提交 b7776e66 编写于 作者: Y yangyaming

Fix dimension bugs.

上级 2763f3e3
......@@ -87,10 +87,13 @@ class SmoothL1LossKernel : public framework::OpKernel {
auto outside_weight = EigenVector<T>::Flatten(*in3);
errors.device(place) = errors * outside_weight;
}
auto loss = EigenMatrix<T>::From(*out1, {in0->dims()[0], 1});
auto loss = EigenVector<T>::Flatten(*out1);
// first dimension of 'X' is the number of samples
auto errors_mat_view = EigenMatrix<T>::From(paddle_errors, in0->dims());
loss.device(place) = errors_mat_view.sum(Eigen::array<int, 1>({1}));
auto mat_dims =
framework::make_ddim({static_cast<int>(in0->dims()[0]),
static_cast<int>(in_counts / in0->dims()[0])});
auto errors_mat_view = EigenMatrix<T>::From(paddle_errors, mat_dims);
loss.device(place) = errors_mat_view.sum(Eigen::array<int, 1>({{1}}));
}
};
......@@ -162,9 +165,9 @@ class SmoothL1LossGradKernel : public framework::OpKernel {
// compute gradients
auto out_grad = EigenMatrix<T>::From(*og);
auto diff_mat_view = EigenMatrix<T>::From(paddle_diff, mat_dims);
auto gradients =
out_grad.broadcast(Eigen::array<int, 2>({1, static_cast<int>(cols)})) *
weights * diff_mat_view;
auto gradients = out_grad.broadcast(
Eigen::array<int, 2>({{1, static_cast<int>(cols)}})) *
weights * diff_mat_view;
if (out0) {
out0->mutable_data<T>(context.GetPlace());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册