PyTorch_张量转换为numpy数组

时间:2025-05-08 14:10:45

使用 tensor.numpy 函数可以将张量转换为 ndarray 数组,但是共享内存,可以使用 copy 函数避免共享。共享内存会导致张量或者numpy中的其中一个更改后,另外一个会受到影响。


代码

import torch 

# 张量转换为 numpy 数组
def test01():
    data_tensor = torch.tensor([2,3,4])

    # 将张量转换为 numpy 数组
    data_numpy = data_tensor.numpy()

    print(type(data_tensor))
    print(type(data_numpy))

    print(data_tensor)
    print(data_numpy)

# 张量和 numpy 数组共享内存
def test02():
    data_tensor = torch.tensor([2,3,4])
    data_numpy = data_tensor.numpy()

    data_tensor[0] = 100 
    print(data_tensor)
    print(data_numpy)

    # 修改 numpy 数组元素的值,看看张量是否会发生变化? 会发生变化
    data_numpy[0] = 200 
    print(data_tensor)
    print(data_numpy)

# 使用 copy 函数实现不共享内存
def test03():
    data_tensor = torch.tensor([2,3,4])
    # 此处,发生了类型转换,可以使用拷贝函数产生新的数据,避免共享内存
    data_numpy = data_tensor.numpy().copy()

    # 修改 numpy 数组元素的值,看看张量是否会发生变化? 不会发生变化
    data_numpy[0] = 100 
    print(data_tensor)
    print(data_numpy)

    data_tensor[0] = 200 
    print(data_tensor)
    print(data_numpy)

if __name__ == "__main__":
    test03()