딥러닝/tensorflow

[tensorflow 2] DataLoader 사용법

scjung 2020. 4. 22. 10:44

tensorflow 2 로 올라오면서 pytorch와 많은 것들이 대동소이해졌다

 

pytorch에서 사용하던 torch.utils.data.Dataset과 torch.utils.data.DataLoader를 tensorflow 2에서도 비슷한 코드로 구현이 가능하다

(동일하지는 않음)

 

방법은 tensorflow.keras.utils.Sequence를 사용한다.

 

Sequence는 torch.utils.data.Dataset과 동일하며

 

tensorflow에서 torch.utils.data.DataLoader의 기능은 fit_generator, predict_generator에 내장되어 있다.

(현재 fit_generator, predict_generator는 deprecated 되었다고 공식 문서에 명시되었다. Model.predict, Model.fit를 대신 사용하면 된다.)

 

 

여기서 중요한 점은 DataLoader 안에서 batch_size가 고려된 데이터를 반환해줘야 한다는 것이다.

(pytorch의 Dataset의 경우 하나의 데이터만 return 시켜주면 된다.)

 

예시)

tensorflow 2의 이미지 shape : (batch_size, height, width, channel)

pytorch의 이미지 shape : (height, width, channel)

 

train용 예제 코드

from tensorflow.keras.utils import Sequence
import numpy as np


class DataLoader(Sequence):
	def __init__(self, data_list, batch_size, *args, **kwargs):
		super(DataLoader, self).__init__(*args, **kwargs)
		self.data_list = np.array_split(data_list, len(data_list) // batch_size)
        
	def __getitem__(self, index):
		return self.data_list[index]
        
	def __len__(self):
		return len(self.data_list)
        
if __name__ == "__main__":
	model = TFModel() # temp model
	data_list = getDataList()
	batch_size = 16
    
	data_loader = DataLoader(data_list, batch_size)
	model.fit(data_loader)