




版權說明:本文檔由用戶提供并上傳,收益歸屬內(nèi)容提供方,若內(nèi)容存在侵權,請進行舉報或認領
文檔簡介
第Pytorch從0實現(xiàn)Transformer的實踐目錄摘要一、構造數(shù)據(jù)1.1句子長度1.2生成句子1.3生成字典1.4得到向量化的句子二、位置編碼2.1計算括號內(nèi)的值2.2得到位置編碼三、多頭注意力3.1selfmask
摘要
Withthecontinuousdevelopmentoftimeseriesprediction,Transformer-likemodelshavegraduallyreplacedtraditionalmodelsinthefieldsofCVandNLPbyvirtueoftheirpowerfuladvantages.Amongthem,theInformerisfarsuperiortothetraditionalRNNmodelinlong-termprediction,andtheSwinTransformerissignificantlystrongerthanthetraditionalCNNmodelinimagerecognition.AdeepgraspofTransformerhasbecomeaninevitablerequirementinthefieldofartificialintelligence.ThisarticlewillusethePytorchframeworktoimplementthepositionencoding,multi-headattentionmechanism,self-mask,causalmaskandotherfunctionsinTransformer,andbuildaTransformernetworkfrom0.
隨著時序預測的不斷發(fā)展,Transformer類模型憑借強大的優(yōu)勢,在CV、NLP領域逐漸取代傳統(tǒng)模型。其中Informer在長時序預測上遠超傳統(tǒng)的RNN模型,SwinTransformer在圖像識別上明顯強于傳統(tǒng)的CNN模型。深層次掌握Transformer已經(jīng)成為從事人工智能領域的必然要求。本文將用Pytorch框架,實現(xiàn)Transformer中的位置編碼、多頭注意力機制、自掩碼、因果掩碼等功能,從0搭建一個Transformer網(wǎng)絡。
一、構造數(shù)據(jù)
1.1句子長度
#關于wordembedding,以序列建模為例
#輸入句子有兩個,第一個長度為2,第二個長度為4
src_len=torch.tensor([2,4]).to(32)
#目標句子有兩個。第一個長度為4,第二個長度為3
tgt_len=torch.tensor([4,3]).to(32)
print(src_len)
print(tgt_len)
輸入句子(src_len)有兩個,第一個長度為2,第二個長度為4
目標句子(tgt_len)有兩個。第一個長度為4,第二個長度為3
1.2生成句子
用隨機數(shù)生成句子,用0填充空白位置,保持所有句子長度一致
src_seq=torch.cat([torch.unsqueeze(F.pad(torch.randint(1,max_num_src_words,(L,)),(0,max(src_len)-L)),0)forLinsrc_len])
tgt_seq=torch.cat([torch.unsqueeze(F.pad(torch.randint(1,max_num_tgt_words,(L,)),(0,max(tgt_len)-L)),0)forLintgt_len])
print(src_seq)
print(tgt_seq)
src_seq為輸入的兩個句子,tgt_seq為輸出的兩個句子。
為什么句子是數(shù)字?在做中英文翻譯時,每個中文或英文對應的也是一個數(shù)字,只有這樣才便于處理。
1.3生成字典
在該字典中,總共有8個字(行),每個字對應8維向量(做了簡化了的)。注意在實際應用中,應當有幾十萬個字,每個字可能有512個維度。
#構造wordembedding
src_embedding_table=nn.Embedding(9,model_dim)
tgt_embedding_table=nn.Embedding(9,model_dim)
#輸入單詞的字典
print(src_embedding_table)
#目標單詞的字典
print(tgt_embedding_table)
字典中,需要留一個維度給classtoken,故是9行。
1.4得到向量化的句子
通過字典取出1.2中得到的句子
#得到向量化的句子
src_embedding=src_embedding_table(src_seq)
tgt_embedding=tgt_embedding_table(tgt_seq)
print(src_embedding)
print(tgt_embedding)
該階段總程序
importtorch
#句子長度
src_len=torch.tensor([2,4]).to(32)
tgt_len=torch.tensor([4,3]).to(32)
#構造句子,用0填充空白處
src_seq=torch.cat([torch.unsqueeze(F.pad(torch.randint(1,8,(L,)),(0,max(src_len)-L)),0)forLinsrc_len])
tgt_seq=torch.cat([torch.unsqueeze(F.pad(torch.randint(1,8,(L,)),(0,max(tgt_len)-L)),0)forLintgt_len])
#構造字典
src_embedding_table=nn.Embedding(9,8)
tgt_embedding_table=nn.Embedding(9,8)
#得到向量化的句子
src_embedding=src_embedding_table(src_seq)
tgt_embedding=tgt_embedding_table(tgt_seq)
print(src_embedding)
print(tgt_embedding)
二、位置編碼
位置編碼是transformer的一個重點,通過加入transformer位置編碼,代替了傳統(tǒng)RNN的時序信息,增強了模型的并發(fā)度。位置編碼的公式如下:(其中pos代表行,i代表列)
2.1計算括號內(nèi)的值
#得到分子pos的值
pos_mat=torch.arange(4).reshape((-1,1))
#得到分母值
i_mat=torch.pow(10000,torch.arange(0,8,2).reshape((1,-1))/8)
print(pos_mat)
print(i_mat)
2.2得到位置編碼
#初始化位置編碼矩陣
pe_embedding_table=torch.zeros(4,8)
#得到偶數(shù)行位置編碼
pe_embedding_table[:,0::2]=torch.sin(pos_mat/i_mat)
#得到奇數(shù)行位置編碼
pe_embedding_table[:,1::2]=torch.cos(pos_mat/i_mat)
pe_embedding=nn.Embedding(4,8)
#設置位置編碼不可更新參數(shù)
pe_embedding.weight=nn.Parameter(pe_embedding_table,requires_grad=False)
print(pe_embedding.weight)
三、多頭注意力
3.1selfmask
有些位置是空白用0填充的,訓練時不希望被這些位置所影響,那么就需要用到selfmask。selfmask的原理是令這些位置的值為無窮小,經(jīng)過softmax后,這些值會變?yōu)?,不會再影響結果。
3.1.1得到有效位置矩陣
#得到有效位置矩陣
vaild_encoder_pos=torch.unsqueeze(torch.cat([torch.unsqueeze(F.pad(torch.ones(L),(0,max(src_len)-L)),0)forLinsrc_len]),2)
valid_encoder_pos_matrix=torch.bmm(vaild_encoder_pos,vaild_encoder_pos.transpose(1,2))
print(valid_encoder_pos_matrix)
3.1.2得到無效位置矩陣
invalid_encoder_pos_matrix=1-valid_encoder_pos_matrix
mask_encoder_self_attention=invalid_encoder_pos_matrix.to(torch.bool)
print(mask_encoder_self_attention)
True代表需要對該位置mask
3.1.3得到mask矩陣
用極小數(shù)填充需要被mask的位置
#初始化mask矩陣
score=torch.randn(2,max(
溫馨提示
- 1. 本站所有資源如無特殊說明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請下載最新的WinRAR軟件解壓。
- 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請聯(lián)系上傳者。文件的所有權益歸上傳用戶所有。
- 3. 本站RAR壓縮包中若帶圖紙,網(wǎng)頁內(nèi)容里面會有圖紙預覽,若沒有圖紙預覽就沒有圖紙。
- 4. 未經(jīng)權益所有人同意不得將文件中的內(nèi)容挪作商業(yè)或盈利用途。
- 5. 人人文庫網(wǎng)僅提供信息存儲空間,僅對用戶上傳內(nèi)容的表現(xiàn)方式做保護處理,對用戶上傳分享的文檔內(nèi)容本身不做任何修改或編輯,并不能對任何下載內(nèi)容負責。
- 6. 下載文件中如有侵權或不適當內(nèi)容,請與我們聯(lián)系,我們立即糾正。
- 7. 本站不保證下載資源的準確性、安全性和完整性, 同時也不承擔用戶因使用這些下載資源對自己和他人造成任何形式的傷害或損失。
最新文檔
- 2025年小學教育學考試卷及答案
- 2025年房地產(chǎn)經(jīng)紀人考試題及答案
- 2025年軟件工程理論與實踐復習試卷及答案
- 2025年心理學基礎知識考試題及答案
- 2025年金融專業(yè)考試試卷及答案
- 跨國法律文件保密碎紙機租賃與售后服務協(xié)議
- 地下綜合管廊建設及運維一體化承包合同
- 區(qū)域獨家品牌授權補充協(xié)議
- 家電品牌維修技師勞務派遣服務合同
- 影視作品網(wǎng)絡播放權獨家代理及收益分成合同
- 養(yǎng)老護理員初級試題庫含參考答案
- 基于云計算的數(shù)據(jù)中心設計與運維
- 2025年社區(qū)居委會試題及答案
- 中西醫(yī)結合內(nèi)科學之循環(huán)系統(tǒng)疾病知到課后答案智慧樹章節(jié)測試答案2025年春湖南中醫(yī)藥大學
- TCHSA 088-2024 口腔頜面修復中三維面部掃描臨床應用指南
- SMT設備安全培訓材料
- 深度解析雙十一消費行為
- 北師大版八年級數(shù)學上冊一次函數(shù)《一次函數(shù)中的三角形面積 》教學課件
- 科技企業(yè)研發(fā)流程的精益化管理
- 《中央空調(diào)原理與維護》課件
- 石油化工壓力管道安裝工藝及質(zhì)量控制重點
評論
0/150
提交評論