Pytorch中tensor与ndarray类型转换及标量转换

时间:2025-03-29 09:32:48

tensor与ndarrary的转换

Pytorch中的tensor与ndarray在底层数据类型设计有相似之处,在Pytorch框架中tensor与ndarray可以较为方便地转换

tensor转ndarray

tensor转ndarray分为浅拷贝与深拷贝

浅拷贝

浅拷贝一般使用numpy()方法

  1. import torch
  2. import numpy as np
  3. data1 = ([1, 2, 3])
  4. print(data1)
  5. data2 = ()
  6. print(data2)
  7. data1[0] = 9
  8. print(data1)
  9. print(data2)
  10. # tensor([1, 2, 3])
  11. # [1 2 3]
  12. # tensor([9, 2, 3])
  13. # [9 2 3]

可以看到,在对转换成ndarray类型的data2进行修改后,tensor的值也随之改变,这是因为二者底层共用一块,为浅拷贝

深拷贝

深拷贝我们可以对tensor进行clone()后再进行转换,clone()会拷贝一份完全独立的张量,并会拷贝计算图

  1. import torch
  2. import numpy as np
  3. data1 = ([1, 2, 3])
  4. print(data1)
  5. data2 = ().numpy()
  6. print(data2)
  7. data1[0] = 9
  8. print(data1)
  9. print(data2)
  10. # tensor([1, 2, 3])
  11. # [1 2 3]
  12. # tensor([9, 2, 3])
  13. # [1 2 3]

可以看到这里在对张量进行修改后,并不会影响ndarray,因为这里为深拷贝

ndarray转tensor

ndarray转tensor同样分为深拷贝和浅拷贝

浅拷贝

浅拷贝一般是通过torch.from_numpy()实现的

  1. import torch
  2. import numpy as np
  3. data1 = ([1, 2, 3])
  4. data2 = torch.from_numpy(data1)
  5. print(data1)
  6. print(data2)
  7. data1[0] = 9
  8. print(data1)
  9. print(data2)
  10. # [1 2 3]
  11. # tensor([1, 2, 3], dtype=torch.int32)
  12. # [9 2 3]
  13. # tensor([9, 2, 3], dtype=torch.int32)

可以看到浅拷贝后,对共享内存的任意一个对象修改都会影响到另一个的值

深拷贝

深拷贝这里我们可以通过对ndarray进行copy()进行深拷贝创立副本

  1. import torch
  2. import numpy as np
  3. data1 = ([1, 2, 3])
  4. data2 = torch.from_numpy(())
  5. print(data1)
  6. print(data2)
  7. data1[0] = 9
  8. print(data1)
  9. print(data2)
  10. # [1 2 3]
  11. # tensor([1, 2, 3], dtype=torch.int32)
  12. # [9 2 3]
  13. # tensor([1, 2, 3], dtype=torch.int32)

张量提取标量

tensor可以分为矢量张量和标量张量,对于从张量中提取标量值一般可以使用item()方法,要求tensor为单个元素才可以使用

  1. import torch
  2. import numpy as np
  3. data1 = (1)
  4. data2 = ([1])
  5. print(data1)
  6. print(data2)
  7. print(())
  8. print(())
  9. # tensor(1)
  10. # tensor([1])
  11. # 1
  12. # 1
  1. import torch
  2. import numpy as np
  3. data1 = ([1, 2, 3])
  4. print(data1)
  5. print(())
  6. tensor([1, 2, 3])
  7. # Traceback (most recent call last):
  8. # File "D:\Pythonproject\teach_day_01\", line 7, in <module>
  9. # print(())
  10. # ^^^^^^^^^^^^
  11. # RuntimeError: a Tensor with 3 elements cannot be converted to Scalar

可以看到非标量张量无法进行item()标量值提取