Loss 出现 Nan

总结可能的原因如下:

  • 数据不正确

    • 比如说我们处理的实际是一个N分类问题的时候,计算cross entropy时将其当作一个M分类问题。如果N>M, 则在计算loss的时候不会报错,计算得到的loss直接就是Nan。
    • 训练样本中的脏数据导致计算结果为 0;
  • 学习率过大

    有时候学习率过大也会导致NAN,为了判别这种错误,我们只需要将学习率设置为0,看是否会继续出现NAN,如果还会出现NAN,则可以排除这种错误。

  • 激活函数有误

    比如我们使用x12x^{\frac{1}{2}} 作为激活函数的时候,其在x0x\leq 0 处是不可导的,此时也会产生NAN的问题。

    上述的例子也说明了,当我们在计算l2loss的时候为何不去开根号,而是保留平方的形式。

    参考:

  • 数据没有归一化

    当我们做一个regression任务的时候,如果prediction没有归一化的话,可能会导致prediction值过大,在计算loss的之后容易产生过大的loss值,而导致梯度爆炸,出现 Nan 的情况。

    参考:数据处理中的归一化和标准化方法

  • 出现一些异常操作

    比如出现除以0,log 0等操作会导致Nan 。比如说在归一化的时候,分母是正样本的总数,但是batch内如果没有正样本呢?这时候我们就需要加入一个小的平滑项,或者是判断,如果是0,返回一个0.0。