PyTorch中的使用

时间:2025-05-06 09:21:39
>>> import torch >>> >>> # 定义一个 33x55 张量 >>> a = torch.randn(33, 55) >>> a.size() torch.Size([33, 55]) >>> >>> # 下面开始尝试 repeat 函数在不同参数情况下的效果 >>> a.repeat(1,1).size() # 原始值:([33, 55]) torch.Size([33, 55]) >>> >>> a.repeat(2,1).size() # 原始值:([33, 55]) torch.Size([66, 55]) >>> >>> a.repeat(1,2).size() # 原始值:([33, 55]) torch.Size([33, 110]) >>> >>> a.repeat(1,1,1).size() # 原始值:([33, 55]) torch.Size([1, 33, 55]) >>> >>> a.repeat(2,1,1).size() # 原始值:([33, 55]) torch.Size([2, 33, 55]) >>> >>> a.repeat(1,2,1).size() # 原始值:([33, 55]) torch.Size([1, 66, 55]) >>> >>> a.repeat(1,1,2).size() # 原始值:([33, 55]) torch.Size([1, 33, 110]) >>> >>> a.repeat(1,1,1,1).size() # 原始值:([33, 55]) torch.Size([1, 1, 33, 55]) >>> >>> # ------------------ 割割 ------------------ >>> # repeat()的参数的个数,不能少于被操作的张量的维度的个数, >>> # 下面是一些错误示例 >>> a.repeat(2).size() # 1D < 2D, error Traceback (most recent call last): File "<stdin>", line 1, in <module> RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor >>> >>> # 定义一个3维的张量,然后展示前面提到的那个错误 >>> b = torch.randn(5,6,7) >>> b.size() # 3D torch.Size([5, 6, 7]) >>> >>> b.repeat(2).size() # 1D < 3D, error Traceback (most recent call last): File "<stdin>", line 1, in <module> RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor >>> >>> b.repeat(2,1).size() # 2D < 3D, error Traceback (most recent call last): File "<stdin>", line 1, in <module> RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor >>> >>> b.repeat(2,1,1).size() # 3D = 3D, okay torch.Size([10, 6, 7]) >>>