python

超轻量级php框架startmvc

pytorch自定义二值化网络层方式

更新时间:2020-08-19 08:36:02 作者:startmvc
任务要求:自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数

任务要求:

自定义一个层主要是定义该层的实现函数,只需要重载Function的forward和backward函数即可,如下:


import torch
from torch.autograd import Function
from torch.autograd import Variable

定义二值化函数


class BinarizedF(Function):
 def forward(self, input):
 self.save_for_backward(input)
 a = torch.ones_like(input)
 b = -torch.ones_like(input)
 output = torch.where(input>=0,a,b)
 return output
 def backward(self, output_grad):
 input, = self.saved_tensors
 input_abs = torch.abs(input)
 ones = torch.ones_like(input)
 zeros = torch.zeros_like(input)
 input_grad = torch.where(input_abs<=1,ones, zeros)
 return input_grad

定义一个module


class BinarizedModule(nn.Module):
 def __init__(self):
 super(BinarizedModule, self).__init__()
 self.BF = BinarizedF()
 def forward(self,input):
 print(input.shape)
 output =self.BF(input)
 return output

进行测试


a = Variable(torch.randn(4,480,640), requires_grad=True)
output = BinarizedModule()(a)
output.backward(torch.ones(a.size()))
print(a)
print(a.grad)

其中, 二值化函数部分也可以按照方式写,但是速度慢了0.05s


class BinarizedF(Function):
 def forward(self, input):
 self.save_for_backward(input)
 output = torch.ones_like(input)
 output[input<0] = -1
 return output
 def backward(self, output_grad):
 input, = self.saved_tensors
 input_grad = output_grad.clone()
 input_abs = torch.abs(input)
 input_grad[input_abs>1] = 0
 return input_grad

以上这篇pytorch自定义二值化网络层方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。

pytorch 二值化 网络层