本發(fā)明涉及深度學習,尤其涉及一種基于知識蒸餾的深度融合多跨域少樣本分類方法。
背景技術:
1、目前,深度學習在計算機視覺領域已經(jīng)取得了較大成功,例如物體分類、圖像檢索和動作識別等任務。深度學習的成功在很大程度上依賴于海量的數(shù)據(jù)和強大的計算資源。然而許多識別分辨和機器學習任務中,人類往往可以從很少的例子中識別新物體或視覺概念,這種快速學習的能力是現(xiàn)在的深度學習所不具備的。因此,如何通過有限的標記數(shù)據(jù)來學習識別新類別引起了人們的廣泛關注,這也是少樣本學習(few-shot?learning)所要解決的核心問題。近幾年來,大量少樣本學習的工作都采用了元學習(meta?learning)的思想,其基本思路是在訓練階段通過訓練一個能夠快速適應新任務的模型,從而在測試階段可以高效地使用少量的樣本來學習新任務。然而,大多數(shù)的少樣本學習方法依賴于與目標任務域分布相同的輔助數(shù)據(jù)。為了解決這個問題,研究者提出了多跨域少樣本學習(multiple?cross-domain?few-shot?learning,mcd-fsl),旨在利用來自多個不同源域的數(shù)據(jù)來提升模型性能。tseng等人通過在多個源域上進行實驗驗證了該方法可以帶來有效的改進。然而,triantafillou等人也指出,簡單的跨域學習并不總是能達到預期的效果。這是因為多個源域之間的數(shù)據(jù)干擾,導致了負遷移現(xiàn)象。為了解決這個問題,部分研究采用知識蒸餾來捕捉正確的源域信息。
2、知識蒸餾(knowledge?distillation,kd)是一種深度學習技術,旨在將一個或多個復雜的教師網(wǎng)絡中的知識轉移到一個較簡單的學生網(wǎng)絡中。學生網(wǎng)絡通常較小、速度更快且資源效率更高。盡管學生網(wǎng)絡的規(guī)模較小,但它可以保持與教師網(wǎng)絡相當?shù)男阅?,甚至提高其泛化能力。傳統(tǒng)的知識蒸餾方法主要依賴于一個強大的教師網(wǎng)絡。hinton等人首先提出了知識蒸餾的概念,他們證明了一個強大的教師網(wǎng)絡能夠有效地指導學生網(wǎng)絡的訓練。然而,近年來的研究發(fā)現(xiàn),使用多個教師網(wǎng)絡進行知識蒸餾,可以進一步提升學生網(wǎng)絡的性能。例如,通過平均多個教師網(wǎng)絡的輸出,可以得到一個更穩(wěn)健的指導信號。然而,簡單地對多個教師網(wǎng)絡的輸出進行平均,可能會忽略每個教師網(wǎng)絡在不同任務上的不同貢獻,進而誤導學生網(wǎng)絡的學習。為了更好地結合跨域少樣本學習和知識蒸餾的方法,研究者們提出了一些新的策略。例如,研究者們引入了自適應的蒸餾溫度和損失權重,使學生網(wǎng)絡能夠更加有效地從多個教師網(wǎng)絡中學習。在訓練過程中,研究者們先在多個源域上訓練多個教師網(wǎng)絡,然后將這些教師網(wǎng)絡中的知識通過蒸餾的方式轉移到學生網(wǎng)絡中。為了確保學生網(wǎng)絡能夠從不同源域中學習到有用的信息,研究者們設計了一種跨域蒸餾損失函數(shù)。這種損失函數(shù)能夠幫助學生網(wǎng)絡在多個源域之間進行有效的知識遷移,從而提升其在目標域上的表現(xiàn)。
3、盡管這些方法在實驗中顯示出了一定的有效性,但當前的方法仍然存在一些問題和挑戰(zhàn)。首先,多源域數(shù)據(jù)之間的分布差異可能會導致網(wǎng)絡模型在融合不同源域信息時出現(xiàn)困難,影響模型的泛化能力。即使通過知識蒸餾的方法對不同源域的數(shù)據(jù)進行整合,仍然無法完全消除源域之間的分布差異對模型性能的負面影響。另外,知識蒸餾過程中涉及的多教師網(wǎng)絡之間可能存在沖突,導致學生網(wǎng)絡無法有效學習所有教師網(wǎng)絡的知識。這需要在蒸餾過程中引入更復雜的權重調(diào)整機制,以平衡不同教師網(wǎng)絡對學生網(wǎng)絡的影響。在跨域學習中,還經(jīng)常涉及到多任務學習的問題,即網(wǎng)絡模型需要同時學習和處理多個不同的任務。如何有效地協(xié)調(diào)和優(yōu)化這些任務之間的關系,以最大化整體性能,也是一個值得思考的問題。
技術實現(xiàn)思路
1、本發(fā)明要解決的技術問題是針對上述現(xiàn)有技術的不足,提供一種基于知識蒸餾的深度融合多跨域少樣本分類方法,實現(xiàn)多跨域少樣本分類。
2、為解決上述技術問題,本發(fā)明所采取的技術方案是:一種基于知識蒸餾的深度融合多跨域少樣本分類方法,包括以下步驟:
3、步驟1:預訓練教師和學生網(wǎng)絡;
4、分別利用n個不同源域的訓練集{z1,z2,…,zn}來訓練n個不同的教師網(wǎng)絡,每個教師網(wǎng)絡都采用相同的網(wǎng)絡結構,包含一個教師特征編碼器et和一個線性分類器ct,初始化n個教師網(wǎng)絡參數(shù),以傳統(tǒng)監(jiān)督訓練的方式利用交叉熵損失對每一個教師網(wǎng)絡進行預訓練,最終得到n個訓練好的教師網(wǎng)絡;利用一個多樣性數(shù)據(jù)集來預訓練學生網(wǎng)絡sp,學生網(wǎng)絡包含一個特征編碼器es_p和一個基于距離度量的分類器cs,以傳統(tǒng)監(jiān)督訓練的方式利用交叉熵損失對學生網(wǎng)絡進行預訓練,得到預訓練的學生網(wǎng)絡sp;
5、利用交叉熵損失函數(shù)對每一個教師網(wǎng)絡和學生網(wǎng)絡進行預訓練,包括:
6、(1)從第n個源域訓練集zn中隨機選取一定量的數(shù)據(jù)作為第n個教師特征編碼器或學生網(wǎng)絡的輸入,經(jīng)過編碼得到第i個樣本圖像的視覺特征
7、
8、其中,為第n個源域訓練集中的第i個樣本圖像,en為第n個特征編碼器;
9、(2)將第i個樣本圖像的視覺特征輸入第n個教師分類器cn,得到第n個源域訓練集中的第i個樣本圖像的類別預測概率:
10、
11、其中,為第i個樣本圖像屬于第w個類別的預測概率;
12、(3)設定教師網(wǎng)絡和學生網(wǎng)絡的目標函數(shù)ln公式如下:
13、
14、其中,為第i個樣本圖像的真實標簽,w為第n個源域訓練集中的樣本類別數(shù);
15、(4)根據(jù)公式(2)訓練第n個特征編碼器en和第n個分類器cn,保留使公式(3)的誤差值最小的第n個特征編碼器en和第n個分類器cn;
16、(5)重復第(1)步~第(4)步,得到預訓練好的n個教師網(wǎng)絡和一個學生網(wǎng)絡sp;
17、步驟2:利用預訓練的學生網(wǎng)絡sp構建學生網(wǎng)絡s;
18、所述學生網(wǎng)絡s是一種基于度量的少樣本模型——原型網(wǎng)絡,包含一個學生特征編碼器es和一個度量函數(shù)d,通過預訓練的學生網(wǎng)絡特征編碼器es_p初始化學生特征編碼器es;
19、步驟3:元訓練階段,從n個不同源域的訓練集中隨機選取一個訓練集作為當前的元訓練集dtrain,根據(jù)元學習的思想,從當前的元訓練集中隨機采樣m個少樣本任務,每個任務都包含一個支持集s和一個查詢集q,支持集s中含有w個類別的樣本圖像數(shù)據(jù),每個類別有k個樣本;查詢集q中含有w個類別的樣本圖像;
20、步驟4:將不同的少樣本任務同時送到n個教師網(wǎng)絡和學生網(wǎng)絡s中;
21、步驟5:依次將支持集s中第k個樣本圖像xk輸入到n個教師特征編碼器和學生特征編碼器es中,分別得到相對應的視覺特征和
22、
23、
24、其中,xk為支持集s中第k個樣本圖像,為第n個教師特征編碼器,n=1,2,…,n,es為學生特征編碼器,為第n個教師特征編碼器en對xk編碼后輸出的視覺特征,為學生特征編碼器es對xk編碼后輸出的視覺特征;
25、步驟6:在每一個源域中,分別對支持集中屬于同一類別的樣本圖像對應的視覺特征取平均,得到每個類別的原型表示為:
26、
27、
28、其中,k為第w個類別的樣本總數(shù),w=1,2,…,w,為經(jīng)過第n個教師特征編碼器編碼后的第w個類別的原型表示,為經(jīng)過學生編碼器編碼后的第w個類別的原型表示;
29、步驟7:依次將查詢集q中的樣本圖像xq輸入到n個教師特征編碼器和學生特征編碼器es中,分別得到相對應的視覺特征和
30、
31、
32、其中,xq為查詢集q中的樣本圖像,為第n個教師特征編碼器對xq編碼后輸出的視覺特征,為學生特征編碼器es對xq編碼后輸出的視覺特征;
33、步驟8:對教師網(wǎng)絡和學生網(wǎng)絡s得到的支持集中每個類別的原型表示和查詢集中樣本圖像對應的視覺特征分別進行歸一化處理,得到對應的歸一化結果:
34、
35、
36、
37、
38、其中,為歸一化后的經(jīng)過第n個教師特征編碼器編碼后的第w個類別的原型表示,為歸一化后的第n個教師特征編碼器對xq編碼后輸出的視覺特征,為歸一化后的經(jīng)過學生編碼器es編碼后的第w個類別的原型表示向量,為歸一化后的學生特征編碼器es對xq編碼后輸出的視覺特征;
39、步驟9:通過歸一化后的教師網(wǎng)絡和學生網(wǎng)絡中查詢樣本對應的視覺特征和每一類的原型表示,基于相似度分別計算在教師網(wǎng)絡和學生網(wǎng)絡中每個查詢樣本屬于每個類別的預測概率矩陣和slogits:
40、
41、
42、其中,表示第n個教師網(wǎng)絡中查詢樣本對應的視覺特征和原型表示之間基于相似度的預測概率矩陣,slogits表示學生網(wǎng)絡中查詢樣本對應的視覺特征和原型表示之間基于相似度的預測概率矩陣,為同一個教師特征編碼器各個原型表示,為學生特征編碼器各個原型表示;
43、步驟10:設定查詢集q中樣本的真實標簽為yquery,根據(jù)查詢集的真實標簽yquery和學生網(wǎng)絡中查詢樣本對應的視覺特征和原型表示之間基于相似度的預測概率矩陣slogits,計算學生網(wǎng)絡的分類損失,設定學生網(wǎng)絡的分類目標函數(shù)lcls如下:
44、
45、其中,q是查詢集中的樣本總數(shù),w是類別總數(shù),是第i個查詢樣本的預測概率向量,是第i個查詢樣本的真實標簽,d(,)為基于歐式距離的度量函數(shù);
46、步驟11:根據(jù)經(jīng)過n個教師特征編碼器和學生特征編碼器編碼后的原型表示和查詢集樣本對應的視覺特征計算類別預測概率矩陣,從而在n個教師網(wǎng)絡和學生網(wǎng)絡之間進行基于軟標簽的知識蒸餾,得到學生網(wǎng)絡的基于軟標簽的目標函數(shù)lkl;
47、1)將n個教師網(wǎng)絡輸出的查詢集樣本圖像xq與每一類的原型表示之間基于相似度的預測概率矩陣進行深度融合,作為訓練學生網(wǎng)絡的目標之一:
48、
49、其中ffc表示全連接網(wǎng)絡,fmean表示平均函數(shù),表示n個教師網(wǎng)絡中查詢樣本和每一類的原型表示之間基于相似度的預測概率矩陣,⊕,⊙,||分別表示點加,點乘,級聯(lián),tfusion表示n個教師網(wǎng)絡融合后的查詢集樣本圖像xq特征與每一類的原型表示之間基于相似度的預測概率矩陣;
50、2)為了使學生網(wǎng)絡與教師網(wǎng)絡的輸出一致,設定學生網(wǎng)絡的基于預測概率矩陣的目標函數(shù)lkl如下:
51、
52、
53、其中,τ是知識蒸餾的溫度系數(shù),q是查詢樣本集中的樣本總數(shù),是學生網(wǎng)絡提取的第i個查詢樣本對應的視覺特征與每一類的原型表示之間基于相似度的預測概率向量,表示n個教師網(wǎng)絡融合后的第i個查詢樣本對應的視覺特征與每一類的原型表示之間基于相似度的預測概率向量;
54、步驟12:根據(jù)經(jīng)過n個教師特征編碼器和學生特征編碼器編碼后的原型表示和查詢集樣本對應的視覺特征,使學生網(wǎng)絡生成的特征圖的空間分布與教師網(wǎng)絡生成的特征圖對齊,將教師網(wǎng)絡的特征圖作為信息傳遞的參考,從而在n個教師網(wǎng)絡和學生網(wǎng)絡之間進行基于注意力的知識蒸餾,得到學生網(wǎng)絡基于特征圖的目標函數(shù)lattention;
55、s1:根據(jù)經(jīng)過n個教師特征編碼器和學生特征編碼器編碼后的支持集樣本和查詢集樣本對應的視覺特征,得到經(jīng)過同一個教師或學生編碼器編碼后的視覺特征:
56、
57、
58、其中,為第n個教師特征編碼器en對xq編碼后輸出的視覺特征,為學生特征編碼器es對xq編碼后輸出的視覺特征,為經(jīng)過第n個教師特征編碼器編碼的支持集和查詢集樣本對應的視覺特征合并后的結果,vs為經(jīng)過學生特征編碼器編碼的支持集和查詢集對應的視覺特征合并后的結果;
59、s2:將n個教師特征編碼器對支持集和查詢集樣本編碼得到的視覺特征進行深度融合,作為訓練學生網(wǎng)絡的目標之一:
60、
61、其中,v1,…,vn表示n個教師特征編碼器合并后的支持集和查詢集樣本的視覺特征,表示對n個教師特征編碼器合并后的支持集和查詢集樣本的視覺特征進行深度融合后的結果;
62、s3:為了使學生網(wǎng)絡與教師網(wǎng)絡的輸出一致,通過支持集和查詢集樣本經(jīng)過學生網(wǎng)絡提取的視覺特征和融合后的多個教師網(wǎng)絡提取的視覺特征,設定學生網(wǎng)絡的基于注意力的目標函數(shù)lattention如下:
63、
64、
65、
66、其中,g是支持集和查詢集樣本總數(shù),是支持集和查詢集樣本集合中第i個樣本經(jīng)過學生特征編碼器得到的特征向量,是支持集和查詢集樣本集合中第i個樣本融合后的視覺特征向量,p是用于注意力轉移函數(shù)的超參數(shù);
67、步驟13:根據(jù)如下學生網(wǎng)絡的總目標函數(shù)公式,使用sgd算法訓練學生特征編碼器:
68、l=λ1×lcls+λ2×lkl+λ3×lattention???(26)
69、其中,l為學生網(wǎng)絡的總目標函數(shù),lcls為學生網(wǎng)絡的分類目標函數(shù),lkl為學生網(wǎng)絡的基于軟標簽的目標函數(shù),lattention為學生網(wǎng)絡的基于特征圖的目標函數(shù),λ1,λ2,λ3為權重系數(shù);
70、步驟14:重復步驟3-13,直至總目標函數(shù)值逐漸收斂且趨于不變時,得到訓練好的學生網(wǎng)絡;
71、步驟15:測試階段,給定一個不同于n個預訓練源域的數(shù)據(jù)集作為目標域,依次將來自目標域測試集的支持集和查詢集的樣本圖像輸入到訓練好的學生特征編碼器es中,得到相應的視覺特征,按照公式(6)(7)計算支持集中各個類別的原型表示,再按照公式(14)(15)計算查詢集樣本圖像屬于各個類別的概率,將計算得到的概率中最大的概率所對應的類別,作為查詢集樣本圖像分類結果。
72、采用上述技術方案所產(chǎn)生的有益效果在于:本發(fā)明提供的一種基于知識蒸餾的深度融合跨域少樣本分類方法,利用知識蒸餾的師生網(wǎng)絡框架實現(xiàn)高效知識遷移,從而提升模型的泛化能力。通過將元學習的訓練策略融入知識蒸餾,結合任務導向的知識蒸餾和多個教師網(wǎng)絡的協(xié)作,不僅為學生網(wǎng)絡提供了豐富有效的知識,還增強了學生網(wǎng)絡對少樣本任務的快速適應能力。從教師網(wǎng)絡的輸出預測和樣本關系兩方面提取監(jiān)督信息,使用深度融合將多個教師網(wǎng)絡的輸出融合,用于指導學生網(wǎng)絡的訓練,提升知識蒸餾的效率。因此,本發(fā)明能夠更好地將多個源域的有效知識遷移到目標域,顯著提高學生網(wǎng)絡在目標少樣本任務上的分類準確率。