Mscanet Model

The document contains a PyTorch implementation of various neural network components, including attention mechanisms like SpatialAttentionBlock and ChannelAttentionBlock, as well as a U-Net architecture with additional features such as scale-aware modules and convolutional blocks. It defines multiple classes for building complex models, focusing on enhancing feature representation through attention and convolutional operations. Overall, the document serves as a foundation for constructing advanced deep learning models for tasks such as image segmentation.


# -*- coding: utf-8 -*-
"""
Created on Wed Apr 10 09:57:49 2019

@author: Fsl
"""
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn import init
from torchvision import models
from functools import partial
# from .resnet import resnet34
# from resnet import resnet34
# import resnet
# unused in this file: scipy.ndimage, numpy, thop.profile, torchsummary

up_kwargs = {'mode': 'bilinear', 'align_corners': True}
BatchNorm2d = nn.BatchNorm2d

class SpatialAttentionBlock(nn.Module):
    def __init__(self, in_channels):
        super(SpatialAttentionBlock, self).__init__()
        self.query = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 8, kernel_size=(1, 3), padding=(0, 1)),
            nn.BatchNorm2d(in_channels // 8),
            nn.ReLU(inplace=True)
        )
        self.key = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 8, kernel_size=(3, 1), padding=(1, 0)),
            nn.BatchNorm2d(in_channels // 8),
            nn.ReLU(inplace=True)
        )
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        :param x: input (B x C x H x W)
        :return: affinity value + x
        """
        B, C, H, W = x.size()
        # compress x: [B, C, H, W] --> [B, H*W, C], i.e. a matrix transpose
        proj_query = self.query(x).view(B, -1, W * H).permute(0, 2, 1)
        proj_key = self.key(x).view(B, -1, W * H)
        affinity = torch.matmul(proj_query, proj_key)
        affinity = self.softmax(affinity)
        proj_value = self.value(x).view(B, -1, H * W)
        weights = torch.matmul(proj_value, affinity.permute(0, 2, 1))
        weights = weights.view(B, C, H, W)
        out = self.gamma * weights + x
        return out

class ChannelAttentionBlock(nn.Module):
    def __init__(self, in_channels):
        super(ChannelAttentionBlock, self).__init__()
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        :param x: input (B x C x H x W)
        :return: affinity value + x
        """
        B, C, H, W = x.size()
        proj_query = x.view(B, C, -1)
        proj_key = x.view(B, C, -1).permute(0, 2, 1)
        affinity = torch.matmul(proj_query, proj_key)
        affinity_new = torch.max(affinity, -1, keepdim=True)[0].expand_as(affinity) - affinity
        affinity_new = self.softmax(affinity_new)
        proj_value = x.view(B, C, -1)
        weights = torch.matmul(affinity_new, proj_value)
        weights = weights.view(B, C, H, W)
        out = self.gamma * weights + x
        return out

class AffinityAttention2(nn.Module):
    """Affinity attention module."""

    def __init__(self, in_channels):
        super(AffinityAttention2, self).__init__()
        self.sab = SpatialAttentionBlock(in_channels)
        self.cab = ChannelAttentionBlock(in_channels)
        # self.conv1x1 = nn.Conv2d(in_channels * 2, in_channels, kernel_size=1)

    def forward(self, x):
        """
        sab: spatial attention block
        cab: channel attention block
        :param x: input tensor
        :return: sab + cab
        """
        sab = self.sab(x)
        cab = self.cab(sab)
        out = sab + cab
        return out
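
# Hedged usage sketch (not part of the original file): a minimal shape check for
# the attention blocks above, assuming in_channels is a multiple of 8 so the 1/8
# query/key channel reduction is non-empty. Note that the HW x HW affinity matrix
# makes SpatialAttentionBlock memory-hungry on large feature maps.
#
#     attn = AffinityAttention2(in_channels=64)
#     x = torch.randn(2, 64, 16, 16)
#     out = attn(x)
#     assert out.shape == x.shape  # the residual formulation preserves shape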

class UnetDsv3(nn.Module):
    def __init__(self, in_size, out_size, scale_factor):
        super(UnetDsv3, self).__init__()
        # NB: scale_factor is used as a target output *size* (H, W) here, not a ratio.
        self.dsv = nn.Sequential(
            nn.Conv2d(in_size, out_size, kernel_size=1, stride=1, padding=0),
            nn.Upsample(size=scale_factor, mode='bilinear'),
        )

    def forward(self, input):
        return self.dsv(input)

class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, relu=True, bn=True, bias=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
                              stride=stride, padding=padding,
                              dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x

class Scale_Aware(nn.Module):
    def __init__(self, in_channels):
        super(Scale_Aware, self).__init__()
        # self.bn = nn.ModuleList([nn.BatchNorm2d(in_channels), nn.BatchNorm2d(in_channels), nn.BatchNorm2d(in_channels)])
        self.conv1x1 = nn.ModuleList(
            [nn.Conv2d(in_channels=2 * in_channels, out_channels=in_channels, dilation=1, kernel_size=1, padding=0),
             nn.Conv2d(in_channels=2 * in_channels, out_channels=in_channels, dilation=1, kernel_size=1, padding=0)])
        self.conv3x3_1 = nn.ModuleList(
            [nn.Conv2d(in_channels=in_channels, out_channels=in_channels // 2, dilation=1, kernel_size=3, padding=1),
             nn.Conv2d(in_channels=in_channels, out_channels=in_channels // 2, dilation=1, kernel_size=3, padding=1)])
        self.conv3x3_2 = nn.ModuleList(
            [nn.Conv2d(in_channels=in_channels // 2, out_channels=2, dilation=1, kernel_size=3, padding=1),
             nn.Conv2d(in_channels=in_channels // 2, out_channels=2, dilation=1, kernel_size=3, padding=1)])
        self.conv_last = ConvBnRelu(in_planes=in_channels, out_planes=in_channels,
                                    ksize=1, stride=1, pad=0, dilation=1)
        self.relu = nn.ReLU()

    def forward(self, x_l, x_h):
        feat = torch.cat([x_l, x_h], dim=1)
        # feat = feat_cat.detach()
        feat = self.relu(self.conv1x1[0](feat))
        feat = self.relu(self.conv3x3_1[0](feat))
        att = self.conv3x3_2[0](feat)
        att = F.softmax(att, dim=1)
        att_1 = att[:, 0, :, :].unsqueeze(1)
        att_2 = att[:, 1, :, :].unsqueeze(1)
        fusion_1_2 = att_1 * x_l + att_2 * x_h
        return fusion_1_2
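
# Hedged usage sketch (an assumption, not in the original file): Scale_Aware
# predicts a two-channel softmax mask and blends two same-shape feature maps
# pixel-wise, as done with the deep-supervision outputs in msca_net below.
#
#     sw = Scale_Aware(in_channels=64)
#     x_l = torch.randn(1, 64, 112, 160)
#     x_h = torch.randn(1, 64, 112, 160)
#     fused = sw(x_l, x_h)  # -> (1, 64, 112, 160)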

class BaseNetHead(nn.Module):
    def __init__(self, in_planes, out_planes, scale,
                 is_aux=False, norm_layer=nn.BatchNorm2d):
        super(BaseNetHead, self).__init__()
        if is_aux:
            self.conv_1x1_3x3 = nn.Sequential(
                ConvBnRelu(in_planes, 64, 1, 1, 0,
                           has_bn=True, norm_layer=norm_layer,
                           has_relu=True, has_bias=False),
                ConvBnRelu(64, 64, 3, 1, 1,
                           has_bn=True, norm_layer=norm_layer,
                           has_relu=True, has_bias=False))
        else:
            self.conv_1x1_3x3 = nn.Sequential(
                ConvBnRelu(in_planes, 32, 1, 1, 0,
                           has_bn=True, norm_layer=norm_layer,
                           has_relu=True, has_bias=False),
                ConvBnRelu(32, 32, 3, 1, 1,
                           has_bn=True, norm_layer=norm_layer,
                           has_relu=True, has_bias=False))
        # self.dropout = nn.Dropout(0.1)
        if is_aux:
            self.conv_1x1_2 = nn.Conv2d(64, out_planes, kernel_size=1, stride=1, padding=0)
        else:
            self.conv_1x1_2 = nn.Conv2d(32, out_planes, kernel_size=1, stride=1, padding=0)
        self.scale = scale

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.normal_(m.weight.data, 1.0, 0.02)
                init.constant_(m.bias.data, 0.0)

    def forward(self, x):
        if self.scale > 1:
            x = F.interpolate(x, scale_factor=self.scale,
                              mode='bilinear', align_corners=True)
        fm = self.conv_1x1_3x3(x)
        # fm = self.dropout(fm)
        output = self.conv_1x1_2(fm)
        return output

class ConvBnRelu(nn.Module):
    def __init__(self, in_planes, out_planes, ksize, stride, pad, dilation=1,
                 groups=1, has_bn=True, norm_layer=nn.BatchNorm2d,
                 has_relu=True, inplace=True, has_bias=False):
        super(ConvBnRelu, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=ksize,
                              stride=stride, padding=pad,
                              dilation=dilation, groups=groups, bias=has_bias)
        self.has_bn = has_bn
        if self.has_bn:
            self.bn = nn.BatchNorm2d(out_planes)
        self.has_relu = has_relu
        if self.has_relu:
            self.relu = nn.ReLU(inplace=inplace)

    def forward(self, x):
        x = self.conv(x)
        if self.has_bn:
            x = self.bn(x)
        if self.has_relu:
            x = self.relu(x)
        return x

class GlobalAvgPool2d(nn.Module):
    def __init__(self):
        """Global average pooling over the input's spatial dimensions."""
        super(GlobalAvgPool2d, self).__init__()

    def forward(self, inputs):
        in_size = inputs.size()
        inputs = inputs.view((in_size[0], in_size[1], -1)).mean(dim=2)
        inputs = inputs.view(in_size[0], in_size[1], 1, 1)
        return inputs

class Local_Channel(nn.Module):
    def __init__(self, in_channel):
        super(Local_Channel, self).__init__()
        self.attn = nn.Sequential(GlobalAvgPool2d(),
                                  nn.Conv2d(in_channel, in_channel, 1),
                                  nn.Sigmoid())
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        attn_map = self.attn(x)
        return x * (1 - self.gamma) + attn_map * x * self.gamma, attn_map

class Local_Spatial(nn.Module):
    def __init__(self, in_channel, mid_channel):
        super(Local_Spatial, self).__init__()
        self.conv1x1 = nn.Conv2d(in_channel, mid_channel, 1)
        # three 3x3 branches with dilation 1 / 2 / 3 for multi-scale context
        self.branch1 = nn.Conv2d(mid_channel, mid_channel, 3, 1, 1, 1)
        self.branch2 = nn.Conv2d(mid_channel, mid_channel, 3, 1, 2, 2)
        self.branch3 = nn.Conv2d(mid_channel, mid_channel, 3, 1, 3, 3)
        self.attn = nn.Conv2d(3 * mid_channel, 1, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        mid = self.conv1x1(x)
        branch1 = self.branch1(mid)
        branch2 = self.branch2(mid)
        branch3 = self.branch3(mid)
        branch123 = torch.cat([branch1, branch2, branch3], dim=1)
        attn_map = self.attn(branch123)
        return x * (1 - self.gamma) + attn_map * x * self.gamma, attn_map

nonlinearity = partial(F.relu, inplace=True)


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        # diffY = x2.size()[2] - x1.size()[2]
        # diffX = x2.size()[3] - x1.size()[3]
        # x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
        #                 diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class CBAM_Module2(nn.Module):
    def __init__(self, channels=512, reduction=2):
        super(CBAM_Module2, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0)
        self.sigmoid_channel = nn.Sigmoid()
        self.conv_after_concat = nn.Conv2d(2, 1, kernel_size=7, stride=1, padding=3)
        self.sigmoid_spatial = nn.Sigmoid()

    def forward(self, x):
        # Channel attention module
        module_input = x
        avg = self.avg_pool(x)
        mx = self.max_pool(x)
        avg = self.fc1(avg)
        mx = self.fc1(mx)
        avg = self.relu(avg)
        mx = self.relu(mx)
        avg = self.fc2(avg)
        mx = self.fc2(mx)
        x = avg + mx
        x = self.sigmoid_channel(x)
        # Spatial attention module
        x = module_input * x + module_input
        module_input = x
        avg = torch.mean(x, 1, True)
        mx, _ = torch.max(x, 1, True)
        x = torch.cat((avg, mx), 1)
        x = self.conv_after_concat(x)
        x = self.sigmoid_spatial(x)
        x = module_input * x + module_input
        return x
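
# Hedged usage sketch (not part of the original file): CBAM_Module2 applies
# channel attention then spatial attention, each with an extra residual term
# ("x * attn + x"), so the output shape equals the input shape.
#
#     cbam = CBAM_Module2(channels=512, reduction=2)
#     y = cbam(torch.randn(1, 512, 7, 10))  # -> (1, 512, 7, 10)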


class Bridge(nn.Module):
    def __init__(self, in_channels_1, in_channels_2, in_channels_3, mid_channels):
        super(Bridge, self).__init__()
        self.mid_channels = mid_channels
        self.conv_qk1 = nn.Conv2d(in_channels_1, mid_channels, 1, 1, 0)
        self.conv_qk2 = nn.Conv2d(in_channels_2, mid_channels, 1, 1, 0)
        self.conv_qk3 = nn.Conv2d(in_channels_3, mid_channels, 1, 1, 0)

        self.conv_v1 = nn.Conv2d(in_channels_1, mid_channels, 1, 1, 0)
        self.conv_v2 = nn.Conv2d(in_channels_2, mid_channels, 1, 1, 0)
        self.conv_v3 = nn.Conv2d(in_channels_3, mid_channels, 1, 1, 0)

        self.conv_out1 = nn.Conv2d(2 * mid_channels + in_channels_1, in_channels_1, 1, 1, 0)
        self.conv_out2 = nn.Conv2d(2 * mid_channels + in_channels_2, in_channels_2, 1, 1, 0)
        self.conv_out3 = nn.Conv2d(2 * mid_channels + in_channels_3, in_channels_3, 1, 1, 0)

    def forward(self, f1, f2, f3):
        batch_size = f1.size(0)
        qk1 = self.conv_qk1(f1).view(batch_size, self.mid_channels, -1)
        qk2 = self.conv_qk2(f2).view(batch_size, self.mid_channels, -1)
        qk3 = self.conv_qk3(f3).view(batch_size, self.mid_channels, -1)

        v1 = self.conv_v1(f1).view(batch_size, self.mid_channels, -1)
        v2 = self.conv_v2(f2).view(batch_size, self.mid_channels, -1)
        v3 = self.conv_v3(f3).view(batch_size, self.mid_channels, -1)

        # pairwise similarities between the flattened scales
        sim12 = torch.matmul(qk1.permute(0, 2, 1), qk2)
        sim23 = torch.matmul(qk2.permute(0, 2, 1), qk3)
        sim31 = torch.matmul(qk3.permute(0, 2, 1), qk1)

        attn12 = F.softmax(sim12, dim=-1)
        attn21 = F.softmax(sim12.permute(0, 2, 1), dim=-1)
        attn23 = F.softmax(sim23, dim=-1)
        attn32 = F.softmax(sim23.permute(0, 2, 1), dim=-1)
        attn31 = F.softmax(sim31, dim=-1)
        attn13 = F.softmax(sim31.permute(0, 2, 1), dim=-1)

        # y_ij: features routed from scale i onto the spatial layout of scale j
        y12 = torch.matmul(v1, attn12).contiguous()
        y13 = torch.matmul(v1, attn13).contiguous()
        y21 = torch.matmul(v2, attn21).contiguous()
        y23 = torch.matmul(v2, attn23).contiguous()
        y31 = torch.matmul(v3, attn31).contiguous()
        y32 = torch.matmul(v3, attn32).contiguous()

        y12 = y12.view(batch_size, self.mid_channels, int(f2.size()[2]), int(f2.size()[3]))
        y13 = y13.view(batch_size, self.mid_channels, int(f3.size()[2]), int(f3.size()[3]))
        y21 = y21.view(batch_size, self.mid_channels, int(f1.size()[2]), int(f1.size()[3]))
        y23 = y23.view(batch_size, self.mid_channels, int(f3.size()[2]), int(f3.size()[3]))
        y31 = y31.view(batch_size, self.mid_channels, int(f1.size()[2]), int(f1.size()[3]))
        y32 = y32.view(batch_size, self.mid_channels, int(f2.size()[2]), int(f2.size()[3]))

        out1 = self.conv_out1(torch.cat([f1, y31, y21], dim=1))
        out2 = self.conv_out2(torch.cat([f2, y12, y32], dim=1))
        out3 = self.conv_out3(torch.cat([f3, y23, y13], dim=1))

        return out1, out2, out3
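
# Hedged usage sketch (not part of the original file): Bridge exchanges
# information across three encoder scales via pairwise cross-attention; each
# output keeps its input's channel count and spatial size. The shapes below
# mirror the c2/c3/c4 features of a ResNet-34 on a 224 x 320 input (an
# assumption consistent with out_size = (112, 160) in msca_net below).
#
#     bridge = Bridge(64, 128, 256, 64)
#     f1 = torch.randn(1, 64, 56, 80)
#     f2 = torch.randn(1, 128, 28, 40)
#     f3 = torch.randn(1, 256, 14, 20)
#     o1, o2, o3 = bridge(f1, f2, f3)  # same shapes as f1, f2, f3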

class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class ResidualConv(nn.Module):
    def __init__(self, input_dim, output_dim, stride, padding):
        super(ResidualConv, self).__init__()

        self.conv_block = nn.Sequential(
            nn.BatchNorm2d(input_dim),
            nn.ReLU(),
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=padding),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        )
        self.conv_skip = nn.Sequential(
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(output_dim),
        )

    def forward(self, x):
        return self.conv_block(x) + self.conv_skip(x)


class DecoderBlock(nn.Module):
    def __init__(self, in_planes, out_planes,
                 norm_layer=nn.BatchNorm2d, scale=2, relu=True, last=False):
        super(DecoderBlock, self).__init__()

        self.conv_3x3 = ConvBnRelu(in_planes, in_planes, 3, 1, 1,
                                   has_bn=True, norm_layer=norm_layer,
                                   has_relu=True, has_bias=False)
        self.conv_1x1 = ConvBnRelu(in_planes, out_planes, 1, 1, 0,
                                   has_bn=True, norm_layer=norm_layer,
                                   has_relu=True, has_bias=False)

        # SAPblock is not defined in this file; its use in forward() is already
        # disabled, so its construction is commented out as well.
        # self.sap = SAPblock(in_planes)
        self.scale = scale
        self.last = last

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.normal_(m.weight.data, 1.0, 0.02)
                init.constant_(m.bias.data, 0.0)

    def forward(self, x):
        if self.last == False:
            x = self.conv_3x3(x)
            # x = self.sap(x)
        if self.scale > 1:
            x = F.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=True)
        x = self.conv_1x1(x)
        return x

class msca_net(nn.Module):
    def __init__(self, classes=2, channels=3, ccm=True, norm_layer=nn.BatchNorm2d,
                 is_training=True, expansion=2, base_channel=32):
        super(msca_net, self).__init__()
        self.backbone = models.resnet34(pretrained=True)
        # self.backbone = resnet34(pretrained=False)
        self.expansion = expansion
        self.base_channel = base_channel
        if self.expansion == 4 and self.base_channel == 64:
            expan = [512, 1024, 2048]
            spatial_ch = [128, 256]
        elif self.expansion == 4 and self.base_channel == 32:
            expan = [256, 512, 1024]
            spatial_ch = [32, 128]
            conv_channel_up = [256, 384, 512]
        elif self.expansion == 2 and self.base_channel == 32:
            expan = [128, 256, 512]
            spatial_ch = [64, 64]
            conv_channel_up = [128, 256, 512]

        conv_channel = expan[0]

        self.is_training = is_training
        # self.sap = SAPblock(expan[-1])

        # self.decoder5 = DecoderBlock(expan[-1], expan[-2], relu=False, last=True)  # 256
        # self.decoder4 = DecoderBlock(expan[-2], expan[-3], relu=False)  # 128
        # self.decoder3 = DecoderBlock(expan[-3], spatial_ch[-1], relu=False)  # 64
        # self.decoder2 = DecoderBlock(spatial_ch[-1], spatial_ch[-2])  # 32

        bilinear = True
        factor = 2
        self.up1 = Up(768, 512 // factor, bilinear)
        self.up2 = Up(384, 256 // factor, bilinear)
        self.up3 = Up(192, 64, bilinear)
        self.up4 = Up(128, 64, bilinear)

        self.main_head = BaseNetHead(64, classes, 2,
                                     is_aux=False, norm_layer=norm_layer)

        # self.relu = nn.ReLU()
        # self.fpt = FPT(feature_dim=4)

        filters = [64, 64, 128, 256]
        self.out_size = (112, 160)
        self.dsv4 = UnetDsv3(in_size=filters[3], out_size=64, scale_factor=self.out_size)
        self.dsv3 = UnetDsv3(in_size=filters[2], out_size=64, scale_factor=self.out_size)
        self.dsv2 = UnetDsv3(in_size=filters[1], out_size=64, scale_factor=self.out_size)
        self.dsv1 = nn.Conv2d(in_channels=filters[0], out_channels=64, kernel_size=1)

        self.sw1 = Scale_Aware(in_channels=64)
        self.sw2 = Scale_Aware(in_channels=64)
        self.sw3 = Scale_Aware(in_channels=64)

        self.affinity_attention = AffinityAttention2(512)
        self.cbam = CBAM_Module2()
        self.gamma1 = nn.Parameter(torch.zeros(1))
        self.gamma2 = nn.Parameter(torch.zeros(1))
        self.gamma3 = nn.Parameter(torch.zeros(1))

        self.bridge = Bridge(64, 128, 256, 64)

    def forward(self, x):
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        c1 = self.backbone.relu(x)      # 1/2, 64

        x = self.backbone.maxpool(c1)
        c2 = self.backbone.layer1(x)    # 1/4, 64
        c3 = self.backbone.layer2(c2)   # 1/8, 128
        c4 = self.backbone.layer3(c3)   # 1/16, 256
        c5 = self.backbone.layer4(c4)   # 1/32, 512
        # d_bottom = self.bottom(c5)

        # m1, m2, m3, m4 = self.fpt(c1, c2, c3, c4)
        m2, m3, m4 = self.bridge(c2, c3, c4)

        # c5 = self.sap(c5)
        attention = self.affinity_attention(c5)
        cbam_attn = self.cbam(c5)
        # l_channel, _ = self.l_channel(c5)
        # l_spatial, _ = self.l_spatial(c5)
        # several parallel fusion variants to try: with/without BN+ReLU, with/without scale-aware
        c5 = self.gamma1 * attention + self.gamma2 * cbam_attn + self.gamma3 * c5

        # d5 = d_bottom + c5  # 512
        # d4 = self.relu(self.decoder5(c5) + m4)  # 256
        # d3 = self.relu(self.decoder4(d4) + m3)  # 128
        # d2 = self.relu(self.decoder3(d3) + m2)  # 64
        # d1 = self.decoder2(d2) + m1  # 32
        d4 = self.up1(c5, m4)
        d3 = self.up2(d4, m3)
        d2 = self.up3(d3, m2)
        d1 = self.up4(d2, c1)

        dsv4 = self.dsv4(d4)
        dsv3 = self.dsv3(d3)
        dsv2 = self.dsv2(d2)
        dsv1 = self.dsv1(d1)

        dsv43 = self.sw1(dsv4, dsv3)
        dsv432 = self.sw2(dsv43, dsv2)
        dsv4321 = self.sw3(dsv432, dsv1)

        main_out = self.main_head(dsv4321)

        final = torch.sigmoid(main_out)

        return final
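
# Hedged smoke test (an assumption about intended usage, not part of the
# original file): out_size = (112, 160) is half of a 224 x 320 input, so the
# deep-supervision maps and dsv1 line up at that resolution. Instantiating the
# model downloads the ImageNet ResNet-34 weights via torchvision.
if __name__ == '__main__':
    net = msca_net(classes=2)
    net.eval()
    with torch.no_grad():
        x = torch.randn(1, 3, 224, 320)
        y = net(x)
    print(y.shape)  # expected: torch.Size([1, 2, 224, 320])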
