热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

Tensorflow中使用tf.conv2d_transpose()函数进行卷积转置操作

我先解释一下必要信息:tf.conv2d_transpose(value,filter,output_shape,strides,paddingSAME,dat

我先解释一下必要信息:

tf.conv2d_transpose(value, filter, output_shape, strides, padding="SAME", data_format="NHWC", name=None)

 

除去name参数用以指定该操作的name,与方法有关的一共六个参数:

 

 

第一个参数value:指需要做反卷积的输入图像,它要求是一个Tensor
第二个参数filter:卷积核,它要求是一个Tensor,具有[filter_height, filter_width, out_channels, in_channels]这样的shape,具体含义是[卷积核的高度,卷积核的宽度,卷积核个数,图像通道数]
第三个参数output_shape:反卷积操作输出的shape,细心的同学会发现卷积操作是没有这个参数的.
第四个参数strides:反卷积时在图像每一维的步长,这是一个一维的向量,长度4
第五个参数padding:string类型的量,只能是"SAME","VALID"其中之一,这个值决定了不同的卷积方式
第六个参数data_format:string类型的量,'NHWC'和'NCHW'其中之一,这是tensorflow新版本中新加的参数,它说明了value参数的数据格式。'NHWC'指tensorflow标准的数据格式[batch, height, width, in_channels],'NCHW'指Theano的数据格式,[batch, in_channels,height, width],当然默认值是'NHWC'

 

 

通俗的讲这个解卷积,也就做反卷积,也叫做转置卷积(最贴切),我们就叫做反卷积吧,它的目的就是卷积的反向操作, 

所以在做这些之前,你心里要有一个正向卷积的流程在心里,什么?你没有?好吧,那我就引导你一下:

input_shape = [1,5,5,3] 
kernel_shape=[2,2,3,1] 
strides=[1,2,2,1] 
padding = "SAME"

 

out_shape   结果应该是什么,应该是[1,3,3,1] 只有一个通道的3*3的图片,

然后我们就对它进行反向操作,注意哪方面不同:

设x=out_shape,#[1,3,3,1]

import tensorflow as tf
tf.set_random_seed(1)x = tf.random_normal(shape=[1,3,3,1])#正向卷积的结果,要作为反向卷积的输出
kernel = tf.random_normal(shape=[2,2,3,1])#正向卷积的kernel的模样# strides 和padding也是假想中 正向卷积的模样。
y = tf.nn.conv2d_transpose(x,kernel,output_shape=[1,5,5,3],strides=[1,2,2,1],padding="SAME")
# 在这里,output_shape=[1,6,6,3]也可以,考虑正向过程,[1,6,6,3]时,然后通过
# kernel_shape:[2,2,3,1],strides:[1,2,2,1]也可以
# 获得x_shape:[1,3,3,1]。
# output_shape 也可以是一个 tensor
sess = tf.Session()
tf.global_variables_initializer().run(session=sess)print(y.eval(session=sess))

然后输出的y就是最上面的input_shape,

我想到了一个很合理的方法就是这样定制你的反卷积网络,也即是你在进行反卷积之前,你要推算一下正向卷积所需要的路径,然后把正向卷积所需要的kernel,和strides写入tf.conv2d_transpose()函数就行了,当然输入和输出要互相对换一下就行了,

下面是我自己实现的3维反卷积操作,原理是一样的:

import tensorflow as tfkernel1 = tf.constant(1.0, shape=[3,3,3,512,512]) #正向卷积核
kernel2 = tf.constant(1.0, shape=[3,3,3,512,512]) #反向卷积核
x3 = tf.constant(1.0, shape=[10,2,7,7,512])#正向卷积输入
y2 = tf.nn.conv3d(x3, kernel1, strides=[1,1,1,1,1], padding="SAME") #正向卷积
pool=tf.nn.max_pool3d(y2,ksize=[1,2,2,2,1],strides=[1,2,2,2,1],padding='SAME')#池化sess=tf.Session()
sess.run(tf.global_variables_initializer())
sess.run(pool)
print(pool.shape)#(10,1,4,4,512)#反向卷积
y3 = tf.nn.conv3d_transpose(pool,kernel2,output_shape=[10,2,7,7,512], strides=[1,2,2,2,1],padding="SAME")
sess.run(y3)
print(y3.shape)#(10,2,7,7,512)

上面的例子是由[10,2,7,7,512]----卷积池化到----[10,1,4,4,512]----反卷积到----[10,2,7,7,512]    

至于内部原理是怎么实现的,请看:https://blog.csdn.net/u012938704/article/details/52838902

https://blog.csdn.net/kekong0713/article/details/68941498

http://deeplearning.net/software/theano_versions/dev/tutorial/conv_arithmetic.html#transposed-convolution-arithmetic


推荐阅读
  • 深入解析Android中的SQLite数据库使用
    本文详细介绍了如何在Android应用中使用SQLite数据库进行数据存储。通过自定义类继承SQLiteOpenHelper,实现数据库的创建与版本管理,并提供了具体的学生信息管理示例代码。 ... [详细]
  • Canvas漫游:碰撞检测与动画模拟
    探索Canvas在Web开发中的应用,通过碰撞检测与动画模拟提升交互体验。 ... [详细]
  • 使用 Azure Service Principal 和 Microsoft Graph API 获取 AAD 用户列表
    本文介绍了一段通用代码示例,该代码不仅能够操作 Azure Active Directory (AAD),还可以通过 Azure Service Principal 的授权访问和管理 Azure 订阅资源。Azure 的架构可以分为两个层级:AAD 和 Subscription。 ... [详细]
  • 深入解析Spring Cloud Ribbon负载均衡机制
    本文详细介绍了Spring Cloud中的Ribbon组件如何实现服务调用的负载均衡。通过分析其工作原理、源码结构及配置方式,帮助读者理解Ribbon在分布式系统中的重要作用。 ... [详细]
  • 本文深入探讨了 Java 中的 Serializable 接口,解释了其实现机制、用途及注意事项,帮助开发者更好地理解和使用序列化功能。 ... [详细]
  • 尽管使用TensorFlow和PyTorch等成熟框架可以显著降低实现递归神经网络(RNN)的门槛,但对于初学者来说,理解其底层原理至关重要。本文将引导您使用NumPy从头构建一个用于自然语言处理(NLP)的RNN模型。 ... [详细]
  • 本文将详细探讨 Java 中提供的不可变集合(如 `Collections.unmodifiableXXX`)和同步集合(如 `Collections.synchronizedXXX`)的实现原理及使用方法,帮助开发者更好地理解和应用这些工具。 ... [详细]
  • Keras 实战:自编码器入门指南
    本文介绍了使用 Keras 框架实现自编码器的基本方法。自编码器是一种用于无监督学习的神经网络模型,主要功能包括数据降维、特征提取等。通过实际案例,我们将展示如何使用全连接层和卷积层来构建自编码器,并讨论不同维度对重建效果的影响。 ... [详细]
  • 前言Git是目前最流行的版本控制系统,在它的基础之上,GitHub和GitLab成为当前最流行的代码托管平台,它们均提供的代码评审、项目管理、持续集成等功能,越来越多的互联网企业都 ... [详细]
  • Explore a common issue encountered when implementing an OAuth 1.0a API, specifically the inability to encode null objects and how to resolve it. ... [详细]
  • 本文详细介绍了如何在Linux系统上安装和配置Smokeping,以实现对网络链路质量的实时监控。通过详细的步骤和必要的依赖包安装,确保用户能够顺利完成部署并优化其网络性能监控。 ... [详细]
  • 1.如何在运行状态查看源代码?查看函数的源代码,我们通常会使用IDE来完成。比如在PyCharm中,你可以Ctrl+鼠标点击进入函数的源代码。那如果没有IDE呢?当我们想使用一个函 ... [详细]
  • XNA 3.0 游戏编程:从 XML 文件加载数据
    本文介绍如何在 XNA 3.0 游戏项目中从 XML 文件加载数据。我们将探讨如何将 XML 数据序列化为二进制文件,并通过内容管道加载到游戏中。此外,还会涉及自定义类型读取器和写入器的实现。 ... [详细]
  • 如何在窗口右下角添加调整大小的手柄
    本文探讨了如何在传统MFC/Win32 API编程中实现类似C# WinForms中的SizeGrip功能,即在窗口的右下角显示一个用于调整窗口大小的手柄。我们将介绍具体的实现方法和相关API。 ... [详细]
  • 本文详细介绍了Java库XChart中的XYSeries类下的setLineColor()方法,并提供了多个实际应用场景的代码示例。 ... [详细]
author-avatar
supe丶r女人帮
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有