提交 076dcb9b 编写于 作者: Y yangyaming

Simpify the initialization for weights.

上级 53ab7e78
...@@ -141,22 +141,12 @@ class SmoothL1LossGradKernel : public framework::OpKernel { ...@@ -141,22 +141,12 @@ class SmoothL1LossGradKernel : public framework::OpKernel {
diff.device(place) = EigenVector<T>::Flatten(*in2).unaryExpr( diff.device(place) = EigenVector<T>::Flatten(*in2).unaryExpr(
SmoothL1LossBackward<T>(sigma2)); SmoothL1LossBackward<T>(sigma2));
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
auto* out1 = context.Output<Tensor>(framework::GradVarName("Y"));
// compute weights // compute weights
Tensor paddle_weights; Tensor paddle_weights;
paddle_weights.mutable_data<T>(mat_dims, context.GetPlace()); paddle_weights.mutable_data<T>(mat_dims, context.GetPlace());
auto weights = EigenMatrix<T>::From(paddle_weights); auto weights = EigenMatrix<T>::From(paddle_weights);
// initialize to 1.0 // initialize to 1.0
if (platform::is_cpu_place(context.GetPlace())) { weights.device(place) = weights.constant(static_cast<T>(1.0));
weights.setConstant(static_cast<T>(1.0));
} else {
Tensor paddle_cpu_weights;
paddle_cpu_weights.mutable_data<T>(mat_dims, platform::CPUPlace());
EigenMatrix<T>::From(paddle_cpu_weights).setConstant(static_cast<T>(1.0));
paddle_weights.CopyFrom<T>(paddle_cpu_weights, context.GetPlace());
}
if (has_weight) { if (has_weight) {
auto inside_weight = EigenMatrix<T>::From(*in0, mat_dims); auto inside_weight = EigenMatrix<T>::From(*in0, mat_dims);
auto outside_weight = EigenMatrix<T>::From(*in1, mat_dims); auto outside_weight = EigenMatrix<T>::From(*in1, mat_dims);
...@@ -170,6 +160,9 @@ class SmoothL1LossGradKernel : public framework::OpKernel { ...@@ -170,6 +160,9 @@ class SmoothL1LossGradKernel : public framework::OpKernel {
Eigen::array<int, 2>({{1, static_cast<int>(cols)}})) * Eigen::array<int, 2>({{1, static_cast<int>(cols)}})) *
weights * diff_mat_view; weights * diff_mat_view;
auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
auto* out1 = context.Output<Tensor>(framework::GradVarName("Y"));
if (out0) { if (out0) {
out0->mutable_data<T>(context.GetPlace()); out0->mutable_data<T>(context.GetPlace());
auto x_grad = EigenMatrix<T>::From(*out0, mat_dims); auto x_grad = EigenMatrix<T>::From(*out0, mat_dims);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册