What is a loss function?
在 PyTorch 中,損失函數(Loss Function)用於衡量模型在訓練數據集上的預測誤差。它可用於指導模型參數的更新,使模型在訓練數據上的效果越來越好。損失函數通常與網絡的最後一層連用,通過自動微分,計算每個參數對損失的影響,指導參數優化方向。
Common loss functions in PyTorch:
-
MSELoss - 平均平方誤差損失,用於回歸問題
-
CrossEntropyLoss - 交叉熵損失,用於分類問題,將 softmax 激活和負對數似然損失合併在一起計算。
-
NLLLoss - 負對數似然損失,用於多分類問題。
-
BCELoss - 二進制交叉熵損失,用於二分類問題。
-
L1Loss - L1 範數損失,使模型對異常點更魯棒。用於線性回歸,邏輯回歸。
-
SmoothL1Loss - 平滑 L1 損失,綜合平方誤差損失和 L1 損失的優點。用於目標檢測。
What are regression problems, classification problems, multi-classification problems, and binary classification problems?
-
回歸問題:回歸問題的目標是預測連續型的數值目標變量,預測結果為一個連續的值。例如房價預測、銷量預測等。比如之前寫的股價預測,其中的損失函數就是 MSELoss。
-
分類問題:分類問題的目標是預測離散的類別標籤,一般輸出是一個類別。例如圖像分類(貓或狗)、垃圾郵件分類。
-
多分類問題:多分類問題的目標是對樣本進行多類別分類,預測結果是多個類別中的一個。分類類別數目大於 2。例如手寫數字識別(0~9 類別)。
-
二分類問題:二分類問題的目標是對樣本進行二分類,只有兩個類別,例如是與否分類、垃圾郵件檢測(垃圾或非垃圾)。
Using loss functions
在 PyTorch 中使用損失函數非常簡單,可以如下實例化一個損失函數對象:
loss_fn = nn.MSELoss()
loss = loss_fn(prediction, target)
選擇合適的損失函數對模型性能和訓練速度有很大影響。實踐中可以測試不同的損失函數,選擇驗證指標效果最好的。