发现torch.sum
的dim
参数可以使用tuple
,总结记录一下
# 1.不指定dim
torch.sum(input, *, dtype=None) → Tensor
- 输入
input
tensor - 输出
input
中各元素的和
例如
t = torch.randint(0, 3, (2,2,2))
print(t)
print(torch.sum(t))
输出为:
tensor([[[1, 1],
[0, 2]],
[[2, 2],
[0, 2]]])
tensor(10)
# 2.指定dim
torch.sum(input, dim, keepdim=False, *, dtype=None) → Tensor
dim
参数可以为常数
和tuple
keepdim
参数指是否对求和的结果squeeze
,如果True
其他维度保持不变,求和的dim
维变为1
# 2.1dim为常数
在维度dim
上求和
import torch
t = torch.randint(0, 3, (2,2,2))
print(t)
print(torch.sum(t, dim=0))
输出为:
tensor([[[0, 1],
[2, 1]],
[[2, 2],
[1, 2]]])
# 相当于在dim=0上求和,0+2, 1+2, 2+1, 2+2
tensor([[2, 3],
[3, 3]])
# 2.2 dim为tuple
import torch
t = torch.randint(0, 3, (2,2,2))
print(t)
print(torch.sum(t, dim=(0,1)))
输出为:
tensor([[[1, 2],
[0, 2]],
[[2, 2],
[2, 0]]])
# 相当于
# 1.在dim=0上求和 1+2,2+2, 0+2, 2+0
# 得[[3,4],[2,2]]
# 2.在dim=1上求和 3+2, 4+2
# 得[5,6]
tensor([5, 6])