本發(fā)明涉及深度學(xué)習(xí),尤其涉及一種基于知識(shí)蒸餾的深度融合多跨域少樣本分類方法。
背景技術(shù):
1、目前,深度學(xué)習(xí)在計(jì)算機(jī)視覺領(lǐng)域已經(jīng)取得了較大成功,例如物體分類、圖像檢索和動(dòng)作識(shí)別等任務(wù)。深度學(xué)習(xí)的成功在很大程度上依賴于海量的數(shù)據(jù)和強(qiáng)大的計(jì)算資源。然而許多識(shí)別分辨和機(jī)器學(xué)習(xí)任務(wù)中,人類往往可以從很少的例子中識(shí)別新物體或視覺概念,這種快速學(xué)習(xí)的能力是現(xiàn)在的深度學(xué)習(xí)所不具備的。因此,如何通過有限的標(biāo)記數(shù)據(jù)來學(xué)習(xí)識(shí)別新類別引起了人們的廣泛關(guān)注,這也是少樣本學(xué)習(xí)(few-shot?learning)所要解決的核心問題。近幾年來,大量少樣本學(xué)習(xí)的工作都采用了元學(xué)習(xí)(meta?learning)的思想,其基本思路是在訓(xùn)練階段通過訓(xùn)練一個(gè)能夠快速適應(yīng)新任務(wù)的模型,從而在測(cè)試階段可以高效地使用少量的樣本來學(xué)習(xí)新任務(wù)。然而,大多數(shù)的少樣本學(xué)習(xí)方法依賴于與目標(biāo)任務(wù)域分布相同的輔助數(shù)據(jù)。為了解決這個(gè)問題,研究者提出了多跨域少樣本學(xué)習(xí)(multiple?cross-domain?few-shot?learning,mcd-fsl),旨在利用來自多個(gè)不同源域的數(shù)據(jù)來提升模型性能。tseng等人通過在多個(gè)源域上進(jìn)行實(shí)驗(yàn)驗(yàn)證了該方法可以帶來有效的改進(jìn)。然而,triantafillou等人也指出,簡(jiǎn)單的跨域?qū)W習(xí)并不總是能達(dá)到預(yù)期的效果。這是因?yàn)槎鄠€(gè)源域之間的數(shù)據(jù)干擾,導(dǎo)致了負(fù)遷移現(xiàn)象。為了解決這個(gè)問題,部分研究采用知識(shí)蒸餾來捕捉正確的源域信息。
2、知識(shí)蒸餾(knowledge?distillation,kd)是一種深度學(xué)習(xí)技術(shù),旨在將一個(gè)或多個(gè)復(fù)雜的教師網(wǎng)絡(luò)中的知識(shí)轉(zhuǎn)移到一個(gè)較簡(jiǎn)單的學(xué)生網(wǎng)絡(luò)中。學(xué)生網(wǎng)絡(luò)通常較小、速度更快且資源效率更高。盡管學(xué)生網(wǎng)絡(luò)的規(guī)模較小,但它可以保持與教師網(wǎng)絡(luò)相當(dāng)?shù)男阅?,甚至提高其泛化能力。傳統(tǒng)的知識(shí)蒸餾方法主要依賴于一個(gè)強(qiáng)大的教師網(wǎng)絡(luò)。hinton等人首先提出了知識(shí)蒸餾的概念,他們證明了一個(gè)強(qiáng)大的教師網(wǎng)絡(luò)能夠有效地指導(dǎo)學(xué)生網(wǎng)絡(luò)的訓(xùn)練。然而,近年來的研究發(fā)現(xiàn),使用多個(gè)教師網(wǎng)絡(luò)進(jìn)行知識(shí)蒸餾,可以進(jìn)一步提升學(xué)生網(wǎng)絡(luò)的性能。例如,通過平均多個(gè)教師網(wǎng)絡(luò)的輸出,可以得到一個(gè)更穩(wěn)健的指導(dǎo)信號(hào)。然而,簡(jiǎn)單地對(duì)多個(gè)教師網(wǎng)絡(luò)的輸出進(jìn)行平均,可能會(huì)忽略每個(gè)教師網(wǎng)絡(luò)在不同任務(wù)上的不同貢獻(xiàn),進(jìn)而誤導(dǎo)學(xué)生網(wǎng)絡(luò)的學(xué)習(xí)。為了更好地結(jié)合跨域少樣本學(xué)習(xí)和知識(shí)蒸餾的方法,研究者們提出了一些新的策略。例如,研究者們引入了自適應(yīng)的蒸餾溫度和損失權(quán)重,使學(xué)生網(wǎng)絡(luò)能夠更加有效地從多個(gè)教師網(wǎng)絡(luò)中學(xué)習(xí)。在訓(xùn)練過程中,研究者們先在多個(gè)源域上訓(xùn)練多個(gè)教師網(wǎng)絡(luò),然后將這些教師網(wǎng)絡(luò)中的知識(shí)通過蒸餾的方式轉(zhuǎn)移到學(xué)生網(wǎng)絡(luò)中。為了確保學(xué)生網(wǎng)絡(luò)能夠從不同源域中學(xué)習(xí)到有用的信息,研究者們?cè)O(shè)計(jì)了一種跨域蒸餾損失函數(shù)。這種損失函數(shù)能夠幫助學(xué)生網(wǎng)絡(luò)在多個(gè)源域之間進(jìn)行有效的知識(shí)遷移,從而提升其在目標(biāo)域上的表現(xiàn)。
3、盡管這些方法在實(shí)驗(yàn)中顯示出了一定的有效性,但當(dāng)前的方法仍然存在一些問題和挑戰(zhàn)。首先,多源域數(shù)據(jù)之間的分布差異可能會(huì)導(dǎo)致網(wǎng)絡(luò)模型在融合不同源域信息時(shí)出現(xiàn)困難,影響模型的泛化能力。即使通過知識(shí)蒸餾的方法對(duì)不同源域的數(shù)據(jù)進(jìn)行整合,仍然無法完全消除源域之間的分布差異對(duì)模型性能的負(fù)面影響。另外,知識(shí)蒸餾過程中涉及的多教師網(wǎng)絡(luò)之間可能存在沖突,導(dǎo)致學(xué)生網(wǎng)絡(luò)無法有效學(xué)習(xí)所有教師網(wǎng)絡(luò)的知識(shí)。這需要在蒸餾過程中引入更復(fù)雜的權(quán)重調(diào)整機(jī)制,以平衡不同教師網(wǎng)絡(luò)對(duì)學(xué)生網(wǎng)絡(luò)的影響。在跨域?qū)W習(xí)中,還經(jīng)常涉及到多任務(wù)學(xué)習(xí)的問題,即網(wǎng)絡(luò)模型需要同時(shí)學(xué)習(xí)和處理多個(gè)不同的任務(wù)。如何有效地協(xié)調(diào)和優(yōu)化這些任務(wù)之間的關(guān)系,以最大化整體性能,也是一個(gè)值得思考的問題。
技術(shù)實(shí)現(xiàn)思路
1、本發(fā)明要解決的技術(shù)問題是針對(duì)上述現(xiàn)有技術(shù)的不足,提供一種基于知識(shí)蒸餾的深度融合多跨域少樣本分類方法,實(shí)現(xiàn)多跨域少樣本分類。
2、為解決上述技術(shù)問題,本發(fā)明所采取的技術(shù)方案是:一種基于知識(shí)蒸餾的深度融合多跨域少樣本分類方法,包括以下步驟:
3、步驟1:預(yù)訓(xùn)練教師和學(xué)生網(wǎng)絡(luò);
4、分別利用n個(gè)不同源域的訓(xùn)練集{z1,z2,…,zn}來訓(xùn)練n個(gè)不同的教師網(wǎng)絡(luò),每個(gè)教師網(wǎng)絡(luò)都采用相同的網(wǎng)絡(luò)結(jié)構(gòu),包含一個(gè)教師特征編碼器et和一個(gè)線性分類器ct,初始化n個(gè)教師網(wǎng)絡(luò)參數(shù),以傳統(tǒng)監(jiān)督訓(xùn)練的方式利用交叉熵?fù)p失對(duì)每一個(gè)教師網(wǎng)絡(luò)進(jìn)行預(yù)訓(xùn)練,最終得到n個(gè)訓(xùn)練好的教師網(wǎng)絡(luò);利用一個(gè)多樣性數(shù)據(jù)集來預(yù)訓(xùn)練學(xué)生網(wǎng)絡(luò)sp,學(xué)生網(wǎng)絡(luò)包含一個(gè)特征編碼器es_p和一個(gè)基于距離度量的分類器cs,以傳統(tǒng)監(jiān)督訓(xùn)練的方式利用交叉熵?fù)p失對(duì)學(xué)生網(wǎng)絡(luò)進(jìn)行預(yù)訓(xùn)練,得到預(yù)訓(xùn)練的學(xué)生網(wǎng)絡(luò)sp;
5、利用交叉熵?fù)p失函數(shù)對(duì)每一個(gè)教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)進(jìn)行預(yù)訓(xùn)練,包括:
6、(1)從第n個(gè)源域訓(xùn)練集zn中隨機(jī)選取一定量的數(shù)據(jù)作為第n個(gè)教師特征編碼器或?qū)W生網(wǎng)絡(luò)的輸入,經(jīng)過編碼得到第i個(gè)樣本圖像的視覺特征
7、
8、其中,為第n個(gè)源域訓(xùn)練集中的第i個(gè)樣本圖像,en為第n個(gè)特征編碼器;
9、(2)將第i個(gè)樣本圖像的視覺特征輸入第n個(gè)教師分類器cn,得到第n個(gè)源域訓(xùn)練集中的第i個(gè)樣本圖像的類別預(yù)測(cè)概率:
10、
11、其中,為第i個(gè)樣本圖像屬于第w個(gè)類別的預(yù)測(cè)概率;
12、(3)設(shè)定教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)的目標(biāo)函數(shù)ln公式如下:
13、
14、其中,為第i個(gè)樣本圖像的真實(shí)標(biāo)簽,w為第n個(gè)源域訓(xùn)練集中的樣本類別數(shù);
15、(4)根據(jù)公式(2)訓(xùn)練第n個(gè)特征編碼器en和第n個(gè)分類器cn,保留使公式(3)的誤差值最小的第n個(gè)特征編碼器en和第n個(gè)分類器cn;
16、(5)重復(fù)第(1)步~第(4)步,得到預(yù)訓(xùn)練好的n個(gè)教師網(wǎng)絡(luò)和一個(gè)學(xué)生網(wǎng)絡(luò)sp;
17、步驟2:利用預(yù)訓(xùn)練的學(xué)生網(wǎng)絡(luò)sp構(gòu)建學(xué)生網(wǎng)絡(luò)s;
18、所述學(xué)生網(wǎng)絡(luò)s是一種基于度量的少樣本模型——原型網(wǎng)絡(luò),包含一個(gè)學(xué)生特征編碼器es和一個(gè)度量函數(shù)d,通過預(yù)訓(xùn)練的學(xué)生網(wǎng)絡(luò)特征編碼器es_p初始化學(xué)生特征編碼器es;
19、步驟3:元訓(xùn)練階段,從n個(gè)不同源域的訓(xùn)練集中隨機(jī)選取一個(gè)訓(xùn)練集作為當(dāng)前的元訓(xùn)練集dtrain,根據(jù)元學(xué)習(xí)的思想,從當(dāng)前的元訓(xùn)練集中隨機(jī)采樣m個(gè)少樣本任務(wù),每個(gè)任務(wù)都包含一個(gè)支持集s和一個(gè)查詢集q,支持集s中含有w個(gè)類別的樣本圖像數(shù)據(jù),每個(gè)類別有k個(gè)樣本;查詢集q中含有w個(gè)類別的樣本圖像;
20、步驟4:將不同的少樣本任務(wù)同時(shí)送到n個(gè)教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)s中;
21、步驟5:依次將支持集s中第k個(gè)樣本圖像xk輸入到n個(gè)教師特征編碼器和學(xué)生特征編碼器es中,分別得到相對(duì)應(yīng)的視覺特征和
22、
23、
24、其中,xk為支持集s中第k個(gè)樣本圖像,為第n個(gè)教師特征編碼器,n=1,2,…,n,es為學(xué)生特征編碼器,為第n個(gè)教師特征編碼器en對(duì)xk編碼后輸出的視覺特征,為學(xué)生特征編碼器es對(duì)xk編碼后輸出的視覺特征;
25、步驟6:在每一個(gè)源域中,分別對(duì)支持集中屬于同一類別的樣本圖像對(duì)應(yīng)的視覺特征取平均,得到每個(gè)類別的原型表示為:
26、
27、
28、其中,k為第w個(gè)類別的樣本總數(shù),w=1,2,…,w,為經(jīng)過第n個(gè)教師特征編碼器編碼后的第w個(gè)類別的原型表示,為經(jīng)過學(xué)生編碼器編碼后的第w個(gè)類別的原型表示;
29、步驟7:依次將查詢集q中的樣本圖像xq輸入到n個(gè)教師特征編碼器和學(xué)生特征編碼器es中,分別得到相對(duì)應(yīng)的視覺特征和
30、
31、
32、其中,xq為查詢集q中的樣本圖像,為第n個(gè)教師特征編碼器對(duì)xq編碼后輸出的視覺特征,為學(xué)生特征編碼器es對(duì)xq編碼后輸出的視覺特征;
33、步驟8:對(duì)教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)s得到的支持集中每個(gè)類別的原型表示和查詢集中樣本圖像對(duì)應(yīng)的視覺特征分別進(jìn)行歸一化處理,得到對(duì)應(yīng)的歸一化結(jié)果:
34、
35、
36、
37、
38、其中,為歸一化后的經(jīng)過第n個(gè)教師特征編碼器編碼后的第w個(gè)類別的原型表示,為歸一化后的第n個(gè)教師特征編碼器對(duì)xq編碼后輸出的視覺特征,為歸一化后的經(jīng)過學(xué)生編碼器es編碼后的第w個(gè)類別的原型表示向量,為歸一化后的學(xué)生特征編碼器es對(duì)xq編碼后輸出的視覺特征;
39、步驟9:通過歸一化后的教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)中查詢樣本對(duì)應(yīng)的視覺特征和每一類的原型表示,基于相似度分別計(jì)算在教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)中每個(gè)查詢樣本屬于每個(gè)類別的預(yù)測(cè)概率矩陣和slogits:
40、
41、
42、其中,表示第n個(gè)教師網(wǎng)絡(luò)中查詢樣本對(duì)應(yīng)的視覺特征和原型表示之間基于相似度的預(yù)測(cè)概率矩陣,slogits表示學(xué)生網(wǎng)絡(luò)中查詢樣本對(duì)應(yīng)的視覺特征和原型表示之間基于相似度的預(yù)測(cè)概率矩陣,為同一個(gè)教師特征編碼器各個(gè)原型表示,為學(xué)生特征編碼器各個(gè)原型表示;
43、步驟10:設(shè)定查詢集q中樣本的真實(shí)標(biāo)簽為yquery,根據(jù)查詢集的真實(shí)標(biāo)簽yquery和學(xué)生網(wǎng)絡(luò)中查詢樣本對(duì)應(yīng)的視覺特征和原型表示之間基于相似度的預(yù)測(cè)概率矩陣slogits,計(jì)算學(xué)生網(wǎng)絡(luò)的分類損失,設(shè)定學(xué)生網(wǎng)絡(luò)的分類目標(biāo)函數(shù)lcls如下:
44、
45、其中,q是查詢集中的樣本總數(shù),w是類別總數(shù),是第i個(gè)查詢樣本的預(yù)測(cè)概率向量,是第i個(gè)查詢樣本的真實(shí)標(biāo)簽,d(,)為基于歐式距離的度量函數(shù);
46、步驟11:根據(jù)經(jīng)過n個(gè)教師特征編碼器和學(xué)生特征編碼器編碼后的原型表示和查詢集樣本對(duì)應(yīng)的視覺特征計(jì)算類別預(yù)測(cè)概率矩陣,從而在n個(gè)教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)之間進(jìn)行基于軟標(biāo)簽的知識(shí)蒸餾,得到學(xué)生網(wǎng)絡(luò)的基于軟標(biāo)簽的目標(biāo)函數(shù)lkl;
47、1)將n個(gè)教師網(wǎng)絡(luò)輸出的查詢集樣本圖像xq與每一類的原型表示之間基于相似度的預(yù)測(cè)概率矩陣進(jìn)行深度融合,作為訓(xùn)練學(xué)生網(wǎng)絡(luò)的目標(biāo)之一:
48、
49、其中ffc表示全連接網(wǎng)絡(luò),fmean表示平均函數(shù),表示n個(gè)教師網(wǎng)絡(luò)中查詢樣本和每一類的原型表示之間基于相似度的預(yù)測(cè)概率矩陣,⊕,⊙,||分別表示點(diǎn)加,點(diǎn)乘,級(jí)聯(lián),tfusion表示n個(gè)教師網(wǎng)絡(luò)融合后的查詢集樣本圖像xq特征與每一類的原型表示之間基于相似度的預(yù)測(cè)概率矩陣;
50、2)為了使學(xué)生網(wǎng)絡(luò)與教師網(wǎng)絡(luò)的輸出一致,設(shè)定學(xué)生網(wǎng)絡(luò)的基于預(yù)測(cè)概率矩陣的目標(biāo)函數(shù)lkl如下:
51、
52、
53、其中,τ是知識(shí)蒸餾的溫度系數(shù),q是查詢樣本集中的樣本總數(shù),是學(xué)生網(wǎng)絡(luò)提取的第i個(gè)查詢樣本對(duì)應(yīng)的視覺特征與每一類的原型表示之間基于相似度的預(yù)測(cè)概率向量,表示n個(gè)教師網(wǎng)絡(luò)融合后的第i個(gè)查詢樣本對(duì)應(yīng)的視覺特征與每一類的原型表示之間基于相似度的預(yù)測(cè)概率向量;
54、步驟12:根據(jù)經(jīng)過n個(gè)教師特征編碼器和學(xué)生特征編碼器編碼后的原型表示和查詢集樣本對(duì)應(yīng)的視覺特征,使學(xué)生網(wǎng)絡(luò)生成的特征圖的空間分布與教師網(wǎng)絡(luò)生成的特征圖對(duì)齊,將教師網(wǎng)絡(luò)的特征圖作為信息傳遞的參考,從而在n個(gè)教師網(wǎng)絡(luò)和學(xué)生網(wǎng)絡(luò)之間進(jìn)行基于注意力的知識(shí)蒸餾,得到學(xué)生網(wǎng)絡(luò)基于特征圖的目標(biāo)函數(shù)lattention;
55、s1:根據(jù)經(jīng)過n個(gè)教師特征編碼器和學(xué)生特征編碼器編碼后的支持集樣本和查詢集樣本對(duì)應(yīng)的視覺特征,得到經(jīng)過同一個(gè)教師或?qū)W生編碼器編碼后的視覺特征:
56、
57、
58、其中,為第n個(gè)教師特征編碼器en對(duì)xq編碼后輸出的視覺特征,為學(xué)生特征編碼器es對(duì)xq編碼后輸出的視覺特征,為經(jīng)過第n個(gè)教師特征編碼器編碼的支持集和查詢集樣本對(duì)應(yīng)的視覺特征合并后的結(jié)果,vs為經(jīng)過學(xué)生特征編碼器編碼的支持集和查詢集對(duì)應(yīng)的視覺特征合并后的結(jié)果;
59、s2:將n個(gè)教師特征編碼器對(duì)支持集和查詢集樣本編碼得到的視覺特征進(jìn)行深度融合,作為訓(xùn)練學(xué)生網(wǎng)絡(luò)的目標(biāo)之一:
60、
61、其中,v1,…,vn表示n個(gè)教師特征編碼器合并后的支持集和查詢集樣本的視覺特征,表示對(duì)n個(gè)教師特征編碼器合并后的支持集和查詢集樣本的視覺特征進(jìn)行深度融合后的結(jié)果;
62、s3:為了使學(xué)生網(wǎng)絡(luò)與教師網(wǎng)絡(luò)的輸出一致,通過支持集和查詢集樣本經(jīng)過學(xué)生網(wǎng)絡(luò)提取的視覺特征和融合后的多個(gè)教師網(wǎng)絡(luò)提取的視覺特征,設(shè)定學(xué)生網(wǎng)絡(luò)的基于注意力的目標(biāo)函數(shù)lattention如下:
63、
64、
65、
66、其中,g是支持集和查詢集樣本總數(shù),是支持集和查詢集樣本集合中第i個(gè)樣本經(jīng)過學(xué)生特征編碼器得到的特征向量,是支持集和查詢集樣本集合中第i個(gè)樣本融合后的視覺特征向量,p是用于注意力轉(zhuǎn)移函數(shù)的超參數(shù);
67、步驟13:根據(jù)如下學(xué)生網(wǎng)絡(luò)的總目標(biāo)函數(shù)公式,使用sgd算法訓(xùn)練學(xué)生特征編碼器:
68、l=λ1×lcls+λ2×lkl+λ3×lattention???(26)
69、其中,l為學(xué)生網(wǎng)絡(luò)的總目標(biāo)函數(shù),lcls為學(xué)生網(wǎng)絡(luò)的分類目標(biāo)函數(shù),lkl為學(xué)生網(wǎng)絡(luò)的基于軟標(biāo)簽的目標(biāo)函數(shù),lattention為學(xué)生網(wǎng)絡(luò)的基于特征圖的目標(biāo)函數(shù),λ1,λ2,λ3為權(quán)重系數(shù);
70、步驟14:重復(fù)步驟3-13,直至總目標(biāo)函數(shù)值逐漸收斂且趨于不變時(shí),得到訓(xùn)練好的學(xué)生網(wǎng)絡(luò);
71、步驟15:測(cè)試階段,給定一個(gè)不同于n個(gè)預(yù)訓(xùn)練源域的數(shù)據(jù)集作為目標(biāo)域,依次將來自目標(biāo)域測(cè)試集的支持集和查詢集的樣本圖像輸入到訓(xùn)練好的學(xué)生特征編碼器es中,得到相應(yīng)的視覺特征,按照公式(6)(7)計(jì)算支持集中各個(gè)類別的原型表示,再按照公式(14)(15)計(jì)算查詢集樣本圖像屬于各個(gè)類別的概率,將計(jì)算得到的概率中最大的概率所對(duì)應(yīng)的類別,作為查詢集樣本圖像分類結(jié)果。
72、采用上述技術(shù)方案所產(chǎn)生的有益效果在于:本發(fā)明提供的一種基于知識(shí)蒸餾的深度融合跨域少樣本分類方法,利用知識(shí)蒸餾的師生網(wǎng)絡(luò)框架實(shí)現(xiàn)高效知識(shí)遷移,從而提升模型的泛化能力。通過將元學(xué)習(xí)的訓(xùn)練策略融入知識(shí)蒸餾,結(jié)合任務(wù)導(dǎo)向的知識(shí)蒸餾和多個(gè)教師網(wǎng)絡(luò)的協(xié)作,不僅為學(xué)生網(wǎng)絡(luò)提供了豐富有效的知識(shí),還增強(qiáng)了學(xué)生網(wǎng)絡(luò)對(duì)少樣本任務(wù)的快速適應(yīng)能力。從教師網(wǎng)絡(luò)的輸出預(yù)測(cè)和樣本關(guān)系兩方面提取監(jiān)督信息,使用深度融合將多個(gè)教師網(wǎng)絡(luò)的輸出融合,用于指導(dǎo)學(xué)生網(wǎng)絡(luò)的訓(xùn)練,提升知識(shí)蒸餾的效率。因此,本發(fā)明能夠更好地將多個(gè)源域的有效知識(shí)遷移到目標(biāo)域,顯著提高學(xué)生網(wǎng)絡(luò)在目標(biāo)少樣本任務(wù)上的分類準(zhǔn)確率。