点击左上方蓝字关注我们
【飞桨开发者说】李思佑,PPDE飞桨开发者技术专家。指导老师:昆明理工大学理学院石剑平
我们都知道,有很多经典的老照片,受限于那个时代的技术,只能以黑白的形式传世。尽管黑白照片别有一番风味,但是彩色照片有时候能给人更强的代入感。本项目通过通俗易懂的方式简单实现黑白照片着色并对部分照片取得不错的着色效果。
黑白照片着色是计算机视觉领域经典的问题。近年来随着卷积神经网络(CNN)的广泛应用,通过CNN为黑白照片着色成为新颖且可行的方向。本文所讲的PaddleColorization项目是基于飞桨开发的,整体实现采用ResNet残差网络为主干网络并设计复合损失函数进行网络训练,并已经承载到了百度学习与实训社区AI Studio,欢迎大家fork学习~有任何问题欢迎在评论区留言互相交流哦。
AI Studio链接:
https://aistudio.baidu.com/aistudio/personalcenter/thirdview/56447
开启着色之旅
着色前的原图
着色后的效果
在学习前,先通过训练好的模型体验下黑白照片的彩色化吧。所需的预训练预测模型已经保存在
'work/model/gray2color.inference.model-0'
将黑白图片上传到'work/try/in'路径下
经由模型预测的着色结果将在'work/try/out'路径下
安装所需的依赖库:
!pip install sklearn scikit-image
导入Paddle及项目所需的Python库,将训练place设置为CUDA环境,读取训练好的预测模型gray2color.inference.model-0进行黑白图片的预测。
import paddle
import numpy as np
from skimage import io,color,transform
from paddle import fluid
import matplotlib.pyplot as plt
import os
import work.crop as crop
import matplotlib.pyplot as pltuse_cuda = True
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.core.Scope()
fluid.scope_guard(scope)
[inference_program, feed_target_names, fetch_targets] = fluid.io.load_inference_model(dirname=r'work/model/gray2color.inference.model-0',executor=exe,)def loadImage(image):'''读取图片,并转为Lab,并提取出L和ab'''img = io.imread(image)lab = np.array(color.rgb2lab(img)).transpose()l = lab[:1, :, :]return l.reshape(1,1,512,512).astype('float32')def lab2rgb(l,ab):a = ab[0]b = ab[1]l = l[:, :, np.newaxis]a = a[:, :, np.newaxis].astype('float64')b = b[:, :, np.newaxis].astype('float64')lab = np.concatenate((l, a, b), axis=2) img = color.lab2rgb((lab))img = transform.rotate(img, 270)img = np.fliplr(img)return imgdef run(input,output):'''处理图片并存储到相应位置INPUTinput 输入图片路径output 输出图片路径OUTPUTNone'''inference_transpiler_program = inference_program.clone()crop.convertjpg(jpgfile=input,outdir=None)l = loadImage(input)result = exe.run(inference_program, feed={feed_target_names[0]: (l)}, fetch_list=fetch_targets)ab = result[0][0]l = l[0][0]img = lab2rgb(l,ab)img = img.astype('float32')plt.grid(False)plt.axis('off')plt.imshow(img)plt.show()plt.savefig(str(output),bbox_inches='tight')if __name__ == '__main__':inpath = './work/try/in/'outpath = './work/try/out/'files = os.listdir(inpath)for file in files:try:run(inpath+file,outpath+file)except Exception as e:print()
项目简介
本项目基于飞桨的结合残差网络(ResNet),通过监督学习的方式,训练模型将黑白图片转换为彩色图片。
ResNet(Residual Network)是2015年ImageNet图像分类、图像物体定位和图像物体检测比赛的冠军。针对随着网络训练加深导致准确度下降的问题,ResNet提出了残差学习方法来减轻训练深层网络的困难。在已有设计思路(BN, 小卷积核,全卷积网络)的基础上,引入了残差模块。每个残差模块包含两条路径,其中一条路径是输入特征的直连通路,另一条路径对该特征做两到三次卷积操作得到该特征的残差,最后再将两条路径上的特征相加。
残差模块如图9所示,左边是基本模块连接方式,由两个输出通道数相同的3x3卷积组成。右边是瓶颈模块(Bottleneck)连接方式,之所以称为瓶颈,是因为上面的1x1卷积用来降维(图示例即256->64),下面的1x1卷积用来升维(图示例即64->256),这样中间3x3卷积的输入和输出通道数都较小(图示例即64->64)。
设计思路:通过训练网络对大量样本的学习得到经验分布(例如天空永远是蓝色的,草永远是绿色的),通过经验分布推得黑白图像上各部分合理的颜色
主要解决问题:大量物体颜色并不是固定的也就是物体颜色具有多模态性(例如:苹果可以是红色也可以是绿色和黄色)。通常使用均方差作为损失函数会让具有颜色多模态属性的物体趋于寻找一个“平均”的颜色(通常为淡黄色)导致着色后的图片饱和度不高。
将Adam优化器beta1参数设置为0.8,具体请参考原论文:
https://arxiv.org/abs/1412.6980
将BatchNorm批归一化中momentum参数设置为0.5。
采用基本模块连接方式。
为抑制多模态问题,在均方差的基础上重新设计损失函数。
损失函数公式如下:
ImageNet项目是一个用于视觉对象识别软件研究的大型可视化数据库。超过1400万的图像URL被ImageNet手动注释,以指示图片中的对象;在至少一百万个图像中,还提供了边界框。ImageNet包含2万多个类别;一个典型的类别,如“气球”或“草莓”,包含数百个图像。第三方图像URL的注释数据库可以直接从ImageNet免费获得;但是实际的图像不属于ImageNet。
自2010年以来,ImageNet项目每年举办一次软件比赛,即ImageNet大规模视觉识别挑战赛(ILSVRC),软件程序竞相正确分类检测物体和场景。ImageNet挑战使用了一个“修剪”的1000个非重叠类的列表。2012年在解决ImageNet挑战方面取得了巨大的突破,被广泛认为是2010年的深度学习革命的开始。
Lab模式是根据Commission International Eclairage(CIE)在1931年所制定的一种测定颜色的国际标准建立的。于1976年被改进并且命名的一种色彩模式。
Lab颜色模型弥补了RGB和CMYK两种色彩模式的不足。它是一种设备无关的颜色模型,也是一种基于生理特征的颜色模型。Lab颜色模型由三个要素组成,一个要素是亮度(L),a 和b是两个颜色通道。a包括的颜色是从深绿色(低亮度值)到灰色(中亮度值)再到亮粉红色(高亮度值);b是从亮蓝色(低亮度值)到灰色(中亮度值)再到黄色(高亮度值)。因此,这种颜色混合后将产生具有明亮效果的色彩。(来源:百度百科)
利用bash对数据集
进行自动化处理
该部分包括重建文件夹、移动和解压数据集、显示数据集中图片数量)(运行时间:约20min)
!mv data/data9402/train.* data/data9244/
!mkdir data/tar
!mkdir work/train
!mkdir work/test
!tar xf data/data9244/ILSVRC2012_img_val.tar -C work/test/
!cd data/tar/;cat ../data9244/train.tar* | tar -x
!cd ./work/train/;ls ../data/tar/*.tar | xargs -n1 tar xf
#显示work/train中图片数量
!find work/train -type f | wc -l
1. 预处理
采用多线程对训练集中单通道图删除(运行时间:约20min)
import os
import imghdr
import numpy as np
from PIL import Image
import threading'''多线程将数据集中单通道图删除'''
def cutArray(l, num):avg &#61; len(l) / float(num)o &#61; []last &#61; 0.0while last < len(l):o.append(l[int(last):int(last &#43; avg)])last &#43;&#61; avgreturn odef deleteErrorImage(path,image_dir):count &#61; 0for file in image_dir:try:image &#61; os.path.join(path,file)image_type &#61; imghdr.what(image)if image_type is not &#39;jpeg&#39;:os.remove(image)count &#61; count &#43; 1img &#61; np.array(Image.open(image))if len(img.shape) is 2:os.remove(image)count &#61; count &#43; 1 except Exception as e:print(e)print(&#39;done!&#39;)print(&#39;已删除数量&#xff1a;&#39; &#43; str(count))class thread(threading.Thread):def __init__(self, threadID, path, files):threading.Thread.__init__(self)self.threadID &#61; threadIDself.path &#61; pathself.files &#61; filesdef run(self):deleteErrorImage(self.path,self.files)if __name__ &#61;&#61; &#39;__main__&#39;:path &#61; &#39;./work/train/&#39;files &#61; os.listdir(path)files &#61; cutArray(files,8)
threadList &#61; []
for i in range(8):
threadList.append(threading.Thread(target&#61;deleteErrorImage,args&#61;(path,files[i])))for t in threadList:t.setDaemon(True)t.start()t.join()
2. 采用多线程对图片进行缩放后裁切到512*512分辨率&#xff08;运行时间&#xff1a;约40min&#xff09;
from PIL import Image
import os.path
import os
import threading
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES &#61; True&#39;&#39;&#39;多线程将图片缩放后再裁切到512*512分辨率&#39;&#39;&#39;
def cutArray(l, num):avg &#61; len(l) / float(num)o &#61; []last &#61; 0.0while last < len(l):o.append(l[int(last):int(last &#43; avg)])last &#43;&#61; avgreturn odef convertjpg(jpgfile,outdir,width&#61;512,height&#61;512):img&#61;Image.open(jpgfile)(l,h) &#61; img.sizerate &#61; min(l,h) / widthtry:img &#61; img.resize((int(l // rate),int(h // rate)),Image.BILINEAR)img &#61; img.crop((0,0,width,height))img.save(os.path.join(outdir,os.path.basename(jpgfile)))except Exception as e:print(e)class thread(threading.Thread):def __init__(self, threadID, inpath, outpath, files):threading.Thread.__init__(self)self.threadID &#61; threadIDself.inpath &#61; inpathself.outpath &#61; outpathself.files &#61; filesdef run(self):count &#61; 0try:for file in self.files:convertjpg(self.inpath &#43; file,self.outpath)count &#61; count &#43; 1except Exception as e:print(e)print(&#39;已处理图片数量&#xff1a;&#39; &#43; str(count))if __name__ &#61;&#61; &#39;__main__&#39;:inpath &#61; &#39;./work/train/&#39;outpath &#61; &#39;./work/train/&#39;files &#61; os.listdir(inpath)files &#61; cutArray(files,8)
T &#61; []
for i in range(8):T.append(thread(i &#43; 1, inpath, outpath, files[i]))
for i in T:
T[i].start()
T[i].join()
导入本项目所需的库
import os
import cv2
import numpy as np
import paddle.dataset as dataset
from skimage import io,color,transform
import sklearn.neighbors as neighbors
import paddle
import paddle.fluid as fluid
import numpy as np
import sys
import os
from skimage import io,color
import matplotlib.pyplot as plt
定义数据预处理工具
DataReader
&#39;&#39;&#39;准备数据&#xff0c;定义Reader()&#39;&#39;&#39;PATH &#61; &#39;work/train/&#39;
TEST &#61; &#39;work/train/&#39;
Q &#61; np.load(&#39;work/Q.npy&#39;)
Weight &#61; np.load(&#39;work/Weight.npy&#39;)class DataGenerater:def __init__(self):datalist &#61; os.listdir(PATH)self.testlist &#61; os.listdir(TEST)#datalist.sort()self.datalist &#61; datalistdef load(self, image):&#39;&#39;&#39;读取图片,并转为Lab&#xff0c;并提取出L和ab&#39;&#39;&#39;img &#61; io.imread(image)lab &#61; np.array(color.rgb2lab(img)).transpose()l &#61; lab[:1,:,:]l &#61; l.astype(&#39;float32&#39;)ab &#61; lab[1:,:,:]ab &#61; ab.astype(&#39;float32&#39;)return l,abdef create_train_reader(self):&#39;&#39;&#39;给dataset定义reader&#39;&#39;&#39;def reader():for img in self.datalist:try:l, ab &#61; self.load(PATH &#43; img)yield l.astype(&#39;float32&#39;), ab.astype(&#39;float32&#39;)except Exception as e:print(e)return readerdef create_test_reader(self,):&#39;&#39;&#39;给test定义reader&#39;&#39;&#39;def reader():for img in self.testlist:l,ab &#61; self.load(TEST &#43; img)yield l.astype(&#39;float32&#39;),ab.astype(&#39;float32&#39;)return reader
def train(batch_sizes &#61; 32):reader &#61; DataGenerater().create_train_reader()return readerdef test():reader &#61; DataGenerater().create_test_reader()return reader
定义网络功能模块并定义网络
本文采用的残差单元如上图所示&#xff0c;由两个输出通道数相同的3x3卷积组成。
网络设计采用3组基本残差模块和2组反卷积层组成
import IPython.display as display
import warnings
warnings.filterwarnings(&#39;ignore&#39;)Q &#61; np.load(&#39;work/Q.npy&#39;)
weight &#61; np.load(&#39;work/Weight.npy&#39;)
Params_dirname &#61; "work/model/gray2color.inference.model"&#39;&#39;&#39;自定义损失函数&#39;&#39;&#39;
def createLoss(predict, truth):&#39;&#39;&#39;均方差&#39;&#39;&#39;loss1 &#61; fluid.layers.square_error_cost(predict,truth)loss2 &#61; fluid.layers.square_error_cost(predict,fluid.layers.fill_constant(shape&#61;[BATCH_SIZE,2,512,512],value&#61;fluid.layers.mean(predict),dtype&#61;&#39;float32&#39;))cost &#61; fluid.layers.mean(loss1) &#43; 16.7 / fluid.layers.mean(loss2)return cost
#组合2D卷积层以及BatchNorm层
def conv_bn_layer(input,ch_out,filter_size,stride,padding,act&#61;&#39;relu&#39;,bias_attr&#61;True):tmp &#61; fluid.layers.conv2d(input&#61;input,filter_size&#61;filter_size,num_filters&#61;ch_out,stride&#61;stride,padding&#61;padding,act&#61;None,bias_attr&#61;bias_attr)return fluid.layers.batch_norm(input&#61;tmp,act&#61;act,momentum&#61;0.5)#定义残差模块&#xff08;由两个输出通道数相同的3x3卷积组成。&#xff09;
def shortcut(input, ch_in, ch_out, stride):if ch_in !&#61; ch_out:return conv_bn_layer(input, ch_out, 1, stride, 0, None)else:return inputdef basicblock(input, ch_in, ch_out, stride):tmp &#61; conv_bn_layer(input, ch_out, 3, stride, 1)tmp &#61; conv_bn_layer(tmp, ch_out, 3, 1, 1, act&#61;None, bias_attr&#61;True)short &#61; shortcut(input, ch_in, ch_out, stride)return fluid.layers.elementwise_add(x&#61;tmp, y&#61;short, act&#61;&#39;relu&#39;)def layer_warp(block_func, input, ch_in, ch_out, count, stride):tmp &#61; block_func(input, ch_in, ch_out, stride)for i in range(1, count):tmp &#61; block_func(tmp, ch_out, ch_out, 1)return tmp###反卷积层
def deconv(x, num_filters, filter_size&#61;5, stride&#61;2, dilation&#61;1, padding&#61;2, output_size&#61;None, act&#61;None):return fluid.layers.conv2d_transpose(input&#61;x,num_filters&#61;num_filters,# 滤波器数量output_size&#61;output_size,# 输出图片大小filter_size&#61;filter_size,# 滤波器大小stride&#61;stride,# 步长dilation&#61;dilation,# 膨胀比例大小padding&#61;padding,use_cudnn&#61;True,# 是否使用cudnn内核act&#61;act# 激活函数)def resnetImagenet(input):res1 &#61; layer_warp(basicblock, input, 64, 128, 1, 2)res2 &#61; layer_warp(basicblock, res1, 128, 256, 1, 2)res3 &#61; layer_warp(basicblock, res2, 256, 512, 4, 1)deconv1 &#61; deconv(res3, num_filters&#61;313, filter_size&#61;4, stride&#61;2, padding&#61;1)deconv2 &#61; deconv(deconv1, num_filters&#61;2, filter_size&#61;4, stride&#61;2, padding&#61;1)return deconv2
训练网络
设置的超参数为&#xff1a;
学习率:2e-5
Epoch:30
Mini-Batch: 10
输入Tensor:[-1,1,512,512]
预训练的预测模型存放路径&#xff1a;
work/model/gray2color.inference.model。
BATCH_SIZE &#61; 10
EPOCH_NUM &#61; 30
def ResNettrain():gray &#61; fluid.layers.data(name&#61;&#39;gray&#39;, shape&#61;[1, 512,512], dtype&#61;&#39;float32&#39;)truth &#61; fluid.layers.data(name&#61;&#39;truth&#39;, shape&#61;[2, 512,512], dtype&#61;&#39;float32&#39;)predict &#61; resnetImagenet(gray)cost &#61; createLoss(predict&#61;predict,truth&#61;truth)return predict,cost&#39;&#39;&#39;optimizer函数&#39;&#39;&#39;
def optimizer_program():return fluid.optimizer.Adam(learning_rate&#61;2e-5,beta1&#61;0.8)train_reader &#61; paddle.batch(paddle.reader.shuffle(reader&#61;train(), buf_size&#61;7500),batch_size&#61;BATCH_SIZE)
test_reader &#61; paddle.batch(reader&#61;test(), batch_size&#61;10)use_cuda &#61; True
if not use_cuda:os.environ[&#39;CPU_NUM&#39;] &#61; str(6)
feed_order &#61; [&#39;gray&#39;, &#39;weight&#39;]
place &#61; fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()main_program &#61; fluid.default_main_program()
star_program &#61; fluid.default_startup_program()&#39;&#39;&#39;网络训练&#39;&#39;&#39;
predict,cost &#61; ResNettrain()&#39;&#39;&#39;优化函数&#39;&#39;&#39;
optimizer &#61; optimizer_program()
optimizer.minimize(cost)exe &#61; fluid.Executor(place)def train_loop():gray &#61; fluid.layers.data(name&#61;&#39;gray&#39;, shape&#61;[1, 512,512], dtype&#61;&#39;float32&#39;)truth &#61; fluid.layers.data(name&#61;&#39;truth&#39;, shape&#61;[2, 512,512], dtype&#61;&#39;float32&#39;)feeder &#61; fluid.DataFeeder(feed_list&#61;[&#39;gray&#39;,&#39;truth&#39;], place&#61;place)exe.run(star_program)#增量训练fluid.io.load_persistables(exe, &#39;work/model/incremental/&#39;, main_program)for pass_id in range(EPOCH_NUM):step &#61; 0for data in train_reader():loss &#61; exe.run(main_program, feed&#61;feeder.feed(data),fetch_list&#61;[cost])step &#43;&#61; 1if step % 1000 &#61;&#61; 0:try:generated_img &#61; exe.run(main_program, feed&#61;feeder.feed(data),fetch_list&#61;[predict])plt.figure(figsize&#61;(15,6))plt.grid(False)for i in range(10):ab &#61; generated_img[0][i]l &#61; data[i][0][0]a &#61; ab[0]b &#61; ab[1]l &#61; l[:, :, np.newaxis]a &#61; a[:, :, np.newaxis].astype(&#39;float64&#39;)b &#61; b[:, :, np.newaxis].astype(&#39;float64&#39;)lab &#61; np.concatenate((l, a, b), axis&#61;2)img &#61; color.lab2rgb((lab))img &#61; transform.rotate(img, 270)img &#61; np.fliplr(img)plt.grid(False)plt.subplot(2, 5, i &#43; 1)plt.imshow(img)plt.axis(&#39;off&#39;)plt.xticks([])plt.yticks([])msg &#61; &#39;Epoch ID&#61;{0} Batch ID&#61;{1} Loss&#61;{2}&#39;.format(pass_id, step, loss[0][0])plt.suptitle(msg,fontsize&#61;20)plt.draw()plt.savefig(&#39;{}/{:04d}_{:04d}.png&#39;.format(&#39;work/output_img&#39;, pass_id, step),bbox_inches&#61;&#39;tight&#39;)plt.pause(0.01)display.clear_output(wait&#61;True)except IOError:print(IOError)fluid.io.save_persistables(exe,&#39;work/model/incremental/&#39;,main_program)fluid.io.save_inference_model(Params_dirname, ["gray"],[predict], exe)
train_loop()
项目总结
对于训练结果&#xff0c;虽然本项目通过抑制平均化加大了离散程度&#xff0c;提高了着色的饱和度&#xff0c;但最终结果仍然有较大现实差距&#xff0c;只能对部分场景有比较好的结果&#xff0c;对人造场景&#xff08;如超市景观等&#xff09;仍然表现力不足。接下来作者准备进一步去设计损失函数&#xff0c;目的是让网络着色结果足以欺骗人的”直觉感受“&#xff0c;而不是一味地接近真实场景。
如在使用过程中有问题&#xff0c;可加入飞桨官方QQ群进行交流&#xff1a;1108045677。
如果您想详细了解更多飞桨的相关内容&#xff0c;请参阅以下文档。
·飞桨开源框架项目地址·
GitHub:
https://github.com/PaddlePaddle/Paddle
Gitee:
https://gitee.com/paddlepaddle/Paddle
·飞桨官网地址·
https://www.paddlepaddle.org.cn/
扫描二维码 &#xff5c; 关注我们
微信号 : PaddleOpenSource
END
今明精彩直播