diff --git a/paddle/operators/layer_norm_op.cc b/paddle/operators/layer_norm_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..1c6d2ae4d05becaeed34d66cad398cc90f9d3ece
--- /dev/null
+++ b/paddle/operators/layer_norm_op.cc
@@ -0,0 +1,370 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/layer_norm_op.h"
+
+#include <cmath>
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+using DataLayout = framework::DataLayout;
+
+// Mutable and read-only Eigen views that interpret a raw buffer as a
+// row-major, dynamically sized matrix with element type T.
+template <typename T>
+using EigenMatrixMapRowMajor = Eigen::Map<
+    Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
+template <typename T>
+using ConstEigenMatrixMapRowMajor = Eigen::Map<
+    const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
+
+// Forward operator: checks that the required variables exist and infers the
+// output shapes. X is flattened at `begin_norm_axis` into a [left, right]
+// matrix; Y has the shape of X, Mean and Variance have shape {left}.
+class LayerNormOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of LayerNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Y"),
+                   "Output(Y) of LayerNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Mean"),
+                   "Output(Mean) of LayerNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Variance"),
+                   "Output(Variance) of LayerNormOp should not be null.");
+
+    auto x_dim = ctx->GetInputDim("X");
+    auto begin_norm_axis = ctx->Attrs().Get<int>("begin_norm_axis");
+    PADDLE_ENFORCE_LT(begin_norm_axis, x_dim.size(),
+                      "'begin_norm_axis' must be less than the rank of X.");
+
+    auto matrix_dim = framework::flatten_to_2d(x_dim, begin_norm_axis);
+    int left = static_cast<int>(matrix_dim[0]);
+    int right = static_cast<int>(matrix_dim[1]);
+    // Scale and Bias are optional; when given they must be 1-D of length
+    // `right` (one entry per normalized element).
+    if (ctx->HasInput("Scale")) {
+      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL);
+      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], right);
+    }
+    if (ctx->HasInput("Bias")) {
+      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL);
+      PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], right);
+    }
+
+    ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
+    ctx->SetOutputDim("Mean", {left});
+    ctx->SetOutputDim("Variance", {left});
+    ctx->ShareLoD("X", "Y");
+  }
+};
+
+// Declares the inputs, outputs and attributes of the layer_norm operator
+// together with the attribute range checks.
+class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  LayerNormOpMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "(LoDTensor) The input tensor.");
+    AddInput("Scale",
+             "(Tensor, optional) Scale is a 1-dimensional tensor of size "
+             "H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])."
+             "It is applied to the output.")
+        .AsDispensable();
+    AddInput("Bias",
+             "(Tensor, optional) Bias is a 1-dimensional tensor of size "
+             "H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])."
+             "It is applied to the output.")
+        .AsDispensable();
+    AddOutput("Y", "(LoDTensor) Result after normalization.");
+    AddOutput("Mean", "(Tensor) Mean of the current mini batch.")
+        .AsIntermediate();
+    AddOutput("Variance", "(Tensor) Variance of the current mini batch.")
+        .AsIntermediate();
+
+    AddAttr<float>("epsilon",
+                   "(float, default 1e-5) Constant for "
+                   "numerical stability")
+        .SetDefault(1e-5)
+        .AddCustomChecker([](const float &epsilon) {
+          PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f,
+                         "'epsilon' should be between 0.0 and 0.001.");
+        });
+    AddAttr<int>("begin_norm_axis",
+                 "(int default:1), the "
+                 "axis of `begin_norm_axis ... Rank(X) - 1` will be "
+                 "normalized. `begin_norm_axis` splits the tensor(`X`) to a "
+                 "matrix [N,H].")
+        .SetDefault(1)
+        .AddCustomChecker([](const int &begin_norm_axis) {
+          PADDLE_ENFORCE_GT(begin_norm_axis, 0,
+                            "'begin_norm_axis' should be greater than zero.");
+        });
+
+    AddComment(R"DOC(
+Layer Normalization.
+
+Layer Norm has been implemented as discussed in the paper:
+https://arxiv.org/abs/1607.06450
+...
+)DOC");
+  }
+};
+
+// CPU forward kernel.
+//
+// X is viewed as a [left, right] row-major matrix. For every row this
+// computes
+//   Mean     = mean(row)
+//   Variance = mean((row - Mean)^2) + epsilon   (epsilon is folded in)
+//   Y        = (row - Mean) / sqrt(Variance) [* Scale] [+ Bias]
+// Scale and Bias are applied only when the corresponding input is present.
+template <typename T>
+class LayerNormKernel<platform::CPUDeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    const float epsilon = ctx.Attr<float>("epsilon");
+    const auto *scale = ctx.Input<Tensor>("Scale");
+    const auto *bias = ctx.Input<Tensor>("Bias");
+    const auto *x = ctx.Input<Tensor>("X");
+    const auto &x_dims = x->dims();
+    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
+
+    auto *output = ctx.Output<Tensor>("Y");
+    auto *mean = ctx.Output<Tensor>("Mean");
+    auto *var = ctx.Output<Tensor>("Variance");
+    output->mutable_data<T>(ctx.GetPlace());
+    mean->mutable_data<T>(ctx.GetPlace());
+    var->mutable_data<T>(ctx.GetPlace());
+
+    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
+    int left = static_cast<int>(matrix_dim[0]);
+    int right = static_cast<int>(matrix_dim[1]);
+
+    auto input_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
+
+    auto mean_map = EigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
+    auto var_map = EigenMatrixMapRowMajor<T>(var->data<T>(), left, 1);
+    auto output_map = EigenMatrixMapRowMajor<T>(output->data<T>(), left, right);
+
+    auto square = [](T ele) { return ele * ele; };
+    auto add_epsilon = [epsilon](T ele) { return ele + epsilon; };
+
+    mean_map = input_map.rowwise().mean();
+    var_map = (input_map - mean_map.replicate(1, right))
+                  .unaryExpr(square)
+                  .rowwise()
+                  .mean()
+                  .unaryExpr(add_epsilon);
+
+    auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
+    // TODO(zcd): Some thinking about output_map, is it appropriate that
+    // `output_map` and `input_map` point to the same memory.
+    // `inv_std` is a lazy Eigen expression over `var_map`; it is evaluated
+    // when assigned into `output_map` below.
+    auto inv_std = var_map.unaryExpr(inv_std_func);
+    if (scale && bias) {
+      auto scale_map =
+          ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
+      auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), 1, right);
+      output_map = (input_map - mean_map.replicate(1, right))
+                       .cwiseProduct(inv_std.replicate(1, right))
+                       .cwiseProduct(scale_map.replicate(left, 1)) +
+                   bias_map.replicate(left, 1);
+    } else if (scale) {
+      auto scale_map =
+          ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
+      output_map = (input_map - mean_map.replicate(1, right))
+                       .cwiseProduct(inv_std.replicate(1, right))
+                       .cwiseProduct(scale_map.replicate(left, 1));
+    } else if (bias) {
+      auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), 1, right);
+      output_map = (input_map - mean_map.replicate(1, right))
+                       .cwiseProduct(inv_std.replicate(1, right)) +
+                   bias_map.replicate(left, 1);
+    } else {
+      output_map = (input_map - mean_map.replicate(1, right))
+                       .cwiseProduct(inv_std.replicate(1, right));
+    }
+  }
+};
+
+// Backward operator: requires X, Scale, Mean, Variance and Y@GRAD, and gives
+// each requested gradient output the shape of its forward counterpart.
+class LayerNormGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    // check input
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of LayerNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Scale"),
+                   "Input(Scale) of LayerNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Mean"),
+                   "Input(Mean) of LayerNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Variance"),
+                   "Input(Variance) of LayerNormOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
+                   "Input(Y@GRAD) of LayerNormOp should not be null.");
+
+    // check output
+    if (ctx->HasOutput(framework::GradVarName("X"))) {
+      ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+    }
+    if (ctx->HasOutput(framework::GradVarName("Scale"))) {
+      ctx->SetOutputDim(framework::GradVarName("Scale"),
+                        ctx->GetInputDim("Scale"));
+    }
+    if (ctx->HasOutput(framework::GradVarName("Bias"))) {
+      ctx->SetOutputDim(framework::GradVarName("Bias"),
+                        ctx->GetInputDim("Bias"));
+    }
+  }
+
+ protected:
+  // The kernel data type is taken from Y@GRAD, which may be stored either as
+  // a Tensor or as a LoDTensor.
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    const auto *var = ctx.InputVar(framework::GradVarName("Y"));
+    if (var == nullptr) {
+      PADDLE_THROW("can't find Y@GRAD");
+    }
+    const Tensor *t = nullptr;
+    if (var->IsType<Tensor>()) {
+      t = &var->Get<Tensor>();
+    } else if (var->IsType<LoDTensor>()) {
+      t = &var->Get<LoDTensor>();
+    }
+    if (t == nullptr) {
+      PADDLE_THROW("can't find Y@GRAD");
+    }
+    return framework::OpKernelType(framework::ToDataType(t->type()),
+                                   ctx.GetPlace());
+  }
+};
+
+// CPU backward kernel.
+//
+// Uses the [left, right] view of X and Y@GRAD plus the saved Mean and
+// Variance (Variance already contains epsilon, see the forward kernel) to
+// fill whichever of Bias@GRAD, Scale@GRAD and X@GRAD are requested.
+template <typename T>
+class LayerNormGradKernel<platform::CPUDeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    const auto *x = ctx.Input<Tensor>("X");
+    const auto *mean = ctx.Input<Tensor>("Mean");
+    const auto *var = ctx.Input<Tensor>("Variance");
+    const auto *scale = ctx.Input<Tensor>("Scale");
+    const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
+
+    const auto &x_dims = x->dims();
+
+    const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
+    auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
+    int left = static_cast<int>(matrix_dim[0]);
+    int right = static_cast<int>(matrix_dim[1]);
+
+    // init output
+    auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
+    auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
+    auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
+
+    auto x_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
+    auto d_y_map = ConstEigenMatrixMapRowMajor<T>(d_y->data<T>(), left, right);
+    auto mean_map = ConstEigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
+    auto var_map = ConstEigenMatrixMapRowMajor<T>(var->data<T>(), left, 1);
+
+    if (d_bias) {
+      // Bias@GRAD is the column-wise sum of Y@GRAD.
+      d_bias->mutable_data<T>(ctx.GetPlace());
+      auto d_bias_map = EigenMatrixMapRowMajor<T>(d_bias->data<T>(), 1, right);
+      d_bias_map = d_y_map.colwise().sum();
+    }
+    if (d_scale) {
+      d_scale->mutable_data<T>(ctx.GetPlace());
+      auto d_scale_map =
+          EigenMatrixMapRowMajor<T>(d_scale->data<T>(), 1, right);
+      auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
+      // There are two equations to compute d_scale. One uses "Y" and the
+      // other does not use "Y"; this is the one that does not.
+      d_scale_map =
+          ((x_map - mean_map.replicate(1, right))
+               .cwiseProduct(
+                   var_map.unaryExpr(inv_std_func).replicate(1, right))
+               .cwiseProduct(d_y_map))
+              .colwise()
+              .sum();
+    }
+
+    if (d_x) {
+      d_x->mutable_data<T>(ctx.GetPlace());
+      auto d_x_map = EigenMatrixMapRowMajor<T>(d_x->data<T>(), left, right);
+      auto triple_product_func = [](T ele) { return ele * ele * ele; };
+      auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
+      // TODO(zcd): this code can be refined
+      // X@GRAD depends on Scale whenever Scale took part in the forward
+      // pass, so the branch is chosen by the presence of the Scale input,
+      // not by whether Scale@GRAD happens to be requested.
+      if (scale) {
+        auto scale_map =
+            ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
+        // dy_dx
+        auto dx_end = var_map.unaryExpr(inv_std_func)
+                          .replicate(1, right)
+                          .cwiseProduct(d_y_map)
+                          .cwiseProduct(scale_map.replicate(left, 1));
+        // dy_dmean_dx
+        auto dx_mean = (T(-1.0) / right) *
+                       var_map.unaryExpr(inv_std_func)
+                           .replicate(1, right)
+                           .cwiseProduct(d_y_map)
+                           .cwiseProduct(scale_map.replicate(left, 1))
+                           .rowwise()
+                           .sum()
+                           .replicate(1, right);
+        // dy_var_dx
+        auto dvar_end_part = (x_map - mean_map.replicate(1, right))
+                                 .cwiseProduct(scale_map.replicate(left, 1))
+                                 .cwiseProduct(d_y_map)
+                                 .rowwise()
+                                 .sum();
+        auto dvar_end = var_map.unaryExpr(inv_std_func)
+                            .unaryExpr(triple_product_func)
+                            .cwiseProduct(dvar_end_part)
+                            .replicate(1, right);
+        auto dx_var =
+            (T(-1.0) / right) *
+            (x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);
+
+        d_x_map = dx_end + dx_mean + dx_var;
+      } else {
+        // dy_dx
+        auto dx_end = var_map.unaryExpr(inv_std_func)
+                          .replicate(1, right)
+                          .cwiseProduct(d_y_map);
+        // dy_dmean_dx
+        auto dx_mean = (T(-1.0) / right) *
+                       var_map.unaryExpr(inv_std_func)
+                           .replicate(1, right)
+                           .cwiseProduct(d_y_map)
+                           .rowwise()
+                           .sum()
+                           .replicate(1, right);
+        // dy_var_dx
+        auto dvar_end_part = (x_map - mean_map.replicate(1, right))
+                                 .cwiseProduct(d_y_map)
+                                 .rowwise()
+                                 .sum();
+        auto dvar_end = var_map.unaryExpr(inv_std_func)
+                            .unaryExpr(triple_product_func)
+                            .cwiseProduct(dvar_end_part)
+                            .replicate(1, right);
+        auto dx_var =
+            (T(-1.0) / right) *
+            (x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);
+
+        d_x_map = dx_end + dx_mean + dx_var;
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(layer_norm, ops::LayerNormOp, ops::LayerNormOpMaker,
+            layer_norm_grad, ops::LayerNormGradOp);
+REGISTER_OP_CPU_KERNEL(
+    layer_norm,
+    ops::LayerNormKernel<paddle::platform::CPUDeviceContext, float>);
+REGISTER_OP_CPU_KERNEL(
+    layer_norm_grad,
+    ops::LayerNormGradKernel<paddle::platform::CPUDeviceContext, float>);
diff --git a/paddle/operators/layer_norm_op.h b/paddle/operators/layer_norm_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..bca35b91e6f52d35dee14aac9d080b52914942e3
--- /dev/null
+++ b/paddle/operators/layer_norm_op.h
@@ -0,0 +1,35 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+// Forward layer_norm kernel. Only the interface is declared here; the body
+// of Compute() is supplied per device by the matching implementation file.
+template <typename DeviceContext, typename T>
+class LayerNormKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override;
+};
+
+// Backward layer_norm kernel. Only the interface is declared here; the body
+// of Compute() is supplied per device by the matching implementation file.
+template <typename DeviceContext, typename T>
+class LayerNormGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override;
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/v2/fluid/tests/test_layer_norm_op.py b/python/paddle/v2/fluid/tests/test_layer_norm_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d5dc7d1a6e834490c3d499b0d92a10bd11ba9aa
--- /dev/null
+++ b/python/paddle/v2/fluid/tests/test_layer_norm_op.py
@@ -0,0 +1,252 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+import numpy as np
+
+from functools import reduce
+from operator import mul
+from op_test import OpTest
+import paddle.v2.fluid.core as core
+from paddle.v2.fluid.op import Operator
+from paddle.v2.fluid.framework import grad_var_name
+
+
+def _reference_layer_norm_naive(x, scale, beta, epsilon, begin_norm_axis=1):
+    """NumPy reference for the layer_norm forward pass.
+
+    `x` is viewed as an [N, D] matrix split at `begin_norm_axis`; each row is
+    normalized, multiplied by `scale` and shifted by `beta` (both of size D).
+
+    Returns (output, mean, var): `output` has the shape of `x`, `mean` and
+    `var` have shape [N], and `var` already includes `epsilon`.
+    The input arrays are not modified.
+    """
+    x_shape = x.shape
+    N = reduce(mul, x_shape[0:begin_norm_axis], 1)
+    D = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1)
+    x_2d = x.reshape([N, D])
+
+    mean = np.mean(x_2d, axis=1)
+    var = np.var(x_2d, axis=1) + epsilon
+    output = scale.reshape([1, D]) * np.divide(
+        (x_2d - mean.reshape([N, 1])),
+        (np.sqrt(var)).reshape([N, 1])) + beta.reshape([1, D])
+
+    return output.reshape(x_shape), mean, var
+
+
+def _reference_layer_norm_grad(x, grad_y, scale, mean, var, begin_norm_axis=1):
+    """NumPy reference for the layer_norm backward pass.
+
+    `mean` and `var` are the values returned by
+    `_reference_layer_norm_naive` (so `var` already includes epsilon).
+
+    Returns (grad_x, d_scale, d_bias) with shapes [N, D], [1, D] and [1, D].
+    The input arrays are not modified.
+    """
+    x_shape = x.shape
+    N = reduce(mul, x_shape[0:begin_norm_axis], 1)
+    D = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1)
+    # Work on reshaped views instead of rewriting the callers' shapes.
+    x_2d, grad_y_2d = x.reshape([N, D]), grad_y.reshape([N, D])
+    var_col, mean_col = var.reshape([N, 1]), mean.reshape([N, 1])
+    scale_row = scale.reshape([1, D])
+
+    # d_bias
+    d_bias = np.sum(grad_y_2d, axis=0).reshape([1, D])
+    # d_scale
+    d_scale = np.sum(((x_2d - mean_col) * np.sqrt(1 / var_col)) * grad_y_2d,
+                     axis=0).reshape([1, D])
+    # dx
+    dx_end = scale_row * np.sqrt(1.0 / var_col) * grad_y_2d
+    d_mean_0 = np.sum(-np.sqrt(1.0 / var_col) * grad_y_2d * scale_row,
+                      axis=1).reshape([N, 1])
+    # the second part equals to zero.
+    d_mean = 1.0 / D * d_mean_0
+    d_std = np.sum(
+        -(1.0 / var_col) * (x_2d - mean_col) * grad_y_2d * scale_row,
+        axis=1).reshape([N, 1]) * (
+            1.0 / D * np.sqrt(1.0 / var_col).reshape([N, 1]) *
+            (x_2d - mean_col))
+
+    grad_x = dx_end + d_mean + d_std
+
+    return grad_x, d_scale, d_bias
+
+
+def get_backward_op(scope, op, no_grad_set):
+    """Build the backward operator of `op` and make sure every variable it
+    reads or writes exists in `scope` as a tensor."""
+    backward_op = core.Operator.backward(op, no_grad_set)
+    for input in backward_op.input_vars():
+        var = scope.var(input)
+        var.get_tensor()
+    for output in backward_op.output_vars():
+        var = scope.var(output)
+        var.get_tensor()
+    return backward_op
+
+
+def create_or_get_tensor(scope, var_name, var, place):
+    """Return the tensor named `var_name` in `scope`, creating the variable
+    if needed; when `var` (a numpy array) is given, copy it into the tensor."""
+    tensor = scope.var(var_name).get_tensor()
+    if var is not None:
+        assert isinstance(var, np.ndarray)
+        tensor.set_lod([[]])
+        tensor.set_dims(var.shape)
+        tensor.set(var, place)
+    return tensor
+
+
+def set_output_grad(scope, outputs, place, feed_dict=None):
+    """Fill the gradient variable of every name in `outputs`.
+
+    Values come from `feed_dict` when present; otherwise the gradient is an
+    all-ones array with the shape and dtype of the forward output.
+    """
+    # The default of None must behave like "nothing fed".
+    if feed_dict is None:
+        feed_dict = {}
+
+    def _set_tensor(name, data=None):
+        out_tensor = scope.find_var(name).get_tensor()
+        grad_tensor = scope.var(grad_var_name(name)).get_tensor()
+        out_dtype = out_tensor.dtype()
+        if data is None:
+            if out_dtype == core.DataType.FP64:
+                data = np.ones(out_tensor.shape(), dtype=np.float64)
+            elif out_dtype == core.DataType.FP32:
+                data = np.ones(out_tensor.shape(), dtype=np.float32)
+            else:
+                raise ValueError("Not supported data type " + str(out_dtype))
+        grad_tensor.set(data, place)
+
+    for output in outputs:
+        _set_tensor(output, feed_dict.get(output))
+
+
+class TestLayerNormdOp(OpTest):
+    """Checks the layer_norm operator against the NumPy references above."""
+
+    def __assert_close(self, tensor, np_array, msg, atol=1e-4):
+        # The tensor is reshaped to the reference's shape before comparing.
+        self.assertTrue(
+            np.allclose(
+                np.array(tensor).reshape(np_array.shape), np_array, atol=atol),
+            msg)
+
+    def __assert_grad_close(self,
+                            tensor,
+                            np_array,
+                            name,
+                            place,
+                            max_relative_error=0.02):
+        a = np.array(tensor).reshape(np_array.shape)
+        b = np_array
+        # Relative error w.r.t. |a|; tiny entries are compared absolutely.
+        abs_a = np.abs(a)
+        abs_a[abs_a < 1e-5] = 1
+
+        diff_mat = np.abs(a - b) / abs_a
+        max_diff = np.max(diff_mat)
+
+        def err_msg():
+            offset = np.argmax(diff_mat > max_relative_error)
+            return ("%s Variable %s max gradient diff %f over limit %f, "
+                    "the first error element is %d, %f, %f") % (
+                        "Gradient Check On %s" % str(place), name, max_diff,
+                        max_relative_error, offset, a.flatten()[offset],
+                        b.flatten()[offset])
+
+        self.assertLessEqual(max_diff, max_relative_error, err_msg())
+
+    def check_forward_backward(self, shape, begin_norm_axis):
+        def test_with_place(place, shape, begin_norm_axis=1):
+            # setUp
+            assert begin_norm_axis > 0 and begin_norm_axis < len(
+                shape), 'begin_norm_axis must be between 0 and len(shape)-1.'
+            # attr
+            epsilon = 0.00001
+            x_shape = shape
+            D = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1)
+            scale_shape = [D]
+            # Seed the generator so the random inputs are reproducible.
+            np.random.seed(123)
+            x_val = np.random.random_sample(x_shape).astype(np.float32)
+            scale_val = np.random.random_sample(scale_shape).astype(np.float32)
+            bias_val = np.random.random_sample(scale_shape).astype(np.float32)
+            y_grad = np.random.random_sample(x_shape).astype(np.float32)
+
+            # run forward
+            y_out, saved_mean, var_ref = _reference_layer_norm_naive(
+                x_val, scale_val, bias_val, epsilon, begin_norm_axis)
+            naive_fw = {"Y": y_out, "Mean": saved_mean, "Variance": var_ref}
+
+            # get gradient
+            x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_layer_norm_grad(
+                x_val, y_grad, scale_val, saved_mean, var_ref, begin_norm_axis)
+            naive_grad = {
+                "X": x_grad_ref,
+                "Scale": scale_grad_ref,
+                "Bias": bias_grad_ref
+            }
+
+            scope = core.Scope()
+
+            # create input
+            input_map = {"X": x_val, "Scale": scale_val, "Bias": bias_val}
+            for i_name in input_map:
+                create_or_get_tensor(scope, i_name, input_map[i_name], place)
+
+            # create output
+            output_map = {"Y": None, "Mean": None, "Variance": None}
+            output_tensor = {}
+            for o_name in output_map:
+                output_tensor[o_name] = create_or_get_tensor(
+                    scope, o_name, output_map[o_name], place)
+
+            layer_norm_op = Operator(
+                "layer_norm",
+                # inputs
+                X="X",
+                Scale="Scale",
+                Bias="Bias",
+                # outputs
+                Y="Y",
+                Mean="Mean",
+                Variance="Variance",
+                # attrs
+                epsilon=epsilon,
+                begin_norm_axis=begin_norm_axis)
+
+            layer_norm_op.run(scope, place)
+
+            # check forward result
+            atol = 5e-2 if isinstance(place, core.CUDAPlace) else 1e-4
+            for o_tensor in output_tensor:
+                self.__assert_close(output_tensor[o_tensor], naive_fw[o_tensor],
+                                    o_tensor, atol)
+
+            # run backward
+            layer_norm_op_grad = get_backward_op(scope, layer_norm_op, set())
+            set_output_grad(
+                scope, ["Y", "Mean", "Variance"],
+                place,
+                feed_dict={"Y": y_grad})
+            layer_norm_op_grad.run(scope, place)
+
+            # get output
+            grad_tensor = {}
+            for o_name in naive_grad:
+                grad_tensor[o_name] = create_or_get_tensor(
+                    scope, grad_var_name(o_name), None, place)
+
+            # check gradient output
+            for o_grad in naive_grad:
+                self.__assert_grad_close(grad_tensor[o_grad],
+                                         naive_grad[o_grad], o_grad + "@GRAD",
+                                         place)
+
+        places = [core.CPUPlace()]
+        if core.is_compile_gpu() and core.op_support_gpu("layer_norm"):
+            places.append(core.CUDAPlace(0))
+
+        for place in places:
+            test_with_place(place, shape, begin_norm_axis)
+
+    def test_check_forward_backward_with_scale_and_bias(self):
+        self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=1)
+        self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=3)
+
+    def test_check_forward_backward_with_scale(self):
+        pass  # TODO(zcd)
+
+    def test_check_forward_backward_with_bias(self):
+        pass  # TODO(zcd)
+
+    def test_check_forward_backward(self):
+        pass  # TODO(zcd)
+
+
+if __name__ == '__main__':
+    unittest.main()