From c90d35564be5fc4d51e8c129ee2e908c1e36c4fe Mon Sep 17 00:00:00 2001 From: hong19860320 <9973393+hong19860320@users.noreply.github.com> Date: Tue, 13 Oct 2020 18:23:56 +0800 Subject: [PATCH] Add batch_norm and layer_norm XPU kernels (#27818) --- paddle/fluid/operators/batch_norm_op_xpu.cc | 167 +++++++++++ paddle/fluid/operators/layer_norm_op_xpu.cc | 114 ++++++++ .../unittests/xpu/test_batch_norm_op_xpu.py | 269 ++++++++++++++++++ .../unittests/xpu/test_layer_norm_op_xpu.py | 111 ++++++++ 4 files changed, 661 insertions(+) create mode 100644 paddle/fluid/operators/batch_norm_op_xpu.cc create mode 100644 paddle/fluid/operators/layer_norm_op_xpu.cc create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_batch_norm_op_xpu.py create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_layer_norm_op_xpu.py diff --git a/paddle/fluid/operators/batch_norm_op_xpu.cc b/paddle/fluid/operators/batch_norm_op_xpu.cc new file mode 100644 index 00000000000..624d5fe65ea --- /dev/null +++ b/paddle/fluid/operators/batch_norm_op_xpu.cc @@ -0,0 +1,167 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_XPU + +#include "paddle/fluid/operators/batch_norm_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using DDim = framework::DDim; + +template +class BatchNormXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const auto epsilon = ctx.Attr("epsilon"); + const auto momentum = ctx.Attr("momentum"); + const auto is_test = ctx.Attr("is_test"); + const auto use_global_stats = ctx.Attr("use_global_stats"); + const auto trainable_stats = ctx.Attr("trainable_statistics"); + bool test_mode = is_test && (!trainable_stats); + bool global_stats = test_mode || use_global_stats; + const auto& data_layout_str = ctx.Attr("data_layout"); + const auto data_layout = framework::StringToDataLayout(data_layout_str); + PADDLE_ENFORCE_EQ(data_layout, DataLayout::kNCHW, + platform::errors::InvalidArgument( + "The 'data_layout' attribute must be NCHW. But " + "recevived 'data_layout' is [%s].", + data_layout_str)); + const auto* x = ctx.Input("X"); + const auto& x_dims = x->dims(); + PADDLE_ENFORCE_EQ(x_dims.size(), 4, + platform::errors::InvalidArgument( + "The input tensor X's dimension must equal to 4. 
But " + "received X's shape = [%s], X's dimension = [%d].", + x_dims, x_dims.size())); + const int N = x_dims[0]; + const int C = x_dims[1]; + const int H = x_dims[2]; + const int W = x_dims[3]; + const auto* scale = ctx.Input("Scale"); + const auto* bias = ctx.Input("Bias"); + const auto* x_data = x->data(); + const auto* scale_data = scale->data(); + const auto* bias_data = bias->data(); + auto* y = ctx.Output("Y"); + auto* y_data = y->mutable_data(ctx.GetPlace()); + auto& dev_ctx = ctx.template device_context(); + if (!global_stats) { + auto* mean_out = ctx.Output("MeanOut"); + auto* variance_out = ctx.Output("VarianceOut"); + auto* saved_mean = ctx.Output("SavedMean"); + auto* saved_variance = ctx.Output("SavedVariance"); + mean_out->mutable_data(ctx.GetPlace()); + variance_out->mutable_data(ctx.GetPlace()); + saved_mean->mutable_data(ctx.GetPlace()); + saved_variance->mutable_data(ctx.GetPlace()); + auto* mean_out_data = mean_out->data(); + auto* variance_out_data = variance_out->data(); + auto* saved_mean_data = saved_mean->data(); + auto* saved_variance_data = saved_variance->data(); + int r = xpu::batch_norm_train_forward( + dev_ctx.x_context(), epsilon, momentum, N, C, H, W, x_data, y_data, + scale_data, bias_data, mean_out_data, variance_out_data, + saved_mean_data, saved_variance_data); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External("XPU API(batch_norm_train_forward) return " + "wrong value[%d], please check whether " + "Baidu Kunlun Card is properly installed.", + r)); + } else { + const auto* mean = ctx.Input("Mean"); + const auto* variance = ctx.Input("Variance"); + const auto* mean_data = mean->data(); + const auto* variance_data = variance->data(); + int r = xpu::batch_norm_infer_forward( + dev_ctx.x_context(), epsilon, N, C, H, W, x_data, y_data, scale_data, + bias_data, mean_data, variance_data); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External("XPU API(batch_norm_infer_forward) return " + "wrong value[%d], please check whether " + "Baidu Kunlun Card is properly installed.", + r)); + } + } +}; + +template +class BatchNormGradXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const auto* x = ctx.Input("X"); + const auto* dy = ctx.Input(framework::GradVarName("Y")); + const auto* scale = ctx.Input("Scale"); + const auto* saved_mean = ctx.Input("SavedMean"); + // SavedVariance have been reverted in forward operator + const auto* saved_inv_variance = ctx.Input("SavedVariance"); + const auto& data_layout_str = ctx.Attr("data_layout"); + const auto data_layout = framework::StringToDataLayout(data_layout_str); + PADDLE_ENFORCE_EQ(data_layout, DataLayout::kNCHW, + platform::errors::InvalidArgument( + "The 'data_layout' attribute must be NCHW. But " + "recevived 'data_layout' is [%s].", + data_layout_str)); + const auto& x_dims = x->dims(); + PADDLE_ENFORCE_EQ(x_dims.size(), 4, + platform::errors::InvalidArgument( + "The input tensor X's dimension must equal to 4. 
But " + "received X's shape = [%s], X's dimension = [%d].", + x_dims, x_dims.size())); + const int N = x_dims[0]; + const int C = x_dims[1]; + const int H = x_dims[2]; + const int W = x_dims[3]; + const auto* x_data = x->data(); + const auto* dy_data = dy->data(); + const auto* scale_data = scale->data(); + const auto* saved_mean_data = saved_mean->data(); + const auto* saved_inv_variance_data = saved_inv_variance->data(); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dscale = ctx.Output(framework::GradVarName("Scale")); + auto* dbias = ctx.Output(framework::GradVarName("Bias")); + auto* dx_data = dx->mutable_data(ctx.GetPlace()); + auto* dscale_data = dscale->mutable_data(ctx.GetPlace()); + auto* dbias_data = dbias->mutable_data(ctx.GetPlace()); + auto& dev_ctx = ctx.template device_context(); + int r = xpu::batch_norm_backward(dev_ctx.x_context(), N, C, H, W, x_data, + dy_data, scale_data, saved_mean_data, + saved_inv_variance_data, dx_data, + dscale_data, dbias_data); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External("XPU API(batch_norm_infer_forward) return " + "wrong value[%d], please check whether " + "Baidu Kunlun Card is properly installed.", + r)); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_XPU_KERNEL( + batch_norm, + ops::BatchNormXPUKernel); +REGISTER_OP_XPU_KERNEL( + batch_norm_grad, + ops::BatchNormGradXPUKernel); + +#endif // PADDLE_WITH_XPU diff --git a/paddle/fluid/operators/layer_norm_op_xpu.cc b/paddle/fluid/operators/layer_norm_op_xpu.cc new file mode 100644 index 00000000000..5a3c865e26c --- /dev/null +++ b/paddle/fluid/operators/layer_norm_op_xpu.cc @@ -0,0 +1,114 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_XPU + +#include "paddle/fluid/operators/layer_norm_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using DDim = framework::DDim; + +template +class LayerNormXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const auto begin_norm_axis = ctx.Attr("begin_norm_axis"); + const auto epsilon = ctx.Attr("epsilon"); + const auto* x = ctx.Input("X"); + const auto& x_dims = x->dims(); + auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis); + int left = static_cast(matrix_dim[0]); + int right = static_cast(matrix_dim[1]); + const auto* scale = ctx.Input("Scale"); + const auto* bias = ctx.Input("Bias"); + auto* y = ctx.Output("Y"); + auto* mean = ctx.Output("Mean"); + auto* variance = ctx.Output("Variance"); + const auto* x_data = x->data(); + const auto* scale_data = (scale == nullptr ? nullptr : scale->data()); + const auto* bias_data = (bias == nullptr ? 
nullptr : bias->data()); + auto* y_data = y->mutable_data(ctx.GetPlace()); + auto* mean_data = mean->mutable_data(ctx.GetPlace()); + auto* variance_data = variance->mutable_data(ctx.GetPlace()); + auto& dev_ctx = ctx.template device_context(); + int r = xpu::layer_norm(dev_ctx.x_context(), left, right, x_data, y_data, + scale_data, bias_data, epsilon, mean_data, + variance_data, false); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External("XPU API(layer_norm) return wrong " + "value[%d], please check whether Baidu " + "Kunlun Card is properly installed.", + r)); + } +}; + +template +class LayerNormGradXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const auto begin_norm_axis = ctx.Attr("begin_norm_axis"); + const auto epsilon = ctx.Attr("epsilon"); + const auto* x = ctx.Input("X"); + const auto& x_dims = x->dims(); + auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis); + int left = static_cast(matrix_dim[0]); + int right = static_cast(matrix_dim[1]); + const auto* mean = ctx.Input("Mean"); + const auto* variance = ctx.Input("Variance"); + const auto* scale = ctx.Input("Scale"); + const auto* dy = ctx.Input(framework::GradVarName("Y")); + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dscale = ctx.Output(framework::GradVarName("Scale")); + auto* dbias = ctx.Output(framework::GradVarName("Bias")); + const auto* x_data = x->data(); + const auto* dy_data = dy->data(); + const auto* mean_data = mean->data(); + const auto* variance_data = variance->data(); + const auto* scale_data = (scale == nullptr ? nullptr : scale->data()); + auto* dscale_data = + (dscale == nullptr ? nullptr : dscale->mutable_data(ctx.GetPlace())); + auto* dbias_data = + (dbias == nullptr ? nullptr : dbias->mutable_data(ctx.GetPlace())); + auto* dx_data = + (dx == nullptr ? nullptr : dx->mutable_data(ctx.GetPlace())); + auto& dev_ctx = ctx.template device_context(); + int r = xpu::layer_norm_backward( + dev_ctx.x_context(), left, right, x_data, scale_data, variance_data, + mean_data, dy_data, dx_data, dscale_data, dbias_data, epsilon); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External("XPU API(layer_norm_backward) return wrong " + "value[%d], please check whether Baidu " + "Kunlun Card is properly installed.", + r)); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_XPU_KERNEL( + layer_norm, + ops::LayerNormXPUKernel); +REGISTER_OP_XPU_KERNEL( + layer_norm_grad, + ops::LayerNormGradXPUKernel); + +#endif // PADDLE_WITH_XPU diff --git a/python/paddle/fluid/tests/unittests/xpu/test_batch_norm_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_batch_norm_op_xpu.py new file mode 100644 index 00000000000..0d9387d6b75 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_batch_norm_op_xpu.py @@ -0,0 +1,269 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
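+#
+# NumPy reference helpers (defined below) against which the XPU kernels are
+# checked:
+#   inference: y = (x - mean) / sqrt(variance + epsilon) * scale + bias,
+#              with per-channel broadcasting for the NCHW and NHWC layouts.
+#   training:  per-channel mean/variance are computed over the (N, H, W)
+#              axes, the running statistics are blended using momentum, and
+#              the gradients follow the formulas documented inside
+#              ref_batch_norm_train.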
+ +from __future__ import print_function + +import sys +sys.path.append("..") +import unittest +import numpy as np +import paddle.fluid.core as core +from op_test import OpTest +from scipy.special import expit, erf +import paddle +import paddle.fluid as fluid +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.fluid import compiler, Program, program_guard + + +def ref_batch_norm_infer(x, scale, bias, mean, variance, momentum, epsilon, + data_layout): + if data_layout == "NCHW": + n, c, h, w = x.shape + mean_tile = np.reshape(mean, (1, c, 1, 1)) + mean_tile = np.tile(mean_tile, (n, 1, h, w)) + variance_tile = np.reshape(variance, (1, c, 1, 1)) + variance_tile = np.tile(variance_tile, (n, 1, h, w)) + normalized_x = (x - mean_tile) / np.sqrt(variance_tile + epsilon) + scale_tile = np.reshape(scale, (1, c, 1, 1)) + scale_tile = np.tile(scale_tile, (n, 1, h, w)) + bias_tile = np.reshape(bias, (1, c, 1, 1)) + bias_tile = np.reshape(bias_tile, (1, c, 1, 1)) + y = normalized_x * scale_tile + bias_tile + elif data_layout == "NHWC": + normalized_x = (x - mean) / np.sqrt(variance + epsilon) + y = normalized_x * scale + bias + else: + raise ValueError( + "Unsupported data layout! Only NCHW and NHWC is supported, but received " + + data_layout) + return y + + +def ref_batch_norm_train(x, y_grad, scale, bias, mean, variance, momentum, + epsilon, data_layout): + # Forward + if data_layout == "NCHW": + n, c, h, w = x.shape + x_square = x * x + x_square_sum = np.sum(x_square, (0, 2, 3)) + x_sum = np.sum(x, axis=(0, 2, 3)) + element_count = np.size(x) / int(np.shape(x)[1]) + saved_mean = x_sum / element_count + saved_variance = x_square_sum / element_count - saved_mean * saved_mean + saved_mean_tile = np.reshape(saved_mean, (1, c, 1, 1)) + saved_mean_tile = np.tile(saved_mean_tile, (n, 1, h, w)) + saved_variance_tile = np.reshape(saved_variance, (1, c, 1, 1)) + saved_variance_tile = np.tile(saved_variance_tile, (n, 1, h, w)) + normalized_x = ( + x - saved_mean_tile) / np.sqrt(saved_variance_tile + epsilon) + scale_tile = np.reshape(scale, (1, c, 1, 1)) + scale_tile = np.tile(scale_tile, (n, 1, h, w)) + bias_tile = np.reshape(bias, (1, c, 1, 1)) + bias_tile = np.reshape(bias_tile, (1, c, 1, 1)) + y = normalized_x * scale_tile + bias_tile + elif data_layout == "NHWC": + x_square = x * x + x_square_sum = np.sum(x_square, (0, 1, 2)) + x_sum = np.sum(x, axis=(0, 1, 2)) + element_count = np.size(x) / int(np.shape(x)[-1]) + saved_mean = x_sum / element_count + saved_variance = x_square_sum / element_count - saved_mean * saved_mean + normalized_x = (x - saved_mean) / np.sqrt(saved_variance + epsilon) + y = normalized_x * scale + bias + else: + raise ValueError( + "Unsupported data layout! Only NCHW and NHWC is supported, but received " + + data_layout) + mean_out = saved_mean * (1. - momentum) + momentum * mean + variance_out = saved_variance * (1. - momentum) + momentum * variance + saved_inv_std = 1. 
/ np.sqrt(saved_variance + epsilon) + # Backward + # Use the following formulas to calculate gradients: + # grad_scale = + # sum(grad_y * (x - mean)) * rsqrt(variance + epsilon) + # + # grad_bias = sum(y) + # + # x_grad = + # 1/N * scale * rsqrt(variance + epsilon) * (N * grad_y - sum(grad_y) - + # (x - mean) * sum(grad_y * (x - mean)) / (variance + epsilon)) + # Transfer from (N, C, H, W) to (N, H, W, C) to simplify computation + if data_layout == "NCHW": + x = np.transpose(x, (0, 2, 3, 1)) + y_grad = np.transpose(y_grad, (0, 2, 3, 1)) + x_grad = scale * ( + y_grad - np.mean( + y_grad, axis=(0, 1, 2)) - (x - saved_mean) * np.mean( + y_grad * (x - saved_mean), axis=(0, 1, 2)) / + (saved_variance + epsilon)) / np.sqrt(saved_variance + epsilon) + scale_grad = np.sum(y_grad * (x - saved_mean) / + np.sqrt(saved_variance + epsilon), + axis=(0, 1, 2)) + bias_grad = np.sum(y_grad, axis=(0, 1, 2)) + # Transfer back to N, C, H, W + if data_layout == "NCHW": + x_grad = np.transpose(x_grad, (0, 3, 1, 2)) + x = np.transpose(x, (0, 3, 1, 2)) + y_grad = np.transpose(y_grad, (0, 3, 1, 2)) + return y, mean_out, variance_out, saved_mean, saved_inv_std, x_grad, scale_grad, bias_grad + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPUBatchNormOp(unittest.TestCase): + def setUp(self): + self.place = paddle.XPUPlace(0) + self.op_type = "batch_norm" + self.dtype = np.float32 + self.shape = [2, 3, 4, 5] + self.data_layout = "NCHW" + self.epsilon = 1e-05 + self.momentum = 0.9 + self.set_attrs() + + if self.data_layout == "NHWC": + channel_size = self.shape[3] + elif self.data_layout == "NCHW": + channel_size = self.shape[1] + else: + raise ValueError( + "Unsupported data layout! Only NCHW and NHWC is supported, but received " + + data_layout) + np.random.seed(1024) + self.x_np = np.random.random_sample(self.shape).astype(self.dtype) + self.scale_np = np.random.random_sample( + [channel_size]).astype(self.dtype) + self.bias_np = np.random.random_sample( + [channel_size]).astype(self.dtype) + self.mean_np = np.zeros([channel_size]).astype(self.dtype) + self.variance_np = np.ones([channel_size]).astype(self.dtype) + self.saved_mean_np = np.zeros([channel_size]).astype(self.dtype) + self.saved_variance_np = np.ones([channel_size]).astype(self.dtype) + + def set_attrs(self): + pass + + def test_infer(self): + paddle.enable_static() + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.data('X', self.x_np.shape, self.x_np.dtype) + scale = paddle.data('Scale', self.scale_np.shape, + self.scale_np.dtype) + bias = paddle.data('Bias', self.bias_np.shape, self.bias_np.dtype) + mean = paddle.data('Mean', self.mean_np.shape, self.mean_np.dtype) + variance = paddle.data('Variance', self.variance_np.shape, + self.variance_np.dtype) + y = F.batch_norm(x, mean, variance, scale, bias, False, + self.momentum, self.epsilon, self.data_layout) + exe = paddle.static.Executor(self.place) + [y_np] = exe.run(feed={ + 'X': self.x_np, + 'Scale': self.scale_np, + 'Bias': self.bias_np, + 'Mean': self.mean_np, + 'Variance': self.variance_np + }, + fetch_list=[y]) + y_np_ref = ref_batch_norm_infer( + self.x_np, self.scale_np, self.bias_np, self.mean_np, + self.variance_np, self.momentum, self.epsilon, self.data_layout) + self.assertEqual(np.allclose(y_np_ref, y_np), True) + + def test_train(self): + y_grad_np = np.random.random_sample(self.shape).astype(self.dtype) + y_np, mean_out_np, variance_out_np, saved_mean_np, saved_variance_np, x_grad_np, scale_grad_np, 
bias_grad_np = ref_batch_norm_train( + self.x_np, y_grad_np, self.scale_np, self.bias_np, self.mean_np, + self.variance_np, self.momentum, self.epsilon, self.data_layout) + inputs = { + 'X': self.x_np, + 'Scale': self.scale_np, + 'Bias': self.bias_np, + 'Mean': self.mean_np, + 'Variance': self.variance_np, + 'Y@GRAD': y_grad_np + } + outputs = { + 'Y': y_np, + 'Mean': mean_out_np, + 'Variance': variance_out_np, + 'SavedMean': saved_mean_np, + 'SavedVariance': saved_variance_np, + 'X@GRAD': x_grad_np, + 'Scale@GRAD': scale_grad_np, + 'Bias@GRAD': bias_grad_np + } + attrs = { + 'momentum': self.momentum, + 'epsilon': self.epsilon, + 'is_test': False, + 'data_layout': self.data_layout, + 'use_mkldnn': False, + 'fuse_with_relu': False, + 'use_global_stats': False, + } + paddle.enable_static() + program = paddle.static.Program() + with paddle.static.program_guard(program): + block = program.global_block() + # Set inputs, outputs and attributes to the forward op of batch_norm + input_vars = {} + for var_name in inputs: + arg_name = var_name + np_value = inputs[var_name] + if not block.has_var(var_name): + block.create_var( + name=var_name, + shape=np_value.shape, + dtype=np_value.dtype) + input_vars[arg_name] = block.var(var_name) + fetch_list = [] + output_vars = {} + for var_name in outputs: + arg_name = var_name + np_value = outputs[var_name] + if not block.has_var(var_name): + block.create_var( + name=var_name, + shape=np_value.shape, + dtype=np_value.dtype) + if var_name == 'Mean': + arg_name = 'MeanOut' # Share memory + if var_name == 'Variance': + arg_name = 'VarianceOut' # Share memory + output_vars[arg_name] = block.var(var_name) + fetch_list.append(var_name) + batch_norm_op = block.append_op( + type="batch_norm", + inputs=input_vars, + outputs=output_vars, + attrs=attrs) + # Generate the backward op_desc of batch_norm + grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc( + batch_norm_op.desc, set(), []) + grad_op_desc = grad_op_desc_list[0] + new_op_desc = block.desc.append_op() + new_op_desc.copy_from(grad_op_desc) + program._sync_with_cpp() + exe = paddle.static.Executor(self.place) + outs = exe.run(program, feed=inputs, fetch_list=fetch_list) + for id, name in enumerate(fetch_list): + self.assertEqual( + np.allclose( + outputs[name], outs[id], atol=1e-4), True) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_layer_norm_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_layer_norm_op_xpu.py new file mode 100644 index 00000000000..b166661c3d6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_layer_norm_op_xpu.py @@ -0,0 +1,111 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
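+#
+# NumPy reference for layer_norm (defined below): the input is flattened to
+# [left, right] at begin_norm_axis, each row is normalized as
+# y = (x - mean) / sqrt(variance + epsilon), and scale/bias are applied when
+# provided.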
+ +import paddle +import numpy as np +import sys +import unittest +from functools import reduce +sys.path.append("..") +from op_test import OpTest +from operator import mul + +paddle.enable_static() + + +def ref_layer_norm(x, scale, bias, epsilon, begin_norm_axis=1): + x_shape = x.shape + left = reduce(mul, x_shape[0:begin_norm_axis], 1) + right = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1) + x.shape = [left, right] + mean = np.mean(x, axis=1) + variance = np.var(x, axis=1) + epsilon + y = np.divide((x - mean.reshape([left, 1])), + (np.sqrt(variance)).reshape([left, 1])) + if scale is not None: + y = scale.reshape([1, right]) * y + if bias is not None: + y = y + bias.reshape([1, right]) + x.shape, y.shape = x_shape, x_shape + return y, mean, variance + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPULayerNormOp(OpTest): + def setUp(self): + self.op_type = "layer_norm" + self.dtype = np.float32 + self.shape = [2, 3, 4, 5] + self.epsilon = 1e-05 + self.begin_norm_axis = 1 + self.set_attrs() + + right = reduce(mul, self.shape[self.begin_norm_axis:len(self.shape)], 1) + np.random.seed(10) + x_np = np.random.uniform(0.1, 1, self.shape).astype(self.dtype) + scale_np = np.random.uniform(0.1, 1, [right]).astype(self.dtype) + bias_np = np.random.uniform(0.1, 1, [right]).astype(self.dtype) + ref_y_np, ref_mean_np, ref_variance_np = ref_layer_norm( + x_np, scale_np, bias_np, self.epsilon, self.begin_norm_axis) + + self.inputs = {'X': x_np, 'Scale': scale_np, 'Bias': bias_np} + self.outputs = { + 'Y': ref_y_np, + 'Mean': ref_mean_np, + 'Variance': ref_variance_np + } + self.attrs = {'begin_norm_axis': self.begin_norm_axis, 'use_xpu': True} + + def set_attrs(self): + pass + + def test_check_output(self): + self.check_output_with_place(paddle.XPUPlace(0), atol=1e-4) + + def test_check_grad(self): + self.check_grad_with_place( + paddle.XPUPlace(0), ['X'], 'Y', max_relative_error=0.02) + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPULayerNormOpAxis2(TestXPULayerNormOp): + def set_attrs(self): + self.begin_norm_axis = 2 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPULayerNormOpAxis3(TestXPULayerNormOp): + def set_attrs(self): + self.begin_norm_axis = 3 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPULayerNormOp2D(TestXPULayerNormOp): + def set_attrs(self): + self.shape = [10, 12] + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPULayerNormOp3D(TestXPULayerNormOp): + def set_attrs(self): + self.shape = [4, 5, 6] + + +if __name__ == "__main__": + unittest.main() -- GitLab
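For reference, a minimal usage sketch (not part of the patch) of one way the new kernels can be exercised from the static-graph API. It assumes a Paddle build with PADDLE_WITH_XPU and a Kunlun card visible as paddle.XPUPlace(0); variable names and shapes are illustrative only.

import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    # NCHW input, matching the layout required by BatchNormXPUKernel.
    x = paddle.data(name='x', shape=[2, 3, 4, 5], dtype='float32')
    bn = fluid.layers.batch_norm(input=x, is_test=False, momentum=0.9,
                                 epsilon=1e-5, data_layout='NCHW')
    # begin_norm_axis=1 flattens the trailing 3*4*5 elements into "right",
    # mirroring the left/right split done in LayerNormXPUKernel.
    y = fluid.layers.layer_norm(input=bn, begin_norm_axis=1, epsilon=1e-5)

place = paddle.XPUPlace(0)  # dispatches to the newly registered XPU kernels
exe = paddle.static.Executor(place)
exe.run(startup_prog)
x_np = np.random.random([2, 3, 4, 5]).astype('float32')
y_np, = exe.run(main_prog, feed={'x': x_np}, fetch_list=[y])
print(y_np.shape)  # (2, 3, 4, 5)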