torch.autograd.Function 学习理解

简介: 文章目录前言一、概述二、例程三、官方的demo(指数函数)

前言


在量化感知训练,为了能够进行反向传播,会引入直通估计器,用于保证参数可以求导。我们需要自己定义这些操作,且定义反向求导函数,由于基础知识薄弱,便仔细学习了相关知识。

一、概述

torch.autograd.Function

只需要实现两个 静态方法:


forward可以有任意多个输入、任意多个输出,但是输入和输出必须是Variable。

backward的输入和输出的个数就是forward()函数的输出和输入的个数。其中,backward()输入表示关于forward()输出的梯度,backward()的输出表示关于forward()的输入的梯度。

另外还要加上ctx,它可以理解为一个上下文管理器。


定义新的操作,意味着定义Function的子类,并且这些子类必须重写以下函数:forward()backward()。初始化函数:__init__()根据实际需求判断是否需要重写。

二、例程

from torch.autograd import Function
class MultiplyAdd(Function):
    @staticmethod
    def forward(ctx, w, x, b):
        print('type in forward', type(x))
        ctx.save_for_backward(w, x)#存储用来反向传播的参数
        output = w*x +b
        return output
    @staticmethod
    def backward(ctx, grad_output):
        w, x = ctx.saved_tensors #deprecated,现在使用saved_tensors
        print('type in backward',type(x))
        grad_w = grad_output * x
        grad_x = grad_output * w
        grad_b = grad_output * 1
        return grad_w, grad_x, grad_b
w = torch.rand(2, 2, requires_grad=True)
x = torch.rand(2, 2, requires_grad=True)
b = torch.rand(2, 2, requires_grad=True)
out = MultiplyAdd.apply(w, x, b)
out.backward(torch.ones(2,2))
w.grad,x.grad,b.grad
(tensor([[0.5159, 0.4950],
         [0.1050, 0.7115]]),
 tensor([[0.6249, 0.4731],
         [0.7905, 0.1637]]),
 tensor([[1., 1.],
         [1., 1.]]))

三、官方的demo(指数函数)

import torch
import torch.nn.functional as F
from torch.autograd import Function
class MyExp(Function):
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        ctx.save_for_backward(result) #将result转移到Variable保存在ctx中
        return result
    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result
def exp(x):
    return MyExp.apply(x)
x = torch.Tensor([0,2]).requires_grad_(True)
out = MyExp.apply(x)
out.backward(torch.Tensor([0,2]))
x.grad
tensor([ 0.0000, 14.7781])
相关文章
|
2月前
|
PyTorch 算法框架/工具
Pytorch学习笔记(五):nn.AdaptiveAvgPool2d()函数详解
PyTorch中的`nn.AdaptiveAvgPool2d()`函数用于实现自适应平均池化,能够将输入特征图调整到指定的输出尺寸,而不需要手动计算池化核大小和步长。
142 1
Pytorch学习笔记(五):nn.AdaptiveAvgPool2d()函数详解
|
4月前
|
TensorFlow API 算法框架/工具
【Tensorflow+keras】解决使用model.load_weights时报错 ‘str‘ object has no attribute ‘decode‘
python 3.6,Tensorflow 2.0,在使用Tensorflow 的keras API,加载权重模型时,报错’str’ object has no attribute ‘decode’
56 0
torch.argmax(dim=1)用法
)torch.argmax(input, dim=None, keepdim=False)返回指定维度最大值的序号;
648 0
|
TensorFlow API 算法框架/工具
解决AttributeError: module ‘keras.utils‘ has no attribute ‘plot_model‘
解决AttributeError: module ‘keras.utils‘ has no attribute ‘plot_model‘
335 0
解决AttributeError: module ‘keras.utils‘ has no attribute ‘plot_model‘
|
Python
Python编程:from __future__ import print_function
Python编程:from __future__ import print_function
102 0
|
开发者 Python
成功解决 from ._conv import register_converters as _register_converters
成功解决 from ._conv import register_converters as _register_converters
成功解决 from ._conv import register_converters as _register_converters
|
Unix Apache 算法框架/工具
成功解决AttributeError: type object 'scipy.interpolate.interpnd.array' has no attribute '__reduce_cython
成功解决AttributeError: type object 'scipy.interpolate.interpnd.array' has no attribute '__reduce_cython
成功解决AttributeError: type object 'scipy.interpolate.interpnd.array' has no attribute '__reduce_cython
成功解决numpy.core._internal.AxisError: axis -1 is out of bounds for array of dimension 0
成功解决numpy.core._internal.AxisError: axis -1 is out of bounds for array of dimension 0