自定义损失函数

自定义损失函数

PyTorch在torch.nn模块为我们提供了许多常用的损失函数,比如:MSELoss,L1Loss,BCELoss… 但是随着深度学习的发展,出现了越来越多的非官方提供的Loss,比如DiceLoss,HuberLoss,SobolevLoss… 这些Loss Function专门针对一些非通用的模型,PyTorch不能将他们全部添加到库中去,因此这些损失函数的实现则需要我们通过自定义损失函数来实现。另外,在一些算法实现中,研究者往往会提出全新的损失函数来提升模型的表现,这时我们既无法使用PyTorch自带的损失函数,也没有相关的博客供参考,此时自己实现损失函数就显得更为重要了。

经过本节的学习,你将收获:

  • 掌握如何自定义损失函数

6.1.1 以函数方式定义

事实上,损失函数仅仅是一个函数而已,因此我们可以通过直接以函数定义的方式定义一个自己的函数,如下所示:

1
2
3
def my_loss(output, target):
loss = torch.mean((output - target)**2)
return loss

6.1.2 以类方式定义

虽然以函数定义的方式很简单,但是以类方式定义更加常用,在以类方式定义损失函数时,我们如果看每一个损失函数的继承关系我们就可以发现Loss函数部分继承自_loss, 部分继承自_WeightedLoss, 而_WeightedLoss继承自_loss_loss继承自 nn.Module。我们可以将其当作神经网络的一层来对待,同样地,我们的损失函数类就需要继承自nn.Module类,在下面的例子中我们以DiceLoss为例向大家讲述。

Dice Loss是一种在分割领域常见的损失函数,定义如下:

image-20230709160015651

实现代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class DiceLoss(nn.Module):
def __init__(self,weight=None,size_average=True):
super(DiceLoss,self).__init__()

def forward(self,inputs,targets,smooth=1):
inputs = F.sigmoid(inputs)
inputs = inputs.view(-1)
targets = targets.view(-1)
intersection = (inputs * targets).sum()
dice = (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
return 1 - dice

# 使用方法
criterion = DiceLoss()
loss = criterion(input,targets)

除此之外,常见的损失函数还有BCE-Dice Loss,Jaccard/Intersection over Union (IoU) Loss,Focal Loss…

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class DiceBCELoss(nn.Module):
def __init__(self, weight=None, size_average=True):
super(DiceBCELoss, self).__init__()

def forward(self, inputs, targets, smooth=1):
inputs = F.sigmoid(inputs)
inputs = inputs.view(-1)
targets = targets.view(-1)
intersection = (inputs * targets).sum()
dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
Dice_BCE = BCE + dice_loss

return Dice_BCE
class IoULoss(nn.Module):
def __init__(self, weight=None, size_average=True):
super(IoULoss, self).__init__()

def forward(self, inputs, targets, smooth=1):
inputs = F.sigmoid(inputs)
inputs = inputs.view(-1)
targets = targets.view(-1)
intersection = (inputs * targets).sum()
total = (inputs + targets).sum()
union = total - intersection

IoU = (intersection + smooth)/(union + smooth)

return 1 - IoU
ALPHA = 0.8
GAMMA = 2

class FocalLoss(nn.Module):
def __init__(self, weight=None, size_average=True):
super(FocalLoss, self).__init__()

def forward(self, inputs, targets, alpha=ALPHA, gamma=GAMMA, smooth=1):
inputs = F.sigmoid(inputs)
inputs = inputs.view(-1)
targets = targets.view(-1)
BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
BCE_EXP = torch.exp(-BCE)
focal_loss = alpha * (1-BCE_EXP)**gamma * BCE

return focal_loss
# 更多的可以参考链接1

注:

在自定义损失函数时,涉及到数学运算时,我们最好全程使用PyTorch提供的张量计算接口,这样就不需要我们实现自动求导功能并且我们可以直接调用cuda,使用numpy或者scipy的数学运算时,操作会有些麻烦,大家可以自己下去进行探索。关于PyTorch使用Class定义损失函数的原因,可以参考PyTorch的讨论区(链接6)

本节参考

【1】https://www.kaggle.com/bigironsphere/loss-function-library-keras-pytorch/notebook
【2】https://www.zhihu.com/question/66988664/answer/247952270
【3】https://blog.csdn.net/dss_dssssd/article/details/84103834
【4】https://zj-image-processing.readthedocs.io/zh_CN/latest/pytorch/自定义损失函数/
【5】https://blog.csdn.net/qq_27825451/article/details/95165265
【6】https://discuss.pytorch.org/t/should-i-define-my-custom-loss-function-as-a-class/89468

原文链接:[6.1 自定义损失函数 — 深入浅出PyTorch (datawhalechina.github.io)](https://datawhalechina.github.io/thorough-pytorch/第六章/6.1 自定义损失函数.html)


自定义损失函数
http://csgituser.github.io/2023/03/08/自定义损失函数/
Author
Museum
Posted on
March 8, 2023
Licensed under