Commit 5939a17c authored by yangyaming

Follow comments and adapt to new interface.

Parent 05211610
@@ -21,24 +21,24 @@ class HuberLossOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext& ctx) const override {
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must be initialized.");
-    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) must be initialized.");
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must be initialized.");
+    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must be initialized.");
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
+    auto x_dims = ctx->GetInputDim("X");
+    auto y_dims = ctx->GetInputDim("Y");
-    PADDLE_ENFORCE_EQ(x->dims(), y->dims());
-    PADDLE_ENFORCE_EQ(framework::arity(x->dims()), 2,
+    PADDLE_ENFORCE_EQ(x_dims, y_dims);
+    PADDLE_ENFORCE_EQ(x_dims.size(), 2,
                       "The rank of Input(X) must be 2 and the shape is "
                       "[batch_size, 1].");
-    PADDLE_ENFORCE_EQ(x->dims()[1], 1,
+    PADDLE_ENFORCE_EQ(x_dims[1], 1,
                       "Each row of Input(X) contains a real value, "
                       "so the 2nd dimension of Input(X) must be 1.");
-    ctx.Output<Tensor>("Residual")->Resize(x->dims());
-    ctx.Output<Tensor>("Out")->Resize({x->dims()[0], 1});
+    ctx->SetOutputDim("Residual", x_dims);
+    ctx->SetOutputDim("Out", {x_dims[0], 1});
+    ctx->ShareLoD("X", "Out");
   }
 };
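The "new interface" named in the commit message is PaddlePaddle's compile-time shape inference: InferShape now receives an InferShapeContext* and reasons about dims by name instead of touching live tensors. A minimal sketch of the pattern, assuming only the framework helpers already used in the hunk above (not part of this diff):

    // Sketch only: the compile-time InferShape pattern adopted above.
    void InferShape(framework::InferShapeContext* ctx) const override {
      PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must be initialized.");
      auto x_dims = ctx->GetInputDim("X");       // shapes only, no tensor data
      ctx->SetOutputDim("Out", {x_dims[0], 1});  // declare the output shape
      ctx->ShareLoD("X", "Out");                 // propagate LoD metadata to Out
    }

ShareLoD is the one genuinely new call here: it forwards the input's LoD (sequence) information to the output, which the old Resize-based code could not express.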
@@ -55,7 +55,7 @@ class HuberLossOpMaker : public framework::OpProtoAndCheckerMaker {
              "The target value of huber loss op."
              "Y is a 2-D tensor with shape [batch_size, 1].");
     AddOutput("Residual",
-              "Intermediate tensor to cache residual value of Y and X."
+              "Intermediate tensor to cache residual value between Y and X."
               "The shape is same as Input(X) and will be reused in backward.")
         .AsIntermediate();
     AddOutput("Out",
@@ -82,25 +82,30 @@ class HuberLossGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  void InferShape(const framework::InferShapeContext& ctx) const override {
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
-    auto* residual = ctx.Input<Tensor>("Residual");
-    auto* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    auto* x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto* y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    PADDLE_ENFORCE_NOT_NULL(x, "Input(X) should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(y, "Input(Y) should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(residual, "Input(Residual) should not be null.");
-    PADDLE_ENFORCE_NOT_NULL(out_grad, "Input(Out@GRAD) should not be null.");
-    PADDLE_ENFORCE_EQ(residual->dims(), x->dims());
-    PADDLE_ENFORCE_EQ(out_grad->dims(), x->dims());
-    if (x_grad) x_grad->Resize(x->dims());
-    if (y_grad) y_grad->Resize(y->dims());
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Residual"),
+                   "Input(Residual) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@GRAD) should not be null.");
+    auto x_dims = ctx->GetInputDim("X");
+    auto y_dims = ctx->GetInputDim("Y");
+    auto residual_dims = ctx->GetInputDim("Residual");
+    auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
+    PADDLE_ENFORCE_EQ(residual_dims, x_dims);
+    PADDLE_ENFORCE_EQ(out_grad_dims, x_dims);
+    auto x_grad_name = framework::GradVarName("X");
+    auto y_grad_name = framework::GradVarName("Y");
+    if (ctx->HasOutput(x_grad_name)) {
+      ctx->SetOutputDim(x_grad_name, x_dims);
+    }
+    if (ctx->HasOutput(y_grad_name)) {
+      ctx->SetOutputDim(y_grad_name, y_dims);
+    }
   }
 };
......
@@ -42,14 +42,14 @@ struct HuberLossForward {
 };
 
 template <typename Place, typename T, typename AttrType = T>
-class HuberLossKernel : public framework::OpKernel {
+class HuberLossKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* in0 = context.Input<Tensor>("X");
     auto* in1 = context.Input<Tensor>("Y");
     auto* out0 = context.Output<Tensor>("Residual");
     auto* out1 = context.Output<Tensor>("Out");
-    auto delta = static_cast<T>(context.op().Attr<AttrType>("delta"));
+    auto delta = static_cast<T>(context.Attr<AttrType>("delta"));
     auto place = context.GetEigenDevice<Place>();
 
     auto x = EigenVector<T>::Flatten(*in0);
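For reference, HuberLossForward (named in the hunk header above) evaluates the standard Huber loss elementwise on the cached residual. Assuming the usual definition with residual r = y - x and threshold delta, which the sign conventions in the backward functor below are consistent with:

    L_\delta(r) =
      \begin{cases}
        \tfrac{1}{2} r^2                            & \text{if } |r| \le \delta \\
        \delta \left( |r| - \tfrac{\delta}{2} \right) & \text{otherwise}
      \end{cases}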
@@ -65,11 +65,10 @@ class HuberLossKernel : public framework::OpKernel {
 template <typename T>
 struct HuberLossBackward {
-  HOSTDEVICE HuberLossBackward(const T& delta, bool is_x)
-      : is_x(is_x), delta(delta) {}
+  HOSTDEVICE HuberLossBackward(const T& delta, T sign)
+      : sign(sign), delta(delta) {}
 
   HOSTDEVICE T operator()(const T& val) const {
-    T sign = is_x ? -1.0 : 1.0;
     T abs_val = std::abs(val);
     if (abs_val <= delta) {
       return sign * val;
@@ -82,12 +81,12 @@ struct HuberLossBackward {
     }
   }
 
-  bool is_x;
+  T sign;
   T delta;
 };
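Replacing the bool is_x with an explicit sign makes the functor match the chain rule directly. With r = y - x:

    \frac{\partial L_\delta}{\partial x} = -L'_\delta(r), \qquad
    \frac{\partial L_\delta}{\partial y} = +L'_\delta(r), \qquad
    L'_\delta(r) =
      \begin{cases}
        r                               & \text{if } |r| \le \delta \\
        \delta \,\operatorname{sign}(r) & \text{otherwise}
      \end{cases}

Hence the grad kernel below passes -1.0 when producing X@GRAD and 1.0 for Y@GRAD, instead of re-deriving the sign inside the functor on every element.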
 template <typename Place, typename T, typename AttrType = T>
-class HuberLossGradKernel : public framework::OpKernel {
+class HuberLossGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* in0 = context.Input<Tensor>("Residual");
@@ -104,14 +103,14 @@ class HuberLossGradKernel : public framework::OpKernel {
       out0->mutable_data<T>(context.GetPlace());
       auto x_grad = EigenVector<T>::Flatten(*out0);
       x_grad.device(place) =
-          out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, true));
+          out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, -1.0));
     }
 
     if (out1) {
       out1->mutable_data<T>(context.GetPlace());
       auto y_grad = EigenVector<T>::Flatten(*out1);
       y_grad.device(place) =
-          out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, false));
+          out_grad * residual.unaryExpr(HuberLossBackward<T>(delta, 1.0));
     }
   }
 };
......
@@ -32,15 +32,15 @@ class TestHuberLossOp(OpTest):
         self.check_output()
 
     def test_check_grad_normal(self):
-        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.05)
+        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.008)
 
     def test_check_grad_ingore_x(self):
         self.check_grad(
-            ['Y'], 'Out', max_relative_error=0.5, no_grad_set=set("residual"))
+            ['Y'], 'Out', max_relative_error=0.008, no_grad_set=set("residual"))
 
     def test_check_grad_ingore_y(self):
         self.check_grad(
-            ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('residual'))
+            ['X'], 'Out', max_relative_error=0.008, no_grad_set=set('residual'))
 
 if __name__ == '__main__':
......