"""
Mscanet Model

Created on Wed Apr 10 09:57:49 2019
@author: Fsl
"""
from scipy import ndimage
import torch
from torchvision import models
import torch.nn as nn
# from .resnet import resnet34
# from resnet import resnet34
# import resnet
from torch.nn import functional as F
# import torchsummary
from torch.nn import init
import numpy as np
from functools import partial
from thop import profile
up_kwargs = {'mode': 'bilinear', 'align_corners': True}
BatchNorm2d = nn.BatchNorm2d
class SpatialAttentionBlock(nn.Module):
def __init__(self, in_channels):
super(SpatialAttentionBlock, self).__init__()
self.query = nn.Sequential(
nn.Conv2d(in_channels,in_channels//8,kernel_size=(1,3), padding=(0,1)),
nn.BatchNorm2d(in_channels//8),
nn.ReLU(inplace=True)
)
self.key = nn.Sequential(
nn.Conv2d(in_channels, in_channels//8, kernel_size=(3,1),
padding=(1,0)),
nn.BatchNorm2d(in_channels//8),
nn.ReLU(inplace=True)
)
self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
self.gamma = nn.Parameter(torch.zeros(1))
self.softmax = nn.Softmax(dim=-1)
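    # Reconstructed forward (assumption): the method body is missing from this
    # snippet. This follows the CSNet-style spatial affinity attention that the
    # layers above imply: horizontal/vertical query-key projections form an
    # affinity matrix that reweights the value features, with a residual
    # connection scaled by the learnable gamma (initialised to zero).
    def forward(self, x):
        B, C, H, W = x.size()
        proj_query = self.query(x).view(B, -1, W * H).permute(0, 2, 1)
        proj_key = self.key(x).view(B, -1, W * H)
        affinity = self.softmax(torch.matmul(proj_query, proj_key))
        proj_value = self.value(x).view(B, -1, H * W)
        weights = torch.matmul(proj_value, affinity.permute(0, 2, 1))
        out = self.gamma * weights.view(B, C, H, W) + x
        return out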
class AffinityAttention2(nn.Module):
""" Affinity attention module """
class UnetDsv3(nn.Module):
def __init__(self, in_size, out_size, scale_factor):
super(UnetDsv3, self).__init__()
        self.dsv = nn.Sequential(
            nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0),
            nn.Upsample(scale_factor=scale_factor, mode='bilinear'))
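    def forward(self, input):
        # Standard deep-supervision (DSV) head: project to out_size channels,
        # then upsample. Forward added here since it was missing from the
        # snippet; it simply applies the Sequential defined above.
        return self.dsv(input)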
class BasicConv(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0,
dilation=1, groups=1,
relu=True, bn=True, bias=False):
super(BasicConv, self).__init__()
self.out_channels = out_planes
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
stride=stride, padding=padding,
dilation=dilation, groups=groups, bias=bias)
        self.bn = (nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01,
                                  affine=True) if bn else None)
self.relu = nn.ReLU() if relu else None
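    def forward(self, x):
        # conv -> optional BN -> optional ReLU; added since the forward was
        # missing from the snippet, mirroring how bn/relu are made optional
        # in __init__.
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x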
class Scale_Aware(nn.Module):
def __init__(self, in_channels):
super(Scale_Aware, self).__init__()
        # self.bn = nn.ModuleList([nn.BatchNorm2d(in_channels), nn.BatchNorm2d(in_channels), nn.BatchNorm2d(in_channels)])
self.conv1x1 = nn.ModuleList(
[nn.Conv2d(in_channels=2 * in_channels, out_channels=in_channels,
dilation=1, kernel_size=1, padding=0),
nn.Conv2d(in_channels=2 * in_channels, out_channels=in_channels,
dilation=1, kernel_size=1, padding=0)])
self.conv3x3_1 = nn.ModuleList(
[nn.Conv2d(in_channels=in_channels, out_channels=in_channels // 2,
dilation=1, kernel_size=3, padding=1),
nn.Conv2d(in_channels=in_channels, out_channels=in_channels // 2,
dilation=1, kernel_size=3, padding=1)])
self.conv3x3_2 = nn.ModuleList(
[nn.Conv2d(in_channels=in_channels // 2, out_channels=2, dilation=1,
kernel_size=3, padding=1),
nn.Conv2d(in_channels=in_channels // 2, out_channels=2, dilation=1,
kernel_size=3, padding=1)])
self.conv_last = ConvBnRelu(in_planes=in_channels, out_planes=in_channels,
ksize=1, stride=1, pad=0, dilation=1)
self.relu = nn.ReLU()
    def forward(self, x_l, x_h):
        feat = torch.cat([x_l, x_h], dim=1)
        # feat=feat_cat.detach()
        feat = self.relu(self.conv1x1[0](feat))
        feat = self.relu(self.conv3x3_1[0](feat))
        att = self.conv3x3_2[0](feat)
        att = F.softmax(att, dim=1)
        # Reconstructed (assumption): the two softmax channels weight the two
        # inputs, giving a scale-aware fusion refined by conv_last.
        att_1 = att[:, 0, :, :].unsqueeze(1)
        att_2 = att[:, 1, :, :].unsqueeze(1)
        fusion = att_1 * x_l + att_2 * x_h
        return self.conv_last(fusion)


class BaseNetHead(nn.Module):
    # The weight-init loop and forward() below survive from the original file
    # (they had spilled into Scale_Aware); the __init__ header and layer
    # definitions are reconstructed (assumption) after the TorchSeg-style
    # BaseNetHead that conv_1x1_3x3 / conv_1x1_2 imply.
    def __init__(self, in_planes, out_planes, scale,
                 is_aux=False, norm_layer=nn.BatchNorm2d):
        super(BaseNetHead, self).__init__()
        mid = 256 if is_aux else 64  # assumed widths
        self.conv_1x1_3x3 = nn.Sequential(
            ConvBnRelu(in_planes, mid, 1, 1, 0, norm_layer=norm_layer),
            ConvBnRelu(mid, mid, 3, 1, 1, norm_layer=norm_layer))
        self.conv_1x1_2 = nn.Conv2d(mid, out_planes, kernel_size=1,
                                    stride=1, padding=0)
        self.scale = scale
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.normal_(m.weight.data, 1.0, 0.02)
                init.constant_(m.bias.data, 0.0)

    def forward(self, x):
        if self.scale > 1:
            x = F.interpolate(x, scale_factor=self.scale,
                              mode='bilinear', align_corners=True)
        fm = self.conv_1x1_3x3(x)
        # fm = self.dropout(fm)
        output = self.conv_1x1_2(fm)
        return output
class ConvBnRelu(nn.Module):
def __init__(self, in_planes, out_planes, ksize, stride, pad, dilation=1,
groups=1, has_bn=True, norm_layer=nn.BatchNorm2d,
has_relu=True, inplace=True, has_bias=False):
super(ConvBnRelu, self).__init__()
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=ksize,
stride=stride, padding=pad,
dilation=dilation, groups=groups, bias=has_bias)
        self.has_bn = has_bn
        if self.has_bn:
            self.bn = norm_layer(out_planes)
        self.has_relu = has_relu
        if self.has_relu:
            self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        x = self.conv(x)
        if self.has_bn:
            x = self.bn(x)
        if self.has_relu:
            x = self.relu(x)
        return x
class GlobalAvgPool2d(nn.Module):
def __init__(self):
"""Global average pooling over the input's spatial dimensions"""
super(GlobalAvgPool2d, self).__init__()
    def forward(self, inputs):
        # Keep a 1x1 spatial map rather than flattening (assumption: a Conv2d
        # follows this module in Local_Channel's nn.Sequential, so the output
        # must stay 4-D).
        return F.adaptive_avg_pool2d(inputs, 1)
class Local_Channel(nn.Module):
def __init__(self, in_channel):
super(Local_Channel, self).__init__()
self.attn = nn.Sequential(GlobalAvgPool2d(), nn.Conv2d(in_channel,
in_channel, 1), nn.Sigmoid())
self.gamma = nn.Parameter(torch.zeros(1))
def forward(self, x):
attn_map = self.attn(x)
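        # learnable gamma blends the input with its channel-gated version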
return x * (1 - self.gamma) + attn_map * x * self.gamma, attn_map
class Local_Spatial(nn.Module):
def __init__(self, in_channel, mid_channel):
super(Local_Spatial, self).__init__()
self.conv1x1 = nn.Conv2d(in_channel, mid_channel, 1)
self.branch1 = nn.Conv2d(mid_channel, mid_channel, 3, 1, 1, 1)
self.branch2 = nn.Conv2d(mid_channel, mid_channel, 3, 1, 2, 2)
self.branch3 = nn.Conv2d(mid_channel, mid_channel, 3, 1, 3, 3)
self.attn = nn.Conv2d(3 * mid_channel, 1, 1)
self.gamma = nn.Parameter(torch.zeros(1))
def forward(self, x):
mid = self.conv1x1(x)
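        # three dilated 3x3 branches (rates 1/2/3) widen the receptive field
        # before predicting a single-channel spatial attention map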
branch1 = self.branch1(mid)
branch2 = self.branch2(mid)
branch3 = self.branch3(mid)
branch123 = torch.cat([branch1, branch2, branch3], dim=1)
attn_map = self.attn(branch123)
return x * (1 - self.gamma) + attn_map * x * self.gamma, attn_map
class Up(nn.Module):
"""Upscaling then double conv"""
class Bridge(nn.Module):
def __init__(self, in_channels_1, in_channels_2, in_channels_3, mid_channels):
super(Bridge, self).__init__()
self.mid_channels = mid_channels
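        # 1x1 projections that embed the three encoder scales into a shared
        # mid_channels query/key space (presumably for cross-scale affinity)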
self.conv_qk1 = nn.Conv2d(in_channels_1, mid_channels, 1, 1, 0)
self.conv_qk2 = nn.Conv2d(in_channels_2, mid_channels, 1, 1, 0)
self.conv_qk3 = nn.Conv2d(in_channels_3, mid_channels, 1, 1, 0)
class Down(nn.Module):
"""Downscaling with maxpool then double conv"""
class ResidualConv(nn.Module):
def __init__(self, input_dim, output_dim, stride, padding):
super(ResidualConv, self).__init__()
self.conv_block = nn.Sequential(
nn.BatchNorm2d(input_dim),
nn.ReLU(),
nn.Conv2d(
input_dim, output_dim, kernel_size=3, stride=stride,
padding=padding
),
nn.BatchNorm2d(output_dim),
nn.ReLU(),
nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
)
self.conv_skip = nn.Sequential(
nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride,
padding=1),
nn.BatchNorm2d(output_dim),
)
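    def forward(self, x):
        # Residual sum of the main conv block and the projected skip path, as
        # in standard ResUNet blocks; added since the forward was missing from
        # this snippet.
        return self.conv_block(x) + self.conv_skip(x)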
class DecoderBlock(nn.Module):
    def __init__(self, in_planes, out_planes,
                 norm_layer=nn.BatchNorm2d, scale=2, relu=True, last=False):
        super(DecoderBlock, self).__init__()
        # Reconstructed (assumption): the 3x3/1x1 ConvBnRelu layers used by
        # forward() were missing from this snippet.
        self.conv_3x3 = ConvBnRelu(in_planes, in_planes, 3, 1, 1,
                                   norm_layer=norm_layer)
        self.conv_1x1 = ConvBnRelu(in_planes, out_planes, 1, 1, 0,
                                   norm_layer=norm_layer)
        # self.sap = SAPblock(in_planes)  # SAPblock is not defined in this file
        self.scale = scale
        self.last = last
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.normal_(m.weight.data, 1.0, 0.02)
                init.constant_(m.bias.data, 0.0)

    def forward(self, x):
        if not self.last:
            x = self.conv_3x3(x)
            # x = self.sap(x)
        if self.scale > 1:
            x = F.interpolate(x, scale_factor=self.scale,
                              mode='bilinear', align_corners=True)
        x = self.conv_1x1(x)
        return x
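class CBAM_Module2(nn.Module):
    # Hedged sketch: CBAM_Module2 is referenced by msca_net below but not
    # defined in this file. Since it is constructed with no arguments, this
    # sketch keeps it channel-agnostic and implements only the spatial half of
    # CBAM: channel-wise avg/max pooling -> 7x7 conv -> sigmoid gate.
    def __init__(self, kernel_size=7):
        super(CBAM_Module2, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size,
                              padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        attn = self.sigmoid(self.conv(torch.cat([avg_out, max_out], dim=1)))
        return x * attn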
class msca_net(nn.Module):
def __init__(self, classes=2, channels=3, ccm=True, norm_layer=nn.BatchNorm2d,
is_training=True, expansion=2,
base_channel=32):
super(msca_net, self).__init__()
self.backbone = models.resnet34(pretrained=True)
# self.backbone =resnet34(pretrained=False)
self.expansion = expansion
self.base_channel = base_channel
if self.expansion == 4 and self.base_channel == 64:
expan = [512, 1024, 2048]
spatial_ch = [128, 256]
elif self.expansion == 4 and self.base_channel == 32:
expan = [256, 512, 1024]
spatial_ch = [32, 128]
conv_channel_up = [256, 384, 512]
elif self.expansion == 2 and self.base_channel == 32:
expan = [128, 256, 512]
spatial_ch = [64, 64]
conv_channel_up = [128, 256, 512]
conv_channel = expan[0]
self.is_training = is_training
# self.sap = SAPblock(expan[-1])
        bilinear = True
        factor = 2
self.up1 = Up(768, 512 // factor, bilinear)
self.up2 = Up(384, 256 // factor, bilinear)
self.up3 = Up(192, 64, bilinear)
self.up4 = Up(128, 64, bilinear)
# self.relu = nn.ReLU()
# self.fpt = FPT(feature_dim=4)
self.sw1 = Scale_Aware(in_channels=64)
self.sw2 = Scale_Aware(in_channels=64)
self.sw3 = Scale_Aware(in_channels=64)
self.affinity_attention = AffinityAttention2(512)
self.cbam = CBAM_Module2()
self.gamma1 = nn.Parameter(torch.zeros(1))
self.gamma2 = nn.Parameter(torch.zeros(1))
self.gamma3 = nn.Parameter(torch.zeros(1))
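        # Reconstructed (assumption): the DSV heads and classifier used in
        # forward() were missing from this snippet. Each head maps a decoder
        # stage to 64 channels at d1's 1/2 resolution; the BaseNetHead with
        # scale=2 then classifies and restores the input resolution.
        self.dsv4 = UnetDsv3(in_size=256, out_size=64, scale_factor=8)
        self.dsv3 = UnetDsv3(in_size=128, out_size=64, scale_factor=4)
        self.dsv2 = UnetDsv3(in_size=64, out_size=64, scale_factor=2)
        self.dsv1 = nn.Conv2d(64, 64, kernel_size=1)
        self.main_head = BaseNetHead(64, classes, 2,
                                     is_aux=False, norm_layer=norm_layer)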
    def forward(self, x):
        x = self.backbone.conv1(x)
x = self.backbone.bn1(x)
c1 = self.backbone.relu(x) # 1/2 64
x = self.backbone.maxpool(c1)
c2 = self.backbone.layer1(x) # 1/4 64
c3 = self.backbone.layer2(c2) # 1/8 128
c4 = self.backbone.layer3(c3) # 1/16 256
c5 = self.backbone.layer4(c4) # 1/32 512
# d_bottom=self.bottom(c5)
# c5 = self.sap(c5)
attention = self.affinity_attention(c5)
cbam_attn = self.cbam(c5)
# l_channel, _ = self.l_channel(c5)
# l_spatial, _ = self.l_spatial(c5)
        c5 = self.gamma1 * attention + self.gamma2 * cbam_attn + self.gamma3 * c5
        # Several parallel fusion options to try: with or without BN+ReLU,
        # with or without scale-aware fusion.
# d5=d_bottom+c5 #512
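        # Reconstructed (assumption): U-Net style decoding with the Up blocks
        # and encoder skips, producing d4..d1 used by the DSV heads below.
        d4 = self.up1(c5, c4)  # 1/16, 256 ch
        d3 = self.up2(d4, c3)  # 1/8, 128 ch
        d2 = self.up3(d3, c2)  # 1/4, 64 ch
        d1 = self.up4(d2, c1)  # 1/2, 64 ch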
dsv4 = self.dsv4(d4)
dsv3 = self.dsv3(d3)
dsv2 = self.dsv2(d2)
dsv1 = self.dsv1(d1)
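        # Reconstructed (assumption): progressive fusion of the DSV maps with
        # the three Scale_Aware blocks, coarsest to finest.
        dsv43 = self.sw1(dsv4, dsv3)
        dsv432 = self.sw2(dsv43, dsv2)
        dsv4321 = self.sw3(dsv432, dsv1)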
main_out = self.main_head(dsv4321)
        final = torch.sigmoid(main_out)
return final