[tensorflow 2] HardNet Model
This code takes the HardNet model written in PyTorch and ports it as-is into a TensorFlow 2 Model.
The original PyTorch code is here:
https://github.com/PingoLH/FCHarDNet/blob/master/ptsemseg/models/hardnet.py
When I tested it, training was far slower than training in PyTorch, but it does run, so I'm posting it for now.
(per epoch: PyTorch 20~30 minutes / TensorFlow 9~10 hours)
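A gap this large often comes from running every step eagerly from Python, so the first thing I would try is tracing the training step into a graph with tf.function. Below is a minimal sketch of what that looks like (it uses the HardNet class from the full code further down; the Adam optimizer, learning rate, and sparse cross-entropy loss are placeholder assumptions for illustration, and I haven't verified how much of the gap this actually closes):

import tensorflow as tf

# assumes the HardNet class defined in the code below
model = HardNet(n_classes=19)
optimizer = tf.keras.optimizers.Adam(1e-3)  # placeholder optimizer/learning rate
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@tf.function  # trace the step into a graph instead of running it eagerly
def train_step(images, labels):
    with tf.GradientTape() as tape:
        logits = model(images, training=True)  # (B, H, W, n_classes)
        loss = loss_fn(labels, logits)         # labels: (B, H, W) integer class ids
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss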
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Layer, UpSampling2D, Conv2D, BatchNormalization, ReLU, AveragePooling2D


class TransitionUp(Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs, skip=None, concat=True, **kwargs):
        # Bilinear 2x upsampling, then (optionally) concatenate the encoder skip.
        out = UpSampling2D(interpolation='bilinear')(inputs)
        if concat:
            skip = tf.dtypes.cast(skip, tf.float32)
            out = tf.concat([out, skip], 3)
        return out


class ConvLayer(Layer):
    def __init__(self, out_channels, kernel=3, stride=1, **kwargs):
        super().__init__(**kwargs)
        # Explicit kernel//2 padding on each side reproduces PyTorch's
        # Conv2d(padding=kernel//2); Keras 'same' padding can differ when stride > 1.
        # No bias because BatchNormalization follows.
        self.conv = Conv2D(out_channels, kernel_size=kernel, strides=stride,
                           padding=[[0, 0],
                                    [kernel // 2, kernel // 2],
                                    [kernel // 2, kernel // 2],
                                    [0, 0]],
                           use_bias=False)
        self.norm = BatchNormalization()
        self.relu = ReLU()

    def call(self, inputs, **kwargs):
        out = self.conv(inputs)
        out = self.norm(out)
        out = self.relu(out)
        return out


class HardBlock(Layer):
    def get_link(self, layer, base_ch, growth_rate, grmul):
        # Compute which earlier layers feed layer `layer`,
        # plus its output and input channel counts.
        if layer == 0:
            return base_ch, 0, []
        out_channels = growth_rate
        link = []
        for i in range(10):
            dv = 2 ** i
            if layer % dv == 0:
                k = layer - dv
                link.append(k)
                if i > 0:
                    out_channels *= grmul
        # Round the channel count to an even number.
        out_channels = int(int(out_channels + 1) / 2) * 2
        in_channels = 0
        for i in link:
            ch, _, _ = self.get_link(i, base_ch, growth_rate, grmul)
            in_channels += ch
        return out_channels, in_channels, link

    def __init__(self, in_channels, growth_rate, grmul, n_layers, keepBase=False, **kwargs):
        super().__init__(**kwargs)
        self.keepBase = keepBase
        self.links = []
        self.out_channels = 0
        self.layers = []
        for i in range(n_layers):
            outch, inch, link = self.get_link(i + 1, in_channels, growth_rate, grmul)
            self.links.append(link)
            self.layers.append(ConvLayer(outch))
            if (i % 2 == 0) or (i == n_layers - 1):
                self.out_channels += outch

    def get_out_ch(self):
        return self.out_channels

    def call(self, inputs, **kwargs):
        layers_ = [inputs]
        for layer in range(len(self.layers)):
            link = self.links[layer]
            tin = [layers_[i] for i in link]
            if len(tin) > 1:
                inputs = tf.concat(tin, 3)
            else:
                inputs = tin[0]
            out = self.layers[layer](inputs)
            layers_.append(out)
        # Keep the last layer, every odd-indexed layer,
        # and the base input if keepBase is set.
        t = len(layers_)
        out_ = []
        for i in range(t):
            if (i == 0 and self.keepBase) or (i == t - 1) or (i % 2 == 1):
                out_.append(layers_[i])
        out = tf.concat(out_, 3)
        return out


class HardNet(Model):
    def __init__(self, n_classes=19, *args, **kwargs):
        super().__init__(*args, **kwargs)
        first_ch = [16, 24, 32, 48]
        ch_list = [64, 96, 160, 224, 320]
        grmul = 1.7
        gr = [10, 16, 18, 24, 32]
        n_layers = [4, 4, 8, 8, 8]
        blks = len(n_layers)
        self.shortcut_layers = []

        # Stem: two stride-2 convs, so the encoder starts at 1/4 resolution.
        self.base = []
        self.base.append(ConvLayer(out_channels=first_ch[0], kernel=3, stride=2))
        self.base.append(ConvLayer(first_ch[1], kernel=3))
        self.base.append(ConvLayer(first_ch[2], kernel=3, stride=2))
        self.base.append(ConvLayer(first_ch[3], kernel=3))

        # Encoder (downsampling path)
        skip_connection_channel_counts = []
        ch = first_ch[3]
        for i in range(blks):
            blk = HardBlock(ch, gr[i], grmul, n_layers[i])
            ch = blk.get_out_ch()
            skip_connection_channel_counts.append(ch)
            self.base.append(blk)
            if i < blks - 1:
                self.shortcut_layers.append(len(self.base) - 1)
            self.base.append(ConvLayer(ch_list[i], kernel=1))
            ch = ch_list[i]
            if i < blks - 1:
                self.base.append(AveragePooling2D(strides=2))

        cur_channels_count = ch
        prev_block_channels = ch
        n_blocks = blks - 1
        self.n_blocks = n_blocks

        #######################
        #   Upsampling path   #
        #######################
        self.transUpBlocks = []
        self.denseBlocksUp = []
        self.conv1x1_up = []
        for i in range(n_blocks - 1, -1, -1):
            self.transUpBlocks.append(TransitionUp())
            cur_channels_count = prev_block_channels + skip_connection_channel_counts[i]
            self.conv1x1_up.append(ConvLayer(cur_channels_count // 2, kernel=1))
            cur_channels_count = cur_channels_count // 2
            # Decoder blocks stay trainable, matching the PyTorch original.
            blk = HardBlock(cur_channels_count, gr[i], grmul, n_layers[i])
            self.denseBlocksUp.append(blk)
            prev_block_channels = blk.get_out_ch()

        self.finalConv = Conv2D(n_classes, kernel_size=1, strides=1, padding='valid')

    def call(self, x, training=None, mask=None):
        skip_connections = []
        for i in range(len(self.base)):
            x = self.base[i](x)
            if i in self.shortcut_layers:
                skip_connections.append(x)
        out = x
        for i in range(self.n_blocks):
            skip = skip_connections.pop()
            out = self.transUpBlocks[i](out, skip, True)
            out = self.conv1x1_up[i](out)
            out = self.denseBlocksUp[i](out)
        out = self.finalConv(out)
        # The decoder output sits at 1/4 of the input resolution; upsample back.
        out = UpSampling2D(size=(4, 4), interpolation='bilinear')(out)
        return out
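For a quick smoke test of the model as posted, a forward pass on a dummy batch should come back at the input resolution. The input size below is an arbitrary assumption, chosen so height and width are divisible by 64 (the encoder downsamples by 64 in total, and the skip concatenations in the decoder only line up when every halving is exact):

import tensorflow as tf

model = HardNet(n_classes=19)            # 19 classes is the Cityscapes default
x = tf.random.normal([1, 256, 512, 3])   # dummy NHWC batch; H and W divisible by 64
y = model(x)
print(y.shape)                           # expected: (1, 256, 512, 19)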