非平衡數據集 focal loss 多類分類

本文為 AI 研習社編譯的技術博客,原標題 :

Multi-class classification with focal loss for imbalanced datasets

作者 | Chengwei Zhang

翻譯 | 汪鵬 校對 | 斯蒂芬·二狗子

審核 | Pita 整理 | 立魚王

原文鏈接:

https://medium.com/swlh/multi-class-classification-with-focal-loss-for-imbalanced-datasets-c478700e65f5

非平衡数据集 focal loss 多类分类

焦點損失函數 Focal Loss(2017年何凱明大佬的論文)被提出用於密集物體檢測任務。它可以訓練高精度的密集物體探測器,哪怕前景和背景之間比例為1:1000(譯者注:facal loss 就是為了解決目標檢測中類別樣本比例嚴重失衡的問題)。本教程將向您展示如何在給定的高度不平衡的數據集的情況下,應用焦點損失函數來訓練一個多分類模型。

背景

讓我們首先了解類別不平衡數據集的一般的處理方法,然後再學習 focal loss 的解決方式。

在多分類問題中,類別平衡的數據集的目標標籤是均勻分佈的。若某類目標的樣本相比其他類在數量上佔據極大優勢,則可以將該數據集視為不平衡的數據集。這種不平衡將導致兩個問題:

  • 訓練效率低下,因為大多數樣本都是簡單的目標,這些樣本在訓練中提供給模型不太有用的信息;

  • 簡單的樣本數量上的極大優勢會搞垮訓練,使模型性能退化。

一種常見的解決方案是執行某種形式的困難樣本挖掘,實現方式就是在訓練時選取困難樣本 或 使用更復雜的採樣,以及重新對樣本加權等方案。

對具體圖像分類問題,對數據增強技術方案變更,以便為樣本不足的類創建增強的數據。

焦點損失函數旨在通過降低內部加權(簡單樣本)來解決類別不平衡問題,這樣即使簡單樣本的數量很大,但它們對總損失的貢獻卻很小。也就是說,該函數側重於用困難樣本稀疏的數據集來訓練。

將 Focal Loss 應用於欺詐檢測任務

為了演示,我們將會使用 Kaggle上的欺詐檢測數據集 構建一個分類器,這個數據及具有極端的類不平衡問題,它包含總共6354407個正常樣本和8213個欺詐案例,兩者比例約為733:1。對這種高度不平衡的數據集的分類問題,若某模型簡單猜測所有輸入樣本為“正常”就可以達到733 /(733 + 1)= 99.86%的準確度,這顯然是不合理。因此,我們需要的是這個模型能夠正確檢測出欺詐案例。

為了證明focal loss 比傳統技術更有效,讓我們建立一個簡單地使用類別權重 class_weight訓練的基準模型,告訴模型“更多地關注”來自代表性不足的欺詐樣本。

非平衡数据集 focal loss 多类分类

基準模型

基準模型的準確率達到了99.87%,略好於通過採取“簡單路線”去猜測所有情況都為“正常”。

我們還繪製了混淆矩陣來展示模型在測試集上的分類性能。你可以看到總共有1140 + 480 = 1620 個樣本被錯誤分類。

非平衡数据集 focal loss 多类分类

混淆矩陣-基準模型

現在讓我們將focal loss應用於這個模型的訓練。你可以在下面看到如何在Keras框架下自定義焦點損失函數focal loss 。

非平衡数据集 focal loss 多类分类

焦點損失函數-模型

焦點損失函數focal loss 有兩個可調的參數。

  • 焦點參數γ(gamma)平滑地調整簡單樣本被加權的速率。當γ= 0時, focal loss 效果與交叉熵函數相同,並且隨著 γ 增加,調製因子的影響同樣增加(γ = 2在實驗中表現的效果最好)。

  • α(alpha):平衡focal loss ,相對於非 α 平衡形式可以略微提高它的準確度。

現在讓我們把訓練好的模型與之前的模型進行比較性能。雷鋒網雷鋒網雷鋒網

Focal Loss 模型:

  • 精確度:99.94%

  • 總錯誤分類測試集樣本:766 + 23 = 789,將錯誤數減少了一半。

非平衡数据集 focal loss 多类分类

混淆矩陣-focal loss模型

結論及導讀

在這個快速教程中,我們為你的知識庫引入了一個新的工具來處理高度不平衡的數據集 — Focal Loss。並通過一個具體的例子展示瞭如何在Keras 的 API 中定義 focal loss進而改善你的分類模型。

你可以在我的GitHub上找到這篇文章的完整源代碼。

有關focal loss的詳細情況,可去查閱論文https://arxiv.org/abs/1708.02002。

最初發表於www.dlology.com.

想要繼續查看該篇文章相關鏈接和參考文獻?

點擊非平衡數據集 focal loss 多類分類】即可訪問:

https://ai.yanxishe.com/page/TextTranslation/1646

AI研習社今日推薦:2019 最新斯坦福 CS224nNLP 課程

自然語言處理(NLP)是信息時代最重要的技術之一,也是人工智能的關鍵部分。NLP的應用無處不在,因為人們幾乎用語言進行交流:網絡搜索,廣告,電子郵件,客戶服務,語言翻譯,醫學報告等。近年來,深度學習方法在許多不同的NLP任務中獲得了非常高的性能,使用單個端到端神經模型,不需要傳統的,任務特定的特徵工程。在本課程中,學生將深入瞭解NLP深度學習的前沿研究。

課程鏈接:https://ai.yanxishe.com/page/groupDetail/59


分享到:


相關文章: