本發明專利技術提供了一種兩階段的長尾學習方法,包括:利用具有長尾分布的圖像分類的訓練集對圖像分類模型進行兩階段的多輪迭代訓練,得到經訓練的圖像分類模型,其中:第一階段的學習,包括:利用頭部類集合和損失函數確定第一銳度感知梯度,以及利用尾部類集合和損失函數確定第二銳度感知梯度,根據第一銳度感知梯度和第二銳度感知梯度更新圖像分類模型的參數;第二階段的學習,包括:利用頭部類集合和損失函數確定第一原始梯度,以及利用尾部類集合和損失函數確定第二銳度感知梯度,根據第一原始梯度和第二銳度感知梯度更新圖像分類模型的參數,本發明專利技術方法在整體上提升了模型的泛化能力。
【技術實現步驟摘要】
本專利技術涉及神經網絡,具體來說涉及長尾分布下的深度學習領域,更具體地說,涉及一種兩階段的長尾學習方法。
技術介紹
1、長尾分布[1][2]是少量類別(一般稱頭部類)在數據集中占據著較多的樣本量,而大量類別(一般稱尾部類)占據了較少的樣本量的數據分布,表現出極大的不平衡性。應用在長尾分布中的深度學習模型被稱為長尾學習模型。
2、從現實世界中搜集到的數據往往存在著長尾分布的現象。例如,在野生動物圖像分類識別中,含有珍稀動物的圖像數量往往遠少于含有老虎、大象等常見動物的圖像數量。由于其現實意義,長尾學習近年來受到了越來越廣泛的關注。
3、傳統的長尾學習方法往往基于重采樣或類別敏感學習的方法來實現。其中,前者對數據集進行修改,通過欠采樣、過采樣等方式生成平衡數據;后者則針對損失函數進行修改,通過修改權重、邊際等,改變不同類別在損失函數中占據的權重。然而,傳統方法往往在尾部類上表現出過擬合的現象,嚴重制約了模型的泛化能力。
4、平滑性正則化方法可以有效解決過擬合問題[3][4]。銳度感知最小化(sam)[5]是一種典型的平滑性正則化方法,它在最小化損失函數的損失值的同時最小化損失函數的銳度,使得模型收斂于平坦的局部或全局極小值點,從而有效提升了模型的泛化能力,為解決過擬合問題提供了有希望的方向。
5、然而,直接將銳度感知最小化方法應用于長尾學習中是不佳的,因頭部類的樣本數量占比更大,其擾動項中頭部類的影響會占據主導地位,影響尾部類上的泛化能力。
6、以上參考文獻的信息如下:p>7、[1]wanli?ouyang,xiaogang?wang,cong?zhang,and?xiaokang?yang.factors?infinetuning?deep?model?for?object?detection?with?long-tail?distribution.incvpr,pages?864–873,2016.
8、[2]songyang?zhang,zeming?li,shipeng?yan,xuming?he,and?jiansun.distribution?alignment:a?unified?framework?for?long-tail?visualrecognition.in?cvpr,pages?2361–2370,2021.
9、[3]bingyi?kang,saining?xie,marcus?rohrbach,zhicheng?yan,albert?gordo,jiashi?feng,and?yannis?kalantidis.decoupling?representation?and?classifierfor?long-tailed?recognition.in?iclr,2020.
10、[4]shaden?alshammari,yu-xiong?wang,deva?ramanan,and?shu?kong.long-tailed?recognition?via?weight?balancing.in?cvpr,pages?6897–6907,2022.
11、[5]pierre?foret,ariel?kleiner,hossein?mobahi,and?behnamneyshabur.sharpness-aware?minimization?for?efficiently?improvinggeneralization.in?iclr,2021.
12、需要說明的是:本
技術介紹
僅用于介紹本專利技術的相關信息,以便于幫助理解本專利技術的技術方案,但并不意味著相關信息必然是現有技術。相關信息與本專利技術方案一同提交和公開,在沒有證據表明相關信息已在本專利技術的申請日以前公開的情況下,相關信息不應被視為現有技術。
技術實現思路
1、因此,本專利技術的目的在于克服上述現有技術的缺陷,提供一種兩階段的長尾學習方法。
2、本專利技術的目的是通過以下技術方案實現的:
3、根據本專利技術的第一方面,提供一種兩階段的長尾學習方法,包括:利用具有長尾分布的圖像分類的訓練集對圖像分類模型進行兩階段的多輪迭代訓練,得到經訓練的圖像分類模型,每輪訓練包括:從訓練集獲取一個批次的樣本,根據該批次的樣本所含標簽屬于預設的頭部類還是尾部類,將該批次的樣本分至頭部類集合或尾部類集合;判斷模型訓練是否達到預設要求,若否,當前輪次進行第一階段的學習,若是,當前輪次進行第二階段的學習,其中:第一階段的學習,包括:利用頭部類集合和損失函數確定第一銳度感知梯度,以及利用尾部類集合和損失函數確定第二銳度感知梯度,根據第一銳度感知梯度和第二銳度感知梯度更新圖像分類模型的參數;執行第二階段的學習,包括:利用頭部類集合和損失函數確定第一原始梯度,以及利用尾部類集合和損失函數確定第二銳度感知梯度,根據第一原始梯度和第二銳度感知梯度更新圖像分類模型的參數。
4、可選的,樣本包括樣本圖像和標簽,標簽指示對應樣本圖像所屬的類別,其中,第一階段的學習包括:將該批次中的樣本圖像輸入圖像分類模型得到第一預測值;根據頭部類集合中樣本圖像對應的第一預測值、標簽和損失函數,確定頭部類集合對應的第一原始梯度;基于銳度感知最小化技術,利用第一原始梯度計算頭部類的擾動項;將圖像分類模型的參數與頭部類的擾動項相加,得到第一擾動參數,用第一擾動參數對頭部類集合的樣本圖像預測得到第二預測值,根據第二預測值、標簽和損失函數求梯度,得到頭部類集合對應的第一銳度感知梯度;根據尾部類集合中樣本圖像對應的第一預測值、標簽和損失函數,確定尾部類集合對應的第二原始梯度;基于銳度感知最小化技術,利用第二原始梯度計算尾部類的擾動項;將圖像分類模型的參數與尾部類的擾動項相加,得到第二擾動參數,用第二擾動參數對尾部類集合的樣本圖像預測得到第三預測值,根據第三預測值、標簽和損失函數求梯度,得到尾部類集合對應的第二銳度感知梯度;根據第一銳度感知梯度和第二銳度感知梯度計算的總梯度,更新圖像分類模型的參數。
5、可選的,第二階段的學習包括:將該批次中的樣本圖像輸入圖像分類模型得到第一預測值;根據頭部類集合中樣本圖像對應的第一預測值、標簽和損失函數,確定頭部類集合對應的第一原始梯度;根據尾部類集合中樣本圖像對應的第一預測值、標簽和損失函數,確定尾部類集合對應的第二原始梯度;基于銳度感知最小化技術,利用第二原始梯度計算尾部類的擾動項;將圖像分類模型的參數與尾部類的擾動項相加,得到第二擾動參數,用第二擾動參數對尾部類集合的樣本圖像預測得到第三預測值,根據第三預測值、標簽和損失函數求梯度,得到尾部類集合對應的第二銳度感知梯度;根據第一原始梯度和第二銳度感知梯度計算的總梯度,更新圖像分類模型的參數。
6、可選的,頭部類的擾動項按照以下方式計算:
7、
8、其中,∈head(wt)表示頭部類的擾動項,是一個向量,其中包括圖像分類模型中本文檔來自技高網
...
【技術保護點】
1.一種兩階段的長尾學習方法,其特征在于,包括:利用具有長尾分布的圖像分類的訓練集對圖像分類模型進行兩階段的多輪迭代訓練,得到經訓練的圖像分類模型,每輪訓練包括:
2.根據權利要求1所述的方法,其特征在于,樣本包括樣本圖像和標簽,標簽指示對應樣本圖像所屬的類別,其中,
3.根據權利要求2所述的方法,其特征在于,第二階段的學習包括:
4.根據權利要求2或者3所述的方法,其特征在于,頭部類的擾動項按照以下方式計算:
5.根據權利要求4所述的方法,其特征在于,尾部類的擾動項按照以下方式計算:
6.根據權利要求5所述的方法,其特征在于,按照以下方式更新圖像分類模型的參數:
7.一種圖像分類方法,包括:
8.一種計算機程序產品,包括計算機程序/指令,該計算機程序/指令被處理器執行時實現權利要求1-7之一所述方法的步驟。
9.一種計算機可讀存儲介質,其特征在于,其上存儲有計算機程序,所述計算機程序可被處理器執行以實現權利要求1-7之一所述方法的步驟。
10.一種電子設備,其特征在于,包括:
...
【技術特征摘要】
1.一種兩階段的長尾學習方法,其特征在于,包括:利用具有長尾分布的圖像分類的訓練集對圖像分類模型進行兩階段的多輪迭代訓練,得到經訓練的圖像分類模型,每輪訓練包括:
2.根據權利要求1所述的方法,其特征在于,樣本包括樣本圖像和標簽,標簽指示對應樣本圖像所屬的類別,其中,
3.根據權利要求2所述的方法,其特征在于,第二階段的學習包括:
4.根據權利要求2或者3所述的方法,其特征在于,頭部類的擾動項按照以下方式計算:
5.根據權利要求4所述的方法,其特征...
【專利技術屬性】
技術研發人員:黃慶明,呂星宇,許倩倩,楊智勇,
申請(專利權)人:中國科學院計算技術研究所,
類型:發明
國別省市:
還沒有人留言評論。發表了對其他瀏覽者有用的留言會獲得科技券。