scjung 2020. 2. 7. 17:49

hardnet을 학습하면서 validation 이미지들의 iou 값이 계속 0.9를 못넘었다

 

그래서 loss function 문제인가 하고 이전에 사용했던 cross entropy 대신 iou loss를 사용하였다

 

링크는 아래

https://discuss.pytorch.org/t/how-to-implement-soft-iou-loss/15152

 

How to implement soft-IoU loss?

I am trying to implement soft-mIoU loss for semantic segmentation as per the following equation. but loss is very low and I am not able to find the wrong step in the implementation. def to_one_hot(tensor,nClasses): n,h,w = tensor.size() one_hot = torch.zer

discuss.pytorch.org

 

여기서 약간의 수정을 가해서 아래와 같은 코드가 됨

def to_one_hot(tensor,nClasses,device):
    n,h,w = tensor.size()
    one_hot = torch.zeros(n,nClasses,h,w).to(device).scatter_(1,tensor.view(n,1,h,w),1)
    return one_hot

class mIoULoss(torch.nn.Module):
    def __init__(self, weight=None, size_average=True, n_classes=4, device='cuda'):
        super(mIoULoss, self).__init__()
        self.classes = n_classes
        self.device = device

    def forward(self, inputs, target):
    	# inputs => N x Classes x H x W
    	# target_oneHot => N x Classes x H x W
        inputs = inputs.to(self.device)
        target = target.to(self.device)
        
        SMOOTH = 1e-6
        N = inputs.size()[0]

        inputs = F.softmax(inputs,dim=1)
        target_oneHot = to_one_hot(target, self.classes,self.device)
        # Numerator Product
        inter = inputs * target_oneHot
        ## Sum over all pixels N x C x H x W => N x C
        inter = inter.view(N,self.classes,-1).sum(2) + SMOOTH

        #Denominator 
        union= inputs + target_oneHot - (inputs*target_oneHot)
        ## Sum over all pixels N x C x H x W => N x C
        union = union.view(N,self.classes,-1).sum(2) + SMOOTH

        loss = inter/union

        ## Return average loss over classes and batch
        return -loss.mean()

저기서 to_one_hot 함수의 scatter 함수가 굉장히 특이해서 좀 찾아봤더니 아래 링크에서 설명을 잘해놨다

 

https://m.blog.naver.com/PostView.nhn?blogId=hongjg3229&logNo=221555913457&proxyReferer=https%3A%2F%2Fwww.google.com%2F

 

[정리][PyTorch] Lab-06 Softmax Classification

https://www.youtube.com/watch?v=B3gtAi-wlG8&list=PLQ28Nx3M4JrhkqBVIXg-i...

blog.naver.com

근데 보통 [num, height, weight] 를 [num, class, height, weight] 로 바꾸지 않나

 

scatter_ 함수는 [num, 1, height, weight] 로 dimention 수를 맞춰줘야 하더라...

 

 

그리고 학습을 계속 했는데 결과는......

 

 

 

 

0.9는 여전히 넘지 않았다

 

이쯤되면 labeling하고 predict하고의 오차 범위라고 생각해야할 듯