Batch Normalization
2026/6/6大约 1 分钟
Batch Normalization
题目描述
编写一个 GPU 程序,实现二维输入张量的批归一化(Batch Normalization)前向传播。给定形状为 的输入张量( 为批量大小, 为特征数),使用可学习的缩放()和平移()参数计算归一化输出。
对于每个特征通道 ,批归一化计算:
实现要求
- 不允许使用外部库。
solve函数签名必须保持不变。- 最终结果必须存储在
output张量中。
示例
示例 1
Input: input = [[1,2],[3,4],[5,6]] (N=3, C=2), gamma=[1,1], beta=[0,0], eps=1e-5
Output: [[-1.224, -1.224], [0, 0], [1.224, 1.224]]示例 2
Input: input = [[0,1],[2,3]] (N=2, C=2), gamma=[2,0.5], beta=[1,-1], eps=1e-5
Output: [[-1, -1.5], [3, -0.5]]约束条件
- ,。
- 。
- ,,。
- 性能测试在 的规模下进行。
解题思路
BatchNorm 需要两趟:第一趟对每个通道做规约求 和 ,第二趟做逐元素归一化和缩放。跨 维度的规约可以利用 warp shuffle 和共享内存高效完成。当 较大时,可以每个 block 处理一个通道,多个 block 并行处理所有通道。也可以使用 Welford 在线算法在单趟中同时计算均值和方差以减少内存访问。欢迎在 GitHub Discussions 分享你的解法。