# Cross-Entropy Loss with Weights and ignore_index

# Background

Segmentation algorithms validated on the ADE20K dataset typically set ignore_index=255 at training time. Because this dataset is exhaustively annotated, i.e. every pixel in an image is assigned a class, the background accounts for only a small fraction of pixels. In SenseTime's mmsegmentation framework, the reduce_zero_label=True parameter in ade20k.py implements exactly this ignoring of the background:


```python
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', reduce_zero_label=True),
    dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0))]
```
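
What reduce_zero_label does is remap the label map before the loss ever sees it: label 0 (the ADE20K background) becomes the ignore value 255 and the remaining classes shift down by one. A minimal sketch of that remapping (assuming the logic in mmsegmentation's LoadAnnotations):

```python
import numpy as np

seg = np.array([[0, 1, 2],
                [3, 0, 1]], dtype=np.uint8)  # 0 is the background class
seg[seg == 0] = 255    # background becomes the ignore value
seg = seg - 1          # shift the real classes to start at 0
seg[seg == 254] = 255  # 255 - 1 == 254 rolls back to the ignore value
print(seg)
# [[255   0   1]
#  [  2 255   0]]
```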

# Analysis

The standard cross-entropy for sample $j$ with logits $f$ and target class $y_j$ is:

$$l_j = -\log\left(\frac{e^{f_{y_j}}}{\sum_{i=1}^{C} e^{f_i}}\right)$$

ignore_index is the label id to be ignored: samples whose target equals ignore_index are excluded when computing the cross-entropy.

For the data:

```python
import torch

y_pre = torch.tensor([[-0.7795, -1.7373,  0.3856],
                      [ 0.4031, -0.5632, -0.2389],
                      [-2.3908,  0.7452,  0.7748]])
y_real = torch.tensor([0, 1, 2])
criterion = torch.nn.CrossEntropyLoss()
print(criterion(y_pre, y_real))
# tensor(1.2784)
```

This can be read as a batch of 3 samples with C=3 classes. The conventional cross-entropy works out to:

$$l_1 = -\log\left(\frac{e^{-0.7795}}{e^{-0.7795}+e^{-1.7373}+e^{0.3856}}\right)$$
$$l_2 = -\log\left(\frac{e^{-0.5632}}{e^{0.4031}+e^{-0.5632}+e^{-0.2389}}\right)$$
$$l_3 = -\log\left(\frac{e^{0.7748}}{e^{-2.3908}+e^{0.7452}+e^{0.7748}}\right)$$
$$l = \frac{l_1+l_2+l_3}{3} = 1.2784$$
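
One way to check that this matches the formula is to read the per-sample terms straight off log_softmax (a quick sanity check, reusing y_pre and y_real from above):

```python
import torch.nn.functional as F

# -log_softmax picked at each sample's target class gives l_1, l_2, l_3
per_sample = -F.log_softmax(y_pre, dim=1)[torch.arange(3), y_real]
print(per_sample)         # ≈ tensor([1.5239, 1.6117, 0.6997])
print(per_sample.mean())  # tensor(1.2784)
```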
  • When the ignore_index parameter is used,

$$l_j = -\log\left(\frac{e^{f_{y_j}}}{\sum_{i=1}^{C} e^{f_i}}\right) \cdot \mathbb{1}\{y_j \neq \text{ignore\_index}\}$$

and the average is taken over the non-ignored samples only.

For example, with ignore_index=2, samples whose label id is 2 are dropped from the loss, leaving only the samples labelled 0 and 1 (the softmax itself is still taken over all C classes):

$$l_1 = -\log\left(\frac{e^{-0.7795}}{e^{-0.7795}+e^{-1.7373}+e^{0.3856}}\right)$$
$$l_2 = -\log\left(\frac{e^{-0.5632}}{e^{0.4031}+e^{-0.5632}+e^{-0.2389}}\right)$$
$$l = \frac{l_1+l_2}{2} = 1.5678$$
```python
criterion = torch.nn.CrossEntropyLoss(ignore_index=2)
print(criterion(y_pre, y_real))
# tensor(1.5678)
```
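
The same number falls out if the ignored samples are simply masked away before calling the unweighted loss (another quick check with the tensors defined above):

```python
keep = y_real != 2  # drop samples whose label equals ignore_index
print(F.cross_entropy(y_pre[keep], y_real[keep]))
# tensor(1.5678)
```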
  • With a class weight $w$,

$$l_j = -\log\left(\frac{e^{f_{y_j}}}{\sum_{i=1}^{C} e^{f_i}}\right) \cdot w[y_j]$$

When computing the average loss, the denominator is likewise weighted: instead of the raw sample count, it is the sum of the per-sample weights $w[y_j]$.

For example, with w=[0.3, 0.6, 0.1], the weighted cross-entropy is computed as:

$$l_1 = -\log\left(\frac{e^{-0.7795}}{e^{-0.7795}+e^{-1.7373}+e^{0.3856}}\right) \cdot 0.3$$
$$l_2 = -\log\left(\frac{e^{-0.5632}}{e^{0.4031}+e^{-0.5632}+e^{-0.2389}}\right) \cdot 0.6$$
$$l_3 = -\log\left(\frac{e^{0.7748}}{e^{-2.3908}+e^{0.7452}+e^{0.7748}}\right) \cdot 0.1$$
$$l = \frac{l_1+l_2+l_3}{0.3+0.6+0.1} = 1.4941$$
```python
w = torch.tensor([0.3, 0.6, 0.1])
criterion = torch.nn.CrossEntropyLoss(weight=w)
print(criterion(y_pre, y_real))
# tensor(1.4941)
```

Weighted cross-entropy is mainly used to counter class imbalance: rare classes are given larger weights so they contribute more to the loss.
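
The weighted average can likewise be checked by hand: scale each per-sample term by w[y_real] and divide by the sum of the weights actually used, as in this sketch:

```python
per_sample = -F.log_softmax(y_pre, dim=1)[torch.arange(3), y_real]
print((per_sample * w[y_real]).sum() / w[y_real].sum())
# tensor(1.4941)
```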

  • Using ignore_index together with a weight w

Combining the two cases above, e.g. with ignore_index=2 and w=[0.3, 0.6, 0.1]:

$$l_1 = -\log\left(\frac{e^{-0.7795}}{e^{-0.7795}+e^{-1.7373}+e^{0.3856}}\right) \cdot 0.3$$
$$l_2 = -\log\left(\frac{e^{-0.5632}}{e^{0.4031}+e^{-0.5632}+e^{-0.2389}}\right) \cdot 0.6$$
$$l = \frac{l_1+l_2}{0.3+0.6} = 1.5824$$
```python
criterion = torch.nn.CrossEntropyLoss(ignore_index=2, weight=w)
print(criterion(y_pre, y_real))
# tensor(1.5824)
```

A hand-rolled implementation of cross-entropy with the weight and ignore_index parameters is given below. It produces the same results as torch.nn.CrossEntropyLoss and should help clarify how the two parameters interact:

```python
import torch

def cross_entropy(inp, tar, ig=None, w=None):
    # Numerator of the average: sum of per-sample (weighted) log-softmax terms
    ss = 0
    for j, s in enumerate(tar):
        # log(e^{f_{y_j}} / sum_i e^{f_i}) for sample j
        ts = torch.log(torch.exp(inp[j][s]) / torch.sum(inp[j].exp()))
        if s.item() == ig:
            ts = 0          # ignored samples contribute nothing
        elif w is not None:
            ts = ts * w[s]  # scale by the class weight w[y_j]
        ss += ts
    # Denominator: effective sample count, i.e. the sum of weights over the
    # non-ignored samples (or their plain count if w is None)
    bi = torch.bincount(tar)
    ns = 0
    for i in range(torch.max(tar) + 1):
        if i == ig:
            continue
        if w is not None:
            ns += bi[i] * w[i]
        else:
            ns += bi[i]
    return -ss / ns

w = torch.tensor([0.3, 0.6, 0.1])
y_pre = torch.tensor([[-0.7795, -1.7373,  0.3856],
                      [ 0.4031, -0.5632, -0.2389],
                      [-2.3908,  0.7452,  0.7748]])
y_real = torch.tensor([0, 1, 2])
criterion = torch.nn.CrossEntropyLoss(ignore_index=2, weight=w)
print(cross_entropy(y_pre, y_real, ig=2, w=w))
print(criterion(y_pre, y_real))
"""
tensor(1.5824)
tensor(1.5824)
"""
```
