未验证 提交 ecd6db43 编写于 作者: F fwenguang 提交者: GitHub

[MLU] add layernorm mlu kernel (#42356)

上级 4e5fb733
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
template <typename T>
class LayerNormMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
const auto epsilon = ctx.Attr<float>("epsilon");
const auto* x = ctx.Input<Tensor>("X");
const auto* scale = ctx.Input<Tensor>("Scale");
const auto* bias = ctx.Input<Tensor>("Bias");
auto* y = ctx.Output<Tensor>("Y");
auto* mean = ctx.Output<Tensor>("Mean");
auto* variance = ctx.Output<Tensor>("Variance");
auto place = ctx.GetPlace();
y->mutable_data<T>(place);
mean->mutable_data<T>(place);
variance->mutable_data<T>(place);
const auto& x_dims = x->dims();
std::vector<int> scale_bias_axes;
std::vector<int> mean_var_axes;
for (auto i = 0; i < x_dims.size(); ++i) {
if (i >= begin_norm_axis) {
scale_bias_axes.push_back(x_dims[i]);
} else {
mean_var_axes.push_back(x_dims[i]);
}
}
MLUCnnlTensorDesc x_desc(*x);
MLUCnnlTensorDesc y_desc(*y);
MLUCnnlTensorDesc mean_var_desc(mean_var_axes.size(), mean_var_axes.data(),
ToCnnlDataType<T>());
// cnnl only support both of scale and bias is NULL or not.
if (!scale && !bias) {
MLUCnnl::LayerNormForward(
ctx, begin_norm_axis, x_desc.get(), GetBasePtr(x),
nullptr /*scale_bias_desc*/, nullptr /*scale*/, nullptr /*bias*/,
epsilon, y_desc.get(), GetBasePtr(y), mean_var_desc.get(),
GetBasePtr(mean), GetBasePtr(variance));
} else {
Tensor tmp_scale(x->dtype());
if (!scale) {
tmp_scale.mutable_data<T>(phi::make_ddim(scale_bias_axes), place);
FillMLUTensorWithHostValue(ctx, static_cast<T>(1), &tmp_scale);
} else {
tmp_scale = *scale;
}
Tensor tmp_bias(x->dtype());
if (!bias) {
tmp_bias.mutable_data<T>(phi::make_ddim(scale_bias_axes), place);
FillMLUTensorWithHostValue(ctx, static_cast<T>(0), &tmp_bias);
} else {
tmp_bias = *bias;
}
// scale and bias should have same type with x/y
MLUCnnlTensorDesc float32_desc(scale_bias_axes.size(),
scale_bias_axes.data(), CNNL_DTYPE_FLOAT);
MLUCnnlTensorDesc float16_desc(scale_bias_axes.size(),
scale_bias_axes.data(), CNNL_DTYPE_HALF);
cnnlCastDataType_t cast_type = GetCastDataType(VT::FP32, VT::FP16);
Tensor final_scale(x->dtype());
if (final_scale.dtype() == DataType::FLOAT16 &&
tmp_scale.dtype() == DataType::FLOAT32) {
final_scale.mutable_data<T>(phi::make_ddim(scale_bias_axes), place);
// cast scale to fp16
MLUCnnl::Cast(ctx, cast_type, float32_desc.get(),
GetBasePtr(&tmp_scale), float16_desc.get(),
GetBasePtr(&final_scale));
} else {
final_scale = tmp_scale;
}
Tensor final_bias(x->dtype());
if (final_bias.dtype() == DataType::FLOAT16 &&
tmp_bias.dtype() == DataType::FLOAT32) {
final_bias.mutable_data<T>(phi::make_ddim(scale_bias_axes), place);
// cast bias to fp16
MLUCnnl::Cast(ctx, cast_type, float32_desc.get(), GetBasePtr(&tmp_bias),
float16_desc.get(), GetBasePtr(&final_bias));
} else {
final_bias = tmp_bias;
}
MLUCnnlTensorDesc scale_bias_desc(
scale_bias_axes.size(), scale_bias_axes.data(), ToCnnlDataType<T>());
MLUCnnl::LayerNormForward(
ctx, begin_norm_axis, x_desc.get(), GetBasePtr(x),
scale_bias_desc.get(), GetBasePtr(&final_scale),
GetBasePtr(&final_bias), epsilon, y_desc.get(), GetBasePtr(y),
mean_var_desc.get(), GetBasePtr(mean), GetBasePtr(variance));
}
}
};
template <typename T>
class LayerNormGradMLUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
const auto* x = ctx.Input<Tensor>("X");
const auto* mean = ctx.Input<Tensor>("Mean");
const auto* variance = ctx.Input<Tensor>("Variance");
const auto* scale = ctx.Input<Tensor>("Scale");
const auto* dy = ctx.Input<Tensor>(framework::GradVarName("Y"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dscale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto* dbias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
auto place = ctx.GetPlace();
dx->mutable_data<T>(place);
const auto& x_dims = x->dims();
std::vector<int> scale_bias_axes;
std::vector<int> mean_var_axes;
for (auto i = 0; i < x_dims.size(); ++i) {
if (i >= begin_norm_axis) {
scale_bias_axes.push_back(x_dims[i]);
} else {
mean_var_axes.push_back(x_dims[i]);
}
}
MLUCnnlTensorDesc x_desc(*x);
MLUCnnlTensorDesc dy_desc(*dy);
MLUCnnlTensorDesc mean_var_desc(mean_var_axes.size(), mean_var_axes.data(),
ToCnnlDataType<T>());
MLUCnnlTensorDesc dx_desc(*dx);
Tensor tmp_scale(x->dtype());
if (!scale) {
tmp_scale.mutable_data<T>(phi::make_ddim(scale_bias_axes), place);
FillMLUTensorWithHostValue(ctx, static_cast<T>(1), &tmp_scale);
} else {
tmp_scale = *scale;
}
MLUCnnlTensorDesc float32_desc(scale_bias_axes.size(),
scale_bias_axes.data(), CNNL_DTYPE_FLOAT);
MLUCnnlTensorDesc float16_desc(scale_bias_axes.size(),
scale_bias_axes.data(), CNNL_DTYPE_HALF);
cnnlCastDataType_t cast_fp32_to_fp16 = GetCastDataType(VT::FP32, VT::FP16);
cnnlCastDataType_t cast_fp16_to_fp32 = GetCastDataType(VT::FP16, VT::FP32);
Tensor final_scale(x->dtype());
if (final_scale.dtype() == DataType::FLOAT16 &&
tmp_scale.dtype() == DataType::FLOAT32) {
final_scale.mutable_data<T>(phi::make_ddim(scale_bias_axes), place);
// cast scale to fp16
MLUCnnl::Cast(ctx, cast_fp32_to_fp16, float32_desc.get(),
GetBasePtr(&tmp_scale), float16_desc.get(),
GetBasePtr(&final_scale));
} else {
final_scale = tmp_scale;
}
Tensor tmp_dscale(x->dtype());
if (dscale && (tmp_dscale.dtype() == dscale->dtype())) {
dscale->mutable_data<T>(place);
tmp_dscale = *dscale;
} else {
tmp_dscale.mutable_data<T>(phi::make_ddim(scale_bias_axes), place);
}
Tensor tmp_dbias(x->dtype());
if (dbias && (tmp_dbias.dtype() == dbias->dtype())) {
dbias->mutable_data<T>(place);
tmp_dbias = *dbias;
} else {
tmp_dbias.mutable_data<T>(phi::make_ddim(scale_bias_axes), place);
}
MLUCnnlTensorDesc scale_desc(scale_bias_axes.size(), scale_bias_axes.data(),
ToCnnlDataType<T>());
MLUCnnl::LayerNormBackward(
ctx, begin_norm_axis, x_desc.get(), GetBasePtr(x), dy_desc.get(),
GetBasePtr(dy), scale_desc.get(), GetBasePtr(&final_scale),
mean_var_desc.get(), GetBasePtr(mean), GetBasePtr(variance),
dx_desc.get(), GetBasePtr(dx), GetBasePtr(&tmp_dscale),
GetBasePtr(&tmp_dbias));
if (dscale && (tmp_dscale.dtype() == DataType::FLOAT16 &&
dscale->dtype() == DataType::FLOAT32)) {
dscale->mutable_data<T>(place);
MLUCnnl::Cast(ctx, cast_fp16_to_fp32, float16_desc.get(),
GetBasePtr(&tmp_dscale), float32_desc.get(),
GetBasePtr(dscale));
}
if (dbias && (tmp_dbias.dtype() == DataType::FLOAT16 &&
dbias->dtype() == DataType::FLOAT32)) {
dbias->mutable_data<T>(place);
MLUCnnl::Cast(ctx, cast_fp16_to_fp32, float16_desc.get(),
GetBasePtr(&tmp_dbias), float32_desc.get(),
GetBasePtr(dbias));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_MLU_KERNEL(layer_norm, ops::LayerNormMLUKernel<float>,
ops::LayerNormMLUKernel<plat::float16>);
REGISTER_OP_MLU_KERNEL(layer_norm_grad, ops::LayerNormGradMLUKernel<float>,
ops::LayerNormGradMLUKernel<plat::float16>);
......@@ -2077,6 +2077,45 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
}
}
/* static */ void MLUCnnl::LayerNormForward(
const ExecutionContext& ctx, int axis, const cnnlTensorDescriptor_t x_desc,
const void* x, const cnnlTensorDescriptor_t weight_bias_desc,
const void* weight, const void* bias, float eps,
const cnnlTensorDescriptor_t y_desc, void* y,
const cnnlTensorDescriptor_t mean_rstd_desc, void* saved_mean,
void* saved_rstd) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
size_t workspace_size;
PADDLE_ENFORCE_MLU_SUCCESS(
cnnlGetLayerNormOpWorkspaceSize(handle, axis, x_desc, &workspace_size));
auto& dev_ctx = GetDevCtxFromCTX(ctx);
Tensor workspace = ctx.AllocateTmpTensor<int8_t, MLUDeviceContext>(
{static_cast<int64_t>(workspace_size)}, dev_ctx);
void* workspace_ptr = workspace.mutable_data(ctx.GetPlace());
PADDLE_ENFORCE_MLU_SUCCESS(
cnnlLayerNormForward(handle, x_desc, x, axis, weight_bias_desc, weight,
bias, eps, workspace_ptr, workspace_size, y_desc, y,
mean_rstd_desc, saved_mean, saved_rstd));
}
/* static */ void MLUCnnl::LayerNormBackward(
const ExecutionContext& ctx, int axis, const cnnlTensorDescriptor_t x_desc,
const void* x, const cnnlTensorDescriptor_t diff_z_desc, const void* diff_z,
const cnnlTensorDescriptor_t weight_bias_desc, const void* weight,
const cnnlTensorDescriptor_t mean_rstd_desc, const void* saved_mean,
const void* saved_rstd, const cnnlTensorDescriptor_t diff_x_desc,
void* diff_x, void* diff_weight, void* diff_bias) {
cnnlHandle_t handle = GetHandleFromCTX(ctx);
PADDLE_ENFORCE_MLU_SUCCESS(cnnlLayerNormBackward(
handle, x_desc, x, axis, diff_z_desc, diff_z, weight_bias_desc, weight,
mean_rstd_desc, saved_mean, saved_rstd, diff_x_desc, diff_x, diff_weight,
diff_bias));
}
/* static */ void MLUCnnl::QuantizeParam(
const ExecutionContext& ctx, const cnnlQuantizeMode_t mode,
const int bitwidth, const cnnlTensorDescriptor_t input_desc,
......
......@@ -146,10 +146,8 @@ const std::map<std::pair<VT::Type, VT::Type>, cnnlCastDataType_t>
{{VT::FP16, /*cast to*/ VT::BOOL}, CNNL_CAST_HALF_TO_BOOL},
{{VT::INT32, /*cast to*/ VT::FP32}, CNNL_CAST_INT32_TO_FLOAT},
{{VT::INT32, /*cast to*/ VT::FP16}, CNNL_CAST_INT32_TO_HALF},
{{VT::INT32, /*cast to*/ VT::INT64}, CNNL_CAST_INT32_TO_INT64},
{{VT::INT32, /*cast to*/ VT::INT16}, CNNL_CAST_INT32_TO_INT16},
{{VT::INT32, /*cast to*/ VT::INT8}, CNNL_CAST_INT32_TO_INT8},
{{VT::INT32, /*cast to*/ VT::BOOL}, CNNL_CAST_INT32_TO_BOOL},
{{VT::INT32, /*cast to*/ VT::INT16}, CNNL_CAST_INT32_TO_INT16},
{{VT::INT16, /*cast to*/ VT::FP32}, CNNL_CAST_INT16_TO_FLOAT},
{{VT::INT16, /*cast to*/ VT::FP16}, CNNL_CAST_INT16_TO_HALF},
{{VT::INT16, /*cast to*/ VT::INT32}, CNNL_CAST_INT16_TO_INT32},
......@@ -158,12 +156,21 @@ const std::map<std::pair<VT::Type, VT::Type>, cnnlCastDataType_t>
{{VT::INT8, /*cast to*/ VT::INT32}, CNNL_CAST_INT8_TO_INT32},
{{VT::UINT8, /*cast to*/ VT::FP32}, CNNL_CAST_UINT8_TO_FLOAT},
{{VT::UINT8, /*cast to*/ VT::FP16}, CNNL_CAST_UINT8_TO_HALF},
{{VT::UINT8, /*cast to*/ VT::INT64}, CNNL_CAST_UINT8_TO_INT64},
{{VT::UINT8, /*cast to*/ VT::INT32}, CNNL_CAST_UINT8_TO_INT32},
{{VT::BOOL, /*cast to*/ VT::FP32}, CNNL_CAST_BOOL_TO_FLOAT},
{{VT::BOOL, /*cast to*/ VT::FP16}, CNNL_CAST_BOOL_TO_HALF},
{{VT::BOOL, /*cast to*/ VT::INT32}, CNNL_CAST_BOOL_TO_INT32},
{{VT::UINT8, /*cast to*/ VT::INT32}, CNNL_CAST_UINT8_TO_INT32},
{{VT::INT32, /*cast to*/ VT::INT64}, CNNL_CAST_INT32_TO_INT64},
{{VT::INT64, /*cast to*/ VT::INT32}, CNNL_CAST_INT64_TO_INT32},
{{VT::INT32, /*cast to*/ VT::BOOL}, CNNL_CAST_INT32_TO_BOOL},
{{VT::UINT8, /*cast to*/ VT::INT64}, CNNL_CAST_UINT8_TO_INT64},
{{VT::INT8, /*cast to*/ VT::INT16}, CNNL_CAST_INT8_TO_INT16},
{{VT::FP32, /*cast to*/ VT::FP64}, CNNL_CAST_FLOAT_TO_DOUBLE},
{{VT::FP64, /*cast to*/ VT::FP32}, CNNL_CAST_DOUBLE_TO_FLOAT},
{{VT::INT64, /*cast to*/ VT::FP32}, CNNL_CAST_INT64_TO_FLOAT},
{{VT::INT64, /*cast to*/ VT::FP16}, CNNL_CAST_INT64_TO_HALF},
{{VT::FP32, /*cast to*/ VT::INT64}, CNNL_CAST_FLOAT_TO_INT64},
{{VT::FP16, /*cast to*/ VT::INT64}, CNNL_CAST_HALF_TO_INT64},
};
cnnlCastDataType_t GetCastDataType(const VT::Type& src_type,
......@@ -1103,6 +1110,24 @@ class MLUCnnl {
const cnnlTensorDescriptor_t x_backprop_desc, void* x_backprop,
void* scale_backprop, void* offset_backprop);
static void LayerNormForward(const ExecutionContext& ctx, int axis,
const cnnlTensorDescriptor_t x_desc,
const void* x,
const cnnlTensorDescriptor_t weight_bias_desc,
const void* weight, const void* bias, float eps,
const cnnlTensorDescriptor_t y_desc, void* y,
const cnnlTensorDescriptor_t mean_rstd_desc,
void* saved_mean, void* saved_rstd);
static void LayerNormBackward(
const ExecutionContext& ctx, int axis,
const cnnlTensorDescriptor_t x_desc, const void* x,
const cnnlTensorDescriptor_t diff_z_desc, const void* diff_z,
const cnnlTensorDescriptor_t weight_bias_desc, const void* weight,
const cnnlTensorDescriptor_t mean_rstd_desc, const void* saved_mean,
const void* saved_rstd, const cnnlTensorDescriptor_t diff_x_desc,
void* diff_x, void* diff_weight, void* diff_bias);
static void Transpose(const ExecutionContext& ctx,
const std::vector<int> perm, const int input_dim,
const cnnlTensorDescriptor_t input_desc,
......@@ -1230,5 +1255,13 @@ inline void TransposeFromMLUTensor(const ExecutionContext& ctx,
GetBasePtr(transformed_output));
}
template <typename T>
inline void FillMLUTensorWithHostValue(const ExecutionContext& ctx, T value,
Tensor* out) {
MLUCnnlTensorDesc out_desc(*out);
MLUCnnl::Fill(ctx, CNNL_POINTER_MODE_HOST, &value, out_desc.get(),
GetBasePtr(out));
}
} // namespace operators
} // namespace paddle
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
from operator import mul
import paddle.fluid.core as core
import paddle.fluid as fluid
import paddle.nn.functional as F
from functools import reduce
import sys
sys.path.append('..')
from op_test import _set_use_system_allocator
from paddle.fluid import Program, program_guard
from paddle.fluid.contrib.mixed_precision.fp16_utils import _keep_layer_norm_scale_bias_to_fp32
from test_layer_norm_op import _reference_layer_norm_naive, _reference_layer_norm_grad
paddle.enable_static()
np.random.random(123)
_set_use_system_allocator(True)
class TestLayerNormOp(unittest.TestCase):
def setUp(self):
self.use_cudnn = True
self.place = paddle.device.MLUPlace(0)
self.__class__.use_mlu = True
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
def check_forward_backward(self,
shape,
begin_norm_axis,
has_scale=True,
has_bias=True,
y_grad_scale=1.0,
use_mkldnn=False):
def test_with_place(place,
shape,
begin_norm_axis,
use_mkldnn=use_mkldnn):
# attr
epsilon = 0.00001
x_shape = shape
D = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1)
scale_shape = [D]
np.random.seed(123)
x = np.random.random_sample(x_shape).astype(np.float32)
scale = np.random.random_sample(scale_shape).astype(
np.float32) if has_scale else None
bias = np.random.random_sample(scale_shape).astype(
np.float32) if has_bias else None
y_grad = (np.random.random_sample(x_shape) *
y_grad_scale).astype(np.float32)
# reference forward & backward
y, mean, variance = _reference_layer_norm_naive(
x, scale, bias, epsilon, begin_norm_axis)
x_grad, scale_grad, bias_grad = _reference_layer_norm_grad(
x, y_grad, scale, bias, mean, variance, begin_norm_axis)
var_dict = locals()
var_dict['y@GRAD'] = y_grad
var_names = ['x', 'mean', 'variance', 'y', 'y@GRAD']
if has_scale:
var_names += ['scale']
if has_bias:
var_names += ['bias']
ground_truth = {name: var_dict[name] for name in var_names}
program = fluid.Program()
with fluid.program_guard(program):
block = program.global_block()
for name in ground_truth:
block.create_var(
name=name,
dtype='float32',
shape=ground_truth[name].shape)
inputs = {"X": block.var('x')}
fetch_list = [
'y',
'mean',
'variance',
'x@GRAD',
]
if has_scale:
inputs["Scale"] = block.var('scale')
fetch_list += ['scale@GRAD']
if has_bias:
inputs["Bias"] = block.var('bias')
fetch_list += ['bias@GRAD']
layer_norm_op = block.append_op(
type="layer_norm",
inputs=inputs,
outputs={
"Y": block.var('y'),
"Mean": block.var('mean'), # share the same memory
"Variance":
block.var('variance'), # share the same memory
},
attrs={
"epsilon": epsilon,
"begin_norm_axis": begin_norm_axis,
"use_mkldnn": use_mkldnn
})
# generate backward op_desc
grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc(
layer_norm_op.desc, set(), [])
grad_op_desc = grad_op_desc_list[0]
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(grad_op_desc)
for var_name in grad_op_desc.output_arg_names():
block.desc.var(var_name.encode("ascii"))
grad_op_desc.infer_var_type(block.desc)
grad_op_desc.infer_shape(block.desc)
for arg in grad_op_desc.output_arg_names():
grad_var = block.desc.find_var(arg.encode("ascii"))
grad_var.set_dtype(core.VarDesc.VarType.FP32)
program._sync_with_cpp()
exe = fluid.Executor(place)
out = exe.run(program,
feed={
name: var_dict[name]
for name in ['x', 'scale', 'bias', 'y@GRAD']
},
fetch_list=fetch_list)
self.__assert_close(y, out[0], "y")
self.__assert_close(mean, out[1], "mean")
self.__assert_close(1 / np.sqrt(variance), out[2], "variance",
1e-3)
self.__assert_close(x_grad, out[3], "x_grad")
if has_scale:
self.__assert_close(scale_grad,
out[fetch_list.index('scale@GRAD')],
"scale_grad", 1e-3)
if has_bias:
self.__assert_close(bias_grad,
out[fetch_list.index('bias@GRAD')],
"bias_grad")
test_with_place(self.place, shape, begin_norm_axis)
def test_check_forward_backward_with_scale_and_bias(self):
self.check_forward_backward(shape=[1, 3, 4, 5], begin_norm_axis=1)
self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=1)
self.check_forward_backward(
shape=[2, 3, 4, 5],
begin_norm_axis=1,
has_scale=False,
has_bias=True)
self.check_forward_backward(
shape=[2, 3, 4, 5],
begin_norm_axis=1,
has_scale=True,
has_bias=False)
self.check_forward_backward(
shape=[2, 3, 4, 5],
begin_norm_axis=1,
has_scale=False,
has_bias=False)
self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=3)
self.check_forward_backward(
shape=[92, 513, 129], begin_norm_axis=2, y_grad_scale=0.1)
self.check_forward_backward(shape=[3, 34, 1134], begin_norm_axis=2)
self.check_forward_backward(
shape=[92, 513, 1134], begin_norm_axis=2, y_grad_scale=0.1)
self.check_forward_backward(
shape=[92, 513, 1134],
begin_norm_axis=2,
has_scale=False,
has_bias=True,
y_grad_scale=0.1)
self.check_forward_backward(
shape=[92, 513, 1134],
begin_norm_axis=2,
has_scale=True,
has_bias=False,
y_grad_scale=0.1)
self.check_forward_backward(
shape=[92, 513, 1134],
begin_norm_axis=2,
has_scale=False,
has_bias=False,
y_grad_scale=0.1)
self.check_forward_backward(
shape=[512, 1024], begin_norm_axis=1, has_scale=True, has_bias=True)
class TestLayerNormAPI(unittest.TestCase):
def test_case(self):
x = fluid.layers.data(
name='x',
shape=[64, 32, 256],
dtype='float32',
append_batch_size=False)
x = fluid.layers.layer_norm(
x,
scale=True,
shift=True,
begin_norm_axis=1,
epsilon=1e-05,
param_attr=None,
bias_attr=None)
x = fluid.layers.layer_norm(
x,
scale=False,
shift=False,
begin_norm_axis=1,
epsilon=1e-05,
param_attr=None,
bias_attr=None)
x = fluid.layers.layer_norm(
x,
scale=False,
shift=False,
begin_norm_axis=1,
epsilon=1e-05,
param_attr="scale",
bias_attr="shift")
class TestDygraphLayerNormAPIError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
paddle.enable_static()
layer_norm = fluid.LayerNorm([32, 32])
# the input of LayerNorm must be Variable.
x1 = np.random.random((3, 32, 32)).astype('float32')
self.assertRaises(TypeError, layer_norm, x1)
# the input dtype of LayerNorm must be float32 or float16
x2 = fluid.layers.data(name='x2', shape=[3, 32, 32], dtype="int32")
self.assertRaises(TypeError, layer_norm, x2)
class TestFP16ScaleBiasLayerNorm(unittest.TestCase):
def check_main(self, x_np, weight_np, bias_np, dtype):
paddle.disable_static()
weight_np = weight_np.astype(dtype)
bias_np = bias_np.astype(dtype)
x = paddle.to_tensor(x_np)
weight = paddle.to_tensor(weight_np)
bias = paddle.to_tensor(bias_np)
x.stop_gradient = False
weight.stop_gradient = False
bias.stop_gradient = False
y = F.layer_norm(x, x.shape[1:], weight, bias)
x_g, w_g, b_g = paddle.grad(y, [x, weight, bias])
y_np = y.numpy().astype('float32')
x_g_np = x_g.numpy().astype('float32')
w_g_np = w_g.numpy().astype('float16')
b_g_np = b_g.numpy().astype('float32')
paddle.enable_static()
return y_np, x_g_np, w_g_np, b_g_np
def test_main(self):
x_np = np.random.random([10, 20]).astype('float16')
weight_np = np.random.random([20]).astype('float16')
bias_np = np.random.random([20]).astype('float16')
y_np_1, x_g_np_1, w_g_np_1, b_g_np_1 = self.check_main(
x_np, weight_np, bias_np, 'float16')
y_np_2, x_g_np_2, w_g_np_2, b_g_np_2 = self.check_main(
x_np, weight_np, bias_np, 'float32')
def assert_equal(x, y):
self.assertTrue(np.array_equal(x, y))
assert_equal(y_np_1, y_np_2)
assert_equal(x_g_np_1, x_g_np_2)
assert_equal(w_g_np_1, w_g_np_2)
assert_equal(b_g_np_1, b_g_np_2)
class TestGetSetKeepLayerNormScaleBiasFP32Flag(unittest.TestCase):
def test_main(self):
self.assertTrue(_keep_layer_norm_scale_bias_to_fp32())
_keep_layer_norm_scale_bias_to_fp32(False)
self.assertFalse(_keep_layer_norm_scale_bias_to_fp32())
_keep_layer_norm_scale_bias_to_fp32(True)
self.assertTrue(_keep_layer_norm_scale_bias_to_fp32())
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册