pytorch的四个hook函数
  csROwDoT4AiY 2023年11月18日 24 0

  训练神经网络模型有时需要观察模型内部模块的输入输出,或是期望在不修改原始模块结构的情况下调整中间模块的输出,pytorch可以用hook回调函数来实现这一功能。主要使用四个hook注册函数:register_forward_hook、register_forward_pre_hook、register_full_backward_hook、register_full_backward_pre_hook。这四个函数可以被继承nn.Module的任意模块调用,传入hook函数并进行注册,从而在执行该模块的相应阶段调用hook函数实现所需功能。

register_forward_hook(self, hook, *, prepend, with_kwargs)

  为模块注册一个在该模块前向传播之后执行的回调函数。

  hook(module, args, output):需执行的回调函数对象,module为当前模块引用,args为当前模块前向传播输入,output为当前模块前向传播输出。可以返回修改后的output来修改该模块前向传播输出。

  prepend:将该hook函数放在回调函数列表最前面,从而最先执行,否则放在队列最后。

  with_kwargs:hook函数是否传入关键字参数,如果为True,则hook可以额外增加关键则参数。

  register_forward_hook注册函数本身返回一个handle句柄,可执行handle.remove()将注册的该hook函数移除。

register_forward_pre_hook(self, hook, *, prepend, with_kwargs)

  为模块注册一个在该模块前向传播之前执行的回调函数。

  hook(module, args):args为该模块前向传播输入。可以返回修改后的args来修改该模块前向传播输入。

  其它参数、特性与前面一致。

register_full_backward_hook(self, hook, prepend)

  为模块注册一个在该模块反向传播之后执行的回调函数。

  hook(module, grad_input, grad_output):grad_input与grad_output分别为该模块前向传播输入和输出的梯度。可以返回修改后的grad_input来修改该模块前向传播输入的梯度。

register_full_backward_pre_hook(self, hook, prepend)

  为模块注册一个在该模块反向传播之前执行的回调函数。

  hook(module, grad_output):grad_output为该模块前向传播输出的梯度。可以返回修改后的grad_output来修改这一梯度。

 



【版权声明】本文内容来自摩杜云社区用户原创、第三方投稿、转载,内容版权归原作者所有。本网站的目的在于传递更多信息,不拥有版权,亦不承担相应法律责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@moduyun.com

  1. 分享:
最后一次编辑于 2023年11月18日 0

暂无评论

推荐阅读
csROwDoT4AiY