Usage in PyTorch
>>> import torch
>>>
>>> # define a 33x55 tensor
>>> a = torch.randn(33, 55)
>>> a.size()
torch.Size([33, 55])
>>>
>>> # now try out repeat() with different numbers of arguments
>>> a.repeat(1,1).size() # original size: ([33, 55])
torch.Size([33, 55])
>>>
>>> a.repeat(2,1).size() # original size: ([33, 55])
torch.Size([66, 55])
>>>
>>> a.repeat(1,2).size() # original size: ([33, 55])
torch.Size([33, 110])
>>>
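>>> # for intuition (a small added example, not part of the session above): with a
>>> # deterministic tensor t it is easy to see that repeat() tiles the values, not
>>> # just the shape -- dim 0 is repeated 2 times, dim 1 is repeated 3 times
>>> t = torch.tensor([[1, 2], [3, 4]])
>>> t.repeat(2, 3)
tensor([[1, 2, 1, 2, 1, 2],
        [3, 4, 3, 4, 3, 4],
        [1, 2, 1, 2, 1, 2],
        [3, 4, 3, 4, 3, 4]])
>>>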
>>> a.repeat(1,1,1).size() # original size: ([33, 55])
torch.Size([1, 33, 55])
>>>
>>> a.repeat(2,1,1).size() # original size: ([33, 55])
torch.Size([2, 33, 55])
>>>
>>> a.repeat(1,2,1).size() # original size: ([33, 55])
torch.Size([1, 66, 55])
>>>
>>> a.repeat(1,1,2).size() # original size: ([33, 55])
torch.Size([1, 33, 110])
>>>
>>> a.repeat(1,1,1,1).size() # original size: ([33, 55])
torch.Size([1, 1, 33, 55])
>>>
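>>> # as the shapes above suggest, when repeat() gets more arguments than the tensor
>>> # has dimensions, the extra leading arguments become new leading dimensions --
>>> # the same result as unsqueezing a singleton dim first (t is the small example
>>> # tensor defined above):
>>> torch.equal(t.repeat(2, 1, 1), t.unsqueeze(0).repeat(2, 1, 1))
True
>>> t.repeat(2, 1, 1).size()
torch.Size([2, 2, 2])
>>>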
>>> # ------------------ divider ------------------
>>> # repeat() must be given at least as many arguments as the tensor has dimensions;
>>> # below are some error examples
>>> a.repeat(2).size() # 1D < 2D, error
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> # define a 3-D tensor and reproduce the error mentioned above
>>> b = torch.randn(5,6,7)
>>> b.size() # 3D
torch.Size([5, 6, 7])
>>>
>>> b.repeat(2).size() # 1D < 3D, error
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> b.repeat(2,1).size() # 2D < 3D, error
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
>>>
>>> b.repeat(2,1,1).size() # 3D = 3D, okay
torch.Size([10, 6, 7])
>>>
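>>> # so, to repeat along only one dimension, pass an argument for every dimension
>>> # and use 1 for the dimensions that should stay unchanged, for example:
>>> b.repeat(1,1,2).size() # repeat only the last dimension of the 3-D tensor b
torch.Size([5, 6, 14])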