딥러닝/tensorflow
[tensorflow 2] DataLoader 사용법
scjung
2020. 4. 22. 10:44
tensorflow 2 로 올라오면서 pytorch와 많은 것들이 대동소이해졌다
pytorch에서 사용하던 torch.utils.data.Dataset과 torch.utils.data.DataLoader를 tensorflow 2에서도 비슷한 코드로 구현이 가능하다
(동일하지는 않음)
방법은 tensorflow.keras.utils.Sequence를 사용한다.
Sequence는 torch.utils.data.Dataset과 동일하며
tensorflow에서 torch.utils.data.DataLoader의 기능은 fit_generator, predict_generator에 내장되어 있다.
(현재 fit_generator, predict_generator는 deprecated 되었다고 공식 문서에 명시되었다. Model.predict, Model.fit를 대신 사용하면 된다.)
여기서 중요한 점은 DataLoader 안에서 batch_size가 고려된 데이터를 반환해줘야 한다는 것이다.
(pytorch의 Dataset의 경우 하나의 데이터만 return 시켜주면 된다.)
예시)
tensorflow 2의 이미지 shape : (batch_size, height, width, channel)
pytorch의 이미지 shape : (height, width, channel)
train용 예제 코드
from tensorflow.keras.utils import Sequence
import numpy as np
class DataLoader(Sequence):
def __init__(self, data_list, batch_size, *args, **kwargs):
super(DataLoader, self).__init__(*args, **kwargs)
self.data_list = np.array_split(data_list, len(data_list) // batch_size)
def __getitem__(self, index):
return self.data_list[index]
def __len__(self):
return len(self.data_list)
if __name__ == "__main__":
model = TFModel() # temp model
data_list = getDataList()
batch_size = 16
data_loader = DataLoader(data_list, batch_size)
model.fit(data_loader)