python

超轻量级php框架startmvc

Pytorch 实现计算分类器准确率(总分类及子分类)

更新时间:2020-08-21 15:18:02 作者:startmvc
分类器平均准确率计算:correct=torch.zeros(1).squeeze().cuda()total=torch.zeros(1).squeeze().cuda()fori,(image

分类器平均准确率计算:


correct = torch.zeros(1).squeeze().cuda()
total = torch.zeros(1).squeeze().cuda()
for i, (images, labels) in enumerate(train_loader):
 images = Variable(images.cuda())
 labels = Variable(labels.cuda())

 output = model(images)

 prediction = torch.argmax(output, 1)
 correct += (prediction == labels).sum().float()
 total += len(labels)
acc_str = 'Accuracy: %f'%((correct/total).cpu().detach().data.numpy())

分类器各个子类准确率计算:


correct = list(0. for i in range(args.class_num))
total = list(0. for i in range(args.class_num))
for i, (images, labels) in enumerate(train_loader):
 images = Variable(images.cuda())
 labels = Variable(labels.cuda())

 output = model(images)

 prediction = torch.argmax(output, 1)
 res = prediction == labels
 for label_idx in range(len(labels)):
 label_single = label[label_idx]
 correct[label_single] += res[label_idx].item()
 total[label_single] += 1
 acc_str = 'Accuracy: %f'%(sum(correct)/sum(total))
 for acc_idx in range(len(train_class_correct)):
 try:
 acc = correct[acc_idx]/total[acc_idx]
 except:
 acc = 0
 finally:
 acc_str += '\tclassID:%d\tacc:%f\t'%(acc_idx+1, acc)

以上这篇Pytorch 实现计算分类器准确率(总分类及子分类)就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

Pytorch 分类器 准确率