发现torch.sumdim参数可以使用tuple,总结记录一下

# 1.不指定dim

torch.sum(input, *, dtype=None) → Tensor
  • 输入inputtensor
  • 输出input中各元素的和

例如

t = torch.randint(0, 3, (2,2,2))
print(t)
print(torch.sum(t))

输出为:

tensor([[[1, 1],
         [0, 2]],

        [[2, 2],
         [0, 2]]])
tensor(10)

# 2.指定dim

torch.sum(input, dim, keepdim=False, *, dtype=None) → Tensor
  • dim参数可以为常数tuple
  • keepdim参数指是否对求和的结果squeeze,如果True其他维度保持不变,求和的dim维变为1

# 2.1dim为常数

在维度dim上求和

import torch

t = torch.randint(0, 3, (2,2,2))

print(t)
print(torch.sum(t, dim=0))

输出为:

tensor([[[0, 1],
         [2, 1]],

        [[2, 2],
         [1, 2]]])
# 相当于在dim=0上求和,0+2, 1+2, 2+1, 2+2
tensor([[2, 3],
        [3, 3]])

# 2.2 dim为tuple

import torch

t = torch.randint(0, 3, (2,2,2))

print(t)
print(torch.sum(t, dim=(0,1)))

输出为:

tensor([[[1, 2],
         [0, 2]],

        [[2, 2],
         [2, 0]]])
# 相当于
# 1.在dim=0上求和 1+2,2+2, 0+2, 2+0
# 得[[3,4],[2,2]]
# 2.在dim=1上求和 3+2, 4+2
# 得[5,6]
tensor([5, 6])
(adsbygoogle = window.adsbygoogle || []).push({});

# 参考资料