diff --git a/paddle/operators/reduce_op.cc b/paddle/operators/reduce_op.cc
index 3ef443d1c7f475cbd578078db02fb5e0d500d060..e4791e6c074bf5f2dcfd1269577c77d9a9fcbf2a 100644
--- a/paddle/operators/reduce_op.cc
+++ b/paddle/operators/reduce_op.cc
@@ -13,6 +13,7 @@ limitations under the License. */
 
 #include "paddle/operators/reduce_op.h"
+#include "paddle/operators/net_op.h"
 
 namespace paddle {
 namespace operators {
 
@@ -161,6 +162,66 @@ class ReduceMinOpMaker : public ReduceOpMaker {
   }
 };
 
+class NormOp : public NetOp {
+ public:
+  NormOp(const std::string &type, const framework::VariableNameMap &inputs,
+         const framework::VariableNameMap &outputs,
+         const framework::AttributeMap &attrs)
+      : NetOp(type, inputs, outputs, attrs) {
+    PADDLE_ENFORCE_NE(Input("X"), framework::kEmptyVarName,
+                      "Input(X) of NormOp should not be null.");
+    PADDLE_ENFORCE_NE(Output("AbsOut"), framework::kEmptyVarName,
+                      "Output(AbsOut) of NormOp should not be null.");
+    PADDLE_ENFORCE_NE(Output("PowOut"), framework::kEmptyVarName,
+                      "Output(PowOut) of NormOp should not be null.");
+    PADDLE_ENFORCE_NE(Output("SumOut"), framework::kEmptyVarName,
+                      "Output(SumOut) of NormOp should not be null.");
+    PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName,
+                      "Output(Out) of NormOp should not be null.");
+    auto dim = Attr<int>("dim");
+    auto keep_dim = Attr<bool>("keep_dim");
+    auto p = Attr<float>("p");
+    PADDLE_ENFORCE_GT(p, 0, "Order of the norm should be positive.");
+    AppendOp(framework::OpRegistry::CreateOp("abs", {{"X", {Input("X")}}},
+                                             {{"Y", {Output("AbsOut")}}}, {}));
+    AppendOp(framework::OpRegistry::CreateOp(
+        "pow", {{"X", {Output("AbsOut")}}}, {{"Y", {Output("PowOut")}}},
+        {{"factor", p}}));
+    framework::AttributeMap sum_attr;
+    sum_attr["dim"] = dim;
+    sum_attr["keep_dim"] = keep_dim;
+    AppendOp(framework::OpRegistry::CreateOp(
+        "reduce_sum", {{"X", {Output("PowOut")}}},
+        {{"Out", {Output("SumOut")}}}, sum_attr));
+    AppendOp(framework::OpRegistry::CreateOp(
+        "pow", {{"X", {Output("SumOut")}}}, {{"Y", {Output("Out")}}},
+        {{"factor", static_cast<float>(1. / p)}}));
+    CompleteAddOp(false);
+  }
+};
+
+class NormOpMaker : public ReduceOpMaker {
+ public:
+  NormOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : ReduceOpMaker(proto, op_checker) {
+    AddOutput("AbsOut",
+              "(Tensor) The intermediate output of Norm operator, "
+              "saving the absolute value of the input tensor X.")
+        .AsIntermediate();
+    AddOutput("PowOut",
+              "(Tensor) The intermediate output of Norm operator, "
+              "saving the p-th power of the output tensor AbsOut.")
+        .AsIntermediate();
+    AddOutput("SumOut",
+              "(Tensor) The intermediate output of Norm operator, "
+              "saving the sum of PowOut reduced on the given dimension.")
+        .AsIntermediate();
+    AddAttr<float>("p", "(float, default 2) The order of Norm.").SetDefault(2);
+    SetComment("Norm", "vector p-norm");
+    AddComment(comment_);
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle
 
@@ -201,3 +262,5 @@ REGISTER_OP_CPU_KERNEL(
 REGISTER_OP_CPU_KERNEL(reduce_min_grad,
                        ops::ReduceGradKernel<paddle::platform::CPUPlace, float,
                                              ops::MinGradFunctor>);
+
+REGISTER_OP_WITHOUT_GRADIENT(norm, ops::NormOp, ops::NormOpMaker);
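Note on the composition above: `NormOp` is a `NetOp` that chains four existing primitives, so the registered `norm` op computes `Out = (sum over dim of |X|^p)^(1/p)`. A minimal NumPy sketch of that chain, with local names mirroring the op's intermediate outputs (an illustration of the pipeline, not Paddle API):

```python
import numpy as np

def norm_net(x, p=2.0, dim=1, keep_dim=False):
    """Mirror the abs -> pow -> reduce_sum -> pow chain built by NormOp."""
    abs_out = np.absolute(x)                                 # "abs" op
    pow_out = np.power(abs_out, p)                           # "pow" op, factor=p
    sum_out = np.sum(pow_out, axis=dim, keepdims=keep_dim)   # "reduce_sum" op
    out = np.power(sum_out, 1.0 / p)                         # "pow" op, factor=1/p
    return abs_out, pow_out, sum_out, out
```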
diff --git a/python/paddle/v2/framework/tests/test_reduce_op.py b/python/paddle/v2/framework/tests/test_reduce_op.py
index 70359d60cbe656150877673c63e81eae92d8ab9a..0fec31c2e22e1eda2c62aa9b38487d703815f414 100644
--- a/python/paddle/v2/framework/tests/test_reduce_op.py
+++ b/python/paddle/v2/framework/tests/test_reduce_op.py
@@ -85,5 +85,33 @@ class Test1DReduce(OpTest):
         self.check_grad(['X'], 'Out')
 
 
+class TestNorm(OpTest):
+    def setUp(self):
+        # Shift x away from 0: the numerical gradient of abs is inaccurate near 0.
+        x = np.random.random((5, 6, 10)).astype("float32") + 0.2
+        p = 2
+        dim = 1
+        keep_dim = False
+        abs_out = np.absolute(x)
+        pow_out = np.power(abs_out, p)
+        sum_out = np.sum(pow_out, axis=dim, keepdims=keep_dim)
+        out = np.power(sum_out, 1. / p)
+        self.op_type = "norm"
+        self.inputs = {'X': x}
+        self.attrs = {"p": p, "dim": dim, "keep_dim": keep_dim}
+        self.outputs = {
+            "AbsOut": abs_out,
+            "PowOut": pow_out,
+            "SumOut": sum_out,
+            "Out": out
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out', max_relative_error=0.01)
+
+
 if __name__ == '__main__':
     unittest.main()
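A note on the test's `+ 0.2` shift and the loose `max_relative_error`: `d|x|/dx` jumps from -1 to +1 at 0, so the central finite difference that numerical gradient checks commonly rely on is badly wrong for elements near 0. A small illustration (the `central_diff` helper here is hypothetical, not `OpTest` internals):

```python
import numpy as np

def central_diff(f, x, eps=1e-3):
    # Central-difference estimate of df/dx, as used by numerical gradient checks.
    return (f(x + eps) - f(x - eps)) / (2 * eps)

print(central_diff(np.abs, 1.0))   # ~1.0, matches d|x|/dx = sign(x)
print(central_diff(np.abs, 1e-4))  # ~0.1, far from the true +1 just right of 0
```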