In this exercise we fully reproduce the UNet model by following its architecture diagram.

1. Model Analysis

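The figure above is the architecture diagram from the original UNet paper. It gives the output shape of every layer and shows exactly how the modules are assembled. To reproduce the model clearly and make it easy to test each part for correctness, the model is split into the following parts: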

I. UNet

        1.DownLayer

                1.1 Double Convolution 2D (2Con)

                1.2 Max Pooling (MP)

        2.UpLayer

                2.1 Double Convolution 2D (2Con)

                2.2 Deconvolution 2D (Decon)

                2.3 Concat Method 

        3. 1×1 Convolution (output layer)
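Before writing any modules, the shape arithmetic in the diagram can be checked in plain Python. The sketch below uses a hypothetical helper, `unet_shapes` (not part of the original code), and assumes square inputs, unpadded 3×3 convolutions, 2×2 max pooling and 2×2 up-convolutions, as in the paper.

```python
def unet_shapes(size: int, depth: int = 4) -> dict:
    """Trace the spatial size of a square input through an unpadded UNet.

    Each DoubleConv2d (two 3x3 convs, no padding) shrinks the size by 4,
    each 2x2 max pool halves it, and each 2x2 transposed conv doubles it.
    """
    skips = []
    for _ in range(depth):
        size -= 4            # DoubleConv2d on the contracting path
        skips.append(size)   # feature map copied to the expanding path
        size //= 2           # MaxPool2d(kernel_size=2, stride=2)
    size -= 4                # bottleneck DoubleConv2d
    bottleneck = size
    ups = []
    for _ in range(depth):
        size *= 2            # ConvTranspose2d(kernel_size=2, stride=2)
        size -= 4            # DoubleConv2d after crop-and-concat
        ups.append(size)
    return {"skips": skips, "bottleneck": bottleneck, "ups": ups}


print(unet_shapes(572))
# {'skips': [568, 280, 136, 64], 'bottleneck': 28, 'ups': [52, 100, 196, 388]}
```

These numbers are the ones the implementation should reproduce for a 572×572 input.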

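Note that the model uses a special concat step. Because the upsampled feature map is smaller than the corresponding feature map from the contracting path, the skip feature is center-cropped to the upsampled size before the two are stacked along the channel dimension. As a result, the border features from the "copy and crop" path are discarded.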
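The crop described above can be made concrete with a small sketch. `center_crop_offsets` is a hypothetical helper, and the size pairs assume a 572×572 input (each upsampled size is twice the preceding expanding-path output, e.g. 28 × 2 = 56).

```python
def center_crop_offsets(src: int, dst: int) -> tuple:
    """Return (start, stop) slice bounds that center-crop a length-`src` axis to `dst`."""
    start = (src - dst) // 2
    return start, start + dst


# (skip-feature size, upsampled size) at each level, deepest level last
levels = [(568, 392), (280, 200), (136, 104), (64, 56)]
for src, dst in levels:
    start, stop = center_crop_offsets(src, dst)
    dropped = 1 - (dst * dst) / (src * src)  # fraction of pixels lost per channel
    print(f"{src} -> {dst}: keep [{start}:{stop}], dropped {dropped:.1%}")
```

The outermost skip connection loses the largest share of its pixels, which is the feature loss mentioned above.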

2. Implementation

2.1 Double Convolution 2D

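import torch
import torch.nn as nn


class DoubleConv2d(nn.Module):
    """Two unpadded 3x3 convolutions, each followed by BatchNorm and ReLU.

    BatchNorm is not in the original 2015 UNet paper (conv + ReLU only);
    it is kept here as part of this reproduction.
    """
    def __init__(self, in_c, out_c):
        super(DoubleConv2d, self).__init__()
        # padding defaults to 0 ("valid" convolution): each conv trims 1 pixel per side
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=3),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels=out_c, out_channels=out_c, kernel_size=3),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        # (N, in_c, H, W) -> (N, out_c, H - 4, W - 4)
        return self.conv(x)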
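A quick way to confirm the "minus 4" behaviour is to run the same layer stack on a random tensor. The sketch below rebuilds the block as a hypothetical `double_conv` helper with an exposed `padding` argument, purely for comparison: `padding=0` matches this article and the paper, while `padding=1` is a common alternative (not used here) that preserves the spatial size and removes the need for cropping.

```python
import torch
import torch.nn as nn


def double_conv(in_c: int, out_c: int, padding: int) -> nn.Sequential:
    # Same layout as DoubleConv2d; `padding` is exposed only for this comparison.
    return nn.Sequential(
        nn.Conv2d(in_c, out_c, kernel_size=3, padding=padding),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_c, out_c, kernel_size=3, padding=padding),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
    )


x = torch.rand(1, 2, 572, 572)
with torch.no_grad():
    valid = double_conv(2, 8, padding=0)(x)  # paper style: "valid" convolutions
    same = double_conv(2, 8, padding=1)(x)   # alternative: size preserved
print(valid.shape)  # torch.Size([1, 8, 568, 568])
print(same.shape)   # torch.Size([1, 8, 572, 572])
```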

2.2 DownLayer

class DownLayer(nn.Module):
    def __init__(self, in_c, out_c):
        super(DownLayer, self).__init__()
        self.doubleCon2d = DoubleConv2d(in_c, out_c)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        con = self.doubleCon2d(x)
        # Pooled map feeds the next level; unpooled map is kept for the skip connection
        return self.maxpool(con), con
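The 2×2 max pool in DownLayer halves the size exactly only when the input size is even. The small sketch below (channel counts chosen arbitrarily) shows both cases: PyTorch's `MaxPool2d` floors the output size, so on an odd input the last row and column are silently dropped.

```python
import torch
import torch.nn as nn

pool = nn.MaxPool2d(kernel_size=2, stride=2)

# Even size: a 568x568 skip feature (as produced by down1) pools cleanly to 284x284.
even = pool(torch.rand(1, 4, 568, 568))
print(even.shape)          # torch.Size([1, 4, 284, 284])

# Odd size: floor division ignores the last row and column.
x = torch.arange(25.0).reshape(1, 1, 5, 5)
odd = pool(x)
print(odd.shape)           # torch.Size([1, 1, 2, 2])
print(odd[0, 0].tolist())  # [[6.0, 8.0], [16.0, 18.0]]
```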

2.3 UpLayer 

class UpLayer(nn.Module):
    def __init__(self, in_c, out_c):
        super(UpLayer, self).__init__()
        self.doubleCon2d = DoubleConv2d(in_c, out_c)
        self.decon = nn.ConvTranspose2d(in_channels=in_c, out_channels=out_c, kernel_size=2, stride=2)

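    def forward(self, x_up, x_left):
        # Upsample the deeper feature map: spatial size x2, channels in_c -> out_c
        up = self.decon(x_up)
        # Crop the skip feature to the upsampled size and stack along channels (out_c + out_c = in_c)
        x = self.concate(x_left, up)
        return self.doubleCon2d(x)

    def concate(self, x, y):
        # x: skip feature (N, C1, H1, W1); y: upsampled feature (N, C2, H2, W2)
        _, _, h1, w1 = x.shape
        _, _, h2, w2 = y.shape
        top = (h1 - h2) // 2
        left = (w1 - w2) // 2
        # Center-crop x to exactly H2 x W2 (also correct when the size gap is odd)
        x = x[:, :, top:top + h2, left:left + w2]
        return torch.cat([x, y], dim=1)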
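The UpLayer data flow can be traced with small tensors. This sketch uses hypothetical channel counts (16 → 8) but the spatial sizes of `up1` for a 572×572 input: a 28×28 bottleneck and a 64×64 skip feature. With `padding=0`, `ConvTranspose2d(kernel_size=2, stride=2)` gives an output size of (28 − 1) × 2 + 2 = 56.

```python
import torch
import torch.nn as nn

in_c, out_c = 16, 8  # hypothetical small channel counts
decon = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2)

x_up = torch.rand(1, in_c, 28, 28)     # plays the role of r5
x_left = torch.rand(1, out_c, 64, 64)  # plays the role of r4

with torch.no_grad():
    up = decon(x_up)                   # 28x28 -> 56x56
    _, _, h1, w1 = x_left.shape
    _, _, h2, w2 = up.shape
    top, left = (h1 - h2) // 2, (w1 - w2) // 2
    cropped = x_left[:, :, top:top + h2, left:left + w2]  # center 56x56 window
    merged = torch.cat([cropped, up], dim=1)              # channels: 8 + 8 = 16

print(up.shape)      # torch.Size([1, 8, 56, 56])
print(merged.shape)  # torch.Size([1, 16, 56, 56])
```

The concatenated tensor has `in_c` channels again, which is why UpLayer builds its DoubleConv2d as `DoubleConv2d(in_c, out_c)`.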

2.4 UNet

class UNet(nn.Module):
    def __init__(self, in_c, out_c):
        super(UNet, self).__init__()
        self.down1 = DownLayer(in_c, 64)
        self.down2 = DownLayer(64, 128)
        self.down3 = DownLayer(128, 256)
        self.down4 = DownLayer(256, 512)
        self.down5 = DownLayer(512, 1024)

        self.up1 = UpLayer(1024, 512)
        self.up2 = UpLayer(512, 256)
        self.up3 = UpLayer(256, 128)
        self.up4 = UpLayer(128, 64)

        self.conv = nn.Conv2d(64, out_c, 1)

    def forward(self, x):
        d1, r1 = self.down1(x)
        print("MP and 2Con out:", d1.shape, r1.shape)
        d2, r2 = self.down2(d1)
        print("MP and 2Con out:", d2.shape, r2.shape)
        d3, r3 = self.down3(d2)
        print("MP and 2Con out:", d3.shape, r3.shape)
        d4, r4 = self.down4(d3)
        print("MP and 2Con out:", d4.shape, r4.shape)
        d5, r5 = self.down5(d4)
        print("MP and 2Con out:", d5.shape, r5.shape)

        u1 = self.up1(r5, r4)
        print("up out:", u1.shape)
        u2 = self.up2(u1, r3)
        print("up out:", u2.shape)
        u3 = self.up3(u2, r2)
        print("up out:", u3.shape)
        u4 = self.up4(u3, r1)
        print("up out:", u4.shape)

        return self.conv(u4)

3. Testing and Verification

if __name__ == "__main__":
    x = torch.rand(1, 2, 572, 572, dtype=torch.float32)
    model = UNet(2, 2)
    out = model(x)
    print(out.shape)

Output:

MP and 2Con out: torch.Size([1, 64, 284, 284]) torch.Size([1, 64, 568, 568])
MP and 2Con out: torch.Size([1, 128, 140, 140]) torch.Size([1, 128, 280, 280])
MP and 2Con out: torch.Size([1, 256, 68, 68]) torch.Size([1, 256, 136, 136])
MP and 2Con out: torch.Size([1, 512, 32, 32]) torch.Size([1, 512, 64, 64])
MP and 2Con out: torch.Size([1, 1024, 14, 14]) torch.Size([1, 1024, 28, 28])
up out: torch.Size([1, 512, 52, 52])
up out: torch.Size([1, 256, 100, 100])
up out: torch.Size([1, 128, 196, 196])
up out: torch.Size([1, 64, 388, 388])
torch.Size([1, 2, 388, 388])

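The output shapes and model structure agree with the paper. A 572×572 input produces a 388×388 output map, and the intermediate sizes (568, 280, 136, 64 on the contracting path; 28 at the bottleneck; 52, 100, 196, 388 on the expanding path) match the architecture diagram. The only extra tensor is the 14×14 max-pooled output of down5, which is computed but never used, because the paper applies no pooling after the bottleneck.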
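The 572×572 test input is not arbitrary. The UNet paper recommends choosing the input tile size so that every 2×2 max pooling is applied to a layer with even height and width. A hypothetical helper, `is_valid_tile`, can enumerate such sizes for this four-level, unpadded architecture. It assumes square inputs and the same shape rules as the model above.

```python
def is_valid_tile(size: int, depth: int = 4) -> bool:
    """True if every 2x2 max pool sees an even size and the output is non-empty."""
    for _ in range(depth):
        size -= 4                # DoubleConv2d before pooling
        if size <= 0 or size % 2:
            return False         # pooling would drop a row/column
        size //= 2               # MaxPool2d(2, 2)
    size -= 4                    # bottleneck DoubleConv2d
    for _ in range(depth):
        size = size * 2 - 4      # up-conv (x2), then DoubleConv2d (-4)
    return size > 0


valid = [s for s in range(150, 600) if is_valid_tile(s)]
print(valid[:3])     # [188, 204, 220]
print(572 in valid)  # True
print(570 in valid)  # False
```

Valid sizes are spaced 2^4 = 16 apart, one factor of 2 per pooling level.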
