龙良曲PyTorch学习笔记_03

时间:2024-02-01 19:36:17
  1 import torch
  2 from torch import nn
  3 from torch.nn import functional as F
  4 from torch import optim
  5 
  6 import torchvision
  7 from matplotlib import pyplot as plt
  8 
  9 # 小工具
 10 
 11 def plot_curve(data):
 12     fig = plt.figure()
 13     plt.plot(range(len(data)),data,color='blue')
 14     plt.legend(['value'],loc='upper right')
 15     plt.xlabel('step')
 16     plt.tlabel('value')
 17     plt.show()
 18 
 19 def plot_image(img,label,name):
 20     fig = plt.figure()
 21     for i in range(6):
 22         plt.subplot(2,3,i+1)
 23         plt,tight_layout()
 24         plt.imshow(img[i][0]*0.3081+0.1307,cmap='gray',interpolation='none')
 25         plt.title("{}:{}".format(name,label[i].item()))
 26         plt.xticks([])
 27         plt.xticks([])
 28         
 29     plt.show()
 30         
 31 def one_hot(label,depth = 10):
 32     out = torch.zeros(label.size(0),depth)
 33     idx = torch.LongTensor(label).view(-1,1)
 34     out.scatter_(dim=1,index=idx,value=1)
 35     return out
 36     
 37 # 一次加载多少图片
 38 batch_size = 512
 39 # step1. load dataset 数据加载
 40 train_loader = torch.utils.data.DataLoader(
 41     torchvision.datasets.MINST('mnist_data',train=True,download=True,
 42                               transform=torchvision.transforms.Compose([
 43                                   torchvision.transfroms.ToTensor(),
 44                                   
 45                                   torchvision.transfroms.Normalize(
 46                                       (0.1307,),(0.3081,))
 47                               ])),
 48     batch_size=batch_size,shuffle=True)
 49 test_loader = torch.utils.data.DataLoader(
 50     torchvision.datasets.MINST('mnist_data/',train=False,download=True,
 51                               transform=torchvision.transforms.Compose([
 52                                   torchvision.transfroms.ToTensor(),
 53                                   torchvision.transfroms.Normalize(
 54                                       (0.1307,),(0.3081,))
 55                               ])),
 56     batch_size=batch_size,shuffle=False)
 57     
 58 # 网络创建
 59 class Net(nn.Module):
 60     
 61     def __init__(self):
 62         super(Net,self).__init__()
 63         
 64     #xw+b
 65     self.fc1 = nn.Linear(28*28,256)
 66     self.fc2 = nn.Linear(256,64)
 67     self.fc3 = nn.Linear(64,10)
 68     
 69     def forward(self,x):
 70         # x:[batch_size,1,28,28]
 71         # h1 = relu(xw1+b1)
 72         x = F.relu(self.fc1(x))
 73         # h1 = relu(h1w2+b2)
 74         x = F.relu(self.fc2(x))
 75         # h3 = h2w3+b3
 76         x = self.fc3(x)
 77         
 78         return x
 79         
 80 net = Net()
 81 # [w1,b1,w2,b1,w3,b3]
 82 optimizer = optim.SGD(net.parameters(),lr=0.01,momentum=0.9)
 83 
 84 train_loss = []
 85 
 86 # 训练
 87 for epoch in range(3):
 88 
 89     for batch_idx,(x,y) in enumerate(train_loader):
 90     
 91         # x: [b,1,28,28], y:[512]
 92         # [b,1,28,28]-->[b,feature]
 93         x = x.view(x.size(0),28*28)
 94         # --> [b,10]
 95         out = net(x)
 96         # --> [b,10]
 97         y_onehot = one_hot(y)
 98         # loss = mse(out,y_onehot)
 99         loss = F.mse_loss(out,y_onehot)
100         # 清零梯度
101         optimizer.zero_grad()
102         # 计算梯度
103         loss.backward()
104         #w' = w - lr*grad 更新梯度
105         optimizer.step()
106         
107         train_loss.append(loss.item())
108         
109         if batch_idx % 10 == 0:
110             print(epoch,batch_idx,loss.item())
111             
112 plot_curve(train_loss)
113             
114 # 得到一个比较好的    [w1,b1,w2,b1,w3,b3]    
115 
116 
# Measure classification accuracy on the test set.
net.eval()  # inference mode (no dropout/batch-norm here, but good practice)
total_correct = 0
# Gradients are not needed for evaluation; no_grad skips building the
# autograd graph, saving memory and time.
with torch.no_grad():
    for x, y in test_loader:
        x = x.view(x.size(0), 28 * 28)
        out = net(x)
        # out: [b, 10] --> pred: [b] (index of the highest score)
        pred = out.argmax(dim=1)
        correct = pred.eq(y).sum().float().item()
        total_correct += correct

total_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc:', acc)
130 
# Visual check: predict one test batch and show the first images with the
# model's predicted labels as titles.
x, y = next(iter(test_loader))
with torch.no_grad():  # inference only; no autograd graph needed
    out = net(x.view(x.size(0), 28 * 28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')
136         
137         
138         
139         
140         
141