提交 53ab7e78 编写于 作者: Y yangyaming

Adapt new interface.

上级 0728943d
......@@ -111,7 +111,8 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OP(smooth_l1_loss, ops::SmoothL1LossOp,
ops::SmoothL1LossOpMaker<float>, ops::SmoothL1LossGradOp);
ops::SmoothL1LossOpMaker<float>, smooth_l1_loss_grad,
ops::SmoothL1LossGradOp);
REGISTER_OP_CPU_KERNEL(
smooth_l1_loss, ops::SmoothL1LossKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
......
......@@ -59,7 +59,7 @@ class SmoothL1LossKernel : public framework::OpKernel {
out1->mutable_data<T>(context.GetPlace());
auto place = context.GetEigenDevice<Place>();
auto sigma = static_cast<T>(context.op_.GetAttr<AttrType>("sigma"));
auto sigma = static_cast<T>(context.op().Attr<AttrType>("sigma"));
T sigma2 = sigma * sigma;
bool has_weight = (in2 != nullptr) && (in3 != nullptr);
......@@ -122,7 +122,7 @@ class SmoothL1LossGradKernel : public framework::OpKernel {
auto* in1 = context.Input<Tensor>("OutsideWeight");
auto* in2 = context.Input<Tensor>("diff");
auto* og = context.Input<Tensor>(framework::GradVarName("Out"));
auto sigma = static_cast<T>(context.op_.GetAttr<AttrType>("sigma"));
auto sigma = static_cast<T>(context.op().Attr<AttrType>("sigma"));
T sigma2 = sigma * sigma;
bool has_weight = (in0 != nullptr) && (in1 != nullptr);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册