Author: 静静敲代码 | Source: Internet | 2023-10-11 21:10
DSLR-Quality Photos on Mobile Devices with Deep Convolutional Networks --- color loss, PyTorch implementation
1. How it works
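I have recently been working on image enhancement and came across this paper, in which the authors propose a loss called the color loss. As the paper describes it, the method blurs away the texture and content of both the input image and the ground truth so that only the color information remains, which lets the loss correct the image's colors. The implementation is fairly simple: first build a Gaussian blur kernel, then use it as a convolution kernel to blur each image, and finally compute the MSE between the blurred input image and the blurred ground truth as the loss.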
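Written out from the description above, the loss can be sketched as follows. Here X is the input image, Y the ground truth, G a normalized Gaussian kernel, and N the number of pixel values; these symbols and the 1/N normalization are notation assumed for this sketch, and the paper's own scaling may differ.

```latex
X_b(i,j) = \sum_{k,l} X(i+k,\, j+l)\, G(k,l), \qquad
Y_b(i,j) = \sum_{k,l} Y(i+k,\, j+l)\, G(k,l)

\mathcal{L}_{\text{color}}(X, Y) = \frac{1}{N} \sum_{i,j} \bigl( X_b(i,j) - Y_b(i,j) \bigr)^2
```

Because both images pass through the same blur, differences in fine texture are largely averaged out before the comparison, while a global color shift survives it.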
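The authors' GitHub repository contains code for the model, but it is implemented in TensorFlow. My own code is in PyTorch, so I rewrote it myself. The authors' code uses a depthwise separable convolution; in my PyTorch version I did not implement the depthwise separable operation.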
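For readers who do want per-channel blurring in PyTorch, one possible approach is a grouped convolution with `groups` equal to the number of channels, which is the depthwise part of a depthwise separable convolution. The sketch below is not the authors' code: the names `gaussian_kernel_2d`, `depthwise_blur` and `color_loss` are hypothetical, and the kernel is a plainly sampled Gaussian rather than the CDF-difference kernel used in this post.

```python
import torch
import torch.nn.functional as F


def gaussian_kernel_2d(size: int = 21, sigma: float = 3.0) -> torch.Tensor:
    """Return a normalized size x size Gaussian kernel (sampled density)."""
    coords = torch.arange(size, dtype=torch.float32) - (size - 1) / 2.0
    g1d = torch.exp(-coords ** 2 / (2.0 * sigma ** 2))
    kernel = torch.outer(g1d, g1d)
    return kernel / kernel.sum()


def depthwise_blur(img: torch.Tensor, kernel: torch.Tensor) -> torch.Tensor:
    """Blur each channel of an [N, C, H, W] image independently."""
    channels = img.shape[1]
    k = kernel.shape[-1]
    kernel = kernel.to(device=img.device, dtype=img.dtype)
    # One copy of the kernel per channel: weight shape [C, 1, k, k].
    weight = kernel.expand(channels, 1, k, k).contiguous()
    # groups=channels: channel i is convolved only with filter i,
    # so colors are never mixed across channels.
    return F.conv2d(img, weight, padding=k // 2, groups=channels)


def color_loss(pred: torch.Tensor, target: torch.Tensor,
               kernel: torch.Tensor) -> torch.Tensor:
    """MSE between the blurred prediction and the blurred target."""
    return F.mse_loss(depthwise_blur(pred, kernel),
                      depthwise_blur(target, kernel))
```

With a 21x21 kernel and `padding=10` the output keeps the input's spatial size, so the loss can be applied directly to RGB batches of shape `[N, 3, H, W]`.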
I'm still a beginner at deep learning, so feel free to leave me a comment if you spot any problems~~~
2. Code
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from math import exp, pi
import numpy as np
import cv2 as cv
import scipy.stats as st
import matplotlib.pyplot as plt


def gauss_kernel(kernlen=21, nsig=3, channels=1):
    # Build a normalized kernlen x kernlen Gaussian kernel from differences of the normal CDF.
    # Note: `channels` is unused while the np.repeat line below stays commented out.
    interval = (2*nsig+1.)/(kernlen)
    x = np.linspace(-nsig-interval/2., nsig+interval/2., kernlen+1)
    kern1d = np.diff(st.norm.cdf(x))
    kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
    kernel = kernel_raw/kernel_raw.sum()
    out_filter = np.array(kernel, dtype = np.float32)
    out_filter = out_filter.reshape((kernlen, kernlen))
    # out_filter = np.repeat(out_filter, channels, axis = 0)
    return out_filter  # kernel_size=21


class SeparableConv2d(nn.Module):
    def __init__(self):
        super(SeparableConv2d, self).__init__()
        kernel = gauss_kernel(21, 3, 3)
        kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
        ## kernel_point = [[1.0]]
        ## kernel_point = torch.FloatTensor(kernel_point).unsqueeze(0).unsqueeze(0)
        # kernel = torch.FloatTensor(kernel).expand(3, 3, 21, 21)  # torch.expand() adds leading dimensions; for three-channel input, expand weight to [3,3,21,21]
        ## kernel_point = torch.FloatTensor(kernel_point).expand(3,3,1,1)
        self.weight = nn.Parameter(data=kernel, requires_grad=False)
        # self.pointwise = nn.Conv2d(1, 1, 1, 1, 0, 1, 1,bias=False)  # single channel: in_channels=1, out_channels=1; three channels: in_channels=3, out_channels=3; the kernel is randomly initialized
        ## self.weight_point = nn.Parameter(data=kernel_point, requires_grad=False)

    def forward(self, img1):
        # weight has shape [1, 1, 21, 21], so img1 must be single-channel: [N, 1, H, W]
        x = F.conv2d(img1, self.weight, groups=1, padding=10)
        ## x = F.conv2d(x, self.weight_point, groups=1, padding=0)  # kernel is [1]
        # x = self.pointwise(x)
        return x
# plt.imshow(out_kernel)
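Since the `SeparableConv2d` module above convolves with a `[1, 1, 21, 21]` weight, it only accepts single-channel input. One way to reuse such a module on RGB images without touching its weight is to fold the channels into the batch dimension, blur, and unfold again. The helper name `blur_per_channel` is hypothetical, and the box filter below is only a stand-in so the sketch runs on its own; in practice the blur module from this post would be passed in instead.

```python
import torch
import torch.nn as nn


def blur_per_channel(blur: nn.Module, img: torch.Tensor) -> torch.Tensor:
    """Apply a single-channel blur module to every channel of an [N, C, H, W] image."""
    n, c, h, w = img.shape
    # Fold channels into the batch dimension: each plane becomes a 1-channel image.
    planes = img.reshape(n * c, 1, h, w)
    blurred = blur(planes)
    # Unfold back; the spatial size follows whatever the blur returns.
    return blurred.reshape(n, c, blurred.shape[-2], blurred.shape[-1])


# Stand-in blur (an assumption for this sketch): a fixed 21x21 box filter.
standin = nn.Conv2d(1, 1, kernel_size=21, padding=10, bias=False)
with torch.no_grad():
    standin.weight.fill_(1.0 / (21 * 21))
standin.weight.requires_grad_(False)

rgb = torch.rand(4, 3, 64, 64)
out = blur_per_channel(standin, rgb)
print(out.shape)  # torch.Size([4, 3, 64, 64])
```

The result is equivalent to blurring each color channel separately with the same kernel, so the color loss can then be taken as the MSE between two tensors blurred this way.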