作者:步履乘风 | 来源:互联网 | 2023-10-12 16:47
Pk对zk的求导,以及Pk对zj的求导请参考 https://blog.csdn.net/u013066730/article/details/86231215 。前向代码:for (int i = 0; i …
Softmax 输出概率 $P_k$ 对输入 $z_k$ 的求导,以及 $P_k$ 对 $z_j$($j \neq k$)的求导,请参考 https://blog.csdn.net/u013066730/article/details/86231215
前向代码:
for (int i = 0; i < outer_num_; ++i) {
for (int j = 0; j < inner_num_; ++j) {
const int label_value = static_cast<int>(label[i * inner_num_ + j]);
if (has_ignore_label_ && label_value == ignore_label_) {
continue;
}
DCHECK_GE(label_value, 0);
DCHECK_LT(label_value, channels);
const int index = i * dim + label_value * inner_num_ + j;
// FL(p_t) = -(1 - p_t) ^ gamma * log(p_t)
// loss -= std::max(power_prob_data[index] * log_prob_data[index],
// Dtype(log(Dtype(FLT_MIN))));
loss -= power_prob_data[index] * log_prob_data[index];
++count;
}
}
// prob
top[0]->mutable_cpu_data()[0] = loss / get_normalizer(normalization_, count);
反向代码:
for (int i = 0; i < outer_num_; ++i) {
for (int j = 0; j < inner_num_; ++j) {
// label
const int label_value = static_cast<int>(label[i * inner_num_ + j]);
// ignore label
if (has_ignore_label_ && label_value == ignore_label_) {
for (int c = 0; c < channels; ++c) {
bottom_diff[i * dim + c * inner_num_ + j] = 0;
}
continue;
}
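// grad = -p_t * dFL/dp_t: the gradient of FL w.r.t. p_t, with the sign flipped
// and pre-multiplied by p_t so the softmax derivative below needs no extra p_t factor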
int ind_i = i * dim + label_value * inner_num_ + j; // index of ground-truth label
Dtype grad = 0 - gamma_ * (power_prob_data[ind_i] / std::max(1 - prob_data[ind_i], eps))
* log_prob_data[ind_i] * prob_data[ind_i]
+ power_prob_data[ind_i];
// the gradient w.r.t input data x
for (int c = 0; c < channels; ++c) {
int ind_j = i * dim + c * inner_num_ + j;
if(c == label_value) {
CHECK_EQ(ind_i, ind_j);
// case i == j (i and j here refer to the indices in the softmax derivative)
bottom_diff[ind_j] = grad * (prob_data[ind_i] - 1);
} else {
// case i != j (i and j here refer to the indices in the softmax derivative)
bottom_diff[ind_j] = grad * prob_data[ind_j];
}
}
// count
++count;
}
}
// Scale gradient
Dtype loss_weight = top[0]->cpu_diff()[0] / get_normalizer(normalization_, count);
caffe_scal(prob_.count(), loss_weight, bottom_diff);