Top-p Sampling
2026/6/6大约 1 分钟
Top-p Sampling
题目描述
编写一个 GPU 程序,实现 LLM 推理中的 top-p(核采样,Nucleus Sampling)。Top-p 采样是一种文本生成技术,从累积概率超过阈值 的最小 token 集合中进行采样,比纯 top-k 或贪心采样更好地平衡随机性和质量。
给定语言模型的 logits(未归一化分数),执行以下步骤:
- 使用 softmax 将 logits 转换为概率
- 按概率降序排序 tokens
- 找到累积概率 的最小集合("核")
- 将核内概率重新归一化使总和为 1
- 使用提供的随机种子从核中采样一个 token
实现要求
- 不允许使用外部库。
solve函数签名必须保持不变。- 确保计算 softmax 时的数值稳定性。
示例
示例 1
Input: logits = [1.0, 2.0, 3.0, 0.5], p = 0.9, seed = 42
Output: sampled_token = 2 或 1(最高概率的两个 token 之一,随机采样)示例 2
Input: logits = [10.0, 1.0, 1.0], p = 0.5, seed = 123
Output: sampled_token = 0(单个 token 占据绝大部分概率质量)约束条件
- 。
- 。
- 。
- 。
- 性能测试在 的规模下进行。
解题思路
Top-p 采样的性能瓶颈在排序——需要对 50k 个概率值降序排列。GPU 上可以使用基数排序或 bitonic sort 对小词表高效排序。排序后做前缀和扫描找到累积概率超过 的截断点,再在核内按概率做加权随机采样。实际应用中通常用 curand 生成随机数来进行采样。欢迎在 GitHub Discussions 分享你的解法。