tensor与ndarrary的转换
Pytorch中的tensor与ndarray在底层数据类型设计有相似之处,在Pytorch框架中tensor与ndarray可以较为方便地转换
tensor转ndarray
tensor转ndarray分为浅拷贝与深拷贝
浅拷贝
浅拷贝一般使用numpy()方法
-
import torch
-
import numpy as np
-
-
data1 = ([1, 2, 3])
-
print(data1)
-
data2 = ()
-
print(data2)
-
data1[0] = 9
-
print(data1)
-
print(data2)
-
# tensor([1, 2, 3])
-
# [1 2 3]
-
# tensor([9, 2, 3])
-
# [9 2 3]
可以看到,在对转换成ndarray类型的data2进行修改后,tensor的值也随之改变,这是因为二者底层共用一块,为浅拷贝
深拷贝
深拷贝我们可以对tensor进行clone()后再进行转换,clone()会拷贝一份完全独立的张量,并会拷贝计算图
-
import torch
-
import numpy as np
-
-
data1 = ([1, 2, 3])
-
print(data1)
-
data2 = ().numpy()
-
print(data2)
-
data1[0] = 9
-
print(data1)
-
print(data2)
-
# tensor([1, 2, 3])
-
# [1 2 3]
-
# tensor([9, 2, 3])
-
# [1 2 3]
可以看到这里在对张量进行修改后,并不会影响ndarray,因为这里为深拷贝
ndarray转tensor
ndarray转tensor同样分为深拷贝和浅拷贝
浅拷贝
浅拷贝一般是通过torch.from_numpy()实现的
-
import torch
-
import numpy as np
-
-
data1 = ([1, 2, 3])
-
data2 = torch.from_numpy(data1)
-
print(data1)
-
print(data2)
-
data1[0] = 9
-
print(data1)
-
print(data2)
-
# [1 2 3]
-
# tensor([1, 2, 3], dtype=torch.int32)
-
# [9 2 3]
-
# tensor([9, 2, 3], dtype=torch.int32)
可以看到浅拷贝后,对共享内存的任意一个对象修改都会影响到另一个的值
深拷贝
深拷贝这里我们可以通过对ndarray进行copy()进行深拷贝创立副本
-
import torch
-
import numpy as np
-
-
data1 = ([1, 2, 3])
-
data2 = torch.from_numpy(())
-
print(data1)
-
print(data2)
-
data1[0] = 9
-
print(data1)
-
print(data2)
-
# [1 2 3]
-
# tensor([1, 2, 3], dtype=torch.int32)
-
# [9 2 3]
-
# tensor([1, 2, 3], dtype=torch.int32)
张量提取标量
tensor可以分为矢量张量和标量张量,对于从张量中提取标量值一般可以使用item()方法,要求tensor为单个元素才可以使用
-
import torch
-
import numpy as np
-
-
data1 = (1)
-
data2 = ([1])
-
print(data1)
-
print(data2)
-
print(())
-
print(())
-
# tensor(1)
-
# tensor([1])
-
# 1
-
# 1
-
import torch
-
import numpy as np
-
-
data1 = ([1, 2, 3])
-
print(data1)
-
-
print(())
-
tensor([1, 2, 3])
-
# Traceback (most recent call last):
-
# File "D:\Pythonproject\teach_day_01\", line 7, in <module>
-
# print(())
-
# ^^^^^^^^^^^^
-
# RuntimeError: a Tensor with 3 elements cannot be converted to Scalar
可以看到非标量张量无法进行item()标量值提取