




版權說明:本文檔由用戶提供并上傳,收益歸屬內容提供方,若內容存在侵權,請進行舉報或認領
文檔簡介
第12章
深度學習1生成式對抗網絡、殘差神經網絡、孿生神經網絡學習目標生成式對抗網絡原理。利用PyTorch框架實現生成式對抗網絡。殘差、孿生神經網絡的基本原理。利用PyTorch框架實現殘差、孿生神經網絡。12234生成式對抗網絡(GenerativeAdversarialNetwork)是將判別模型與生成模型有機地進行融合的深度神經網絡框架,其包含的生成器(估計數據分布)與判別器(判斷數據真偽)通過相互競爭或對抗的方式提取數據蘊含的內在規律。312.2.5生成式對抗網絡生成式對抗網絡的基本原理為:首先根據隨機噪聲利用生成器生成新圖像,然后利用判別器判別一幅圖像的真實性(如圖像的判別概率為1則為真、為0則為假);在訓練過程中,生成器盡量生成真實圖像以欺騙判別器,而判別器則盡量把生成器生成的圖像與真實圖像分開,兩者進而構成了一個動態的“博弈過程”。最終,生成器可生成“以假亂真”的圖像(判別器難以進行判別或判別概率為1)。412.2.5生成式對抗網絡
512.2.5生成式對抗網絡
612.2.5生成式對抗網絡利用MNIST數據集構建生成對抗網絡產生“以假亂真”的手寫字體圖像。1.導入PyTorch框架庫及torchvision庫importtorchimporttorch.nnasnnfromtorchvisionimportdatasetsfromtorchvisionimporttransformsfromtorchvision.utilsimportsave_imagefromtorch.autogradimportVariablefromtorch.utils.dataimportDataLoader712.2.5生成式對抗網絡2.構造數據img_transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5,),std=(0.5,))])#預處理操作data_transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5],[0.5])])#加載數據train_dataset=datasets.MNIST(root='./data',train=True,transform=data_transform,download=True)test_dataset=datasets.MNIST(root='./data',train=False,transform=data_transform)train_loader=DataLoader(train_dataset,batch_size=batch_size,shuffle=True)test_loader=DataLoader(test_dataset,batch_size=batch_size,shuffle=False)812.2.5生成式對抗網絡3.定義GAN神經網絡classdiscriminator(nn.Module):def__init__(self):super(discriminator,self).__init__()self.discriminator=nn.Sequential(nn.Linear(784,256),nn.ReLU(True),nn.Linear(256,128),nn.ReLU(True),nn.Linear(128,1),nn.Sigmoid())defforward(self,x):x=self.discriminator(x)returnx912.2.5生成式對抗網絡3.定義GAN神經網絡classgenerator(nn.Module):def__init__(self,input_size):super(generator,self).__init__()self.generator=nn.Sequential(nn.Linear(input_size,128),nn.ReLU(True),nn.Linear(128,256),nn.ReLU(True),nn.Linear(256,784),nn.Tanh())defforward(self,x):x=self.generator(x)returnx1012.2.5生成式對抗網絡
4.訓練GAN神經網絡D=discriminator()#實例化判別器G=generator(Z)#實例化生成器loss=nn.BCELoss()#定義二元交叉熵損失函數d_optimizer=torch.optim.Adam(D.parameters(),lr=0.0001)#定義判別器之優化器g_optimizer=torch.optim.Adam(G.parameters(),lr=0.0001)#定義生成器之優化器T=100#訓練迭代次數1112.2.5生成式對抗網絡forepochinrange(T):fori,(im,_)inenumerate(train_loader):num_im=im.size(0)#訓練判別器im=im.view(num_im,-1)real_im=Variable(im)real_label=Variable(torch.ones(num_im))fake_label=Variable(torch.zeros(num_im))real_pred=D(real_im).squeeze(-1)#預測真圖像的類別標記(理想情況為1)d_loss_real=loss(real_pred,real_label)#真圖像對應損失(預測類別標記,真實類別標記)#real_scores=real_predz_vector=Variable(torch.randn(num_im,Z))#生成噪聲向量fake_img=G(z_vector)#生成假圖像1212.2.5生成式對抗網絡fake_pred=D(fake_img).squeeze(-1)#判別器對假圖像的預測類別標記(理想情況為0)d_loss_fake=loss(fake_pred,fake_label)#假圖像對應損失(預測類別標記,真實類別標記)#fake_scores=fake_predd_loss=d_loss_real+d_loss_fake#真假圖像損失之和d_optimizer.zero_grad()#梯度清零d_loss.backward()#誤差反傳d_optimizer.step()#更新參數#訓練生成器z_vector=Variable(torch.randn(num_im,Z))#生成噪聲向量fake_im=G(z_vector)#生成假圖像fake_pred=D(fake_im).squeeze(-1)#判別器對假圖像的預測類別標記g_loss=loss(fake_pred,real_label)#計算損失g_optimizer.zero_grad()#梯度清零g_loss.backward()#誤差反傳g_optimizer.step()#更新參數1312.2.5生成式對抗網絡
if(i+1)%100==0:print('Epoch[{}/{}],d_loss:{:.6f},g_loss:{:.6f}'.format(epoch,T,d_loss.data.numpy(),g_loss.data.numpy(),))
ifepoch==0:real_images=to_im(real_im.data)save_image(real_images,'./results/real_images.png')fake_images=to_im(fake_img.data)save_image(fake_images,'./results/fake_images-{}.png'.format(epoch+1))1412.2.5生成式對抗網絡1512.2.5生成式對抗網絡生成對抗網絡主要用于在已知數據的基礎上生成可靠的或“以假亂真”的新數據,其結構由生成器與判別器兩部分構成,而由于其主要解決樣本真偽問題,因而,相應的損失函數通常采用二元交叉損失函數。左圖所示生成器旨在生成可以讓判別器無法判別的樣本,而判別器旨在可靠地判別樣本的真偽,兩者相互對抗之后,判別器將可以根據任何噪聲數據生成“以假亂真”的數據,在實際應用中,生成對抗網絡可以根據在已知數據集的基礎上對其進行擴展,可在一定程度上解決以數據驅動為特色的深度神經網絡所需數據匱乏的問題。1612.2.6殘差神經網絡對于深度神經網絡,其層次越深,特征表達或非線性建模能力越強,但也由于易出現梯度彌散與梯度爆炸等問題導致其在實際中難以被訓練或泛化能力較差。事實上,以卷積神經網絡為例,對從輸入層至輸出層的數據不斷進行的濾波處理(如卷積與池化)雖然在一定程度上可以避免過擬合與降低運算量,但同時也可能損失一些潛在的關鍵信息(類似有損壓縮),特別在層次增多時,此問題將更為嚴重(如清晰的圖像經過多次卷積后將無法被辨識)。
12.2.6殘差神經網絡17為了解決上述問題,如左圖所示,殘差神經網絡通過在傳統層或模塊的基礎上引入恒等映射(在輸入與輸出之間建立直接的關聯通道)的方式將原數據與“期望輸出與原數據之間的殘差”進行融合,使用層或模塊以原數據作為參考實現特征的學習而不至于損失較多的信息。利用MNIST數據集構建殘差神經網絡并訓練與測試。1.導入庫importtorchimporttorch.nnasnnimporttorch.nn.functionalasFfromtorchvisionimportdatasetsfromtorchvisionimporttransformsfromtorch.utils.dataimportDataLoader18
12.2.6殘差神經網絡19
12.2.6殘差神經網絡2.生成數據#每次處理的圖像數量batch_size=128#預處理操作data_transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize([0.5],[0.5])])#加載數據train_dataset=datasets.MNIST(root='./data',train=True,transform=data_transform,download=False)test_dataset=datasets.MNIST(root='./data',train=False,transform=data_transform)train_loader=DataLoader(train_dataset,batch_size=batch_size,shuffle=True)test_loader=DataLoader(test_dataset,batch_size=batch_size,shuffle=False)20
12.2.6殘差神經網絡3.定義殘差神經網絡模塊classResidualBlock(nn.Module):def__init__(self,channels):super(ResidualBlock,self).__init__()self.channels=channelsself.conv1=nn.Conv2d(channels,channels,kernel_size=3,padding=1)self.conv2=nn.Conv2d(channels,channels,kernel_size=3,padding=1)defforward(self,x):x=self.conv1(x)y=self.conv2(F.relu(x))y+=xy=F.relu(y)returny4.定義神經網絡classNET(nn.Module):def__init__(self):super(NET,self).__init__()self.conv1=nn.Conv2d(1,16,kernel_size=5)self.conv2=nn.Conv2d(16,32,kernel_size=5)self.mp=nn.MaxPool2d(2)self.rblock1=ResidualBlock(16)#引入殘差神經網絡模塊(通道數:16)self.rblock2=ResidualBlock(32)#引入殘差神經網絡模塊(通道數:32)self.fc=nn.Linear(512,10)
21
12.2.6殘差神經網絡defforward(self,x):x=self.conv1(x)x=self.mp(F.relu(x))x=self.rblock1(x)x=self.conv2(x)x=self.mp(F.relu(x))x=self.rblock2(x)x=x.view(x.size(0),-1)x=self.fc(x)returnx
#實例化神經網絡對象model=NET()22
12.2.6殘差神經網絡23
12.2.6殘差神經網絡4.訓練卷積神經網絡#定義損失函數(多元交叉熵損失函數)loss=torch.nn.CrossEntropyLoss()#定義優化器optimizer=torch.optim.SGD(model.parameters(),lr=0.01,momentum=.5)#神經網絡訓練T=10#訓練迭代次數24
12.2.6殘差神經網絡forepochinrange(T):#running_loss=0loss_=0.0#累積訓練誤差acc_=0.0#累積訓練精度fori,datainenumerate(train_loader):im,label=data
label_pred=model(im)#前向傳播L=loss(label_pred,label)#計算誤差loss_+=L.data.numpy()#誤差累積_,label_opt=torch.max(label_pred,1)#求取預測概率對應的類別acc_+=(label_opt==label).float().mean()#累積精度optimizer.zero_grad()#梯度清零L.backward()#誤差反傳optimizer.step()#更新參數#顯示誤差與精度變化if(epoch==0)|((epoch+1)%2==0):print('Epoch:[{}/{}],Loss:{:.4f},Accuracy:{:.4f}'.format(epoch+1,T,loss_/i,acc_/i))25
12.2.6殘差神經網絡5.神經網絡測試model.eval()acc_=0.0#累積精度fori,datainenumerate(test_loader,1):im,label=datalabel_pred=model(im)#前向傳播_,label_opt=torch.max(label_pred,1)acc_+=(label_opt==label).float().mean()print('Accuracy:{:.4f}'.format(acc_/i))Accuracy:0.988226
12.2.6殘差神經網絡殘差神經網絡通過在神經網絡層或模型的輸入與輸出之間構造恒等映射的方式解決在神經網絡層或模型較多時出現的梯度彌散或梯度爆炸等問題,在具體應用中,殘差神經網絡模塊可嵌入至其他深度神經網絡框架之中,具有較高的靈活性。本例中的神經網絡結構較為簡單,但通過嵌入殘差神經網絡模塊,其依然獲得了較高的精度,在一定程度上表明殘差神經網絡模塊的有效性。2712.2.7孿生神經網絡孿生神經網絡(Siameseneuralnetwork),又名雙生神經網絡,是基于兩個神經網絡建立的耦合構架。孿生神經網絡以兩個樣本為輸入,輸出其嵌入高維度空間的表征,以比較兩個樣本的相似程度。
12.2.7孿生神經網絡28孿生神經網絡(SiameseNeuralNetwork)旨在利用兩個神經網絡將兩個輸入數據映射至高維特征空間以比較其相似程度。如下圖所示,狹義的孿生神經網絡由兩個權重共享的神經網絡拼接而成,而廣義的孿生神經網絡或“偽孿生神經網絡”則可由任意兩個神經網絡拼接。
12.2.7孿生神經網絡29
12.2.7孿生神經網絡30
利用CIFAR數據集(數據組織方式:CIFAR文件夾中包含train與test兩個文件夾,其中,10個類別的圖像分別保存至以數字0~9命名的子文件夾內)構建孿生神經網絡,要求如下。
通過自定義數據集類的方式構建孿生神經網絡訓練數據與測試數據。
構建孿生神經網絡并訓練與測試(自定義對比損失函數)。31
12.2.7孿生神經網絡利用CIFAR數據集(數據組織方式:CIFAR文件夾中包含train與test兩個文件夾,其中,10個類別的圖像分別保存至以數字0~9命名的子文件夾內)構建孿生神經網絡,要求如下。
通過自定義數據集類的方式構建孿生神經網絡訓練數據與測試數據。
構建孿生神經網絡并訓練與測試(自定義對比損失函數)。32
12.2.7孿生神經網絡1.導入庫importtorchfromtorchimportnn,optimfromtorch.utils.dataimportDataLoader,Datasetfromtorchvisionimportdatasets,transformsimporttorch.nn.functionalasFimportmatplotlib.pyplotasplt#導入繪圖庫fromPILimportImage#導入圖像處理庫importPILimportnumpyasnp
#導入科學計算庫importos#導入文件處理庫
importglob33
12.2.7孿生神經網絡34
12.2.7孿生神經網絡defget_label_from_image_path(image_path):returnint(os.path.split(image_path)[0].split('\\')[-1])classMyDataSet(Dataset):def__init__(self,image_folder,transform=None,should_invert=True):self.image_folder=image_folder#存放圖片的文件夾
self.transform=transform#預處理
self.should_invert=should_invert#通道反轉
self.image_list=glob.glob(self.image_folder+'/*/*.jpg')#圖片列表示例:D:\\data\\1\\0001.jpgD:\\data\\2\\0003.jpg35
12.2.7孿生神經網絡def__getitem__(self,*args):label_yn=np.random.randint(2)#兩圖像若相似為1或若不相似為0im_A_path=np.random.choice(self.image_list)#隨機選擇1幅圖片im_A_label=int(os.path.split(im_A_path)[0].split('\\')[-1])#圖片真實類別iflabel_yn:#抽取與當前圖像屬于同一個類別的圖像whileTrue:im_B_path=np.random.choice(self.image_list)#隨機選擇圖像im_B_label=int(os.path.split(im_B_path)[0].split('\\')[-1])#圖片真實類別ifim_A_label==im_B_label:#若類別相同則終止breakelse:whileTrue:im_B_path=np.random.choice(self.image_list)#隨機選擇圖像im_B_label=int(os.path.split(im_B_path)[0].split('\\')[-1])#圖片真實類別ifim_A_label!=im_B_label:#若類別不相同則終止break#讀取圖像im_A=Image.open(im_A_path)im_B=Image.open(im_B_path)#判斷是否進行通道反轉ifself.yn_invert:im_A=PIL.ImageOps.invert(im_A)im_B=PIL.ImageOps.invert(im_B)#判斷是否進行預處理ifself.transformisnotNone:im_A=self.transform(im_A)im_B=self.transform(im_B)returnim_A,im_B,label_yn#返回兩幅圖像與相似標記def__len__(self):returnlen(self.image_list)
36
12.2.7孿生神經網絡#定義預處理操作transform=transforms.Compose([transforms.Grayscale(num_output_channels=1),#轉成單通道transforms.ToTensor(),#轉換為張量transforms.Normalize((0.5,),(0.5,)),#歸一化
])#加載數據train_dir="./DATA/CIFAR/train/"#訓練集文件夾
train_dataset=MyDataSet(im_dir=train_dir,transform=transform,yn_invert=False)train_dataloader=DataLoader(train_dataset,shuffle=True,batch_size=32)test_dir="./DATA/CIFAR/test/"#測試集文件夾
test_dataset=MyDataSet(im_dir=test_dir,transform=transform,yn_invert=False)test_dataloader=DataLoader(test_dataset,shuffle=True,batch_size=32)37
12.2.7孿生神經網絡38
12.2.7孿生神經網絡#讀取數據集test_set=enumerate(test_dataloader)ix,test_im_label=next(test_set)ims_1=test_im_label[0]print(ims_1.size())>>torch.Size([32,1,32,32])ims_2=test_im_label[1]print(ims_2.size())>>torch.Size([32,1,32,32])label=test_im_label[2]print(label)39
12.2.7孿生神經網絡2.構建孿生神經網絡classSiameseNetwork(nn.Module):def__init__(self):super().__init__()n=nn.Sequential(nn.ReflectionPad2d(1),nn.Conv2d(1,5,3,1),nn.ReLU(),nn.BatchNorm2d(5),nn.ReflectionPad2d(1),nn.Conv2d(5,10,3,1),nn.ReLU(),nn.BatchNorm2d(10),nn.ReflectionPad2d(1),nn.Conv2d(10,20,3,1),nn.ReLU(inplace=True),nn.BatchNorm2d(20),)40
12.2.7孿生神經網絡self.fc=nn.Sequential(nn.Linear(32*32*20,100),nn.ReLU(),nn.Linear(100,50),nn.ReLU(),nn.Linear(50,5))defforward_once(self,x):y=n(x)y=y.view(y.size()[0],-1)y=self.fc(y)returnydefforward(self,x1,x2):y1=self.forward_once(x1)y2=self.forward_once(x2)returny1,y2net=SiameseNetwork()#定義模型41
12.2.7孿生神經網絡3.
定義對比損失函數classContrastiveLoss(torch.nn.Module):def__init__(self,margin=2.0):super(ContrastiveLoss,self).__init__()self.margin=margindefforward(self,x1,x2,label):euclidean_distance=F.pairwise_distance(x1,x2,keepdim=True)loss=torch.mean((1-label)*torch.pow(euclidean_distance,2)+(label)*torch.pow(torch.clamp(self.margin-euclidean_distance,min=0.0),2))returnlossloss=ContrastiveLoss()42
12.2.7孿生神經網絡4.定義優化器optimizer=optim.Adam(net.parameters(),lr=0.001)#優化器5.模型訓練T=10forepochinrange(T):loss_=0.0#累積訓練誤差fori,datainenumerate(train_dataloader,1):
im_1,im_2,label=dataoutput1,output2=net(im_1,im_2)L=loss(output1,output2,label)optimizer.zero_grad()L.backward()optimizer.step()loss_+=L.data.numpy()#誤差累積#顯示誤差變化if(epoch==0)|((epoch+1)%2==0):print('Epoch:[{}/{}],Loss:{:.4f}'.format(epoch+1,T,loss_/i))43
12.2.7孿生神經網絡6.模型測試output1,output2=net(ims_1,ims_2)euclidean_distance=F.pairwise_distance(output1,output2)#也可以設置閾值求精度,如:euclidean_distance[euclidean_distance<2]=1euclidean_distance[euclidean_distance>=2]=0y_diff=torch.abs(euclidean_distance-label)#求差異Acc=torch.mean(y_diff)print('精度A2:{}'.format(Acc))44
12.2.7孿生神經網絡Epoch:[1/100],Loss:1.0626Epoch:[2/100],Loss:0.9127Epoch:[4/100],Loss:0.8167Epoch:[6/100],Loss:0.7341Epoch:[8/100],Loss:0.6428Epoch:[10/100],Loss:0.5639Epoch:[12/100],Loss:0.4999Epoch:[14/100],Loss:0.4407Epoch:[16/100],Loss:0.3927Epoch:[18/100],
溫馨提示
- 1. 本站所有資源如無特殊說明,都需要本地電腦安裝OFFICE2007和PDF閱讀器。圖紙軟件為CAD,CAXA,PROE,UG,SolidWorks等.壓縮文件請下載最新的WinRAR軟件解壓。
- 2. 本站的文檔不包含任何第三方提供的附件圖紙等,如果需要附件,請聯系上傳者。文件的所有權益歸上傳用戶所有。
- 3. 本站RAR壓縮包中若帶圖紙,網頁內容里面會有圖紙預覽,若沒有圖紙預覽就沒有圖紙。
- 4. 未經權益所有人同意不得將文件中的內容挪作商業或盈利用途。
- 5. 人人文庫網僅提供信息存儲空間,僅對用戶上傳內容的表現方式做保護處理,對用戶上傳分享的文檔內容本身不做任何修改或編輯,并不能對任何下載內容負責。
- 6. 下載文件中如有侵權或不適當內容,請與我們聯系,我們立即糾正。
- 7. 本站不保證下載資源的準確性、安全性和完整性, 同時也不承擔用戶因使用這些下載資源對自己和他人造成任何形式的傷害或損失。
最新文檔
- 2025年環境工程師職業資格考試題及答案
- 社區考試題簡答題及答案
- 財務會計準則實操試題及答案
- 西方政治制度中的利益集團作用分析試題及答案
- 機電工程新興技術應用試題及答案
- 知識共享政策的實施與效果評估試題及答案
- 軟件設計師考試關鍵思考點試題及答案
- 網絡流量監控的趨勢與試題及答案
- 意識到考試復習的重要內容試題及答案
- 網絡策略與商業價值關系分析試題及答案
- 薪酬激勵實施方案
- 2025年上海市各區高三語文一模試題匯編之文言文二閱讀(含答案)
- 大學英語四級寫作課件
- 《PBR次世代游戲建模技術》教學大綱
- 國家開放大學本科《管理英語3》一平臺機考真題及答案總題庫珍藏版
- 20萬噸高塔造粒顆粒硝酸銨工藝安全操作規程
- CJJ82-2012 園林綠化工程施工及驗收規范
- 江蘇省南京市2022-2023學年四年級下學期數學期末試卷(含答案)
- 江蘇省南京市建鄴區2022-2023學年五年級下學期期末數學試卷
- 提高感染性休克集束化治療完成率工作方案
- 肝硬化病人健康宣教課件
評論
0/150
提交評論