提交 a672b291 编写于 作者: F frankwhzhang

fix code style, test=develop

上级 ea95f9c3
......@@ -22,7 +22,9 @@ namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
/*Todo:
*Find a way to adapt TolerableValue, using blas or eigen.
*/
template <typename T>
struct TolerableValue {
HOSTDEVICE T operator()(const T& x) const {
......@@ -86,27 +88,27 @@ class BprLossGradientOpKernel : public framework::OpKernel<T> {
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
const int step_size = x->dims()[0];
const int num_classes_ = x->dims()[1];
T* dx_ = dx->mutable_data<T>(ctx.GetPlace());
const T* dy_ = dy->data<T>();
const T* x_ = x->data<T>();
const int64_t* label_pos_ = label_pos->data<int64_t>();
const int num_classes = x->dims()[1];
T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
const T* dy_data = dy->data<T>();
const T* x_data = x->data<T>();
const int64_t* label_pos_data = label_pos->data<int64_t>();
for (size_t sample_id = 0; sample_id < step_size; sample_id++) {
for (size_t x_offset = sample_id * num_classes_;
x_offset < (sample_id + 1) * num_classes_; x_offset++) {
dx_[x_offset] = static_cast<T>(0);
for (size_t x_offset = sample_id * num_classes;
x_offset < (sample_id + 1) * num_classes; x_offset++) {
dx_data[x_offset] = static_cast<T>(0);
}
auto p_index = sample_id * num_classes_ + label_pos_[sample_id];
for (size_t ni = 0; ni < num_classes_; ni++) {
if (label_pos_[sample_id] == ni) continue;
auto n_index = sample_id * num_classes_ + ni;
auto grad_ =
-dy_[sample_id] /
((num_classes_ - 1) *
(1.0f + TolerableValue<T>()(std::exp(x_[p_index] - x_[n_index]))));
dx_[p_index] += grad_;
dx_[n_index] -= grad_;
auto p_index = sample_id * num_classes + label_pos_data[sample_id];
for (size_t ni = 0; ni < num_classes; ni++) {
if (label_pos_data[sample_id] == ni) continue;
auto n_index = sample_id * num_classes + ni;
auto grad_ = -dy_data[sample_id] /
((num_classes - 1) *
(1.0f + TolerableValue<T>()(std::exp(x_data[p_index] -
x_data[n_index]))));
dx_data[p_index] += grad_;
dx_data[n_index] -= grad_;
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册