热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

UNet神经网络

0引言随着深度学习领域中各类算法的迅速发展,卷积神经网络(CNN)被广泛应用在了分类任务上,输出的结果是整个图像的类标签。

0 引言

随着深度学习领域中各类算法的迅速发展,卷积神经网络(CNN)被广泛应用在了分类任务上,输出的结果是整个图像的类标签。在生物医学领域,医生需要对病人的病灶区域进行病理分析,这时需要一种更先进的网络模型,即能通过少量的图片训练集,就能实现对像素点类别的预测,并且可以对像素点进行着色绘图,形成更复杂、严谨的判断。于是U-Net网络被设计了出来。


1 U-Net概念及原理

U-Net网络结构最早由Ronneberger等人于2015年提出。该图像的核心思想是引入了跳跃连接,使得图像分割的精度大大提升。
U-Net网络的主要结构包括了解码器编码器瓶颈层三个部分。


  • 编码器:包括了四个程序块。每个程序块都包括 3×33\times33×3 的卷积(使用Relu激活函数),步长为 2222×22\times22×2 的池化层(下采样)。每个程序块处理后,特征图逐步减小。

  • 解码器: 与编码器部分对称,也包括四个程序块,每个程序块包括步长为 2222×22\times22×2 的上采样操作,然后与编码部分进行特征映射级联(Concatenate),即拼接,最后通过两个 3×33\times33×3 的卷积(Relu)。

  • 瓶颈层:包含两个 3×33\times33×3 的卷积层。

最后经过一个 1×11\times11×1的卷积层得到最后的输出。
在这里插入图片描述
如图所示,该网络模型形似字母“U”,故称为U-Net。

整体过程:
先对图片进行卷积和池化。比如说一开始输入的图片大小是 224×224224\times224224×224,进过四次池化后,分别得到 112×112112\times112112×112 , 56×5656\times5656×56 , 28×2828 \times 2828×28, 14×1414 \times 1414×14 四个不同尺寸的特征图。然后对 14×1414\times 1414×14 的特征图做上采样,得到 28×2828\times2828×28 的特征图。将这个 28×2828\times2828×28的特征图与之前池化得到的 28×2828\times2828×28 特征图进行通道上的拼接(concat),然后再对拼接之后的特征图做卷积和上采样,得到 56×5656\times 5656×56 的特征图,然后再与之前的 56×5656\times5656×56 拼接,卷积然后再上采样,经过四次就就可以得到一个与原输入图像大小相同的图片了。

在本图片上的U-Net中,它输入大小为 572×572572\times 572572×572, 而输出大小为 388×388388 \times 388388×388, 那是因为它在卷积过程中没有加padding层所造成的。


2 代码

import torch
import torch.nn as nn
import torch.nn.functional as F# Double Convolution
class DoubleConv2d(nn.Module):def __init__(self, inputChannel, outputChannel):super(DoubleConv2d, self).__init__()self.conv = nn.Sequential(nn.Conv2d(inputChannel, outputChannel, kernel_size=3, padding=1),nn.BatchNorm2d(outputChannel),nn.ReLU(True),nn.Conv2d(outputChannel, outputChannel, kernel_size=3, padding=1),nn.BatchNorm2d(outputChannel),nn.ReLU(True))def forward(self, x):out = self.conv(x)return out# Down Sampling
class DownSampling(nn.Module):def __init__(self):super(DownSampling, self).__init__()self.down = nn.MaxPool2d(kernel_size=2)def forward(self, x):out = self.down(x)return out# Up Sampling
class UpSampling(nn.Module):# Use the deconvolutiondef __init__(self, inputChannel, outputChannel):super(UpSampling, self).__init__()self.up = nn.Sequential(nn.ConvTranspose2d(inputChannel, outputChannel, kernel_size=2, stride=2),nn.BatchNorm2d(outputChannel))def forward(self, x, y):x =self.up(x)diffY = y.size()[2] - x.size()[2]diffX = y.size()[3] - x.size()[3]x = F.pad(x, [diffX // 2, diffX - diffX // 2,diffY // 2, diffY - diffY // 2])out = torch.cat([y, x], dim=1)return outclass Unet(nn.Module):def __init__(self):super(Unet, self).__init__()self.layer1 = DoubleConv2d(1, 64)self.layer2 = DoubleConv2d(64, 128)self.layer3 = DoubleConv2d(128, 256)self.layer4 = DoubleConv2d(256, 512)self.layer5 = DoubleConv2d(512, 1024)self.layer6 = DoubleConv2d(1024, 512)self.layer7 = DoubleConv2d(512, 256)self.layer8 = DoubleConv2d(256, 128)self.layer9 = DoubleConv2d(128, 64)self.layer10 = nn.Conv2d(64, 2, kernel_size=3, padding=1) # The last output layerself.down = DownSampling()self.up1 = UpSampling(1024, 512)self.up2 = UpSampling(512, 256)self.up3 = UpSampling(256, 128)self.up4 = UpSampling(128, 64)def forward(self, x):conv1 = self.layer1(x)down1 = self.down(conv1)conv2 = self.layer2(down1)down2 = self.down(conv2)conv3 = self.layer3(down2)down3 = self.down(conv3)conv4 = self.layer4(down3)down4 = self.down(conv4)conv5 = self.layer5(down4)up1 = self.up1(conv5, conv4)conv6 = self.layer6(up1)up2 = self.up2(conv6, conv3)conv7 = self.layer7(up2)up3 = self.up3(conv7, conv2)conv8 = self.layer8(up3)up4 = self.up4(conv8, conv1)conv9 = self.layer9(up4)out = self.layer10(conv9)return out# Test partmynet = Unet()
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# mynet.to(device)
input = torch.rand(3, 1, 572, 572)
# output = mynet(input.to(device))
output = mynet(input)
print(output.shape) # (3,2,572,572)


https://www.jianshu.com/p/a73f74992b1a
https://arxiv.org/pdf/1505.04597v1.pdf
https://blog.csdn.net/qq_34107425/article/details/110184747
https://blog.csdn.net/weixin_41857483/article/details/120768804


推荐阅读
  • 本博文基于《Amalgamationofproteinsequence,structureandtextualinformationforimprovingprote ... [详细]
  • 浏览器中的异常检测算法及其在深度学习中的应用
    本文介绍了在浏览器中进行异常检测的算法,包括统计学方法和机器学习方法,并探讨了异常检测在深度学习中的应用。异常检测在金融领域的信用卡欺诈、企业安全领域的非法入侵、IT运维中的设备维护时间点预测等方面具有广泛的应用。通过使用TensorFlow.js进行异常检测,可以实现对单变量和多变量异常的检测。统计学方法通过估计数据的分布概率来计算数据点的异常概率,而机器学习方法则通过训练数据来建立异常检测模型。 ... [详细]
  • 微软头条实习生分享深度学习自学指南
    本文介绍了一位微软头条实习生自学深度学习的经验分享,包括学习资源推荐、重要基础知识的学习要点等。作者强调了学好Python和数学基础的重要性,并提供了一些建议。 ... [详细]
  • 本文介绍了brain的意思、读音、翻译、用法、发音、词组、同反义词等内容,以及脑新东方在线英语词典的相关信息。还包括了brain的词汇搭配、形容词和名词的用法,以及与brain相关的短语和词组。此外,还介绍了与brain相关的医学术语和智囊团等相关内容。 ... [详细]
  • Linux重启网络命令实例及关机和重启示例教程
    本文介绍了Linux系统中重启网络命令的实例,以及使用不同方式关机和重启系统的示例教程。包括使用图形界面和控制台访问系统的方法,以及使用shutdown命令进行系统关机和重启的句法和用法。 ... [详细]
  • Python正则表达式学习记录及常用方法
    本文记录了学习Python正则表达式的过程,介绍了re模块的常用方法re.search,并解释了rawstring的作用。正则表达式是一种方便检查字符串匹配模式的工具,通过本文的学习可以掌握Python中使用正则表达式的基本方法。 ... [详细]
  • 不同优化算法的比较分析及实验验证
    本文介绍了神经网络优化中常用的优化方法,包括学习率调整和梯度估计修正,并通过实验验证了不同优化算法的效果。实验结果表明,Adam算法在综合考虑学习率调整和梯度估计修正方面表现较好。该研究对于优化神经网络的训练过程具有指导意义。 ... [详细]
  • Html5-Canvas实现简易的抽奖转盘效果
    本文介绍了如何使用Html5和Canvas标签来实现简易的抽奖转盘效果,同时使用了jQueryRotate.js旋转插件。文章中给出了主要的html和css代码,并展示了实现的基本效果。 ... [详细]
  • iOS Swift中如何实现自动登录?
    本文介绍了在iOS Swift中如何实现自动登录的方法,包括使用故事板、SWRevealViewController等技术,以及解决用户注销后重新登录自动跳转到主页的问题。 ... [详细]
  • Whatsthedifferencebetweento_aandto_ary?to_a和to_ary有什么区别? ... [详细]
  • Learning to Paint with Model-based Deep Reinforcement Learning
    本文介绍了一种基于模型的深度强化学习方法,通过结合神经渲染器,教机器像人类画家一样进行绘画。该方法能够生成笔画的坐标点、半径、透明度、颜色值等,以生成类似于给定目标图像的绘画。文章还讨论了该方法面临的挑战,包括绘制纹理丰富的图像等。通过对比实验的结果,作者证明了基于模型的深度强化学习方法相对于基于模型的DDPG和模型无关的DDPG方法的优势。该研究对于深度强化学习在绘画领域的应用具有重要意义。 ... [详细]
  • Introduction(简介)Forbeingapowerfulobject-orientedprogramminglanguage,Cisuseda ... [详细]
  • OCR:用字符识别方法将形状翻译成计算机文字的过程Matlab:商业数学软件;CUDA:CUDA™是一种由NVIDIA推 ... [详细]
  • [转载]从零开始学习OpenGL ES之四 – 光效
    继续我们的iPhoneOpenGLES之旅,我们将讨论光效。目前,我们没有加入任何光效。幸运的是,OpenGL在没有设置光效的情况下仍然可 ... [详细]
  • 人工智能推理能力与假设检验
    最近Google的Deepmind开始研究如何让AI做数学题。这个问题的提出非常有启发,逻辑推理,发现新知识的能力应该是强人工智能出现自我意识之前最需要发展的能力。深度学习目前可以 ... [详细]
author-avatar
mobiledu2502928483
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有