数据类型的坑

时间:2025-04-27 07:36:20

all_img = torch.tensor([])
# 采用下面的语句读取图片
img = torch.from_numpy(cv2.imread('{}.JPEG'.format(5))).unsqueeze(0)
# 此时 img 的维度 (1,224,224,3)

# 将img合并入all_img 中
all_img = all_img.cat((all_img,img)) # 报错 RuntimeError: Expected object of scalar type Byte but got scalar type Float for sequence element 1 in sequence argument at position #1 'tensors'

发现img 是 int 8 类型,所以转化一下img 的数据类型


all_img = torch.tensor([])
# 采用下面的语句读取图片
img = torch.from_numpy(cv2.imread('{}.JPEG'.format(5))).unsqueeze(0).type(torch.float32) #### 注意此处
# 此时 img 的维度 (1,224,224,3)

# 将img合并入all_img 中
all_img = all_img.cat((all_img,img)) # 报错 RuntimeError: Expected object of scalar type Byte but got scalar type Float for sequence element 1 in sequence argument at position #1 'tensors'

这样才可以