前言
在量化感知训练,为了能够进行反向传播,会引入直通估计器,用于保证参数可以求导。我们需要自己定义这些操作,且定义反向求导函数,由于基础知识薄弱,便仔细学习了相关知识。
一、概述
torch.autograd.Function
只需要实现两个 静态方法:
forward可以有任意多个输入、任意多个输出,但是输入和输出必须是Variable。
backward的输入和输出的个数就是forward()函数的输出和输入的个数。其中,backward()输入表示关于forward()输出的梯度,backward()的输出表示关于forward()的输入的梯度。
另外还要加上ctx,它可以理解为一个上下文管理器。
定义新的操作,意味着定义Function
的子类,并且这些子类必须重写以下函数:forward()
和backward()
。初始化函数:__init__()
根据实际需求判断是否需要重写。
二、例程
from torch.autograd import Function class MultiplyAdd(Function): @staticmethod def forward(ctx, w, x, b): print('type in forward', type(x)) ctx.save_for_backward(w, x)#存储用来反向传播的参数 output = w*x +b return output @staticmethod def backward(ctx, grad_output): w, x = ctx.saved_tensors #deprecated,现在使用saved_tensors print('type in backward',type(x)) grad_w = grad_output * x grad_x = grad_output * w grad_b = grad_output * 1 return grad_w, grad_x, grad_b w = torch.rand(2, 2, requires_grad=True) x = torch.rand(2, 2, requires_grad=True) b = torch.rand(2, 2, requires_grad=True) out = MultiplyAdd.apply(w, x, b) out.backward(torch.ones(2,2)) w.grad,x.grad,b.grad
(tensor([[0.5159, 0.4950], [0.1050, 0.7115]]), tensor([[0.6249, 0.4731], [0.7905, 0.1637]]), tensor([[1., 1.], [1., 1.]]))
三、官方的demo(指数函数)
import torch import torch.nn.functional as F from torch.autograd import Function class MyExp(Function): @staticmethod def forward(ctx, i): result = i.exp() ctx.save_for_backward(result) #将result转移到Variable保存在ctx中 return result @staticmethod def backward(ctx, grad_output): result, = ctx.saved_tensors return grad_output * result def exp(x): return MyExp.apply(x) x = torch.Tensor([0,2]).requires_grad_(True) out = MyExp.apply(x) out.backward(torch.Tensor([0,2])) x.grad
tensor([ 0.0000, 14.7781])