未验证 提交 3dab2e20 编写于 作者: Z Zhong Hui 提交者: GitHub

Add op define extra for norm and frobenius norm op. (#35329)

上级 a53460aa
...@@ -35,7 +35,12 @@ class NormOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -35,7 +35,12 @@ class NormOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Norm", AddOutput("Norm",
"(Tensor) A tensor saved the `sqrt(sum(x) + epsion)` will " "(Tensor) A tensor saved the `sqrt(sum(x) + epsion)` will "
"be used in backward kernel.") "be used in backward kernel.")
.AsIntermediate(); .AsIntermediate()
.AsExtra();
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training.")
.SetDefault(false);
AddOutput("Out", "(Tensor) A tensor of the same shape as X."); AddOutput("Out", "(Tensor) A tensor of the same shape as X.");
AddComment(R"DOC( AddComment(R"DOC(
...@@ -59,11 +64,14 @@ class NormOp : public framework::OperatorWithKernel { ...@@ -59,11 +64,14 @@ class NormOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "NormOp"); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "NormOp");
auto xdim = ctx->GetInputDim("X"); auto xdim = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", xdim); ctx->SetOutputDim("Out", xdim);
if (ctx->Attrs().Get<bool>("is_test") == false) {
int axis = ctx->Attrs().Get<int>("axis"); int axis = ctx->Attrs().Get<int>("axis");
if (axis < 0) axis = xdim.size() + axis; if (axis < 0) axis = xdim.size() + axis;
xdim[axis] = 1; xdim[axis] = 1;
ctx->SetOutputDim("Norm", xdim); ctx->SetOutputDim("Norm", xdim);
} }
}
}; };
class NormOpGrad : public framework::OperatorWithKernel { class NormOpGrad : public framework::OperatorWithKernel {
......
...@@ -65,16 +65,29 @@ class NormCUDAKernel : public framework::OpKernel<T> { ...@@ -65,16 +65,29 @@ class NormCUDAKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_x = ctx.Input<framework::Tensor>("X"); auto* in_x = ctx.Input<framework::Tensor>("X");
auto* out_y = ctx.Output<framework::Tensor>("Out"); auto* out_y = ctx.Output<framework::Tensor>("Out");
auto* out_norm = ctx.Output<framework::Tensor>("Norm");
const T* x = in_x->data<T>();
T* y = out_y->mutable_data<T>(ctx.GetPlace());
T* norm = out_norm->mutable_data<T>(ctx.GetPlace());
auto xdim = in_x->dims(); auto xdim = in_x->dims();
auto ndim = out_norm->dims();
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
T eps = static_cast<T>(ctx.Attr<float>("epsilon"));
if (axis < 0) axis = xdim.size() + axis; if (axis < 0) axis = xdim.size() + axis;
T eps = static_cast<T>(ctx.Attr<float>("epsilon"));
bool is_test = ctx.Attr<bool>("is_test");
framework::Tensor* out_norm;
framework::Tensor out_norm_tmp;
if (is_test) {
auto out_dim = in_x->dims();
out_dim[axis] = 1;
out_norm = &out_norm_tmp;
out_norm->Resize(out_dim);
} else {
out_norm = ctx.Output<framework::Tensor>("Norm");
}
const T* x = in_x->data<T>();
T* y = out_y->mutable_data<T>(ctx.GetPlace());
T* norm = out_norm->mutable_data<T>(ctx.GetPlace());
int pre, n, post; int pre, n, post;
GetDims(xdim, axis, &pre, &n, &post); GetDims(xdim, axis, &pre, &n, &post);
......
...@@ -38,9 +38,6 @@ class NormKernel : public framework::OpKernel<T> { ...@@ -38,9 +38,6 @@ class NormKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_x = ctx.Input<framework::Tensor>("X"); auto* in_x = ctx.Input<framework::Tensor>("X");
auto* out_y = ctx.Output<framework::Tensor>("Out"); auto* out_y = ctx.Output<framework::Tensor>("Out");
auto* out_norm = ctx.Output<framework::Tensor>("Norm");
out_y->mutable_data<T>(ctx.GetPlace());
out_norm->mutable_data<T>(ctx.GetPlace());
auto xdim = in_x->dims(); auto xdim = in_x->dims();
T eps = static_cast<T>(ctx.Attr<float>("epsilon")); T eps = static_cast<T>(ctx.Attr<float>("epsilon"));
...@@ -49,6 +46,22 @@ class NormKernel : public framework::OpKernel<T> { ...@@ -49,6 +46,22 @@ class NormKernel : public framework::OpKernel<T> {
int pre, n, post; int pre, n, post;
GetDims(xdim, axis, &pre, &n, &post); GetDims(xdim, axis, &pre, &n, &post);
bool is_test = ctx.Attr<bool>("is_test");
framework::Tensor* out_norm;
framework::Tensor out_norm_tmp;
if (is_test) {
auto out_dim = in_x->dims();
out_dim[axis] = 1;
out_norm = &out_norm_tmp;
out_norm->Resize(out_dim);
} else {
out_norm = ctx.Output<framework::Tensor>("Norm");
}
out_y->mutable_data<T>(ctx.GetPlace());
out_norm->mutable_data<T>(ctx.GetPlace());
auto* place = ctx.template device_context<DeviceContext>().eigen_device(); auto* place = ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 3> shape(pre, n, post); Eigen::DSizes<int, 3> shape(pre, n, post);
......
...@@ -645,7 +645,8 @@ class ReduceOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -645,7 +645,8 @@ class ReduceOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(-1); .SetDefault(-1);
AddAttr<bool>("use_mkldnn", AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel") "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false); .SetDefault(false)
.AsExtra();
AddComment(string::Sprintf(R"DOC( AddComment(string::Sprintf(R"DOC(
%s Operator. %s Operator.
......
...@@ -89,6 +89,33 @@ class TestNormOp5(TestNormOp): ...@@ -89,6 +89,33 @@ class TestNormOp5(TestNormOp):
pass pass
@skip_check_grad_ci(reason="skip check grad for test mode.")
class TestNormTestOp(OpTest):
def setUp(self):
self.op_type = "norm"
self.init_test_case()
x = np.random.random(self.shape).astype("float64")
y, norm = l2_norm(x, self.axis, self.epsilon)
self.inputs = {'X': x}
self.attrs = {
'epsilon': self.epsilon,
'axis': self.axis,
'is_test': True
}
self.outputs = {'Out': y}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
pass
def init_test_case(self):
self.shape = [2, 3, 4, 5]
self.axis = 1
self.epsilon = 1e-8
class API_NormTest(unittest.TestCase): class API_NormTest(unittest.TestCase):
def test_errors(self): def test_errors(self):
with fluid.program_guard(fluid.Program()): with fluid.program_guard(fluid.Program()):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册