损失函数是什么
在 PyTorch 中,损失函数 (Loss Function) 是用于衡量模型在训练数据集上的预测误差的函数。它可用于指导模型参数的更新,使模型在训练数据上的效果越来越好。损失函数通常与网络的最后一层连用,通过自动微分,计算每个参数对损失的影响,指导参数优化方向。
PyTorch 中常用的损失函数有:
-
MSELoss - 平均平方误差损失,用于回归问题
-
CrossEntropyLoss - 交叉熵损失,用于分类问题,将 softmax 激活和负对数似然损失合并在一起计算。
-
NLLLoss - 负对数似然损失,用于多分类问题。
-
BCELoss - 二进制交叉熵损失,用于二分类问题。
-
L1Loss - L1 范数损失,使模型对异常点更鲁棒。用于线性回归,逻辑回归。
-
SmoothL1Loss - 平滑 L1 损失,综合平方误差损失和 L1 损失的优点。用于目标检测。
什么是回归问题,分类问题,多分类问题,二分类问题
-
回归问题:回归问题的目标是预测连续型的数值目标变量,预测结果为一个连续的值。如房价预测、销量预测等。比如之前写的股价预测,其中的损失函数就是 MSELoss
-
分类问题:分类问题的目标是预测离散的类别标签,一般输出是一个类别。如图像分类 (猫 or 狗)、垃圾邮件分类。
-
多分类问题:多分类问题的目标是对样本进行多类别分类,预测结果是多个类别中的一个。分类类别数目大于 2。如手写数字识别 (0~9 类别)。
-
二分类问题:二分类问题的目标是对样本进行二分类,只有两个类别,如是与否分类、垃圾邮件检测 (垃圾 or 非垃圾)。
使用损失函数
在 PyTorch 中使用损失函数非常简单,可以如下实例化一个损失函数对象:
loss_fn = nn.MSELoss()
loss = loss_fn(prediction, target)
选择合适的损失函数对模型性能和训练速度有很大影响。实践中可以测试不同的损失函数,选择验证指标效果最好的。