Commit 5939a17c authored by yangyaming

Follow comments and adapt to new interface.

Parent 05211610
@@ -21,24 +21,24 @@ class HuberLossOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
- protected:
-  void InferShape(const framework::InferShapeContext& ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must be initialized.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) must be initialized.");
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must be initialized.");
+    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must be initialized.");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto y_dims = ctx->GetInputDim("Y");
 
-    PADDLE_ENFORCE_EQ(x->dims(), y->dims());
-    PADDLE_ENFORCE_EQ(framework::arity(x->dims()), 2,
+    PADDLE_ENFORCE_EQ(x_dims, y_dims);
+    PADDLE_ENFORCE_EQ(x_dims.size(), 2,
                       "The rank of Input(X) must be 2 and the shape is "
                       "[batch_size, 1].");
-    PADDLE_ENFORCE_EQ(x->dims()[1], 1,
+    PADDLE_ENFORCE_EQ(x_dims[1], 1,
                       "Each row of Input(X) contains a real value, "
                       "so the 2nd dimension of Input(X) must be 1.");
 
-    ctx.Output<Tensor>("Residual")->Resize(x->dims());
-    ctx.Output<Tensor>("Out")->Resize({x->dims()[0], 1});
+    ctx->SetOutputDim("Residual", x_dims);
+    ctx->SetOutputDim("Out", {x_dims[0], 1});
+    ctx->ShareLoD("X", "Out");
   }
 };
@@ -55,7 +55,7 @@ class HuberLossOpMaker : public framework::OpProtoAndCheckerMaker {
              "The target value of huber loss op."
              "Y is a 2-D tensor with shape [batch_size, 1].");
     AddOutput("Residual",
-              "Intermediate tensor to cache residual value of Y and X."
+              "Intermediate tensor to cache residual value between Y and X."
               "The shape is same as Input(X) and will be reused in backward.")
         .AsIntermediate();
     AddOutput("Out",
@@ -82,25 +82,30 @@ class HuberLossGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
- protected:
-  void InferShape(const framework::InferShapeContext& ctx) const override {
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
-    auto* residual = ctx.Input<Tensor>("Residual");
-    auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    auto* x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto* y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
-
-    PADDLE_ENFORCE_NOT_NULL(x, "Input(X) should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(y, "Input(Y) should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(residual, "Input(Residual) should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(out_grad, "Input(Out@GRAD) should not be null.");
-
-    PADDLE_ENFORCE_EQ(residual->dims(), x->dims());
-    PADDLE_ENFORCE_EQ(out_grad->dims(), x->dims());
-
-    if (x_grad) x_grad->Resize(x->dims());
-    if (y_grad) y_grad->Resize(y->dims());
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Residual"),
+                   "Input(Residual) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should not be null.");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto y_dims = ctx->GetInputDim("Y");
+    auto residual_dims = ctx->GetInputDim("Residual");
+    auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
+
+    PADDLE_ENFORCE_EQ(residual_dims, x_dims);
+    PADDLE_ENFORCE_EQ(out_grad_dims, x_dims);
+
+    auto x_grad_name = framework::GradVarName("X");
+    auto y_grad_name = framework::GradVarName("Y");
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, x_dims);
+    }
+    if (ctx->HasOutput(y_grad_name)) {
+      ctx->SetOutputDim(y_grad_name, y_dims);
+    }
   }
 };
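For reference, the quantity this operator computes is the standard Huber loss on the residual. The forward functor itself lies outside these hunks, so the following is only a sketch of the usual definition, taking r = Y - X (the direction implied by the gradient signs in the kernel hunks below) and threshold delta:

\[
r = y - x, \qquad
\text{out} =
\begin{cases}
\tfrac{1}{2}\,r^{2}, & |r| \le \delta,\\
\delta\bigl(|r| - \tfrac{1}{2}\delta\bigr), & |r| > \delta.
\end{cases}
\]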
@@ -42,14 +42,14 @@ struct HuberLossForward {
 };
 
 template <typename Place, typename T, typename AttrType = T>
-class HuberLossKernel : public framework::OpKernel {
+class HuberLossKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* in0 = context.Input<Tensor>("X");
     auto* in1 = context.Input<Tensor>("Y");
     auto* out0 = context.Output<Tensor>("Residual");
     auto* out1 = context.Output<Tensor>("Out");
-    auto delta = static_cast<T>(context.op().Attr<AttrType>("delta"));
+    auto delta = static_cast<T>(context.Attr<AttrType>("delta"));
     auto place = context.GetEigenDevice<Place>();
 
     auto x = EigenVector<T>::Flatten(*in0);
@@ -65,11 +65,10 @@ class HuberLossKernel : public framework::OpKernel {
 
 template <typename T>
 struct HuberLossBackward {
-  HOSTDEVICE HuberLossBackward(const T& delta, bool is_x)
-      : is_x(is_x), delta(delta) {}
+  HOSTDEVICE HuberLossBackward(const T& delta, T sign)
+      : sign(sign), delta(delta) {}
 
   HOSTDEVICE T operator()(const T& val) const {
-    T sign = is_x ? -1.0 : 1.0;
     T abs_val = std::abs(val);
     if (abs_val <= delta) {
       return sign * val;
@@ -82,12 +81,12 @@ struct HuberLossBackward {
     }
   }
 
-  bool is_x;
+  T sign;
   T delta;
 };
 
 template <typename Place, typename T, typename AttrType = T>
-class HuberLossGradKernel : public framework::OpKernel {
+class HuberLossGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* in0 = context.Input<Tensor>("Residual");
@@ -104,14 +103,14 @@ class HuberLossGradKernel : public framework::OpKernel {
       out0->mutable_data<T>(context.GetPlace());
       auto x_grad = EigenVector<T>::Flatten(*out0);
       x_grad.device(place) =
-          out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, true));
+          out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, -1.0));
     }
 
     if (out1) {
       out1->mutable_data<T>(context.GetPlace());
       auto y_grad = EigenVector<T>::Flatten(*out1);
       y_grad.device(place) =
-          out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, false));
+          out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, 1.0));
     }
   }
 };
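The sign argument introduced above replaces the old is_x flag and folds the chain-rule factor for each input into the functor. As a sketch of the math it encodes, assuming r = y - x as before and the standard Huber derivative for the branch not shown in the hunk:

\[
\frac{\partial\,\text{out}}{\partial r} =
\begin{cases}
r, & |r| \le \delta,\\
\delta\,\operatorname{sgn}(r), & |r| > \delta,
\end{cases}
\qquad
\frac{\partial\,\text{out}}{\partial x} = -\frac{\partial\,\text{out}}{\partial r},
\qquad
\frac{\partial\,\text{out}}{\partial y} = +\frac{\partial\,\text{out}}{\partial r}.
\]

This is why the grad kernel passes sign = -1.0 when filling X@GRAD and sign = 1.0 for Y@GRAD, multiplying by out_grad afterwards.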
@@ -32,15 +32,15 @@ class TestHuberLossOp(OpTest):
         self.check_output()
 
     def test_check_grad_normal(self):
-        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.05)
+        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.008)
 
     def test_check_grad_ingore_x(self):
         self.check_grad(
-            ['Y'], 'Out', max_relative_error=0.5, no_grad_set=set("residual"))
+            ['Y'], 'Out', max_relative_error=0.008, no_grad_set=set("residual"))
 
     def test_check_grad_ingore_y(self):
         self.check_grad(
-            ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('residual'))
+            ['X'], 'Out', max_relative_error=0.008, no_grad_set=set('residual'))
 
 
 if __name__ == '__main__':
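The tightened max_relative_error values indicate the analytic gradient is expected to match a numeric reference closely. As a minimal NumPy sketch of such a forward reference (huber_loss_np is a hypothetical helper, not code from the test file; the r = y - x direction is the assumption used above):

    import numpy as np

    def huber_loss_np(x, y, delta):
        # Residual and elementwise Huber loss for inputs of shape [batch_size, 1],
        # mirroring the op's two outputs, Residual and Out.
        residual = y - x
        abs_r = np.abs(residual)
        loss = np.where(abs_r <= delta,
                        0.5 * residual * residual,
                        delta * (abs_r - 0.5 * delta))
        return residual, loss

    x = np.random.uniform(0.0, 1.0, (100, 1)).astype('float32')
    y = np.random.uniform(0.0, 1.0, (100, 1)).astype('float32')
    residual, out = huber_loss_np(x, y, delta=1.0)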