diff --git a/paddle/operators/norm_op.cc b/paddle/operators/norm_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..3835da630d2142b20134b6554f268824d4718752 --- /dev/null +++ b/paddle/operators/norm_op.cc @@ -0,0 +1,106 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +Indicesou may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/norm_op.h" +namespace paddle { +namespace operators { + +class NormOpMaker : public framework::OpProtoAndCheckerMaker { + public: + NormOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput( + "X", + "(Tensor) The input tensor of norm operator. " + "The format of input tensor is NCHW. Where N is batch size, C is the " + "number of channels, H and W is the height and width of feature."); + AddInput("Scale", + "(Tensor) The input tensor of norm operator. " + "The format of input tensor is C * 1."); + AddAttr("epsilon", + "(float, default 1e-10) Constant " + "for numerical stability.") + .SetDefault(1.0e-10f); + AddOutput("Out", + "(Tensor) The output tensor of norm operator." + "N * M." + "M = C * H * W"); + AddComment(R"DOC( + "Input shape: $(N, C, H, W)$ + Sclae shape: $(C, 1)$ + Output shape: $(N, C, H, W)$ + Where + forward + $$ + [\frac {x_{1}}{\sqrt{\sum{x_{i}^{2}}}} \frac {x_{2}}{\sqrt{\sum{x_{i}^{2}}}} \frac {x_{3}}{\sqrt{\sum{x_{i}^{2}}}} \cdot \cdot \cdot \frac {x_{n}}{\sqrt{\sum{x_{i}^{2}}}}] + $$ + backward + $$ + \frac{\frac{\mathrm{d}L }{\mathrm{d}y_{1}} - \frac {x_{1}\sum {\frac{\mathrm{d} L}{\mathrm{d} y_{j}}}x_{j}}{\sum x_{j}^{2}} }{\sqrt{\sum{x_{j}^{2}}}} + $$ + )DOC"); + } +}; + +class NormOp : public framework::OperatorWithKernel { + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } + + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of NormOp" + "should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of NormOp should not be null."); + auto in_x_dims = ctx->GetInputDim("X"); + ctx->SetOutputDim("Out", in_x_dims); + } +}; + +class NormOpGrad : public framework::OperatorWithKernel { + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } + + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Input(X@GRAD) should not be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(norm, ops::NormOp, ops::NormOpMaker, norm_grad, ops::NormOpGrad); +REGISTER_OP_CPU_KERNEL( + norm, ops::NormKernel, + ops::NormKernel); +REGISTER_OP_CPU_KERNEL( + norm_grad, ops::NormGradKernel, + ops::NormGradKernel); diff --git a/paddle/operators/norm_op.cu b/paddle/operators/norm_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..7d84aaa73248d6bbf1cf104105ce2a2847000eea --- /dev/null +++ b/paddle/operators/norm_op.cu @@ -0,0 +1,24 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +Indicesou may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#define EIGEN_USE_GPU + +#include "paddle/operators/norm_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + norm, ops::NormKernel, + ops::NormKernel); +REGISTER_OP_CUDA_KERNEL( + norm_grad, ops::NormGradKernel, + ops::NormGradKernel); diff --git a/paddle/operators/norm_op.h b/paddle/operators/norm_op.h new file mode 100644 index 0000000000000000000000000000000000000000..d3dcf4834162f7a57c9a33b51736e8f4d4446134 --- /dev/null +++ b/paddle/operators/norm_op.h @@ -0,0 +1,162 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +Indicesou may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +using EigenVector = framework::EigenVector; +template +using EigenMatrix = framework::EigenMatrix; + +template +class NormKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const framework::Tensor* in_x = context.Input("X"); + const framework::Tensor* scale = context.Input("Scale"); + auto* out = context.Output("Out"); + T epsilon = context.Attr("epsilon"); + out->mutable_data(context.GetPlace()); + int batch_size = in_x->dims()[0]; + int channels = in_x->dims()[1]; + int height = in_x->dims()[2]; + int width = in_x->dims()[3]; + int fea_len = height * width; + auto* place = + context.template device_context().eigen_device(); + auto x = EigenMatrix::From( + *in_x, framework::make_ddim({batch_size, fea_len * channels})); + // get square + framework::Tensor x_square; + x_square.mutable_data(in_x->dims(), context.GetPlace()); + auto x_square_eigen = EigenMatrix::From( + x_square, framework::make_ddim({batch_size, fea_len * channels})); + x_square_eigen.device(*place) = x.square(); + auto scale_eigen = EigenVector::Flatten(*scale); + for (int n = 0; n < batch_size; ++n) { + framework::Tensor in_x_batch = in_x->Slice(n, n + 1); + auto in_x_batch_eigen = EigenMatrix::From( + in_x_batch, framework::make_ddim({channels, fea_len})); + framework::Tensor x_square_batch = x_square.Slice(n, n + 1); + auto x_square_batch_eigen = EigenMatrix::From( + x_square_batch, framework::make_ddim({channels, fea_len})); + framework::Tensor out_batch = out->Slice(n, n + 1); + auto out_batch_eigen = EigenMatrix::From( + out_batch, framework::make_ddim({channels, fea_len})); + framework::Tensor tmp_tensor; + tmp_tensor.mutable_data(framework::make_ddim({1, fea_len}), + context.GetPlace()); + auto tmp = EigenVector::Flatten(tmp_tensor); + // get colsum and sqrt , inverse + auto dim = Eigen::array({{0}}); + tmp.device(*place) = x_square_batch_eigen.sum(dim); + tmp.device(*place) = (tmp + epsilon).sqrt().inverse(); + Eigen::array broadcast_dim_col; + broadcast_dim_col[1] = 1; + broadcast_dim_col[0] = channels; + out_batch_eigen.device(*place) = + in_x_batch_eigen * (tmp.broadcast(broadcast_dim_col)); + Eigen::array broadcast_dim_row; + broadcast_dim_row[1] = fea_len; + broadcast_dim_row[0] = 1; + out_batch_eigen.device(*place) = + out_batch_eigen * (scale_eigen.broadcast(broadcast_dim_row)); + } + } +}; +template +class NormGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const framework::Tensor* in_x = context.Input("X"); + const framework::Tensor* scale = context.Input("Scale"); + const framework::Tensor* out_grad = + context.Input(framework::GradVarName("Out")); + T epsilon = context.Attr("epsilon"); + framework::Tensor* in_x_grad = + context.Output(framework::GradVarName("X")); + in_x_grad->mutable_data(context.GetPlace()); + int batch_size = in_x->dims()[0]; + int channels = in_x->dims()[1]; + int height = in_x->dims()[2]; + int width = in_x->dims()[3]; + int fea_len = height * width; + auto* place = + context.template device_context().eigen_device(); + + auto scale_eigen = EigenVector::Flatten(*scale); + auto x = EigenMatrix::From( + *in_x, framework::make_ddim({batch_size, fea_len * channels})); + // get square + framework::Tensor x_square; + x_square.mutable_data(in_x->dims(), context.GetPlace()); + auto x_square_eigen = EigenMatrix::From( + x_square, framework::make_ddim({batch_size, fea_len * channels})); + x_square_eigen.device(*place) = x.square(); + + for (int n = 0; n < batch_size; ++n) { + framework::Tensor in_x_batch = in_x->Slice(n, n + 1); + auto in_x_batch_eigen = EigenMatrix::From( + in_x_batch, framework::make_ddim({channels, fea_len})); + framework::Tensor in_g_batch = in_x_grad->Slice(n, n + 1); + auto in_g_batch_eigen = EigenMatrix::From( + in_g_batch, framework::make_ddim({channels, fea_len})); + framework::Tensor x_square_batch = x_square.Slice(n, n + 1); + auto x_square_batch_eigen = EigenMatrix::From( + x_square_batch, framework::make_ddim({channels, fea_len})); + framework::Tensor outg_batch = out_grad->Slice(n, n + 1); + auto outg_batch_eigen = EigenMatrix::From( + outg_batch, framework::make_ddim({channels, fea_len})); + + framework::Tensor tmp_tensor; + tmp_tensor.mutable_data(framework::make_ddim({1, fea_len}), + context.GetPlace()); + auto tmp_eigen = EigenVector::Flatten(tmp_tensor); + auto dim = Eigen::array({{0}}); + tmp_eigen.device(*place) = (in_x_batch_eigen * outg_batch_eigen).sum(dim); + framework::Tensor norm_tmp_tensor; + norm_tmp_tensor.mutable_data(framework::make_ddim({1, fea_len}), + context.GetPlace()); + auto norm_tmp_eigen = EigenVector::Flatten(norm_tmp_tensor); + norm_tmp_eigen.device(*place) = + (x_square_batch_eigen.sum(dim) + epsilon).sqrt(); + Eigen::array broadcast_dim_col; + broadcast_dim_col[1] = 1; + broadcast_dim_col[0] = channels; + in_g_batch_eigen.device(*place) = + in_x_batch_eigen * tmp_eigen.broadcast(broadcast_dim_col); + in_g_batch_eigen.device(*place) = + in_g_batch_eigen / + (norm_tmp_eigen * norm_tmp_eigen).broadcast(broadcast_dim_col); + in_g_batch_eigen.device(*place) = outg_batch_eigen - in_g_batch_eigen; + // outg_batch_eigen + (in_g_batch_eigen * -1); + in_g_batch_eigen.device(*place) = + in_g_batch_eigen / norm_tmp_eigen.broadcast(broadcast_dim_col); + Eigen::array broadcast_dim_row; + broadcast_dim_row[1] = fea_len; + broadcast_dim_row[0] = 1; + in_g_batch_eigen.device(*place) = + in_g_batch_eigen * (scale_eigen.broadcast(broadcast_dim_row)); + } + } +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/fluid/tests/test_norm_op.py b/python/paddle/v2/fluid/tests/test_norm_op.py new file mode 100644 index 0000000000000000000000000000000000000000..23e6841b916a6a26f3c1b4ec4a2569e70e2fdbf9 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_norm_op.py @@ -0,0 +1,57 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def norm(input, scale, epsilon): + s0, s1, s2, s3 = input.shape + x_square = input * input + for i in xrange(s0): + input_batch = input[i:i + 1, :, :, :] + input_batch = input_batch.reshape(s1, s2 * s3) + x_square_batch = x_square[i:i + 1, :, :, :] + x_square_batch = x_square_batch.reshape(s1, s2 * s3) + square_colsum = x_square_batch.sum(axis=0) + epsilon + tmp = pow(square_colsum, 0.5) + tmp = np.reciprocal(tmp) + tmp_tile = np.tile(tmp, s1) + tmp_tile = tmp_tile.reshape(s1, s2 * s3) + scale_tile = np.tile(scale, (1, s2 * s3)) + scale_tile = scale_tile.reshape(s1, s2 * s3) + out_batch = input_batch * tmp_tile * scale_tile + out_batch = out_batch.reshape(1, s1, s2, s3) + if i == 0: + out = out_batch + else: + out = np.concatenate((out, out_batch), 0) + out.reshape(s0, s1, s2, s3) + return out + + +class TestNormOp(OpTest): + def setUp(self): + self.op_type = "norm" + self.init_test_case() + input = np.random.random(self.shape).astype("float32") + scale = np.array([10, 10, 10]) + self.inputs = { + 'X': input.astype('float32'), + 'Scale': scale.astype('float32') + } + self.attrs = {'epsilon': self.epsilon} + output = norm(input, scale, self.epsilon) + self.outputs = {'Out': output.astype('float32')} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + def init_test_case(self): + self.shape = [1, 3, 2, 2] + self.epsilon = 1e-6 + + +if __name__ == '__main__': + unittest.main()