# RCAB 模块参考 (RCAB module reference): https://blog.csdn.net/qq_41251963/article/details/120195167
## Residual Group (RG)
class ResidualGroup(nn.Module):
    """Residual Group (RG): a chain of RCAB blocks plus a trailing conv, wrapped in a skip connection.

    forward(x) returns ``body(x) + x``, where ``body`` is ``n_resblocks`` RCAB
    blocks followed by one ``conv(n_feat, n_feat, kernel_size)`` layer.

    Args:
        conv: convolution factory, called as ``conv(in_ch, out_ch, kernel_size)``;
            it is also forwarded to every RCAB.
        n_feat: number of feature channels used throughout the group.
        kernel_size: convolution kernel size.
        reduction: forwarded unchanged to every RCAB.
        act: accepted for interface compatibility but NOT used (see NOTE below).
        res_scale: accepted for interface compatibility but NOT used (see NOTE below).
        n_resblocks: number of RCAB blocks in the group.
    """

    def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):
        super(ResidualGroup, self).__init__()
        # NOTE(review): `act` and `res_scale` are ignored. Every RCAB is built
        # with its own fresh nn.ReLU(True) and res_scale=1 regardless of what the
        # caller passes. This preserves the existing behaviour (presumably
        # inherited from the upstream RCAN reference code). Wiring them through
        # could change model outputs for non-default configs, so confirm intent
        # before changing it.
        modules_body = [
            RCAB(conv, n_feat, kernel_size, reduction,
                 bias=True, bn=False, act=nn.ReLU(True), res_scale=1)
            for _ in range(n_resblocks)
        ]
        # The closing conv maps n_feat -> n_feat channels so the body output can
        # be added to the group input in forward(). This assumes `conv` preserves
        # spatial size (e.g. "same" padding) -- TODO confirm for custom factories.
        modules_body.append(conv(n_feat, n_feat, kernel_size))
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        """Return ``self.body(x) + x`` (long skip connection over the whole group)."""
        res = self.body(x)
        res += x  # in-place add onto the freshly computed body output
        return res
## Residual Channel Attention Network (RCAN)
class RCAN(nn.Module):
    """Residual Channel Attention Network (RCAN).

    Data flow in forward():
        sub_mean -> head conv -> body (residual groups + conv, with a long skip
        connection around the whole body) -> tail (common.Upsampler + conv)
        -> add_mean.

    Args:
        args: configuration object. This constructor reads ``n_resgroups``,
            ``n_resblocks``, ``n_feats``, ``reduction``, ``scale`` (only
            ``scale[0]`` is used), ``rgb_range``, ``n_colors`` and ``res_scale``.
        conv: convolution factory, called as ``conv(in_ch, out_ch, kernel_size)``.
    """

    def __init__(self, args, conv=common.default_conv):
        super(RCAN, self).__init__()

        groups = args.n_resgroups
        blocks_per_group = args.n_resblocks
        width = args.n_feats
        ksize = 3
        squeeze = args.reduction
        upscale = args.scale[0]
        relu = nn.ReLU(True)

        # RGB mean/std handed to common.MeanShift; the mean values are for DIV2K.
        mean_rgb = (0.4488, 0.4371, 0.4040)
        std_rgb = (1.0, 1.0, 1.0)
        # Construction order below is deliberate. Building layers consumes the
        # RNG for parameter init, and attribute assignment order fixes the
        # state_dict / parameters() ordering.
        self.sub_mean = common.MeanShift(args.rgb_range, mean_rgb, std_rgb)

        # Head: one conv from n_colors input channels to `width` feature channels.
        head_layers = [conv(args.n_colors, width, ksize)]

        # Body: `groups` residual groups, then one width -> width conv.
        body_layers = []
        for _ in range(groups):
            body_layers.append(
                ResidualGroup(conv, width, ksize, squeeze, act=relu,
                              res_scale=args.res_scale,
                              n_resblocks=blocks_per_group)
            )
        body_layers.append(conv(width, width, ksize))

        # Tail: common.Upsampler with factor `upscale`, then a conv back to n_colors.
        tail_layers = [
            common.Upsampler(conv, upscale, width, act=False),
            conv(width, args.n_colors, ksize),
        ]

        # The trailing positional 1 presumably selects the "add mean back"
        # direction of MeanShift -- confirm in common.MeanShift.
        self.add_mean = common.MeanShift(args.rgb_range, mean_rgb, std_rgb, 1)

        self.head = nn.Sequential(*head_layers)
        self.body = nn.Sequential(*body_layers)
        self.tail = nn.Sequential(*tail_layers)

    def forward(self, x):
        """Run mean shift, head, residual body, tail and inverse mean shift.

        Assumes ``x`` is an image batch with ``args.n_colors`` channels
        (presumably (N, C, H, W)) -- TODO confirm.
        """
        feats = self.head(self.sub_mean(x))
        deep = self.body(feats)
        deep += feats  # long skip connection around the whole body (in place)
        return self.add_mean(self.tail(deep))