딥러닝/tensorflow

[tensorflow 2] HardNet 모델

scjung 2020. 4. 22. 13:27

해당 코드는 pytorch 로 짜여진 hardnet 모델을 tensorflow 2의 Model로 그대로 넣은 코드이다.

 

기존 pytorch 코드는 여기

 

https://github.com/PingoLH/FCHarDNet/blob/master/ptsemseg/models/hardnet.py

 

PingoLH/FCHarDNet

Fully Convolutional HarDNet for Segmentation in Pytorch - PingoLH/FCHarDNet

github.com

 

테스트 해봤을 때 속도가 pytorch로 학습하는 속도보다 훠얼씬 떨어지지만 일단은 돌아가니 올리도록 한다.

(1 epoch 당 pytorch: 20~30분 / tensorflow: 9~10시간)

 

import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Layer, UpSampling2D, Conv2D, BatchNormalization, ReLU, AveragePooling2D


class TransitionUp(Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs, skip=None, concat=True, **kwargs):
        out = UpSampling2D(interpolation='bilinear')(inputs)

        if concat:
            skip = tf.dtypes.cast(skip, tf.float32)
            out = tf.concat([out, skip], 3)

        return out


class ConvLayer(Layer):
    def __init__(self, out_channels, kernel=3, stride=1, **kwargs):
        super().__init__(**kwargs)

        self.conv = Conv2D(out_channels,
                           kernel_size=kernel,
                           strides=stride,
                           padding=[[0, 0], [kernel // 2, kernel // 2], [kernel // 2, kernel // 2], [0, 0]],
                           use_bias=False)
        self.norm = BatchNormalization()
        self.relu = ReLU()

    def call(self, inputs, **kwargs):
        out = self.conv(inputs)
        out = self.norm(out)
        out = self.relu(out)

        return out


class HardBlock(Layer):
    def get_link(self, layer, base_ch, growth_rate, grmul):
        if layer == 0:
            return base_ch, 0, []
        out_channels = growth_rate
        link = []
        for i in range(10):
            dv = 2 ** i
            if layer % dv == 0:
                k = layer - dv
                link.append(k)
                if i > 0:
                    out_channels *= grmul
        out_channels = int(int(out_channels + 1) / 2) * 2
        in_channels = 0
        for i in link:
            ch, _, _ = self.get_link(i, base_ch, growth_rate, grmul)
            in_channels += ch
        return out_channels, in_channels, link

    def __init__(self, in_channels, growth_rate, grmul, n_layers, keepBase=False, **kwargs):
        super().__init__(**kwargs)
        self.keepBase = keepBase
        self.links = []
        self.out_channels = 0
        self.layers = []

        for i in range(n_layers):
            outch, inch, link = self.get_link(i + 1, in_channels, growth_rate, grmul)
            self.links.append(link)
            self.layers.append(ConvLayer(outch))
            if (i % 2 == 0) or (i == n_layers - 1):
                self.out_channels += outch

    def get_out_ch(self):
        return self.out_channels

    def call(self, inputs, **kwargs):
        layers_ = [inputs]

        for layer in range(len(self.layers)):
            link = self.links[layer]
            tin = []
            for i in link:
                tin.append(layers_[i])
            if len(tin) > 1:
                inputs = tf.concat(tin, 3)
            else:
                inputs = tin[0]
            out = self.layers[layer](inputs)
            layers_.append(out)
        t = len(layers_)
        out_ = []
        for i in range(t):
            if (i == 0 and self.keepBase) or (i == t - 1) or (i % 2 == 1):
                out_.append(layers_[i])
        out = tf.concat(out_, 3)
        return out


class HardNet(Model):
    def __init__(self, n_classes=19, *args, **kwargs):
        super().__init__(*args, **kwargs)
        first_ch = [16, 24, 32, 48]
        ch_list = [64, 96, 160, 224, 320]
        grmul = 1.7
        gr = [10, 16, 18, 24, 32]
        n_layers = [4, 4, 8, 8, 8]

        blks = len(n_layers)

        self.shortcut_layers = []

        self.base = []
        self.base.append(ConvLayer(out_channels=first_ch[0], kernel=3, stride=2))
        self.base.append(ConvLayer(first_ch[1], kernel=3))
        self.base.append(ConvLayer(first_ch[2], kernel=3, stride=2))
        self.base.append(ConvLayer(first_ch[3], kernel=3))

        skip_connection_channel_counts = []
        ch = first_ch[3]
        for i in range(blks):
            blk = HardBlock(ch, gr[i], grmul, n_layers[i])
            ch = blk.get_out_ch()
            skip_connection_channel_counts.append(ch)
            self.base.append(blk)
            if i < blks - 1:
                self.shortcut_layers.append(len(self.base) - 1)

            self.base.append(ConvLayer(ch_list[i], kernel=1))
            ch = ch_list[i]

            if i < blks - 1:
                self.base.append(AveragePooling2D(strides=2))

        cur_channels_count = ch
        prev_block_channels = ch
        n_blocks = blks - 1
        self.n_blocks = n_blocks

        #######################
        #   Upsampling path   #
        #######################

        self.transUpBlocks = []
        self.denseBlocksUp = []
        self.conv1x1_up = []

        for i in range(n_blocks - 1, -1, -1):
            self.transUpBlocks.append(TransitionUp())
            cur_channels_count = prev_block_channels + skip_connection_channel_counts[i]
            self.conv1x1_up.append(ConvLayer(cur_channels_count // 2, kernel=1))
            cur_channels_count = cur_channels_count // 2

            blk = HardBlock(cur_channels_count, gr[i], grmul, n_layers[i])
            blk.trainable = False
            self.denseBlocksUp.append(blk)
            prev_block_channels = blk.get_out_ch()

        self.finalConv = Conv2D(n_classes, kernel_size=1, strides=1, padding='valid')

    def call(self, x, training=None, mask=None):
        skip_connections = []

        for i in range(len(self.base)):
            x = self.base[i](x)
            if i in self.shortcut_layers:
                skip_connections.append(x)
        out = x

        for i in range(self.n_blocks):
            skip = skip_connections.pop()
            out = self.transUpBlocks[i](out, skip, True)
            out = self.conv1x1_up[i](out)
            out = self.denseBlocksUp[i](out)

        out = self.finalConv(out)
        out = UpSampling2D(size=(4, 4), interpolation='bilinear')(out)
        return out