热门标签 | HotTags
当前位置:  开发笔记 > 编程语言 > 正文

经典论文复现|基于深度卷积网络的图像超分辨率算法

笔者本次选择复现的是汤晓鸥教授和何恺明团队发表于2015年的经典论文——论文复现代码:

笔者本次选择复现的是汤晓鸥教授和何恺明团队发表于 2015 年的经典论文—— SRCNN 。 超分辨率技术(Super-Resolution)是指从观测到的低分辨率图像重建出相应的高分辨率图像,在监控设备、卫星图像和医学影像等领域都有重要的应用价值。在深度卷积网络的浪潮下, 本文 首次提出了基于深度卷积网络的端到端超分辨率算法。

经典论文复现 | 基于深度卷积网络的图像超分辨率算法

论文复现代码: http://aistudio.baidu.com/aistudio/#/projectdetail/24446

SRCNN流程

经典论文复现 | 基于深度卷积网络的图像超分辨率算法 ▲ SRCNN算法框架

SRCNN 将深度学习与传统稀疏编码之间的关系作为依据,将 3 层网络划分为 图像块提取 (Patch extraction and representation)、 非线性映射 (Non-linear mapping)以及最终的 重建 (Reconstruction)。

SRCNN 具体流程如下:

1. 先将低分辨率图像使用双三次差值放大至目标尺寸(如放大至 2 倍、3 倍、4 倍),此时仍然称放大至目标尺寸后的图像为低分辨率图像(Low-resolution image),即图中的输入(input);

2. 将低分辨率图像输入三层 卷积神经网络 。举例:在论文其中一个实验相关设置,对 YCrCb 颜色空间中的 Y 通道进行重建,网络形式为 (conv1+relu1)—(conv2+relu2)—(conv3+relu3);第一层卷积:卷积核尺寸 9×9 (f1×f1),卷积核数目 64 (n1),输出 64 张特征图;第二层卷积:卷积核尺寸 1×1 (f2×f2),卷积核数目 32 (n2),输出 32 张特征图;第三层卷积:卷积核尺寸 5×5 (f3×f3),卷积核数目 1 (n3),输出 1 张特征图即为最终重建高分辨率图像。

经典论文复现 | 基于深度卷积网络的图像超分辨率算法

训练

训练数据集

论文中某一实验采用 91 张自然图像作为训练数据集,对训练集中的图像先使用双三次差值缩小到低分辨率尺寸,再将其放大到目标放大尺寸,最后切割成诸多 33 × 33 图像块作为训练数据,作为标签数据的则为图像中心的 21 × 21 图像块(与卷积层细节设置相关)。

损失函数

采用 MSE 函数作为 卷积神经网络 损失函数。

经典论文复现 | 基于深度卷积网络的图像超分辨率算法

卷积层细节设置

第一层卷积核 9 × 9,得到特征图尺寸为 (33-9)/1+1=25,第二层卷积核 1 × 1,得到特征图尺寸不变,第三层卷积核 5 × 5,得到特征图尺寸为 (25-5)/1+1=21。训练时得到的尺寸为 21 × 21,因此图像中心的 21 × 21 图像块作为标签数据(卷积训练时不进行 padding)。

# 查看个人持久化工作区文件
!ls /home/aistudio/work/
# coding=utf-8
import os
import paddle.fluid as fluid
import paddle.v2 as paddle
from PIL import Image
import numpy as np
import scipy.misc
import scipy.ndimage
import h5py
import glob

FLAGS={"epoch": 10,"batch_size": 128,"image_size": 33,"label_size": 21,
      "learning_rate": 1e-4,"c_dim": 1,"scale": 3,"stride": 14,
      "checkpoint_dir": "checkpoint","sample_dir": "sample","is_train": True}

#utils
def read_data(path):
    with h5py.File(path, 'r') as hf:
        data = np.array(hf.get('data'))
        label = np.array(hf.get('label'))
        return data, label

def preprocess(path, scale=3):

    image = imread(path, is_grayscale=True)
    label_ = modcrop(image, scale)

    label_ = label_ / 255.
    input_ = scipy.ndimage.interpolation.zoom(label_, zoom=(1. / scale), prefilter=False)  # 一次
    input_ = scipy.ndimage.interpolation.zoom(input_, zoom=(scale / 1.), prefilter=False)  # 二次,bicubic

    return input_, label_

def prepare_data(dataset):
    if FLAGS['is_train']:
        data_dir = os.path.join(os.getcwd(), dataset)
        data = glob.glob(os.path.join(data_dir, "*.bmp"))
    else:
        data_dir = os.path.join(os.sep, (os.path.join(os.getcwd(), dataset)), "Set5")
        data = glob.glob(os.path.join(data_dir, "*.bmp"))

    return data

def make_data(data, label):
    if not os.path.exists('data/checkpoint'):
        os.makedirs('data/checkpoint')
    if FLAGS['is_train']:
        savepath = os.path.join(os.getcwd(), 'data/checkpoint/train.h5')
    # else:
    #     savepath = os.path.join(os.getcwd(), 'data/checkpoint/test.h5')

    with h5py.File(savepath, 'w') as hf:
        hf.create_dataset('data', data=data)
        hf.create_dataset('label', data=label)

def imread(path, is_grayscale=True):
    if is_grayscale:
        return scipy.misc.imread(path, flatten=True, mode='YCbCr').astype(np.float)  # 将图像转灰度
    else:
        return scipy.misc.imread(path, mode='YCbCr').astype(np.float)  # 默认为false

def modcrop(image, scale=3):

    if len(image.shape) == 3:  # 彩色 800*600*3
        h, w, _ = image.shape
        h = h - np.mod(h, scale)
        w = w - np.mod(w, scale)
        image = image[0:h, 0:w, :]
    else:  # 灰度 800*600
        h, w = image.shape
        h = h - np.mod(h, scale)
        w = w - np.mod(w, scale)
        image = image[0:h, 0:w]
    return image

def input_setup(config):
    if config['is_train']:
        data = prepare_data(dataset="data/data899/Train.zip_files/Train")
    else:
        data = prepare_data(dataset="Test")

    sub_input_sequence = []
    sub_label_sequence = []
    padding = abs(config['image_size'] - config['label_size']) // 2  # 6 填充

    if config['is_train']:
        for i in range(len(data)):
            input_, label_ = preprocess(data[i], config['scale'])  # data[i]为数据目录

            if len(input_.shape) == 3:
                h, w, _ = input_.shape
            else:
                h, w = input_.shape
            for x in range(0, h - config['image_size'] + 1, config['stride']):
                for y in range(0, w - config['image_size'] + 1, config['stride']):
                    sub_input = input_[x:x + config['image_size'], y:y + config['image_size']]  # [33 x 33]
                    sub_label = label_[x + padding:x + padding + config['label_size'],
                                y + padding:y + padding + config['label_size']]  # [21 x 21]

                    # Make channel value,颜色通道1
                    sub_input = sub_input.reshape([config['image_size'], config['image_size'], 1])
                    sub_label = sub_label.reshape([config['label_size'], config['label_size'], 1])

                    sub_input_sequence.append(sub_input)
                    sub_label_sequence.append(sub_label)
        arrdata = np.asarray(sub_input_sequence)  # [?, 33, 33, 1]
        arrlabel = np.asarray(sub_label_sequence)  # [?, 21, 21, 1]

        make_data(arrdata, arrlabel)  # 把处理好的数据进行存储,路径为checkpoint/..
    else:
        input_, label_ = preprocess(data[4], config['scale'])

        if len(input_.shape) == 3:
            h, w, _ = input_.shape
        else:
            h, w = input_.shape
        input = input_.reshape([h, w, 1])

        label = label_[6:h - 6, 6:w - 6]
        label = label.reshape([h - 12, w - 12, 1])

        sub_input_sequence.append(input)
        sub_label_sequence.append(label)

        input1 = np.asarray(sub_input_sequence)
        label1 = np.asarray(sub_label_sequence)
        return input1, label1, h, w

def imsave(image, path):
    return scipy.misc.imsave(path, image)
#train
def reader_creator_image_and_label():
    input_setup(FLAGS)
    data_dir= os.path.join('./data/{}'.format(FLAGS['checkpoint_dir']), "train.h5")
    images,labels=read_data(data_dir)
    def reader():
        for i in range(len(images)):
            yield images, labels
    return reader
def train(use_cuda, num_passes,BATCH_SIZE = 128, model_save_dir='../models'):
    if FLAGS['is_train']:
      images = fluid.layers.data(name='images', shape=[1, FLAGS['image_size'], FLAGS['image_size']], dtype='float32')
      labels = fluid.layers.data(name='labels', shape=[1, FLAGS['label_size'], FLAGS['label_size']], dtype='float32')
    else:
      _,_,FLAGS['image_size'],FLAGS['label_size']=input_setup(FLAGS)
      images = fluid.layers.data(name='images', shape=[1, FLAGS['image_size'], FLAGS['label_size']], dtype='float32')
      labels = fluid.layers.data(name='labels', shape=[1, FLAGS['image_size']-12, FLAGS['label_size']-12], dtype='float32')

    #feed_order=['images','labels']
    # 获取神经网络的训练结果
    predict = model(images)
    # 获取损失函数
    cost = fluid.layers.square_error_cost(input=predict, label=labels)
    # 定义平均损失函数
    avg_cost = fluid.layers.mean(cost)
    # 定义优化方法
    optimizer = fluid.optimizer.Momentum(learning_rate=1e-4,momentum=0.9)
    opts =optimizer.minimize(avg_cost)

    # 是否使用GPU
    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()

    # 初始化执行器
    exe=fluid.Executor(place)
    exe.run(fluid.default_startup_program())
    # 获取训练数据
    train_reader = paddle.batch(
        reader_creator_image_and_label(), batch_size=BATCH_SIZE)
    # 获取测试数据
    # test_reader = paddle.batch(
    #     read_data(), batch_size=BATCH_SIZE)
    #print(len(next(train_reader())))
    feeder = fluid.DataFeeder(place=place, feed_list=[images, labels])
    for pass_id in range(num_passes):
        for batch_id, data in enumerate(train_reader()):
            avg_cost_value = exe.run(fluid.default_main_program(),
                                    feed=feeder.feed(data),
                                    fetch_list=[avg_cost])

            if batch_id%100 == 0:
                print("loss="+avg_cost_value[0])

def model(images):
    conv1=fluid.layers.conv2d(input=images, num_filters=64, filter_size=9, act='relu')
    conv2=fluid.layers.conv2d(input=conv1, num_filters=32, filter_size=1,act='relu')
    conv3=fluid.layers.conv2d(input=conv2, num_filters=1, filter_size=5)
    return conv3

if __name__ == '__main__':
    # 开始训练
    train(use_cuda=False, num_passes=10)

测试

全卷积网络

所用网络为全卷积网络,因此作为实际测试时,直接输入完整图像即可。

Padding

训练时得到的实际上是除去四周 (33-21)/2=6 像素外的图像,若直接采用训练时的设置(无 padding),得到的图像最后会减少四周各 6 像素(如插值放大后输入 512 × 512,输出 500 × 500)。

因此在测试时每一层卷积都进行了 padding(卷积核尺寸为 1 × 1的不需要进 行 padding),这样保证插值放大后输入与输出尺寸的一致性。

重建结果

客观评价指标 PSNR 与 SSIM:相比其他传统方法,SRCNN 取得更好的重建效果。

经典论文复现 | 基于深度卷积网络的图像超分辨率算法

主观效果:相比其他传统方法,SRCNN 重建效果更具优势。

经典论文复现 | 基于深度卷积网络的图像超分辨率算法


以上所述就是小编给大家介绍的《经典论文复现 | 基于深度卷积网络的图像超分辨率算法》,希望对大家有所帮助,如果大家有任何疑问请给我留言,小编会及时回复大家的。在此也非常感谢大家对 我们 的支持!


推荐阅读
  • 在第七天的深度学习课程中,我们将重点探讨DGL框架的高级应用,特别是在官方文档指导下进行数据集的下载与预处理。通过详细的步骤说明和实用技巧,帮助读者高效地构建和优化图神经网络的数据管道。此外,我们还将介绍如何利用DGL提供的模块化工具,实现数据的快速加载和预处理,以提升模型训练的效率和准确性。 ... [详细]
  • 深入解析 Django 中用户模型的自定义方法与技巧 ... [详细]
  • 2019年后蚂蚁集团与拼多多面试经验详述与深度剖析
    2019年后蚂蚁集团与拼多多面试经验详述与深度剖析 ... [详细]
  • Java 8 引入了 Stream API,这一新特性极大地增强了集合数据的处理能力。通过 Stream API,开发者可以更加高效、简洁地进行集合数据的遍历、过滤和转换操作。本文将详细解析 Stream API 的核心概念和常见用法,帮助读者更好地理解和应用这一强大的工具。 ... [详细]
  • 优化后的标题:校园互联新方案:10397连接教育未来 ... [详细]
  • voc生成xml 代码
    目录 lxmlwindows安装 读取示例 可视化 生成示例 上面是代码,下面有调用示例 api调用代码,其实只有几行:这个生成代码也很简 ... [详细]
  • 在Unity3D的第13天学习中,我们深入探讨了关节系统和布料模拟技术。关节系统作为Unity中的关键物理组件,能够实现游戏对象间的动态连接,如刚体间的关系、门的开合动作以及角色的布娃娃效果。铰链关节涉及两个刚体的交互,能够精确模拟复杂的机械运动,为游戏增添了真实感。此外,布料模拟技术则进一步提升了角色衣物和环境装饰物的自然表现,增强了视觉效果的真实性和沉浸感。 ... [详细]
  • 如何使用 org.geomajas.configuration.FontStyleInfo.getColor() 方法及其代码示例详解 ... [详细]
  • 利用ViewComponents在Asp.Net Core中构建高效分页组件
    通过运用 ViewComponents 技术,在 Asp.Net Core 中实现了高效的分页组件开发。本文详细介绍了如何通过创建 `PaginationViewComponent` 类并利用 `HelloWorld.DataContext` 上下文,实现对分页参数的定义与管理,从而提升 Web 应用程序的性能和用户体验。 ... [详细]
  • 如何在Android应用中设计和实现专业的启动欢迎界面(Splash Screen)
    在Android应用开发中,设计与实现一个专业的启动欢迎界面(Splash Screen)至关重要。尽管Android设计指南对使用Splash Screen的态度存在争议,但一个精心设计的启动界面不仅能提升用户体验,还能增强品牌识别度。本文将探讨如何在遵循最佳实践的同时,通过技术手段实现既美观又高效的启动欢迎界面,包括加载动画、过渡效果以及性能优化等方面。 ... [详细]
  • 本文提供了PyTorch框架中常用的预训练模型的下载链接及详细使用指南,涵盖ResNet、Inception、DenseNet、AlexNet、VGGNet等六大分类模型。每种模型的预训练参数均经过精心调优,适用于多种计算机视觉任务。文章不仅介绍了模型的下载方式,还详细说明了如何在实际项目中高效地加载和使用这些模型,为开发者提供全面的技术支持。 ... [详细]
  • 使用React与Ant Design 3.x构建IP地址输入组件
    本文深入探讨了利用React框架结合Ant Design 3.x版本开发IP地址输入组件的方法。通过详细的代码示例,展示了如何高效地创建具备良好用户体验的IP输入框,对于前端开发者而言具有较高的实践指导意义。 ... [详细]
  • 如何在datetimebox中进行赋值与取值操作
    在 datetimebox 中进行赋值和取值操作时,可以通过以下方法实现:使用 `$('#j_dateStart').datebox('setValue', '指定日期')` 进行赋值,而通过 `$('#j_dateStart').datebox('getValue')` 获取当前选中的日期值。若需要清空日期值,可以使用 `$('#j_dateStart').datebox('clear')` 方法。这些操作能够确保日期控件的准确性和灵活性,适用于各种前端应用场景。 ... [详细]
  • 在财务分析与金融数据处理中,利用Python的强大库如NumPy和SciPy可以高效地计算各种财务指标。例如,通过调用这些库中的函数,可以轻松计算货币的时间价值,包括终值(FV)等关键指标。此外,这些库还提供了丰富的统计和数学工具,有助于进行更深入的数据分析和模型构建。 ... [详细]
  • HTML5 Web存储技术是许多开发者青睐本地应用程序的重要原因之一,因为它能够实现在客户端本地存储数据。HTML5通过引入Web Storage API,使得Web应用程序能够在浏览器中高效地存储数据,从而提升了应用的性能和用户体验。相较于传统的Cookie机制,Web Storage不仅提供了更大的存储容量,还简化了数据管理和访问的方式。本文将从基础概念、关键技术到实际应用,全面解析HTML5 Web存储技术,帮助读者深入了解其工作原理和应用场景。 ... [详细]
author-avatar
一起去钓鱼
这个家伙很懒,什么也没留下!
PHP1.CN | 中国最专业的PHP中文社区 | DevBox开发工具箱 | json解析格式化 |PHP资讯 | PHP教程 | 数据库技术 | 服务器技术 | 前端开发技术 | PHP框架 | 开发工具 | 在线工具
Copyright © 1998 - 2020 PHP1.CN. All Rights Reserved | 京公网安备 11010802041100号 | 京ICP备19059560号-4 | PHP1.CN 第一PHP社区 版权所有