diff --git a/paddle/fluid/operators/group_norm_op_npu.cc b/paddle/fluid/operators/group_norm_op_npu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..4ef8320cbdecd630069da05f71814bffb7f8e063
--- /dev/null
+++ b/paddle/fluid/operators/group_norm_op_npu.cc
@@ -0,0 +1,306 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/group_norm_op.h"
+#include <vector>
+#include "paddle/fluid/operators/npu_op_runner.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+struct GroupNormFunction {
+ public:
+  explicit GroupNormFunction(const framework::ExecutionContext& ctx)
+      : ctx(ctx) {
+    place = ctx.GetPlace();
+    stream = ctx.template device_context<paddle::platform::NPUDeviceContext>()
+                 .stream();
+  }
+  void ReduceMean(const Tensor* x, Tensor* y, const std::vector<int>& dim,
+                  bool keep_dims = true) {
+    // y should be init first
+    const auto& runner = NpuOpRunner("ReduceMeanD", {*x}, {*y},
+                                     {{"axes", dim}, {"keep_dims", keep_dims}});
+    runner.Run(stream);
+  }
+  void ReduceSum(const Tensor* x, Tensor* y, const std::vector<int>& dim,
+                 bool keep_dims = true) {
+    // y should be init first
+    const auto& runner = NpuOpRunner("ReduceSumD", {*x}, {*y},
+                                     {{"axes", dim}, {"keep_dims", keep_dims}});
+    runner.Run(stream);
+  }
+  void Add(const Tensor* x, const Tensor* y, Tensor* z) {
+    // z should be init first
+    const auto& runner = NpuOpRunner("AddV2", {*x, *y}, {*z}, {});
+    runner.Run(stream);
+  }
+  void Sub(const Tensor* x, const Tensor* y, Tensor* z) {
+    // z should be init first
+    const auto& runner = NpuOpRunner("Sub", {*x, *y}, {*z}, {});
+    runner.Run(stream);
+  }
+  void Mul(const Tensor* x, const Tensor* y, Tensor* z) {
+    // z should be init first
+    const auto& runner = NpuOpRunner("Mul", {*x, *y}, {*z}, {});
+    runner.Run(stream);
+  }
+  void Div(const Tensor* x, const Tensor* y, Tensor* z) {
+    // z should be init first
+    const auto& runner = NpuOpRunner("Div", {*x, *y}, {*z}, {});
+    runner.Run(stream);
+  }
+  void DivNoNan(const Tensor* x, const Tensor* y, Tensor* z) {
+    // z should be init first
+    const auto& runner = NpuOpRunner("DivNoNan", {*x, *y}, {*z}, {});
+    runner.Run(stream);
+  }
+  void Transpose(const Tensor* x, Tensor* y, const std::vector<int>& axis) {
+    // y should be init first
+    const auto& runner =
+        NpuOpRunner("TransposeD", {*x}, {*y}, {{"perm", axis}});
+    runner.Run(stream);
+  }
+  void Sqrt(const Tensor* x, Tensor* y) {
+    // y should be init first
+    const auto& runner = NpuOpRunner("Sqrt", {*x}, {*y}, {});
+    runner.Run(stream);
+  }
+  void Adds(const Tensor* x, float scalar, Tensor* y) {
+    // y should be init first
+    const auto& runner = NpuOpRunner("Adds", {*x}, {*y}, {{"value", scalar}});
+    runner.Run(stream);
+  }
+  Tensor ReduceMeanToNG(const Tensor* x, const DataLayout& data_layout,
+                        const int64_t N, const int64_t C, const int64_t H,
+                        const int64_t W, const int G) {
+    Tensor y(x->type());
+    // y.mutable_data<T>( {N,G,1}, place );
+    if (data_layout == DataLayout::kNCHW) {
+      y.mutable_data<T>({N, G, 1}, place);
+      // shape of x is [N, G, C*H*W/G]
+      this->ReduceMean(x, &y, std::vector<int>{2});
+    } else {
+      y.mutable_data<T>({N, 1, G}, place);
+      // shape of x is [N, C*H*W/G, G]
+      Tensor x_trans(x->type());
+      x_trans.mutable_data<T>({N, G, C * H * W / G}, place);
+      this->Transpose(x, &x_trans, std::vector<int>{0, 2, 1});
+      this->ReduceMean(&x_trans, &y, std::vector<int>{2});
+    }
+    return y;
+  }
+
+ private:
+  platform::Place place;
+  aclrtStream stream;
+  const framework::ExecutionContext& ctx;
+};
+
+template <typename T>
+class GroupNormNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
+    const DataLayout data_layout =
+        framework::StringToDataLayout(data_layout_str);
+    const float epsilon = ctx.Attr<float>("epsilon");
+    auto* scale = ctx.Input<Tensor>("Scale");
+    auto* bias = ctx.Input<Tensor>("Bias");
+    auto* x = ctx.Input<Tensor>("X");
+
+    auto* y = ctx.Output<Tensor>("Y");
+    auto* mean = ctx.Output<Tensor>("Mean");
+    auto* var = ctx.Output<Tensor>("Variance");
+    const auto groups = ctx.Attr<int>("groups");
+
+    auto place = ctx.GetPlace();
+    Tensor xnorm(x->type());
+    xnorm.mutable_data<T>(x->dims(), place);
+    GroupNormFunction<T> F(ctx);
+    if (data_layout != DataLayout::kNCHW) {
+      xnorm.Resize({x->dims()[0], x->dims()[3], x->dims()[1], x->dims()[2]});
+      F.Transpose(x, &xnorm, std::vector<int>{0, 3, 1, 2});
+    } else {
+      TensorCopy(*x, platform::NPUPlace(), &xnorm);
+    }
+    auto N = xnorm.dims()[0];
+    auto C = xnorm.dims()[1];
+    auto H = xnorm.dims()[2];
+    auto W = xnorm.dims()[3];
+    xnorm.Resize({N * groups, C * H * W / groups});
+    std::vector<int> axis = {1};
+    auto reduce_dim = mean->dims();
+
+    mean->mutable_data<T>({N * groups, 1}, place);
+    var->mutable_data<T>({N * groups, 1}, place);
+    y->mutable_data<T>(place);
+    F.ReduceMean(&xnorm, mean, axis);
+
+    F.Sub(&xnorm, mean, &xnorm);
+    Tensor sqr(x->type());
+    sqr.mutable_data<T>(xnorm.dims(), place);
+
+    F.Mul(&xnorm, &xnorm, &sqr);
+    F.ReduceMean(&sqr, var, axis);
+    Tensor std(x->type());
+    std.mutable_data<T>(var->dims(), place);
+    F.Adds(var, epsilon, &std);
+    F.Sqrt(&std, &std);
+    y->Resize(xnorm.dims());
+    F.Div(&xnorm, &std, y);
+    y->Resize({N, C, H, W});
+    if (scale) {
+      Tensor scale_t(scale->type());
+      scale_t.ShareDataWith(*scale);
+      scale_t.Resize({C, 1, 1});
+      F.Mul(y, &scale_t, y);
+    }
+    if (bias) {
+      Tensor bias_t(bias->type());
+      bias_t.ShareDataWith(*bias);
+      bias_t.Resize({C, 1, 1});
+      F.Add(y, &bias_t, y);
+    }
+    if (data_layout != DataLayout::kNCHW) {
+      F.Transpose(y, y, std::vector<int>{0, 2, 3, 1});
+      y->Resize({x->dims()});
+    }
+    mean->Resize(reduce_dim);
+    var->Resize(reduce_dim);
+  }
+};
+
+template <typename T>
+class GroupNormGradNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
+    const DataLayout data_layout =
+        framework::StringToDataLayout(data_layout_str);
+    const float epsilon = ctx.Attr<float>("epsilon");
+    auto* y = ctx.Input<Tensor>("Y");
+    auto* var = ctx.Input<Tensor>("Variance");
+
+    auto* scale = ctx.Input<Tensor>("Scale");
+    auto* bias = ctx.Input<Tensor>("Bias");
+    auto* d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
+    const auto G = ctx.Attr<int>("groups");
+
+    // init output
+    auto* d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto* d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
+    auto* d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
+
+    GroupNormFunction<T> F(ctx);
+    auto place = ctx.GetPlace();
+    auto _type = y->type();
+
+    Tensor xnorm(_type);
+    xnorm.mutable_data<T>(y->dims(), place);
+    Tensor scale_share(_type);
+    scale_share.ShareDataWith(*scale);
+    Tensor bias_share(_type);
+    bias_share.ShareDataWith(*bias);
+
+    int64_t N = y->dims()[0];
+    int64_t C, H, W;
+    framework::DDim scale_bias_dim;
+    if (data_layout == DataLayout::kNCHW) {
+      C = y->dims()[1];
+      H = y->dims()[2];
+      W = y->dims()[3];
+      scale_bias_dim = framework::make_ddim({C, 1, 1});
+    } else {
+      C = y->dims()[3];
+      H = y->dims()[1];
+      W = y->dims()[2];
+      scale_bias_dim = framework::make_ddim({1, 1, C});
+    }
+    scale_share.Resize(scale_bias_dim);
+    bias_share.Resize(scale_bias_dim);
+    F.Sub(y, &bias_share, &xnorm);
+    F.DivNoNan(&xnorm, &scale_share, &xnorm);
+
+    if (d_bias) {
+      d_bias->mutable_data<T>(place);
+      if (data_layout == DataLayout::kNCHW) {
+        F.ReduceSum(d_y, d_bias, std::vector<int>{0, 2, 3}, false);
+      } else {
+        F.ReduceSum(d_y, d_bias, std::vector<int>{0, 1, 2}, false);
+      }
+    }
+    if (d_scale) {
+      d_scale->mutable_data<T>(place);
+      Tensor dy_xnorm(_type);
+      dy_xnorm.mutable_data<T>(d_y->dims(), place);
+      F.Mul(d_y, &xnorm, &dy_xnorm);
+      if (data_layout == DataLayout::kNCHW) {
+        F.ReduceSum(&dy_xnorm, d_scale, std::vector<int>{0, 2, 3});
+      } else {
+        F.ReduceSum(&dy_xnorm, d_scale, std::vector<int>{0, 1, 2});
+      }
+    }
+
+    // std = Sqrt(var+epsilon), init shape = [ N, G ]
+    Tensor std(_type);
+    std.mutable_data<T>(var->dims(), place);
+    F.Adds(var, epsilon, &std);
+    F.Sqrt(&std, &std);
+    // d_xnorm_std = dy_proc * scale / std
+    Tensor d_xnorm_std(_type);
+    d_xnorm_std.mutable_data<T>(y->dims(), place);
+    F.Mul(d_y, &scale_share, &d_xnorm_std);
+    if (data_layout == DataLayout::kNCHW) {
+      xnorm.Resize({N, G, C * H * W / G});
+      d_xnorm_std.Resize({N, G, C * H * W / G});
+      std.Resize({N, G, 1});
+    } else {
+      xnorm.Resize({N, C * H * W / G, G});
+      d_xnorm_std.Resize({N, C * H * W / G, G});
+      std.Resize({N, 1, G});
+    }
+    F.Div(&d_xnorm_std, &std, &d_xnorm_std);
+
+    // d_x = d_xnorm_std
+    //      - Mean ( d_xnorm_std * x_norm, axis=1, keepdim=True ) * x_norm
+    //      - Mean ( d_xnorm_std, axis=1, keepdim=True )
+    d_x->mutable_data<T>(place);
+    d_x->Resize(xnorm.dims());
+    F.Mul(&d_xnorm_std, &xnorm, d_x);
+    Tensor dx1 = F.ReduceMeanToNG(d_x, data_layout, N, C, H, W, G);
+    F.Mul(&dx1, &xnorm, d_x);
+
+    Tensor dx2 = F.ReduceMeanToNG(&d_xnorm_std, data_layout, N, C, H, W, G);
+
+    F.Sub(&d_xnorm_std, d_x, d_x);
+    F.Sub(d_x, &dx2, d_x);
+
+    d_x->Resize(y->dims());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OP_NPU_KERNEL(group_norm, ops::GroupNormNPUKernel<float>,
+                       ops::GroupNormNPUKernel<plat::float16>);
+REGISTER_OP_NPU_KERNEL(group_norm_grad, ops::GroupNormGradNPUKernel<float>,
+                       ops::GroupNormGradNPUKernel<plat::float16>);
diff --git a/python/paddle/fluid/tests/unittests/npu/test_group_norm_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_group_norm_op_npu.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ab1161be36dd850facb9ac1ee19f213a509616a
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/npu/test_group_norm_op_npu.py
@@ -0,0 +1,217 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+import unittest
+import numpy as np
+
+import sys
+sys.path.append("..")
+
+from operator import mul
+from op_test import OpTest
+import paddle
+import paddle.fluid.core as core
+import paddle.fluid as fluid
+
+paddle.enable_static()
+
+
+def group_norm_naive(x, scale, bias, epsilon, groups, data_layout):
+    if data_layout == "NHWC":
+        x = np.transpose(x, (0, 3, 1, 2))  # NHWC => NCHW
+    N, C, H, W = x.shape
+    G = groups
+    x = x.reshape((N * G, -1))
+    mean = np.mean(x, axis=1, keepdims=True)
+    var = np.var(x, axis=1, keepdims=True)
+    xnorm = (x - mean) / np.sqrt(var + epsilon)
+    xnorm = xnorm.reshape((N, C, H, W))
+    output = xnorm * scale.reshape((-1, 1, 1)) + bias.reshape((-1, 1, 1))
+    if data_layout == "NHWC":
+        output = np.transpose(output, (0, 2, 3, 1))  # NCHW => NHWC
+        xnorm = np.transpose(xnorm, (0, 2, 3, 1))
+    return output, mean.reshape((N, G)), var.reshape((N, G))
+
+
+class TestGroupNormOpError(unittest.TestCase):
+    def test_errors(self):
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+
+            def test_x_type():
+                input = np.random.random(2, 100, 3, 5).astype('float32')
+                groups = 2
+                fluid.layers.group_norm(input, groups)
+
+            self.assertRaises(TypeError, test_x_type)
+
+            def test_x_dtype():
+                x2 = fluid.layers.data(
+                    name='x2', shape=[2, 100, 3, 5], dtype='int32')
+                groups = 2
+                fluid.layers.group_norm(x2, groups)
+
+            self.assertRaises(TypeError, test_x_dtype)
+
+
+class TestGroupNormOp(OpTest):
+    def setUp(self):
+        self.set_npu()
+        self.op_type = 'group_norm'
+        self.place = paddle.NPUPlace(0)
+
+        self.init_dtype()
+
+        self.data_format = "NCHW"
+        self.atol = 1e-6
+        self.max_relative_error = 0.005
+        self.shape = (2, 100, 3, 5)
+        self.attrs = {'epsilon': 1e-5, 'groups': 2, 'data_layout': "NCHW"}
+        self.compare_between_place = False
+        self.init_test_case()
+
+        input = np.random.random(self.shape).astype(self.dtype)
+        if self.data_format == "NHWC":
+            input = np.transpose(input, (0, 2, 3, 1))
+        scale = np.random.random([self.shape[1]]).astype(self.dtype)
+        bias = np.random.random([self.shape[1]]).astype(self.dtype)
+        output, mean, var = group_norm_naive(
+            input, scale, bias, self.attrs['epsilon'], self.attrs['groups'],
+            self.data_format)
+
+        self.inputs = {
+            'X': OpTest.np_dtype_to_fluid_dtype(input),
+            'Scale': OpTest.np_dtype_to_fluid_dtype(scale),
+            'Bias': OpTest.np_dtype_to_fluid_dtype(bias)
+        }
+        self.outputs = {'Y': output, 'Mean': mean, 'Variance': var}
+        self.attrs['data_layout'] = self.data_format
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, atol=self.atol)
+
+    def test_check_grad(self):
+        if self.dtype == np.float16:
+            return
+
+        self.__class__.exist_check_grad = True
+        inputs_to_check = ['X', 'Scale', 'Bias']
+        output_names = 'Y'
+        no_grad_set = set()
+        cpu_place = fluid.CPUPlace()
+        cpu_grads = self._get_gradient(inputs_to_check, cpu_place,
+                                       output_names, no_grad_set)
+        npu_grads = self._get_gradient(inputs_to_check, self.place,
+                                       output_names, no_grad_set)
+
+        self._assert_is_close(cpu_grads, npu_grads, inputs_to_check,
+                              self.max_relative_error,
+                              "Gradient Check between places")
+
+    def init_test_case(self):
+        pass
+
+
+class TestGroupNormOp1(TestGroupNormOp):
+    def init_test_case(self):
+        self.attrs['groups'] = 1
+
+
+class TestGroupNormOp2(TestGroupNormOp):
+    def init_test_case(self):
+        self.attrs['groups'] = 4
+
+
+class TestGroupNormOpBigEps1(TestGroupNormOp):
+    def init_test_case(self):
+        self.attrs['groups'] = 1
+        self.attrs['epsilon'] = 0.5
+
+
+class TestGroupNormOpBigEps2(TestGroupNormOp):
+    def init_test_case(self):
+        self.attrs['groups'] = 4
+        self.attrs['epsilon'] = 0.5
+
+
+class TestGroupNormOpBigEps3(TestGroupNormOp):
+    def init_test_case(self):
+        self.attrs['epsilon'] = 0.5
+
+
+class TestGroupNormOp1_With_NHWC(TestGroupNormOp):
+    def init_test_case(self):
+        self.attrs['groups'] = 1
+        self.data_format = "NHWC"
+
+
+class TestGroupNormOp2_With_NHWC(TestGroupNormOp):
+    def init_test_case(self):
+        self.attrs['groups'] = 4
+        self.data_format = "NHWC"
+
+
+class TestGroupNormOpBigEps1_With_NHWC(TestGroupNormOp):
+    def init_test_case(self):
+        self.attrs['groups'] = 1
+        self.attrs['epsilon'] = 0.5
+        self.data_format = "NHWC"
+
+
+class TestGroupNormOpBigEps2_With_NHWC(TestGroupNormOp):
+    def init_test_case(self):
+        self.attrs['groups'] = 4
+        self.attrs['epsilon'] = 0.5
+        self.data_format = "NHWC"
+
+
+class TestGroupNormOpBigEps3_With_NHWC(TestGroupNormOp):
+    def init_test_case(self):
+        self.attrs['epsilon'] = 0.5
+        self.data_format = "NHWC"
+
+
+class TestGroupNormOpFP16(TestGroupNormOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+class TestGroupNormOpFP16_With_NHWC(TestGroupNormOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+    def init_test_case(self):
+        self.data_format = "NHWC"
+
+
+class TestGroupNormException(unittest.TestCase):
+    # data_layout is not NHWC or NCHW
+    def test_exception(self):
+        data = fluid.data(name='data', shape=[None, 3, 3, 4], dtype="float64")
+
+        def attr_data_format():
+            out = fluid.layers.group_norm(
+                input=data, groups=2, data_layout="NDHW")
+
+        self.assertRaises(ValueError, attr_data_format)
+
+
+if __name__ == '__main__':
+    unittest.main()
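Reviewer note: the input-gradient path in GroupNormGradNPUKernel implements the commented formula d_x = d_xnorm_std - Mean(d_xnorm_std * x_norm) * x_norm - Mean(d_xnorm_std), with d_xnorm_std = d_y * scale / std reduced per group. As a sanity reference only, here is a minimal NumPy sketch of that formula for NCHW input; the helper name group_norm_grad_naive and its signature are illustrative and not part of this patch.

import numpy as np

def group_norm_grad_naive(x, d_y, scale, epsilon, groups):
    # Reference input gradient for GroupNorm, NCHW layout (illustrative only).
    N, C, H, W = x.shape
    G = groups
    # Recompute per-group statistics, flattening each group to one row.
    xg = x.reshape((N * G, -1))
    mean = xg.mean(axis=1, keepdims=True)
    std = np.sqrt(xg.var(axis=1, keepdims=True) + epsilon)
    xnorm = (xg - mean) / std  # shape [N*G, C*H*W/G]
    # d_xnorm_std = d_y * scale / std, matching the kernel's intermediate.
    d_xnorm = (d_y * scale.reshape((1, C, 1, 1))).reshape((N * G, -1))
    g = d_xnorm / std
    # d_x = d_xnorm_std - Mean(d_xnorm_std * x_norm) * x_norm - Mean(d_xnorm_std)
    d_x = (g - (g * xnorm).mean(axis=1, keepdims=True) * xnorm
           - g.mean(axis=1, keepdims=True))
    return d_x.reshape((N, C, H, W))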