You need to sign in or sign up before continuing.
未验证 提交 c90d3556 编写于 作者: H hong19860320 提交者: GitHub

Add batch_norm and layer_norm XPU kernels (#27818)

上级 ddcd1b53
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/batch_norm_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
template <typename DeviceContext, typename T>
class BatchNormXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto epsilon = ctx.Attr<float>("epsilon");
const auto momentum = ctx.Attr<float>("momentum");
const auto is_test = ctx.Attr<bool>("is_test");
const auto use_global_stats = ctx.Attr<bool>("use_global_stats");
const auto trainable_stats = ctx.Attr<bool>("trainable_statistics");
bool test_mode = is_test && (!trainable_stats);
bool global_stats = test_mode || use_global_stats;
const auto& data_layout_str = ctx.Attr<std::string>("data_layout");
const auto data_layout = framework::StringToDataLayout(data_layout_str);
PADDLE_ENFORCE_EQ(data_layout, DataLayout::kNCHW,
platform::errors::InvalidArgument(
"The 'data_layout' attribute must be NCHW. But "
"recevived 'data_layout' is [%s].",
data_layout_str));
const auto* x = ctx.Input<Tensor>("X");
const auto& x_dims = x->dims();
PADDLE_ENFORCE_EQ(x_dims.size(), 4,
platform::errors::InvalidArgument(
"The input tensor X's dimension must equal to 4. But "
"received X's shape = [%s], X's dimension = [%d].",
x_dims, x_dims.size()));
const int N = x_dims[0];
const int C = x_dims[1];
const int H = x_dims[2];
const int W = x_dims[3];
const auto* scale = ctx.Input<Tensor>("Scale");
const auto* bias = ctx.Input<Tensor>("Bias");
const auto* x_data = x->data<T>();
const auto* scale_data = scale->data<T>();
const auto* bias_data = bias->data<T>();
auto* y = ctx.Output<Tensor>("Y");
auto* y_data = y->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
if (!global_stats) {
auto* mean_out = ctx.Output<Tensor>("MeanOut");
auto* variance_out = ctx.Output<Tensor>("VarianceOut");
auto* saved_mean = ctx.Output<Tensor>("SavedMean");
auto* saved_variance = ctx.Output<Tensor>("SavedVariance");
mean_out->mutable_data<T>(ctx.GetPlace());
variance_out->mutable_data<T>(ctx.GetPlace());
saved_mean->mutable_data<T>(ctx.GetPlace());
saved_variance->mutable_data<T>(ctx.GetPlace());
auto* mean_out_data = mean_out->data<T>();
auto* variance_out_data = variance_out->data<T>();
auto* saved_mean_data = saved_mean->data<T>();
auto* saved_variance_data = saved_variance->data<T>();
int r = xpu::batch_norm_train_forward(
dev_ctx.x_context(), epsilon, momentum, N, C, H, W, x_data, y_data,
scale_data, bias_data, mean_out_data, variance_out_data,
saved_mean_data, saved_variance_data);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API(batch_norm_train_forward) return "
"wrong value[%d], please check whether "
"Baidu Kunlun Card is properly installed.",
r));
} else {
const auto* mean = ctx.Input<Tensor>("Mean");
const auto* variance = ctx.Input<Tensor>("Variance");
const auto* mean_data = mean->data<T>();
const auto* variance_data = variance->data<T>();
int r = xpu::batch_norm_infer_forward(
dev_ctx.x_context(), epsilon, N, C, H, W, x_data, y_data, scale_data,
bias_data, mean_data, variance_data);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API(batch_norm_infer_forward) return "
"wrong value[%d], please check whether "
"Baidu Kunlun Card is properly installed.",
r));
}
}
};
template <typename DeviceContext, typename T>
class BatchNormGradXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* x = ctx.Input<Tensor>("X");
const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto* scale = ctx.Input<Tensor>("Scale");
const auto* saved_mean = ctx.Input<Tensor>("SavedMean");
// SavedVariance have been reverted in forward operator
const auto* saved_inv_variance = ctx.Input<Tensor>("SavedVariance");
const auto& data_layout_str = ctx.Attr<std::string>("data_layout");
const auto data_layout = framework::StringToDataLayout(data_layout_str);
PADDLE_ENFORCE_EQ(data_layout, DataLayout::kNCHW,
platform::errors::InvalidArgument(
"The 'data_layout' attribute must be NCHW. But "
"recevived 'data_layout' is [%s].",
data_layout_str));
const auto& x_dims = x->dims();
PADDLE_ENFORCE_EQ(x_dims.size(), 4,
platform::errors::InvalidArgument(
"The input tensor X's dimension must equal to 4. But "
"received X's shape = [%s], X's dimension = [%d].",
x_dims, x_dims.size()));
const int N = x_dims[0];
const int C = x_dims[1];
const int H = x_dims[2];
const int W = x_dims[3];
const auto* x_data = x->data<T>();
const auto* dy_data = dy->data<T>();
const auto* scale_data = scale->data<T>();
const auto* saved_mean_data = saved_mean->data<T>();
const auto* saved_inv_variance_data = saved_inv_variance->data<T>();
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dscale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto* dbias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
auto* dx_data = dx->mutable_data<T>(ctx.GetPlace());
auto* dscale_data = dscale->mutable_data<T>(ctx.GetPlace());
auto* dbias_data = dbias->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r = xpu::batch_norm_backward(dev_ctx.x_context(), N, C, H, W, x_data,
dy_data, scale_data, saved_mean_data,
saved_inv_variance_data, dx_data,
dscale_data, dbias_data);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API(batch_norm_infer_forward) return "
"wrong value[%d], please check whether "
"Baidu Kunlun Card is properly installed.",
r));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
batch_norm,
ops::BatchNormXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(
batch_norm_grad,
ops::BatchNormGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
#endif // PADDLE_WITH_XPU
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/layer_norm_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
template <typename DeviceContext, typename T>
class LayerNormXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
const auto epsilon = ctx.Attr<float>("epsilon");
const auto* x = ctx.Input<Tensor>("X");
const auto& x_dims = x->dims();
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
const auto* scale = ctx.Input<Tensor>("Scale");
const auto* bias = ctx.Input<Tensor>("Bias");
auto* y = ctx.Output<Tensor>("Y");
auto* mean = ctx.Output<Tensor>("Mean");
auto* variance = ctx.Output<Tensor>("Variance");
const auto* x_data = x->data<T>();
const auto* scale_data = (scale == nullptr ? nullptr : scale->data<T>());
const auto* bias_data = (bias == nullptr ? nullptr : bias->data<T>());
auto* y_data = y->mutable_data<T>(ctx.GetPlace());
auto* mean_data = mean->mutable_data<T>(ctx.GetPlace());
auto* variance_data = variance->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r = xpu::layer_norm(dev_ctx.x_context(), left, right, x_data, y_data,
scale_data, bias_data, epsilon, mean_data,
variance_data, false);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API(layer_norm) return wrong "
"value[%d], please check whether Baidu "
"Kunlun Card is properly installed.",
r));
}
};
template <typename DeviceContext, typename T>
class LayerNormGradXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
const auto epsilon = ctx.Attr<float>("epsilon");
const auto* x = ctx.Input<Tensor>("X");
const auto& x_dims = x->dims();
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
const auto* mean = ctx.Input<Tensor>("Mean");
const auto* variance = ctx.Input<Tensor>("Variance");
const auto* scale = ctx.Input<Tensor>("Scale");
const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dscale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto* dbias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
const auto* x_data = x->data<T>();
const auto* dy_data = dy->data<T>();
const auto* mean_data = mean->data<T>();
const auto* variance_data = variance->data<T>();
const auto* scale_data = (scale == nullptr ? nullptr : scale->data<T>());
auto* dscale_data =
(dscale == nullptr ? nullptr : dscale->mutable_data<T>(ctx.GetPlace()));
auto* dbias_data =
(dbias == nullptr ? nullptr : dbias->mutable_data<T>(ctx.GetPlace()));
auto* dx_data =
(dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r = xpu::layer_norm_backward(
dev_ctx.x_context(), left, right, x_data, scale_data, variance_data,
mean_data, dy_data, dx_data, dscale_data, dbias_data, epsilon);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API(layer_norm_backward) return wrong "
"value[%d], please check whether Baidu "
"Kunlun Card is properly installed.",
r));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
layer_norm,
ops::LayerNormXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(
layer_norm_grad,
ops::LayerNormGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
#endif // PADDLE_WITH_XPU
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import sys
sys.path.append("..")
import unittest
import numpy as np
import paddle.fluid.core as core
from op_test import OpTest
from scipy.special import expit, erf
import paddle
import paddle.fluid as fluid
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.fluid import compiler, Program, program_guard
def ref_batch_norm_infer(x, scale, bias, mean, variance, momentum, epsilon,
data_layout):
if data_layout == "NCHW":
n, c, h, w = x.shape
mean_tile = np.reshape(mean, (1, c, 1, 1))
mean_tile = np.tile(mean_tile, (n, 1, h, w))
variance_tile = np.reshape(variance, (1, c, 1, 1))
variance_tile = np.tile(variance_tile, (n, 1, h, w))
normalized_x = (x - mean_tile) / np.sqrt(variance_tile + epsilon)
scale_tile = np.reshape(scale, (1, c, 1, 1))
scale_tile = np.tile(scale_tile, (n, 1, h, w))
bias_tile = np.reshape(bias, (1, c, 1, 1))
bias_tile = np.reshape(bias_tile, (1, c, 1, 1))
y = normalized_x * scale_tile + bias_tile
elif data_layout == "NHWC":
normalized_x = (x - mean) / np.sqrt(variance + epsilon)
y = normalized_x * scale + bias
else:
raise ValueError(
"Unsupported data layout! Only NCHW and NHWC is supported, but received "
+ data_layout)
return y
def ref_batch_norm_train(x, y_grad, scale, bias, mean, variance, momentum,
epsilon, data_layout):
# Forward
if data_layout == "NCHW":
n, c, h, w = x.shape
x_square = x * x
x_square_sum = np.sum(x_square, (0, 2, 3))
x_sum = np.sum(x, axis=(0, 2, 3))
element_count = np.size(x) / int(np.shape(x)[1])
saved_mean = x_sum / element_count
saved_variance = x_square_sum / element_count - saved_mean * saved_mean
saved_mean_tile = np.reshape(saved_mean, (1, c, 1, 1))
saved_mean_tile = np.tile(saved_mean_tile, (n, 1, h, w))
saved_variance_tile = np.reshape(saved_variance, (1, c, 1, 1))
saved_variance_tile = np.tile(saved_variance_tile, (n, 1, h, w))
normalized_x = (
x - saved_mean_tile) / np.sqrt(saved_variance_tile + epsilon)
scale_tile = np.reshape(scale, (1, c, 1, 1))
scale_tile = np.tile(scale_tile, (n, 1, h, w))
bias_tile = np.reshape(bias, (1, c, 1, 1))
bias_tile = np.reshape(bias_tile, (1, c, 1, 1))
y = normalized_x * scale_tile + bias_tile
elif data_layout == "NHWC":
x_square = x * x
x_square_sum = np.sum(x_square, (0, 1, 2))
x_sum = np.sum(x, axis=(0, 1, 2))
element_count = np.size(x) / int(np.shape(x)[-1])
saved_mean = x_sum / element_count
saved_variance = x_square_sum / element_count - saved_mean * saved_mean
normalized_x = (x - saved_mean) / np.sqrt(saved_variance + epsilon)
y = normalized_x * scale + bias
else:
raise ValueError(
"Unsupported data layout! Only NCHW and NHWC is supported, but received "
+ data_layout)
mean_out = saved_mean * (1. - momentum) + momentum * mean
variance_out = saved_variance * (1. - momentum) + momentum * variance
saved_inv_std = 1. / np.sqrt(saved_variance + epsilon)
# Backward
# Use the following formulas to calculate gradients:
# grad_scale =
# sum(grad_y * (x - mean)) * rsqrt(variance + epsilon)
#
# grad_bias = sum(y)
#
# x_grad =
# 1/N * scale * rsqrt(variance + epsilon) * (N * grad_y - sum(grad_y) -
# (x - mean) * sum(grad_y * (x - mean)) / (variance + epsilon))
# Transfer from (N, C, H, W) to (N, H, W, C) to simplify computation
if data_layout == "NCHW":
x = np.transpose(x, (0, 2, 3, 1))
y_grad = np.transpose(y_grad, (0, 2, 3, 1))
x_grad = scale * (
y_grad - np.mean(
y_grad, axis=(0, 1, 2)) - (x - saved_mean) * np.mean(
y_grad * (x - saved_mean), axis=(0, 1, 2)) /
(saved_variance + epsilon)) / np.sqrt(saved_variance + epsilon)
scale_grad = np.sum(y_grad * (x - saved_mean) /
np.sqrt(saved_variance + epsilon),
axis=(0, 1, 2))
bias_grad = np.sum(y_grad, axis=(0, 1, 2))
# Transfer back to N, C, H, W
if data_layout == "NCHW":
x_grad = np.transpose(x_grad, (0, 3, 1, 2))
x = np.transpose(x, (0, 3, 1, 2))
y_grad = np.transpose(y_grad, (0, 3, 1, 2))
return y, mean_out, variance_out, saved_mean, saved_inv_std, x_grad, scale_grad, bias_grad
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestXPUBatchNormOp(unittest.TestCase):
def setUp(self):
self.place = paddle.XPUPlace(0)
self.op_type = "batch_norm"
self.dtype = np.float32
self.shape = [2, 3, 4, 5]
self.data_layout = "NCHW"
self.epsilon = 1e-05
self.momentum = 0.9
self.set_attrs()
if self.data_layout == "NHWC":
channel_size = self.shape[3]
elif self.data_layout == "NCHW":
channel_size = self.shape[1]
else:
raise ValueError(
"Unsupported data layout! Only NCHW and NHWC is supported, but received "
+ data_layout)
np.random.seed(1024)
self.x_np = np.random.random_sample(self.shape).astype(self.dtype)
self.scale_np = np.random.random_sample(
[channel_size]).astype(self.dtype)
self.bias_np = np.random.random_sample(
[channel_size]).astype(self.dtype)
self.mean_np = np.zeros([channel_size]).astype(self.dtype)
self.variance_np = np.ones([channel_size]).astype(self.dtype)
self.saved_mean_np = np.zeros([channel_size]).astype(self.dtype)
self.saved_variance_np = np.ones([channel_size]).astype(self.dtype)
def set_attrs(self):
pass
def test_infer(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.data('X', self.x_np.shape, self.x_np.dtype)
scale = paddle.data('Scale', self.scale_np.shape,
self.scale_np.dtype)
bias = paddle.data('Bias', self.bias_np.shape, self.bias_np.dtype)
mean = paddle.data('Mean', self.mean_np.shape, self.mean_np.dtype)
variance = paddle.data('Variance', self.variance_np.shape,
self.variance_np.dtype)
y = F.batch_norm(x, mean, variance, scale, bias, False,
self.momentum, self.epsilon, self.data_layout)
exe = paddle.static.Executor(self.place)
[y_np] = exe.run(feed={
'X': self.x_np,
'Scale': self.scale_np,
'Bias': self.bias_np,
'Mean': self.mean_np,
'Variance': self.variance_np
},
fetch_list=[y])
y_np_ref = ref_batch_norm_infer(
self.x_np, self.scale_np, self.bias_np, self.mean_np,
self.variance_np, self.momentum, self.epsilon, self.data_layout)
self.assertEqual(np.allclose(y_np_ref, y_np), True)
def test_train(self):
y_grad_np = np.random.random_sample(self.shape).astype(self.dtype)
y_np, mean_out_np, variance_out_np, saved_mean_np, saved_variance_np, x_grad_np, scale_grad_np, bias_grad_np = ref_batch_norm_train(
self.x_np, y_grad_np, self.scale_np, self.bias_np, self.mean_np,
self.variance_np, self.momentum, self.epsilon, self.data_layout)
inputs = {
'X': self.x_np,
'Scale': self.scale_np,
'Bias': self.bias_np,
'Mean': self.mean_np,
'Variance': self.variance_np,
'Y@GRAD': y_grad_np
}
outputs = {
'Y': y_np,
'Mean': mean_out_np,
'Variance': variance_out_np,
'SavedMean': saved_mean_np,
'SavedVariance': saved_variance_np,
'X@GRAD': x_grad_np,
'Scale@GRAD': scale_grad_np,
'Bias@GRAD': bias_grad_np
}
attrs = {
'momentum': self.momentum,
'epsilon': self.epsilon,
'is_test': False,
'data_layout': self.data_layout,
'use_mkldnn': False,
'fuse_with_relu': False,
'use_global_stats': False,
}
paddle.enable_static()
program = paddle.static.Program()
with paddle.static.program_guard(program):
block = program.global_block()
# Set inputs, outputs and attributes to the forward op of batch_norm
input_vars = {}
for var_name in inputs:
arg_name = var_name
np_value = inputs[var_name]
if not block.has_var(var_name):
block.create_var(
name=var_name,
shape=np_value.shape,
dtype=np_value.dtype)
input_vars[arg_name] = block.var(var_name)
fetch_list = []
output_vars = {}
for var_name in outputs:
arg_name = var_name
np_value = outputs[var_name]
if not block.has_var(var_name):
block.create_var(
name=var_name,
shape=np_value.shape,
dtype=np_value.dtype)
if var_name == 'Mean':
arg_name = 'MeanOut' # Share memory
if var_name == 'Variance':
arg_name = 'VarianceOut' # Share memory
output_vars[arg_name] = block.var(var_name)
fetch_list.append(var_name)
batch_norm_op = block.append_op(
type="batch_norm",
inputs=input_vars,
outputs=output_vars,
attrs=attrs)
# Generate the backward op_desc of batch_norm
grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc(
batch_norm_op.desc, set(), [])
grad_op_desc = grad_op_desc_list[0]
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(grad_op_desc)
program._sync_with_cpp()
exe = paddle.static.Executor(self.place)
outs = exe.run(program, feed=inputs, fetch_list=fetch_list)
for id, name in enumerate(fetch_list):
self.assertEqual(
np.allclose(
outputs[name], outs[id], atol=1e-4), True)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import numpy as np
import sys
import unittest
from functools import reduce
sys.path.append("..")
from op_test import OpTest
from operator import mul
paddle.enable_static()
def ref_layer_norm(x, scale, bias, epsilon, begin_norm_axis=1):
x_shape = x.shape
left = reduce(mul, x_shape[0:begin_norm_axis], 1)
right = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1)
x.shape = [left, right]
mean = np.mean(x, axis=1)
variance = np.var(x, axis=1) + epsilon
y = np.divide((x - mean.reshape([left, 1])),
(np.sqrt(variance)).reshape([left, 1]))
if scale is not None:
y = scale.reshape([1, right]) * y
if bias is not None:
y = y + bias.reshape([1, right])
x.shape, y.shape = x_shape, x_shape
return y, mean, variance
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestXPULayerNormOp(OpTest):
def setUp(self):
self.op_type = "layer_norm"
self.dtype = np.float32
self.shape = [2, 3, 4, 5]
self.epsilon = 1e-05
self.begin_norm_axis = 1
self.set_attrs()
right = reduce(mul, self.shape[self.begin_norm_axis:len(self.shape)], 1)
np.random.seed(10)
x_np = np.random.uniform(0.1, 1, self.shape).astype(self.dtype)
scale_np = np.random.uniform(0.1, 1, [right]).astype(self.dtype)
bias_np = np.random.uniform(0.1, 1, [right]).astype(self.dtype)
ref_y_np, ref_mean_np, ref_variance_np = ref_layer_norm(
x_np, scale_np, bias_np, self.epsilon, self.begin_norm_axis)
self.inputs = {'X': x_np, 'Scale': scale_np, 'Bias': bias_np}
self.outputs = {
'Y': ref_y_np,
'Mean': ref_mean_np,
'Variance': ref_variance_np
}
self.attrs = {'begin_norm_axis': self.begin_norm_axis, 'use_xpu': True}
def set_attrs(self):
pass
def test_check_output(self):
self.check_output_with_place(paddle.XPUPlace(0), atol=1e-4)
def test_check_grad(self):
self.check_grad_with_place(
paddle.XPUPlace(0), ['X'], 'Y', max_relative_error=0.02)
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestXPULayerNormOpAxis2(TestXPULayerNormOp):
def set_attrs(self):
self.begin_norm_axis = 2
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestXPULayerNormOpAxis3(TestXPULayerNormOp):
def set_attrs(self):
self.begin_norm_axis = 3
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestXPULayerNormOp2D(TestXPULayerNormOp):
def set_attrs(self):
self.shape = [10, 12]
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestXPULayerNormOp3D(TestXPULayerNormOp):
def set_attrs(self):
self.shape = [4, 5, 6]
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册