From bbd6e09c224e6c44a98af016f931545363596cfe Mon Sep 17 00:00:00 2001 From: yangyaming Date: Wed, 20 Sep 2017 17:09:07 +0800 Subject: [PATCH] Using LoDTensor for output. --- paddle/operators/modified_huber_loss_op.cc | 7 ++++--- paddle/operators/modified_huber_loss_op.h | 12 +++++++----- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/paddle/operators/modified_huber_loss_op.cc b/paddle/operators/modified_huber_loss_op.cc index a6e76c8166..6fe018f9a8 100644 --- a/paddle/operators/modified_huber_loss_op.cc +++ b/paddle/operators/modified_huber_loss_op.cc @@ -34,8 +34,8 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(x->dims().size(), 2, "The tensor rank of X must be 2."); PADDLE_ENFORCE_EQ(x->dims()[1], 1, "The 2nd dimension of X must be 1."); - context.Output("IntermediateVal")->Resize(x->dims()); - context.Output("Out")->Resize({x->dims()[0], 1}); + context.Output("IntermediateVal")->Resize(x->dims()); + context.Output("Out")->Resize({x->dims()[0], 1}); } }; @@ -80,7 +80,8 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel { auto* y = context.Input("Y"); auto* intermediate_val = context.Input("IntermediateVal"); auto* out_grad = context.Input(framework::GradVarName("Out")); - auto* x_grad = context.Output(framework::GradVarName("X")); + auto* x_grad = + context.Output(framework::GradVarName("X")); PADDLE_ENFORCE_NOT_NULL(x, "X must be initialized."); PADDLE_ENFORCE_NOT_NULL(y, "Y must be initialized."); diff --git a/paddle/operators/modified_huber_loss_op.h b/paddle/operators/modified_huber_loss_op.h index e78be06ebd..2b2aae1708 100644 --- a/paddle/operators/modified_huber_loss_op.h +++ b/paddle/operators/modified_huber_loss_op.h @@ -52,8 +52,8 @@ class ModifiedHuberLossKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { auto* in0 = context.Input("X"); auto* in1 = context.Input("Y"); - auto* out0 = context.Output("IntermediateVal"); - auto* out1 = context.Output("Out"); + auto* out0 = context.Output("IntermediateVal"); + auto* out1 = context.Output("Out"); out0->mutable_data(context.GetPlace()); out1->mutable_data(context.GetPlace()); @@ -77,9 +77,11 @@ class ModifiedHuberLossGradCPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* in0 = context.Input("Y"); - auto* in1 = context.Input("IntermediateVal"); - auto* in2 = context.Input(framework::GradVarName("Out")); - auto* out0 = context.Output(framework::GradVarName("X")); + auto* in1 = context.Input("IntermediateVal"); + auto* in2 = + context.Input(framework::GradVarName("Out")); + auto* out0 = + context.Output(framework::GradVarName("X")); if (out0) { const T* y_ptr = in0->data(); -- GitLab