torch.nn.Hardshrink
CLASS torch.nn.Hardshrink(lambd=0.5)
参数
lambd ([float]) – the λ \lambdaλ 默认为 0.5
Hardshrink是一种非线性函数,它用于对输入进行硬阈值处理。它将输入值与阈值进行比较,并将小于阈值的值设置为0,并将大于阈值的值保持不变。
Hardshrink函数的定义如下:
hardshrink(x, threshold) =
x if x <= -threshold or x >= threshold
0 otherwise
在Hardshrink函数中,输入值x与给定的阈值threshold进行比较。如果x小于等于 -threshold 或大于等于 threshold,则保持x不变。否则,返回0作为输出。
Hardshrink函数的作用类似于一个门控,用于去除输入中较小的幅度变化,从而加强较大幅度变化的信号。它可以用于降噪、稀疏化和特征选择等任务。通过调整阈值参数,可以控制过滤的灵敏度和压缩程度。
需要注意的是,Hardshrink函数是一个逐元素的操作,可以应用于向量、矩阵或任意大小的张量。
代码
import torch
import torch.nn as nn
m = nn.Hardshrink()
input = torch.randn(2)
output = m(input)
print("input: ", input) # input: tensor([ 0.2078, -1.4333])
print("output: ", output) # output: tensor([ 0.0000, -1.4333])
Lnton 羚通视频算法算力云平台专注于音视频算法、算力、云平台的高科技人工智能, 公司基于视频分析技术、视频智能传输技术、远程监测技术以及智能语音融合技术等, 拥有多款可支持 ONVIF、RTSP、GB/T28181 等多协议、多路数的音视频智能分析服务器。