PyTorch中使用多维索引张量对高维张量批量索引的正确方法
本文深入讲解如何在 PyTorch 中利用形状为 [b, k] 的索引张量 B,对形状为 [b, m, n] 的高维张量 A 执行高效批量索引,最终得到 [b, k, n] 的输出。核心思路在于合理扩展索引维度并配合 torch.gather 实现精准的逐行抽取。
很多人处理高维张量的批量索引时都会遇到瓶颈——尤其是当索引本身是 batch 维独立的二维张量,而目标张量还需要保留最后一维的所有列。直接使用 torch.index_select 或 torch.take 会发现根本行不通:它们只接受一维索引。而 torch.gather 虽然能处理多维输入,但要求输入张量与索引在除指定维度外的所有维度上严格对齐。那么,如何用形状为 [b, k] 的索引 B,从形状为 [b, m, n] 的张量 A 中提取每 batch 独立的 k 行,同时完整保留最后一维 n 的全部数值?
关键在于对索引张量进行升维并使其与目标张量形状对齐。具体步骤拆解如下:
-
明确索引语义——我们期望的输出结果满足
out[b, k, n] == A[b, B[b, k], n],即对每个 batch b,从A[b](形状[m, n])中挑选出B[b, k]所指定的第 k 行,总共取出 k 行,且每行保留完整的 n 列信息。 -
扩展索引维度:先将
B(形状[b, k])转换为[b, k, 1],再通过广播机制扩展为[b, k, n],从而与A的最后一维匹配:B_expanded = B.unsqueeze(-1).expand(-1, -1, A.size(-1)) # [b, k, n]
-
传入 torch.gather:沿着
dim=1(即 m 维度)执行 gather 操作。此时A的形状为[b, m, n],B_expanded的形状为[b, k, n],除 gather 维度外其他维度均已对齐:out = torch.gather(A, dim=1, index=B_expanded) # 输出 shape: [b, k, n]
下面提供一个完整可运行的示例,方便直接验证实现效果:
import torch
b, m, n, k = 2, 5, 4, 3
A = torch.randn(b, m, n) # [2, 5, 4]
B = torch.randint(0, m, (b, k)) # [2, 3],值 ∈ [0, 4]
# 扩展索引:[b,k] → [b,k,1] → [b,k,n]
B_idx = B.unsqueeze(-1).expand(-1, -1, n)
# 沿 dim=1 gather
out = torch.gather(A, dim=1, index=B_idx)
print(f"A.shape: {A.shape}") # torch.Size([2, 5, 4])
print(f"B.shape: {B.shape}") # torch.Size([2, 3])
print(f"out.shape: {out.shape}") # torch.Size([2, 3, 4])
# 验证:out[0,0] 应等于 A[0, B[0,0]]
assert torch.equal(out[0, 0], A[0, B[0, 0]])
有几个细节值得特别留意:
- 索引张量
B中的每个数值必须严格落在[0, m)区间内,否则会触发IndexError; torch.gather不支持负索引(与 NumPy 不同),使用前需确保索引为非负;- 该操作完全可微,若需要回传梯度可直接使用。如果仅用于推理阶段并希望加速,也可考虑
torch.nn.functional.embedding——只需将A视为 embedding 权重,将B视为 token IDs 即可; - 当然也能用循环配合
torch.index_select实现,但那样会失去向量化优势,性能相差悬殊,通常不推荐。
掌握这一模式后,像序列抽取、top-k 特征筛选、动态掩码选择等常见场景,都能轻松应对。
游乐网为非赢利性网站,所展示的游戏/软件/文章内容均来自于互联网或第三方用户上传分享,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系youleyoucom@outlook.com。
同类文章
PyTorch中使用多维索引张量对高维张量批量索引的正确方法
本文深入讲解如何在 PyTorch 中利用形状为 [b, k] 的索引张量 B,对形状为 [b, m, n] 的高维张量 A 执行高效批量索引,最终得到 [b, k, n] 的输出。核心思路在于合理扩展索引维度并配合 torch gather 实现精准的逐行抽取。 很多人处理高维张量的批量索引时都会
Go中...操作符解包切片传递可变参数函数
在 Go 语言中,` ` 运算符放在切片变量后面(如 `slice `)的作用是将该切片“展开”为多个独立参数,专门用于调用那些接受可变参数(` T`)的函数,例如 `append` 或 `fmt Println`。这是一种类型安全的语法糖,并非省略号或通配符,能够帮助开发者更简洁地处理
macOS与WSL2下PHP多版本切换失效问题排查与修复指南
本文深入分析在 macOS 或 WSL2(Ubuntu)开发环境中,通过 Homebrew 管理 PHP 多版本时,php -v 始终显示旧版本(如 php@5 6)的深层原因,并给出系统性解决方案,覆盖 PATH 冲突、符号链接逻辑、Shell 初始化配置、系统残留配置等关键环节。 遇到这种情况的
PHP JSON解析深层嵌套对象属性访问失败的解决方法
使用 json_decode() 解析 API 返回的 JSON 数据时,经常遇到某个子属性无法正常获取,始终返回 NULL —— 这是许多 PHP 开发者都曾碰到过的棘手问题。通常并非数据丢失,而是对象嵌套层级比预期更深,导致访问路径不正确。 举例来说,你看到返回的 JSON 里有一个 appea
nnU-Net v2预处理卡死问题的成因分析与实用解决指南
> 使用 nnUNetv2_plan_and_preprocess 处理大规模数据集(例如 704 例样本)时,程序常因多进程加载导致死锁而停滞。核心原因在于默认并发数过高引发资源竞争或 I O 阻塞,适当降低并发数即可稳定完成全量预处理。 你在使用 `nnunetv2_plan_and_prepr
- 日榜
- 周榜
- 月榜
相关攻略
2026-07-03 06:53
2026-07-03 06:53
2026-07-03 06:53
2026-07-03 06:53
2026-07-03 06:52
2026-07-03 06:52
2026-07-03 06:52
2026-07-03 06:52
热门教程
- 游戏攻略
- 安卓教程
- 苹果教程
- 电脑教程
热门话题

