深度學(xué)習(xí)使用RNNLSTM分類MNIST數(shù)據(jù)集_第1頁
深度學(xué)習(xí)使用RNNLSTM分類MNIST數(shù)據(jù)集_第2頁
深度學(xué)習(xí)使用RNNLSTM分類MNIST數(shù)據(jù)集_第3頁
全文預(yù)覽已結(jié)束

下載本文檔

版權(quán)說明:本文檔由用戶提供并上傳,收益歸屬內(nèi)容提供方,若內(nèi)容存在侵權(quán),請進行舉報或認(rèn)領(lǐng)

文檔簡介

1、深度學(xué)習(xí)-使用RNN-LSTM分類MNIST數(shù)據(jù)集傳統(tǒng)CNN有一個主要特點,就是沒有記憶,他們單獨處理每個輸入,在輸入和輸入之間沒有保存任何狀態(tài),對于這樣的網(wǎng)絡(luò),有些事是無法較好實現(xiàn)的。比如像閱讀一篇文章或者一個句子,需要不斷地對輸入信息進行整合理解,才能更好的解決問題,這時候就需要用到RNN(循環(huán)神經(jīng)網(wǎng)絡(luò))RNN簡單理解就是一個序列當(dāng)前的輸出與前面的輸出也有關(guān)。具體的表現(xiàn)形式為網(wǎng)絡(luò)會對前面的信息進行記憶并應(yīng)用于當(dāng)前輸出的計算中,即隱藏層之間的結(jié)點不再無連接而是有連接的,并且隱藏層的輸入不僅包括輸入層的輸出還包括上一時刻隱藏層的輸出。htht輸入層延遲器輸出層理藏層XiRNN理論上來說,應(yīng)該能

2、記住許多步之前見過的信息,但實際上,由于距離過遠(yuǎn),當(dāng)梯度小于零時,較為久遠(yuǎn)的位置容易出現(xiàn)梯度消失,當(dāng)梯度大于零時,較為久遠(yuǎn)的位置容易出現(xiàn)梯度爆炸,所以就有了LSTM。可以簡單理解為,主線部分,用于存儲長期記憶的,分線部分,用于存儲短期記憶的,方便序列中的信息可以在任意位置跳上傳送帶,被傳送到更晚的實踐步。下邊一個例子基于前一篇的CNN,實現(xiàn)同樣的功能,只不過我在打印準(zhǔn)確率的時候一直出錯,所以把那一部分代碼給去掉了。稍微整理了下但還是不是理解的很透徹。importtorchimporttorchvisionfromtorchimportnnfromtorch.utils.dataimportDa

3、taLoader”超參數(shù)”EPOCH=1#一共訓(xùn)練多少次BATCH_SIZE=64#每批的訓(xùn)練個數(shù)LR=0.01#學(xué)習(xí)率DOWNLOAD_MNIST=False#是否下載,第一次執(zhí)行為True,后邊就False”下載圖片數(shù)據(jù)”train_data=torchvision.datasets.MNIST(root=./mnist/,#保存位置train=True,#是否為訓(xùn)練數(shù)據(jù)transform=torchvision.transforms.ToTensor(),#轉(zhuǎn)換為tensor形式,由(0,255)轉(zhuǎn)換為(0,1)download=DOWNLOAD_MNIST,#是否下載,第一次執(zhí)行為Tr

4、ue,后邊就False)test_data=torchvision.datasets.MNIST(root=./mnist/,train=False)”處理數(shù)據(jù)”train_loader=DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True)#數(shù)據(jù)加載器,打亂數(shù)據(jù)分批訓(xùn)練test_x=torch.unsqueeze(test_data.data,dim=1).type(torch.FloatTensor):2000/255.#格式從(2000,28,28)轉(zhuǎn)換為(2000,1,28,28),值為(0,1)test_y=

5、test_data.targets:2000#只測試前兩千個”創(chuàng)建神經(jīng)網(wǎng)絡(luò)模型”classRNN(nn.Module):def_init_(self):super(RNN,self)._init_()self.rnn=nn.LSTM(#使用LSTM形式input_size=28,#圖片每行的像素點個數(shù)hidden_size=64,#隱藏層中單元個數(shù)num_layers=1,#RNN層數(shù),層數(shù)越多效果越好但是計算量大batch_first=True,#(batch,time_step,input_size),將batch放在第一維)self.out=nn.Linear(64,10)#輸出層def

6、forward(self,x):r_out,(h_n,h_c)=self.rnn(x,None)out=self.out(r_out:,-1,:)#選取最后一個時間點的outputreturnoutrnn=RNN()optimizer=torch.optim.Adam(rnn.parameters(),lr=LR)#使用Adam訓(xùn)練loss_func=nn.CrossEntropyLoss()#CrossEntropyLoss(用于處理多分類問題”訓(xùn)練神經(jīng)網(wǎng)絡(luò)模型”forepochinrange(EPOCH):forstep,(x,b_y)inenumerate(train_loader):#givesbatchdatab_x=x.view(-1,28,28)#reshapexto(batch,time_step,input_size)output=rnn(b_x)#生成結(jié)果loss=loss_func(output,b_y)#統(tǒng)計損失optimizer.zero_grad()#清空當(dāng)前梯度loss.backward()#反向傳播計算梯度optimizer.step()#以學(xué)習(xí)效率0.001來優(yōu)化梯度”打印數(shù)據(jù)”test_output=rnn(test_x:10.view(-1,28,28)#取前十個進行輸出pred_y=torch.max(test_output,

溫馨提示

  • 1. 本站所有資源如無特殊說明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請下載最新的WinRAR軟件解壓。
  • 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請聯(lián)系上傳者。文件的所有權(quán)益歸上傳用戶所有。
  • 3. 本站RAR壓縮包中若帶圖紙,網(wǎng)頁內(nèi)容里面會有圖紙預(yù)覽,若沒有圖紙預(yù)覽就沒有圖紙。
  • 4. 未經(jīng)權(quán)益所有人同意不得將文件中的內(nèi)容挪作商業(yè)或盈利用途。
  • 5. 人人文庫網(wǎng)僅提供信息存儲空間,僅對用戶上傳內(nèi)容的表現(xiàn)方式做保護處理,對用戶上傳分享的文檔內(nèi)容本身不做任何修改或編輯,并不能對任何下載內(nèi)容負(fù)責(zé)。
  • 6. 下載文件中如有侵權(quán)或不適當(dāng)內(nèi)容,請與我們聯(lián)系,我們立即糾正。
  • 7. 本站不保證下載資源的準(zhǔn)確性、安全性和完整性, 同時也不承擔(dān)用戶因使用這些下載資源對自己和他人造成任何形式的傷害或損失。

評論

0/150

提交評論