pytorch的cross_entropy
  sZgmf4bMUDTI 2023年11月24日 31 0

PyTorch的Cross Entropy

介绍

在机器学习和深度学习中,交叉熵(Cross Entropy)是一个重要的损失函数,特别适用于分类问题。PyTorch是一个流行的深度学习框架,提供了很多高级功能,其中包括了交叉熵损失函数。在本文中,我们将介绍PyTorch中的交叉熵损失函数及其使用方法。

交叉熵损失函数

交叉熵损失函数可以用来衡量模型的输出与真实标签之间的差异。对于一个分类任务,我们通常使用一个softmax层将模型输出转化为概率分布。交叉熵损失函数的定义如下:

$$ H(p,q) = -\sum_{i} p_i \log(q_i) $$

其中,p是真实标签的概率分布,q是模型输出的概率分布。交叉熵损失函数的值越小,说明模型的输出与真实标签越接近。

在PyTorch中,我们可以使用torch.nn.CrossEntropyLoss类来计算交叉熵损失。这个类将softmax层和交叉熵损失函数结合在一起,方便我们使用。下面是一个示例:

import torch
import torch.nn as nn

# 模型输出
outputs = torch.tensor([[0.1, 0.2, 0.7], [0.3, 0.5, 0.2]])
# 真实标签
labels = torch.tensor([2, 0])

# 创建交叉熵损失函数
loss_fn = nn.CrossEntropyLoss()

# 计算损失
loss = loss_fn(outputs, labels)
print(loss)

这段代码中,我们首先定义了一个模型的输出outputs和对应的真实标签labels。然后,我们创建了一个nn.CrossEntropyLoss的实例loss_fn。最后,我们调用loss_fn__call__方法并传入模型输出和真实标签来计算损失。

代码示例

下面我们将通过一个简单的分类问题,来演示交叉熵损失函数的使用。我们首先生成一些随机数据用于训练:

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np

# 随机种子
torch.manual_seed(42)
np.random.seed(42)

# 生成数据
num_samples = 1000
num_classes = 5

# 输入特征
X = torch.randn(num_samples, 10)
# 真实标签
y = torch.randint(low=0, high=num_classes, size=(num_samples,))

# 数据可视化
label_counts = np.bincount(y.numpy())
plt.pie(label_counts, labels=[f"Class {i}" for i in range(num_classes)], autopct='%1.1f%%')
plt.title("Label Distribution")
plt.show()

上面的代码中,我们使用torch.randn生成了一个大小为(num_samples, 10)的输入特征矩阵X,并使用torch.randint生成了一个大小为(num_samples,)的真实标签向量y。然后,我们统计了每个类别的样本数量,并使用饼状图进行可视化。

接下来,我们定义一个简单的全连接神经网络,并使用交叉熵损失函数进行训练:

# 定义模型
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, num_classes)
)

# 定义优化器
optimizer = optim.SGD(model.parameters(), lr=0.1)

# 定义损失函数
loss_fn = nn.CrossEntropyLoss()

# 训练模型
num_epochs = 100
losses = []

for epoch in range(num_epochs):
    # 前向传播
    outputs = model(X)
    # 计算损失
    loss = loss_fn(outputs, y)
    losses.append(loss.item())
    
    # 反向传播和优化
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# 绘制损失曲线
plt.plot(losses)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss")
plt.show()
【版权声明】本文内容来自摩杜云社区用户原创、第三方投稿、转载,内容版权归原作者所有。本网站的目的在于传递更多信息,不拥有版权,亦不承担相应法律责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@moduyun.com

  1. 分享:
最后一次编辑于 2023年11月24日 0

暂无评论

推荐阅读
sZgmf4bMUDTI