Categorical Cross Entropy Loss
2026/6/6大约 1 分钟
Categorical Cross Entropy Loss
题目描述
编写一个 GPU 程序,计算一批预测结果的分类交叉熵损失。给定预测 logits 矩阵 ()和真实类别标签向量 true_labels(长度为 ),计算批量平均交叉熵损失。
单个样本 的损失(logits 为 ,真实标签为 )使用数值稳定公式计算:
最终输出存储在 loss 变量中,为 个样本的平均损失:
实现要求
- 不允许使用外部库。
solve函数签名必须保持不变。- 最终结果(平均损失)必须存储在
loss中。
示例
示例 1
Input: N = 2, C = 3
logits = [[1.0, 2.0, 0.5], [0.1, 3.0, 1.5]]
true_labels = [1, 1]
Output: loss = 0.3548926示例 2
Input: N = 3, C = 4
logits = [[-0.5, 1.5, 0.0, 1.0], [2.0, -1.0, 0.5, 0.5], [0.0, 0.0, 0.0, 0.0]]
true_labels = [3, 0, 1]
Output: loss = 0.98820376约束条件
- ,。
- 。
- 。
- 性能测试在 的规模下进行。
解题思路
交叉熵损失 = log-softmax + 负对数似然。与 Softmax 类似,需要先找每行的最大值("max trick"),再计算 log-sum-exp 减去真实标签对应的 logit。核心是每行独立计算,可以用一个线程块处理一行或一个 warp 处理一行。当 较大时,warp-level reduction 非常高效。欢迎在 GitHub Discussions 分享你的解法。