# 带权重和ignore_index的交叉商损失函数
# 背景
使用ADE20K
数据集进行验证的分割算法,因这个数据集是exausted annotated
,也就是图像中的每个像素都标注了类别,因此背景只占了很少的一部分,因此训练时会设置ignore_index=255
,在商汤的框架mmsegmentation
中的ade20k.py
(opens new window)中的educe_zero_label=True
参数,正是为了实现ignore index
忽略背景
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', reduce_zero_label=True),
dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0))]
# 解析
正常的交叉熵计算公式为:
ignore index
即需要忽略掉的label id
,再计算交叉熵的时候会去除掉ignore index
,
对于数据:
import torch
y_pre = torch.tensor([[-0.7795, -1.7373, 0.3856],
[ 0.4031, -0.5632, -0.2389],
[-2.3908, 0.7452, 0.7748]])
y_real = torch.tensor([0, 1, 2])
critien = torch.nn.CrossEntropyLoss()
print(critien(y_pre, y_real))
# tensor(1.2784)
可以将其当成batch=3
,类别数C=3
的例子,常规交叉熵的计算方式为:
- 使用
ignore_index
参数时,
譬如ignore_index=2
,则其计算会忽略掉label id 2
,相当于只剩下类别0,1
critien = torch.nn.CrossEntropyLoss(ignore_index=2)
print(critien(y_pre, y_real))
# tensor(1.5678)
- 带权重
w
时,
计算average loss ,算样本个数时也需乘上w[labels[i]]
譬如w=[0.3, 0.6, 0.1]
, 带权重交叉熵的计算方式为:
w = torch.tensor([0.3, 0.6, 0.1])
critien = torch.nn.CrossEntropyLoss(weight=w)
print(critien(y_pre, y_real))
# tensor(1.4941)
带权重交叉熵损失函数主要用于处理类别不平衡的问题。
- 同时使用
ignore_index
和权重w
,
综合以上两种情况,譬如ignore_index=2
和权重w=[0.3, 0.6, 0.1]
critien = torch.nn.CrossEntropyLoss(ignore_index=2, weight=w)
print(critien(y_pre, y_real))
# tensor(1.5824)
动手实现的计算带权重和ignore_index
参数的交叉熵损失函数如下,测试可以得到与torch.nn.CrossEntropyLoss
一样的结果,帮助理解其原理
import torch
def cross_entropy(inp, tar,ig=None, w=None):
ss = 0
for i,s in enumerate(tar):
ts = torch.log((torch.exp(inp[i][s])/torch.sum(inp[i].exp())))
ts = 0 if s.item() == ig else ts
if w is not None:
ts = ts * w[s]
ss += ts
bi = torch.bincount(tar)
ns = 0
for i in range(torch.max(tar)+1):
if i == ig:
continue
if w is not None:
ns += bi[i] * w[i]
else:
ns +=bi[i]
return -ss / ns
w = torch.tensor([0.3, 0.6, 0.1])
y_pre = torch.tensor([[-0.7795, -1.7373, 0.3856],
[ 0.4031, -0.5632, -0.2389],
[-2.3908, 0.7452, 0.7748]])
y_real = torch.tensor([0, 1, 2])
critien = torch.nn.CrossEntropyLoss(ignore_index=2, weight=w)
print(cross_entropy(y_pre, y_real, ig=2, w=w))
print(critien(y_pre, y_real))
"""
tensor(1.5824)
tensor(1.5824)
"""