diff --git a/paddle/operators/huber_loss_op.cc b/paddle/operators/huber_loss_op.cc index 8c2ca86ccce099c1b4c17e6ac33854824ff680f4..2d9449f5ca50dab8d2a7928c4311ec2d66b47904 100644 --- a/paddle/operators/huber_loss_op.cc +++ b/paddle/operators/huber_loss_op.cc @@ -21,24 +21,24 @@ class HuberLossOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - protected: - void InferShape(const framework::InferShapeContext& ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must be initialized."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) must be initialized."); + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must be initialized."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must be initialized."); - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); - PADDLE_ENFORCE_EQ(x->dims(), y->dims()); - PADDLE_ENFORCE_EQ(framework::arity(x->dims()), 2, + PADDLE_ENFORCE_EQ(x_dims, y_dims); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, "The rank of Input(X) must be 2 and the shape is " "[batch_size, 1]."); - PADDLE_ENFORCE_EQ(x->dims()[1], 1, + PADDLE_ENFORCE_EQ(x_dims[1], 1, "Each row of Input(X) contains a real value, " "so the 2nd dimension of Input(X) must be 1."); - ctx.Output("Residual")->Resize(x->dims()); - ctx.Output("Out")->Resize({x->dims()[0], 1}); + ctx->SetOutputDim("Residual", x_dims); + ctx->SetOutputDim("Out", {x_dims[0], 1}); + ctx->ShareLoD("X", "Out"); } }; @@ -55,7 +55,7 @@ class HuberLossOpMaker : public framework::OpProtoAndCheckerMaker { "The target value of huber loss op." "Y is a 2-D tensor with shape [batch_size, 1]."); AddOutput("Residual", - "Intermediate tensor to cache residual value of Y and X." + "Intermediate tensor to cache residual value between Y and X." "The shape is same as Input(X) and will be reused in backward.") .AsIntermediate(); AddOutput("Out", @@ -82,25 +82,30 @@ class HuberLossGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - protected: - void InferShape(const framework::InferShapeContext& ctx) const override { - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* residual = ctx.Input("Residual"); - auto* out_grad = ctx.Input(framework::GradVarName("Out")); - auto* x_grad = ctx.Output(framework::GradVarName("X")); - auto* y_grad = ctx.Output(framework::GradVarName("Y")); - - PADDLE_ENFORCE_NOT_NULL(x, "Input(X) should not be null."); - PADDLE_ENFORCE_NOT_NULL(y, "Input(Y) should not be null."); - PADDLE_ENFORCE_NOT_NULL(residual, "Input(Residual) should not be null."); - PADDLE_ENFORCE_NOT_NULL(out_grad, "Input(Out@GRAD) should not be null."); - - PADDLE_ENFORCE_EQ(residual->dims(), x->dims()); - PADDLE_ENFORCE_EQ(out_grad->dims(), x->dims()); - - if (x_grad) x_grad->Resize(x->dims()); - if (y_grad) y_grad->Resize(y->dims()); + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Residual"), + "Input(Residual) should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + auto residual_dims = ctx->GetInputDim("Residual"); + auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out")); + + PADDLE_ENFORCE_EQ(residual_dims, x_dims); + PADDLE_ENFORCE_EQ(out_grad_dims, x_dims); + + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, y_dims); + } } }; diff --git a/paddle/operators/huber_loss_op.h b/paddle/operators/huber_loss_op.h index 6913141bde38eaf85500723059cb8672aad8d02f..d8a2da52f508ce131570b011acca1eb67c48c343 100644 --- a/paddle/operators/huber_loss_op.h +++ b/paddle/operators/huber_loss_op.h @@ -42,14 +42,14 @@ struct HuberLossForward { }; template -class HuberLossKernel : public framework::OpKernel { +class HuberLossKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* in0 = context.Input("X"); auto* in1 = context.Input("Y"); auto* out0 = context.Output("Residual"); auto* out1 = context.Output("Out"); - auto delta = static_cast(context.op().Attr("delta")); + auto delta = static_cast(context.Attr("delta")); auto place = context.GetEigenDevice(); auto x = EigenVector::Flatten(*in0); @@ -65,11 +65,10 @@ class HuberLossKernel : public framework::OpKernel { template struct HuberLossBackward { - HOSTDEVICE HuberLossBackward(const T& delta, bool is_x) - : is_x(is_x), delta(delta) {} + HOSTDEVICE HuberLossBackward(const T& delta, T sign) + : sign(sign), delta(delta) {} HOSTDEVICE T operator()(const T& val) const { - T sign = is_x ? -1.0 : 1.0; T abs_val = std::abs(val); if (abs_val <= delta) { return sign * val; @@ -82,12 +81,12 @@ struct HuberLossBackward { } } - bool is_x; + T sign; T delta; }; template -class HuberLossGradKernel : public framework::OpKernel { +class HuberLossGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* in0 = context.Input("Residual"); @@ -104,14 +103,14 @@ class HuberLossGradKernel : public framework::OpKernel { out0->mutable_data(context.GetPlace()); auto x_grad = EigenVector::Flatten(*out0); x_grad.device(place) = - out_grad * residual.unaryExpr(HuberLossBackward(delta, true)); + out_grad * residual.unaryExpr(HuberLossBackward(delta, -1.0)); } if (out1) { out1->mutable_data(context.GetPlace()); auto y_grad = EigenVector::Flatten(*out1); y_grad.device(place) = - out_grad * residual.unaryExpr(HuberLossBackward(delta, false)); + out_grad * residual.unaryExpr(HuberLossBackward(delta, 1.0)); } } }; diff --git a/python/paddle/v2/framework/tests/test_huber_loss_op.py b/python/paddle/v2/framework/tests/test_huber_loss_op.py index ff0a17c184d1beb9a3d6c1b9b847776bf1592733..b2f102d4fc02a524514d4ba1cd523ddc1e4604ea 100644 --- a/python/paddle/v2/framework/tests/test_huber_loss_op.py +++ b/python/paddle/v2/framework/tests/test_huber_loss_op.py @@ -32,15 +32,15 @@ class TestHuberLossOp(OpTest): self.check_output() def test_check_grad_normal(self): - self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.05) + self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.008) def test_check_grad_ingore_x(self): self.check_grad( - ['Y'], 'Out', max_relative_error=0.5, no_grad_set=set("residual")) + ['Y'], 'Out', max_relative_error=0.008, no_grad_set=set("residual")) def test_check_grad_ingore_y(self): self.check_grad( - ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('residual')) + ['X'], 'Out', max_relative_error=0.008, no_grad_set=set('residual')) if __name__ == '__main__':