Commit a53191f1 authored by guosheng

Add norm_op
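
For reference, the sub-net assembled by this op (abs → pow → reduce_sum → pow) evaluates the vector p-norm of the input along a given dimension:

$$\lVert x \rVert_p = \Bigl(\sum_i |x_i|^p\Bigr)^{1/p}, \qquad p > 0.$$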

Parent 9fbf94b6
@@ -13,6 +13,7 @@
limitations under the License. */
#include "paddle/operators/reduce_op.h"
#include "paddle/operators/net_op.h"
namespace paddle {
namespace operators {
@@ -161,6 +162,66 @@ class ReduceMinOpMaker : public ReduceOpMaker {
  }
};
class NormOp : public NetOp {
 public:
  NormOp(const std::string &type, const framework::VariableNameMap &inputs,
         const framework::VariableNameMap &outputs,
         const framework::AttributeMap &attrs)
      : NetOp(type, inputs, outputs, attrs) {
    PADDLE_ENFORCE_NE(Input("X"), framework::kEmptyVarName,
                      "Input(X) of NormOp should not be null.");
    PADDLE_ENFORCE_NE(Output("AbsOut"), framework::kEmptyVarName,
                      "Output(AbsOut) of NormOp should not be null.");
    PADDLE_ENFORCE_NE(Output("PowOut"), framework::kEmptyVarName,
                      "Output(PowOut) of NormOp should not be null.");
    PADDLE_ENFORCE_NE(Output("SumOut"), framework::kEmptyVarName,
                      "Output(SumOut) of NormOp should not be null.");
    PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName,
                      "Output(Out) of NormOp should not be null.");
    auto dim = Attr<int>("dim");
    auto keep_dim = Attr<bool>("keep_dim");
    auto p = Attr<float>("p");
    PADDLE_ENFORCE_GT(p, 0, "Order of the norm should be positive.");
    // AbsOut = |X|
    AppendOp(framework::OpRegistry::CreateOp("abs", {{"X", {Input("X")}}},
                                             {{"Y", {Output("AbsOut")}}}, {}));
    // PowOut = AbsOut^p
    AppendOp(framework::OpRegistry::CreateOp("pow", {{"X", {Output("AbsOut")}}},
                                             {{"Y", {Output("PowOut")}}},
                                             {{"factor", p}}));
    // SumOut = sum of PowOut, reduced along dim
    framework::AttributeMap sum_attr;
    sum_attr["dim"] = dim;
    sum_attr["keep_dim"] = keep_dim;
    AppendOp(framework::OpRegistry::CreateOp(
        "reduce_sum", {{"X", {Output("PowOut")}}},
        {{"Out", {Output("SumOut")}}}, sum_attr));
    // Out = SumOut^(1/p)
    AppendOp(framework::OpRegistry::CreateOp(
        "pow", {{"X", {Output("SumOut")}}}, {{"Y", {Output("Out")}}},
        {{"factor", static_cast<float>(1. / p)}}));
    CompleteAddOp(false);
  }
};
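
To make the composition concrete, here is a minimal NumPy sketch of what the four chained sub-ops compute; the helper name `norm_net` and the equivalence check against `np.linalg.norm` are illustrative, not part of this commit:

```python
import numpy as np

def norm_net(x, p=2.0, dim=1, keep_dim=False):
    """Mirrors NormOp's sub-net: abs -> pow(p) -> reduce_sum -> pow(1/p)."""
    abs_out = np.absolute(x)                                # "abs" sub-op
    pow_out = np.power(abs_out, p)                          # "pow" sub-op, factor=p
    sum_out = np.sum(pow_out, axis=dim, keepdims=keep_dim)  # "reduce_sum" sub-op
    return np.power(sum_out, 1.0 / p)                       # "pow" sub-op, factor=1/p

x = np.random.random((5, 6, 10)).astype("float32") + 0.2
np.testing.assert_allclose(
    norm_net(x, p=2.0, dim=1), np.linalg.norm(x, ord=2, axis=1), rtol=1e-5)
```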
class NormOpMaker : public ReduceOpMaker {
 public:
  NormOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
      : ReduceOpMaker(proto, op_checker) {
    AddOutput("AbsOut",
              "(Tensor) The intermediate output of Norm operator, "
              "saving the absolute value of the input tensor X.")
        .AsIntermediate();
    AddOutput("PowOut",
              "(Tensor) The intermediate output of Norm operator, "
              "saving the p-th power of the intermediate output AbsOut.")
        .AsIntermediate();
    AddOutput("SumOut",
              "(Tensor) The intermediate output of Norm operator, "
              "saving the sum of PowOut reduced on the given dimension.")
        .AsIntermediate();
    AddAttr<float>("p", "(float, default 2) The order of the norm.")
        .SetDefault(2);
    SetComment("Norm", "vector p-norm");
    AddComment(comment_);
  }
};
}  // namespace operators
}  // namespace paddle
@@ -201,3 +262,5 @@ REGISTER_OP_CPU_KERNEL(
REGISTER_OP_CPU_KERNEL(reduce_min_grad,
                       ops::ReduceGradKernel<paddle::platform::CPUPlace, float,
                                             ops::MaxOrMinGradFunctor>);
REGISTER_OP_WITHOUT_GRADIENT(norm, ops::NormOp, ops::NormOpMaker);
@@ -85,5 +85,33 @@ class Test1DReduce(OpTest):
        self.check_grad(['X'], 'Out')
class TestNorm(OpTest):
    def setUp(self):
        # Keep x away from 0 to avoid errors in the numerical gradient
        # when the gradient is evaluated near 0.
        x = np.random.random((5, 6, 10)).astype("float32") + 0.2
        p = 2
        dim = 1
        keep_dim = False
        abs_out = np.absolute(x)
        pow_out = np.power(abs_out, p)
        sum_out = np.sum(pow_out, axis=dim, keepdims=keep_dim)
        out = np.power(sum_out, 1. / p)
        self.op_type = "norm"
        self.inputs = {'X': x}
        self.attrs = {"p": p, "dim": dim, "keep_dim": keep_dim}
        self.outputs = {
            "AbsOut": abs_out,
            "PowOut": pow_out,
            "SumOut": sum_out,
            "Out": out
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', max_relative_error=0.01)
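
A note on the gradient check above: away from zero, the analytic gradient of the p-norm, against which the numerical gradient is compared, is

$$\frac{\partial \lVert x \rVert_p}{\partial x_i} = \operatorname{sign}(x_i)\,\frac{|x_i|^{p-1}}{\lVert x \rVert_p^{\,p-1}},$$

whose sign factor is non-differentiable at $x_i = 0$; this is why `setUp` shifts `x` away from zero.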
if __name__ == '__main__':
    unittest.main()