深度學習使用RNNLSTM分類MNIST數據集_第1頁
深度學習使用RNNLSTM分類MNIST數據集_第2頁
深度學習使用RNNLSTM分類MNIST數據集_第3頁
全文預覽已結束

下載本文檔

版權說明:本文檔由用戶提供并上傳,收益歸屬內容提供方,若內容存在侵權,請進行舉報或認領

文檔簡介

1、深度學習-使用RNN-LSTM分類MNIST數據集傳統CNN有一個主要特點,就是沒有記憶,他們單獨處理每個輸入,在輸入和輸入之間沒有保存任何狀態,對于這樣的網絡,有些事是無法較好實現的。比如像閱讀一篇文章或者一個句子,需要不斷地對輸入信息進行整合理解,才能更好的解決問題,這時候就需要用到RNN(循環神經網絡)RNN簡單理解就是一個序列當前的輸出與前面的輸出也有關。具體的表現形式為網絡會對前面的信息進行記憶并應用于當前輸出的計算中,即隱藏層之間的結點不再無連接而是有連接的,并且隱藏層的輸入不僅包括輸入層的輸出還包括上一時刻隱藏層的輸出。htht輸入層延遲器輸出層理藏層XiRNN理論上來說,應該能

2、記住許多步之前見過的信息,但實際上,由于距離過遠,當梯度小于零時,較為久遠的位置容易出現梯度消失,當梯度大于零時,較為久遠的位置容易出現梯度爆炸,所以就有了LSTM。可以簡單理解為,主線部分,用于存儲長期記憶的,分線部分,用于存儲短期記憶的,方便序列中的信息可以在任意位置跳上傳送帶,被傳送到更晚的實踐步。下邊一個例子基于前一篇的CNN,實現同樣的功能,只不過我在打印準確率的時候一直出錯,所以把那一部分代碼給去掉了。稍微整理了下但還是不是理解的很透徹。importtorchimporttorchvisionfromtorchimportnnfromtorch.utils.dataimportDa

3、taLoader”超參數”EPOCH=1#一共訓練多少次BATCH_SIZE=64#每批的訓練個數LR=0.01#學習率DOWNLOAD_MNIST=False#是否下載,第一次執行為True,后邊就False”下載圖片數據”train_data=torchvision.datasets.MNIST(root=./mnist/,#保存位置train=True,#是否為訓練數據transform=torchvision.transforms.ToTensor(),#轉換為tensor形式,由(0,255)轉換為(0,1)download=DOWNLOAD_MNIST,#是否下載,第一次執行為Tr

4、ue,后邊就False)test_data=torchvision.datasets.MNIST(root=./mnist/,train=False)”處理數據”train_loader=DataLoader(dataset=train_data,batch_size=BATCH_SIZE,shuffle=True)#數據加載器,打亂數據分批訓練test_x=torch.unsqueeze(test_data.data,dim=1).type(torch.FloatTensor):2000/255.#格式從(2000,28,28)轉換為(2000,1,28,28),值為(0,1)test_y=

5、test_data.targets:2000#只測試前兩千個”創建神經網絡模型”classRNN(nn.Module):def_init_(self):super(RNN,self)._init_()self.rnn=nn.LSTM(#使用LSTM形式input_size=28,#圖片每行的像素點個數hidden_size=64,#隱藏層中單元個數num_layers=1,#RNN層數,層數越多效果越好但是計算量大batch_first=True,#(batch,time_step,input_size),將batch放在第一維)self.out=nn.Linear(64,10)#輸出層def

6、forward(self,x):r_out,(h_n,h_c)=self.rnn(x,None)out=self.out(r_out:,-1,:)#選取最后一個時間點的outputreturnoutrnn=RNN()optimizer=torch.optim.Adam(rnn.parameters(),lr=LR)#使用Adam訓練loss_func=nn.CrossEntropyLoss()#CrossEntropyLoss(用于處理多分類問題”訓練神經網絡模型”forepochinrange(EPOCH):forstep,(x,b_y)inenumerate(train_loader):#givesbatchdatab_x=x.view(-1,28,28)#reshapexto(batch,time_step,input_size)output=rnn(b_x)#生成結果loss=loss_func(output,b_y)#統計損失optimizer.zero_grad()#清空當前梯度loss.backward()#反向傳播計算梯度optimizer.step()#以學習效率0.001來優化梯度”打印數據”test_output=rnn(test_x:10.view(-1,28,28)#取前十個進行輸出pred_y=torch.max(test_output,

溫馨提示

  • 1. 本站所有資源如無特殊說明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請下載最新的WinRAR軟件解壓。
  • 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請聯系上傳者。文件的所有權益歸上傳用戶所有。
  • 3. 本站RAR壓縮包中若帶圖紙,網頁內容里面會有圖紙預覽,若沒有圖紙預覽就沒有圖紙。
  • 4. 未經權益所有人同意不得將文件中的內容挪作商業或盈利用途。
  • 5. 人人文庫網僅提供信息存儲空間,僅對用戶上傳內容的表現方式做保護處理,對用戶上傳分享的文檔內容本身不做任何修改或編輯,并不能對任何下載內容負責。
  • 6. 下載文件中如有侵權或不適當內容,請與我們聯系,我們立即糾正。
  • 7. 本站不保證下載資源的準確性、安全性和完整性, 同時也不承擔用戶因使用這些下載資源對自己和他人造成任何形式的傷害或損失。

評論

0/150

提交評論