# 全球DeepFake攻防挑战赛&DataWhale AI 夏令营——图像赛道
# (Global DeepFake Attack & Defense Challenge & DataWhale AI Summer Camp — image track)
import torch
import dataset
# Seed PyTorch's global RNG. cuDNN autotuning is enabled and deterministic
# kernels are disabled, which favours speed: GPU results are therefore not
# guaranteed to be bit-for-bit reproducible despite the fixed seed.
torch.manual_seed(0)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data.dataset import Dataset
from dataset import FFDIDataset
import timm
import time
from Model import model
import pandas as pd
import numpy as np
import cv2
from PIL import Image
from tqdm import tqdm
class AverageMeter(object):
    """Track the latest value of a metric and its sample-weighted running mean.

    ``fmt`` is a format spec such as ``':6.3f'``; ``str()`` renders the meter as
    ``"<name> <latest> (<average>)"`` using that spec for both numbers.
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the latest value, running sum, sample count and mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest observation, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**self.__dict__)
class ProgressMeter(object):
    """Print one tab-separated progress line: a batch counter followed by each meter.

    Args:
        num_batches: total number of batches; sets the padded width of the counter.
        *meters: objects rendered with ``str()`` after the counter
            (e.g. ``AverageMeter`` instances).
        prefix: text printed in front of the counter, e.g. ``'Epoch: [3]'``.
            Keyword-only; defaults to no prefix.
    """

    def __init__(self, num_batches, *meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def pr2int(self, batch):
        """Print the progress line for batch index ``batch``."""
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    # Readable alias; the original method name is kept for existing callers.
    display = pr2int

    def _get_batch_fmtstr(self, num_batches):
        """Return a counter template such as ``'[{:3d}/100]'`` for ``num_batches``."""
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` and report top-1 accuracy.

    The model is switched to eval mode and run without gradient tracking.
    Loss and accuracy are accumulated as sample-weighted means over all batches.

    Args:
        val_loader: iterable of ``(images, targets)`` batches; ``targets`` are
            class indices, as consumed by ``criterion``.
        model: classifier whose output has one logit per class along dim 1.
        criterion: loss function called as ``criterion(output, targets)``.

    Returns:
        The ``AverageMeter`` holding top-1 accuracy in percent; its ``.avg`` is
        the accuracy over the whole loader (a 0-dim tensor when at least one
        batch was seen).
    """
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')

    # Send batches to wherever the model's parameters live instead of a
    # hard-coded .cuda(); fall back to CUDA for a parameter-less model.
    param = next(model.parameters(), None)
    device = param.device if param is not None else torch.device('cuda')

    model.eval()
    with torch.no_grad():
        for images, targets in tqdm(val_loader, total=len(val_loader)):
            images = images.to(device)
            targets = targets.to(device)

            output = model(images)
            loss = criterion(output, targets)

            # Percentage of samples whose arg-max logit equals the target class.
            acc = (output.argmax(1).view(-1) == targets.view(-1)).float().mean() * 100
            losses.update(loss.item(), images.size(0))
            top1.update(acc, images.size(0))

    # top1.avg may be a 0-dim tensor; Tensor.__format__ formats it like a float.
    print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
    return top1
def predict(test_loader, model, tta=10):
    """Return softmax class probabilities for every sample in ``test_loader``.

    The loader is traversed ``tta`` times (test-time augmentation). The
    per-pass probabilities are averaged, so every row still sums to 1. Extra
    passes only change the result if the loader's transforms are random.

    Args:
        test_loader: iterable of ``(images, labels)`` batches; labels are ignored.
            It must not shuffle, so that rows line up across passes and with
            the caller's table.
        model: classifier whose output has one logit per class along dim 1.
        tta: number of passes over the loader; must be >= 1.

    Returns:
        ``np.ndarray`` of shape ``(num_samples, num_classes)`` with the mean
        softmax probabilities.

    Raises:
        ValueError: if ``tta`` is smaller than 1.
    """
    if tta < 1:
        raise ValueError('tta must be >= 1, got {}'.format(tta))

    # Send batches to wherever the model's parameters live instead of a
    # hard-coded .cuda(); fall back to CUDA for a parameter-less model.
    param = next(model.parameters(), None)
    device = param.device if param is not None else torch.device('cuda')

    model.eval()
    summed = None
    with torch.no_grad():
        for _ in range(tta):
            batches = []
            for images, _labels in tqdm(test_loader, total=len(test_loader)):
                probs = F.softmax(model(images.to(device)), dim=1)
                batches.append(probs.cpu().numpy())
            pass_pred = np.vstack(batches)
            summed = pass_pred if summed is None else summed + pass_pred
    # Average rather than sum so the output stays a probability in [0, 1].
    return summed / tta
def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch over ``train_loader``.

    Args:
        train_loader: iterable of ``(images, targets)`` batches with
            class-index targets, as consumed by ``criterion``.
        model: classifier whose output has one logit per class along dim 1;
            it is switched to train mode.
        criterion: loss function called as ``criterion(output, targets)``.
        optimizer: stepped once per batch.
        epoch: epoch index, shown as the prefix of each progress line.

    Returns:
        The ``AverageMeter`` holding the epoch's running top-1 accuracy in percent.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    progress = ProgressMeter(len(train_loader), batch_time, losses, top1)
    progress.prefix = 'Epoch: [{}]'.format(epoch)

    # Send batches to wherever the model's parameters live instead of a
    # hard-coded .cuda(); fall back to CUDA for a parameter-less model.
    param = next(model.parameters(), None)
    device = param.device if param is not None else torch.device('cuda')

    model.train()
    end = time.time()
    for i, (images, targets) in enumerate(train_loader):
        # non_blocking lets the host-to-device copy run asynchronously when
        # the source tensor is in pinned memory (the loaders use pin_memory).
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        output = model(images)
        loss = criterion(output, targets)

        losses.update(loss.item(), images.size(0))
        # Percentage of samples whose arg-max logit equals the target class.
        acc = (output.argmax(1).view(-1) == targets.view(-1)).float().mean() * 100
        top1.update(acc, images.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()
        if i % 100 == 0:
            progress.pr2int(i)
    return top1
if __name__ == '__main__':
    # Only the first NUM_SAMPLES rows of each split are used, which keeps this
    # run a quick smoke test; raise it (or drop the .head()) for real training.
    NUM_SAMPLES = 10
    BATCH_SIZE = 40
    SUBMISSION_PATH = 'submit.csv'

    train_label, val_label = dataset.read_labels()

    # ImageNet per-channel mean / std.
    normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # Training adds random horizontal/vertical flips as data augmentation.
    train_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # Deterministic preprocessing shared by validation and prediction.
    eval_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        normalize,
    ])

    def _make_loader(labels, transform, shuffle):
        """Build a DataLoader over the first NUM_SAMPLES rows of ``labels``."""
        subset = labels.head(NUM_SAMPLES)
        return torch.utils.data.DataLoader(
            FFDIDataset(subset['path'], subset['target'], transform),
            batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=4, pin_memory=True,
        )

    train_loader = _make_loader(train_label, train_transform, shuffle=True)
    val_loader = _make_loader(val_label, eval_transform, shuffle=False)

    model = model.cuda()
    criterion = nn.CrossEntropyLoss().cuda()  # cross-entropy over class logits
    optimizer = torch.optim.Adam(model.parameters(), 0.005)  # Adam, lr=0.005
    # Multiply the learning rate by 0.85 every 4 epochs.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.85)

    best_acc = 0.0
    for epoch in range(1):
        print('Epoch: ', epoch)
        train(train_loader, model, criterion, optimizer, epoch)
        val_acc = validate(val_loader, model, criterion)
        # Parameters are updated only inside train(): an optimizer.step() here
        # would re-apply the last training batch's stale gradients.
        scheduler.step()
        # float() accepts a 0-dim tensor as well as a plain number (which is
        # what the meter holds if the validation loader yielded no batches).
        acc = float(val_acc.avg)
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), f'./model_{round(best_acc, 2)}.pt')

    # Predict on the validation subset (no separate test split is loaded
    # here); shuffle=False keeps rows aligned with val_label.head(NUM_SAMPLES).
    test_loader = _make_loader(val_label, eval_transform, shuffle=False)
    submission = val_label.head(NUM_SAMPLES).copy()
    # Column 1 is the softmax probability of class index 1 (presumably the
    # "fake" label — confirm against the dataset's target encoding).
    submission['y_pred'] = predict(test_loader, model, 1)[:, 1]
    submission[['img_name', 'y_pred']].to_csv(SUBMISSION_PATH, index=False)