未验证 提交 e261c792 编写于 作者: C Cao Ying 提交者: GitHub

Merge pull request #7789 from chengduoZH/feature/layer_norm

Add CPU implementation of layer normalization operator.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/layer_norm_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DataLayout = framework::DataLayout;
template <typename T>
using EigenMatrixMapRowMajor = Eigen::Map<
Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
template <typename T>
using ConstEigenMatrixMapRowMajor = Eigen::Map<
const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>;
class LayerNormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of LayerNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Y"),
"Output(Y) of LayerNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Mean"),
"Output(Mean) of LayerNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Variance"),
"Output(Variance) of LayerNormOp should not be null.");
auto x_dim = ctx->GetInputDim("X");
auto begin_norm_axis = ctx->Attrs().Get<int>("begin_norm_axis");
PADDLE_ENFORCE_LT(begin_norm_axis, x_dim.size(),
"'begin_norm_axis' must be less than the rank of X.");
auto matrix_dim = framework::flatten_to_2d(x_dim, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
if (ctx->HasInput("Scale")) {
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], right);
}
if (ctx->HasInput("Bias")) {
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], right);
}
ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
ctx->SetOutputDim("Mean", {left});
ctx->SetOutputDim("Variance", {left});
ctx->ShareLoD("X", "Y");
}
};
class LayerNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LayerNormOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(LoDTensor) The input tensor.");
AddInput("Scale",
"(Tensor, optional) Scale is a 1-dimensional tensor of size "
"H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])."
"It is applied to the output.")
.AsDispensable();
AddInput("Bias",
"(Tensor, optional) Bias is a 1-dimensional tensor of size "
"H(`begin_norm_axis` splits the tensor(`X`) to a matrix [N,H])."
"It is applied to the output.")
.AsDispensable();
AddOutput("Y", "(LoDTensor) Result after normalization.");
AddOutput("Mean", "(Tensor) Mean of the current mini batch.")
.AsIntermediate();
AddOutput("Variance", "(Tensor) Variance of the current mini batch.")
.AsIntermediate();
AddAttr<float>("epsilon",
"(float, default 1e-5) Constant for "
"numerical stability")
.SetDefault(1e-5)
.AddCustomChecker([](const float &epsilon) {
PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f,
"'epsilon' should be between 0.0 and 0.001.");
});
AddAttr<int>("begin_norm_axis",
"(int default:1), the "
"axis of `begin_norm_axis ... Rank(X) - 1` will be "
"normalized. `begin_norm_axis` splits the tensor(`X`) to a "
"matrix [N,H].")
.SetDefault(1)
.AddCustomChecker([](const int &begin_norm_axis) {
PADDLE_ENFORCE_GT(begin_norm_axis, 0,
"'begin_norm_axis' should be greater than zero.");
});
AddComment(R"DOC(
Layer Normalization.
Layer Norm has been implemented as discussed in the paper:
https://arxiv.org/abs/1607.06450
...
)DOC");
}
};
template <typename T>
class LayerNormKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const float epsilon = ctx.Attr<float>("epsilon");
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *bias = ctx.Input<Tensor>("Bias");
const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims();
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
auto *output = ctx.Output<Tensor>("Y");
auto *mean = ctx.Output<Tensor>("Mean");
auto *var = ctx.Output<Tensor>("Variance");
output->mutable_data<T>(ctx.GetPlace());
mean->mutable_data<T>(ctx.GetPlace());
var->mutable_data<T>(ctx.GetPlace());
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
auto input_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
auto mean_map = EigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
auto var_map = EigenMatrixMapRowMajor<T>(var->data<T>(), left, 1);
auto output_map = EigenMatrixMapRowMajor<T>(output->data<T>(), left, right);
auto squre = [](T ele) { return ele * ele; };
auto add_epslion = [epsilon](T ele) { return ele + epsilon; };
mean_map = input_map.rowwise().mean();
var_map = (input_map - mean_map.replicate(1, right))
.unaryExpr(squre)
.rowwise()
.mean()
.unaryExpr(add_epslion);
auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
// TODO(zcd): Some thinking about output_map, is it appropriate that
// `output_map` and `input_map` point to the same memory.
auto inv_std = var_map.unaryExpr(inv_std_func);
if (scale && bias) {
auto scale_map =
ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), 1, right);
output_map = (input_map - mean_map.replicate(1, right))
.cwiseProduct(inv_std.replicate(1, right))
.cwiseProduct(scale_map.replicate(left, 1)) +
bias_map.replicate(left, 1);
} else if (scale) {
auto scale_map =
ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
output_map = (input_map - mean_map.replicate(1, right))
.cwiseProduct(inv_std.replicate(1, right))
.cwiseProduct(scale_map.replicate(left, 1));
} else if (bias) {
auto bias_map = ConstEigenMatrixMapRowMajor<T>(bias->data<T>(), 1, right);
output_map = (input_map - mean_map.replicate(1, right))
.cwiseProduct(inv_std.replicate(1, right)) +
bias_map.replicate(left, 1);
} else {
output_map = (input_map - mean_map.replicate(1, right))
.cwiseProduct(inv_std.replicate(1, right));
}
}
};
class LayerNormGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
// check input
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of LayerNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Scale"),
"Input(Scale) of LayerNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Mean"),
"Input(Mean) of LayerNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Variance"),
"Input(Variance) of LayerNormOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
"Input(Y@GRAD) of LayerNormOp should not be null.");
// check output
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
if (ctx->HasOutput(framework::GradVarName("Scale"))) {
ctx->SetOutputDim(framework::GradVarName("Scale"),
ctx->GetInputDim("Scale"));
}
if (ctx->HasOutput(framework::GradVarName("Bias"))) {
ctx->SetOutputDim(framework::GradVarName("Bias"),
ctx->GetInputDim("Bias"));
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
const auto *var = ctx.InputVar(framework::GradVarName("Y"));
if (var == nullptr) {
PADDLE_THROW("can't find Y@GRAD");
}
const Tensor *t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
}
if (t == nullptr) {
PADDLE_THROW("can't find Y@GRAD");
}
return framework::OpKernelType(framework::ToDataType(t->type()),
ctx.GetPlace());
}
};
template <typename T>
class LayerNormGradKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *x = ctx.Input<Tensor>("X");
const auto *mean = ctx.Input<Tensor>("Mean");
const auto *var = ctx.Input<Tensor>("Variance");
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto &x_dims = x->dims();
const auto begin_norm_axis = ctx.Attr<int>("begin_norm_axis");
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
// init output
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
auto x_map = ConstEigenMatrixMapRowMajor<T>(x->data<T>(), left, right);
auto d_y_map = ConstEigenMatrixMapRowMajor<T>(d_y->data<T>(), left, right);
auto mean_map = ConstEigenMatrixMapRowMajor<T>(mean->data<T>(), left, 1);
auto var_map = ConstEigenMatrixMapRowMajor<T>(var->data<T>(), left, 1);
if (d_bias) {
d_bias->mutable_data<T>(ctx.GetPlace());
auto d_bias_map = EigenMatrixMapRowMajor<T>(d_bias->data<T>(), 1, right);
d_bias_map = d_y_map.colwise().sum();
}
if (d_scale) {
d_scale->mutable_data<T>(ctx.GetPlace());
auto d_scale_map =
EigenMatrixMapRowMajor<T>(d_scale->data<T>(), 1, right);
auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
// There are two equation to compute d_scale. One uses "Y" and the other
// does not use "Y"
d_scale_map =
((x_map - mean_map.replicate(1, right))
.cwiseProduct(
var_map.unaryExpr(inv_std_func).replicate(1, right))
.cwiseProduct(d_y_map))
.colwise()
.sum();
}
if (d_x) {
d_x->mutable_data<T>(ctx.GetPlace());
auto d_x_map = EigenMatrixMapRowMajor<T>(d_x->data<T>(), left, right);
auto triple_product_func = [](T ele) { return ele * ele * ele; };
auto inv_std_func = [](T ele) { return std::sqrt(1 / ele); };
// TODO(zcd): these code can be refined
if (d_scale) {
auto scale_map =
ConstEigenMatrixMapRowMajor<T>(scale->data<T>(), 1, right);
// dy_dx
auto dx_end = var_map.unaryExpr(inv_std_func)
.replicate(1, right)
.cwiseProduct(d_y_map)
.cwiseProduct(scale_map.replicate(left, 1));
// dy_dmean_dx
auto dx_mean = (T(-1.0) / right) *
var_map.unaryExpr(inv_std_func)
.replicate(1, right)
.cwiseProduct(d_y_map)
.cwiseProduct(scale_map.replicate(left, 1))
.rowwise()
.sum()
.replicate(1, right);
// dy_var_dx
auto dvar_end_part = (x_map - mean_map.replicate(1, right))
.cwiseProduct(scale_map.replicate(left, 1))
.cwiseProduct(d_y_map)
.rowwise()
.sum();
auto dvar_end = var_map.unaryExpr(inv_std_func)
.unaryExpr(triple_product_func)
.cwiseProduct(dvar_end_part)
.replicate(1, right);
auto dx_var =
(T(-1.0) / right) *
(x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);
d_x_map = dx_end + dx_mean + dx_var;
} else {
// dy_dx
auto dx_end = var_map.unaryExpr(inv_std_func)
.replicate(1, right)
.cwiseProduct(d_y_map);
// dy_dmean_dx
auto dx_mean = (T(-1.0) / right) *
var_map.unaryExpr(inv_std_func)
.replicate(1, right)
.cwiseProduct(d_y_map)
.rowwise()
.sum()
.replicate(1, right);
// dy_var_dx
auto dvar_end_part = (x_map - mean_map.replicate(1, right))
.cwiseProduct(d_y_map)
.rowwise()
.sum();
auto dvar_end = var_map.unaryExpr(inv_std_func)
.unaryExpr(triple_product_func)
.cwiseProduct(dvar_end_part)
.replicate(1, right);
auto dx_var =
(T(-1.0) / right) *
(x_map - mean_map.replicate(1, right)).cwiseProduct(dvar_end);
d_x_map = dx_end + dx_mean + dx_var;
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(layer_norm, ops::LayerNormOp, ops::LayerNormOpMaker,
layer_norm_grad, ops::LayerNormGradOp);
REGISTER_OP_CPU_KERNEL(
layer_norm,
ops::LayerNormKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL(
layer_norm_grad,
ops::LayerNormGradKernel<paddle::platform::CPUDeviceContext, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class LayerNormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
};
template <typename DeviceContext, typename T>
class LayerNormGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from operator import mul
from op_test import OpTest
import paddle.v2.fluid.core as core
from paddle.v2.fluid.op import Operator
from paddle.v2.fluid.framework import grad_var_name
def _reference_layer_norm_naive(x, scale, beta, epsilon, begin_norm_axis=1):
x_shape = x.shape
N = reduce(mul, x_shape[0:begin_norm_axis], 1)
D = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1)
x.shape = [N, D]
mean = np.mean(x, axis=1)
var = np.var(x, axis=1) + epsilon
output = scale.reshape([1, D]) * np.divide(
(x - mean.reshape([N, 1])),
(np.sqrt(var)).reshape([N, 1])) + beta.reshape([1, D])
x.shape, output.shape = x_shape, x_shape
return output, mean, var
def _reference_layer_norm_grad(x, grad_y, scale, mean, var, begin_norm_axis=1):
x_shape = x.shape
scale_shape = scale.shape
N = reduce(mul, x_shape[0:begin_norm_axis], 1)
D = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1)
x.shape, grad_y.shape = [N, D], [N, D]
var.shape, mean.shape = [N, 1], [N, 1]
scale.shape = [1, D]
# d_bias
d_bias = np.sum(grad_y, axis=0).reshape([1, D])
# d_scale
d_scale = np.sum(((x - mean) * np.sqrt(1 / var)) * grad_y,
axis=0).reshape([1, D])
# dx
dx_end = scale * np.sqrt(1.0 / var) * grad_y
d_mean_0 = np.sum(-np.sqrt(1.0 / var) * grad_y * scale, axis=1).reshape(
[N, 1]) # the second part equals to zero.
d_mean = 1.0 / D * d_mean_0
d_std = np.sum(
-(1.0 / var) * (x - mean) * grad_y * scale, axis=1).reshape([N, 1]) * (
1.0 / D * np.sqrt(1.0 / var).reshape([N, 1]) * (x - mean))
grad_x = dx_end + d_mean + d_std
grad_y.shape = x_shape
x.shape = x_shape
scale.shape = scale_shape
return grad_x, d_scale, d_bias
def get_backward_op(scope, op, no_grad_set):
backward_op = core.Operator.backward(op, no_grad_set)
for input in backward_op.input_vars():
var = scope.var(input)
var.get_tensor()
for output in backward_op.output_vars():
var = scope.var(output)
var.get_tensor()
return backward_op
def create_or_get_tensor(scope, var_name, var, place):
tensor = scope.var(var_name).get_tensor()
if var is not None:
assert isinstance(var, np.ndarray)
tensor.set_lod([[]])
tensor.set_dims(var.shape)
tensor.set(var, place)
return tensor
def set_output_grad(scope, outputs, place, feed_dict=None):
def __set_tensor__(name, data=None):
out_tensor = scope.find_var(name).get_tensor()
grad_tensor = scope.var(grad_var_name(name)).get_tensor()
out_dtype = out_tensor.dtype()
if data is None:
if out_dtype == core.DataType.FP64:
data = np.ones(out_tensor.shape(), dtype=np.float64)
elif out_dtype == core.DataType.FP32:
data = np.ones(out_tensor.shape(), dtype=np.float32)
else:
raise ValueError("Not supported data type " + str(out_dtype))
grad_tensor.set(data, place)
for output in outputs:
data = None
if output in feed_dict:
data = feed_dict[output]
__set_tensor__(output, data)
class TestLayerNormdOp(OpTest):
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
self.assertTrue(
np.allclose(
np.array(tensor).reshape(np_array.shape), np_array, atol=atol),
msg)
def __assert_grad_close(self,
tensor,
np_array,
name,
place,
max_relative_error=0.02):
a = np.array(tensor).reshape(np_array.shape)
b = np_array
abs_a = np.abs(a)
abs_a[abs_a < 1e-5] = 1
diff_mat = np.abs(a - b) / abs_a
max_diff = np.max(diff_mat)
def err_msg():
offset = np.argmax(diff_mat > max_relative_error)
return ("%s Variable %s max gradient diff %f over limit %f, "
"the first error element is %d, %f, %f") % (
"Gradient Check On %s" % str(place), name, max_diff,
max_relative_error, offset, a.flatten()[offset],
b.flatten()[offset])
self.assertLessEqual(max_diff, max_relative_error, err_msg())
def check_forward_backward(self, shape, begin_norm_axis):
def test_with_place(place, shape, begin_norm_axis=1):
# setUp
assert begin_norm_axis > 0 and begin_norm_axis < len(
shape), 'begin_norm_axis must be between 0 and len(shape)-1.'
# attr
epsilon = 0.00001
x_shape = shape
D = reduce(mul, x_shape[begin_norm_axis:len(x_shape)], 1)
scale_shape = [D]
np.random.random(123)
x_val = np.random.random_sample(x_shape).astype(np.float32)
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
bias_val = np.random.random_sample(scale_shape).astype(np.float32)
y_grad = np.random.random_sample(x_shape).astype(np.float32)
# run forward
y_out, saved_mean, var_ref = _reference_layer_norm_naive(
x_val, scale_val, bias_val, epsilon, begin_norm_axis)
naive_fw = {"Y": y_out, "Mean": saved_mean, "Variance": var_ref}
# get gradient
x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_layer_norm_grad(
x_val, y_grad, scale_val, saved_mean, var_ref, begin_norm_axis)
naive_grad = {
"X": x_grad_ref,
"Scale": scale_grad_ref,
"Bias": bias_grad_ref
}
scope = core.Scope()
# create input
input_map = {"X": x_val, "Scale": scale_val, "Bias": bias_val}
for i_name in input_map:
create_or_get_tensor(scope, i_name, input_map[i_name], place)
# create output
output_map = {"Y": None, "Mean": None, "Variance": None}
output_tensor = {}
for o_name in output_map:
output_tensor[o_name] = create_or_get_tensor(
scope, o_name, output_map[o_name], place)
layer_norm_op = Operator(
"layer_norm",
# inputs
X="X",
Scale="Scale",
Bias="Bias",
# outputs
Y="Y",
Mean="Mean",
Variance="Variance",
# attrs
epsilon=epsilon,
begin_norm_axis=begin_norm_axis)
layer_norm_op.run(scope, place)
# check forward result
atol = 5e-2 if isinstance(place, core.CUDAPlace) else 1e-4
for o_tensor in output_tensor:
self.__assert_close(output_tensor[o_tensor], naive_fw[o_tensor],
o_tensor, atol)
# run backward
layer_norm_op_grad = get_backward_op(scope, layer_norm_op, set())
set_output_grad(
scope, ["Y", "Mean", "Variance"],
place,
feed_dict={"Y": y_grad})
layer_norm_op_grad.run(scope, place)
# get output
grad_tensor = {}
for o_name in naive_grad:
grad_tensor[o_name] = x_ = create_or_get_tensor(
scope, grad_var_name(o_name), None, place)
# check gradient output
for o_grad in naive_grad:
self.__assert_grad_close(grad_tensor[o_grad],
naive_grad[o_grad], o_grad + "@GRAD",
place)
places = [core.CPUPlace()]
if core.is_compile_gpu() and core.op_support_gpu("layer_norm"):
places.append(core.CUDAPlace(0))
for place in places:
test_with_place(place, shape, begin_norm_axis)
def test_check_forward_backward_with_scale_and_bias(self):
self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=1)
self.check_forward_backward(shape=[2, 3, 4, 5], begin_norm_axis=3)
def test_check_forward_backward_with_scale(self):
pass # TODO(zcd)
def test_check_forward_backward_with_bias(self):
pass # TODO(zcd)
def test_check_forward_backward(self):
pass # TODO(zcd)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册