-
IOU loss딥러닝/pytorch 2020. 2. 7. 17:49
hardnet을 학습하면서 validation 이미지들의 iou 값이 계속 0.9를 못넘었다
그래서 loss function 문제인가 하고 이전에 사용했던 cross entropy 대신 iou loss를 사용하였다
링크는 아래
https://discuss.pytorch.org/t/how-to-implement-soft-iou-loss/15152
How to implement soft-IoU loss?
I am trying to implement soft-mIoU loss for semantic segmentation as per the following equation. but loss is very low and I am not able to find the wrong step in the implementation. def to_one_hot(tensor,nClasses): n,h,w = tensor.size() one_hot = torch.zer
discuss.pytorch.org
여기서 약간의 수정을 가해서 아래와 같은 코드가 됨
def to_one_hot(tensor,nClasses,device): n,h,w = tensor.size() one_hot = torch.zeros(n,nClasses,h,w).to(device).scatter_(1,tensor.view(n,1,h,w),1) return one_hot class mIoULoss(torch.nn.Module): def __init__(self, weight=None, size_average=True, n_classes=4, device='cuda'): super(mIoULoss, self).__init__() self.classes = n_classes self.device = device def forward(self, inputs, target): # inputs => N x Classes x H x W # target_oneHot => N x Classes x H x W inputs = inputs.to(self.device) target = target.to(self.device) SMOOTH = 1e-6 N = inputs.size()[0] inputs = F.softmax(inputs,dim=1) target_oneHot = to_one_hot(target, self.classes,self.device) # Numerator Product inter = inputs * target_oneHot ## Sum over all pixels N x C x H x W => N x C inter = inter.view(N,self.classes,-1).sum(2) + SMOOTH #Denominator union= inputs + target_oneHot - (inputs*target_oneHot) ## Sum over all pixels N x C x H x W => N x C union = union.view(N,self.classes,-1).sum(2) + SMOOTH loss = inter/union ## Return average loss over classes and batch return -loss.mean()
저기서 to_one_hot 함수의 scatter 함수가 굉장히 특이해서 좀 찾아봤더니 아래 링크에서 설명을 잘해놨다
[정리][PyTorch] Lab-06 Softmax Classification
https://www.youtube.com/watch?v=B3gtAi-wlG8&list=PLQ28Nx3M4JrhkqBVIXg-i...
blog.naver.com
근데 보통 [num, height, weight] 를 [num, class, height, weight] 로 바꾸지 않나
scatter_ 함수는 [num, 1, height, weight] 로 dimention 수를 맞춰줘야 하더라...
그리고 학습을 계속 했는데 결과는......
0.9는 여전히 넘지 않았다
이쯤되면 labeling하고 predict하고의 오차 범위라고 생각해야할 듯
'딥러닝 > pytorch' 카테고리의 다른 글
아파서 내일 할 것만 정리 (0) 2020.02.06 HardNet 개발 (2) (0) 2020.02.05 HardNet 개발 (5) 2020.02.05 pytorch gpu parallel 설정 (0) 2020.02.04