Softmax
2026/6/6大约 1 分钟
Softmax
原始题目:LeetGPU - Softmax
题目描述
编写一个 GPU 程序,计算 32 位浮点数数组的 softmax 函数。对于长度为 的输入数组 ,softmax 是一个长度相同的数组,其第 个元素定义为:
你的解法应使用 "max trick" 来处理潜在的溢出问题:在指数运算之前,将输入数组的每个元素减去最大值。
实现要求
- 只允许使用原生功能(不允许使用外部库)。
solve函数签名必须保持不变。- 最终结果必须存储在数组
output中。
示例
示例 1
Input: [1.0, 2.0, 3.0], N = 3
Output: [0.090, 0.244, 0.665](近似值)示例 2
Input: [-10.0, -5.0, 0.0, 5.0, 10.0], N = 5
Output: [2.047e-09, 3.038e-07, 4.509e-05, 6.693e-03, 9.933e-01](近似值)约束条件
- 。
- 性能测试在 的规模下进行。
解题思路
Softmax 需要三趟遍历:找最大值(reduce)、指数求和(reduce)、逐元素除法(map)。在线程协作层面,可以先并行找局部最大值,通过 warp shuffle 或 shared memory 规约得到全局最大值。第二轮同样的规约模式求指数和。两趟 reduce 都可以使用高效的分块规约策略。注意处理 的数值稳定性。欢迎在 GitHub Discussions 分享你的解法。