未验证 提交 19fd0717 编写于 作者: Q qingqing01 提交者: GitHub

Make the normalization operator more general and fix bug in l2_normalize. (#11348)

* Add normalization operator.
1. Refine the raw norm_op and let it more general to support to normalize Tensor along any axis.
2. There is a bug in l2_normalize API, which lacks sqrt after `reduce_sum`.
3. Use norm_op to refine the l2_normalize API.
4. Fix bug in test_normalization_wrapper.py.
上级 f15504e5
......@@ -16,40 +16,34 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename AttrType>
class NormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
"X",
"(Tensor) The input tensor of norm operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of feature.");
AddInput("Scale",
"(Tensor) The input tensor of norm operator. "
"The format of input tensor is C * 1.");
AddAttr<AttrType>("epsilon",
"(float, default 1e-10) Constant "
"for numerical stability.")
AddInput("X", "(Tensor) A tensor of rank >= axis.");
AddAttr<int>("axis",
"The axis on which to apply normalization. If axis < 0, "
"the dimension to normalization is rank(X) + axis. -1 is "
"the last dimension.");
AddAttr<float>("epsilon",
"(float, default 1e-10) The epsilon value is used "
"to avoid division by zero.")
.SetDefault(1.0e-10f);
AddOutput("Out",
"(Tensor) The output tensor of norm operator."
"N * M."
"M = C * H * W");
AddOutput("Norm",
"(Tensor) A tensor saved the `sqrt(sum(x) + epsion)` will "
"be used in backward kernel.")
.AsIntermediate();
AddOutput("Out", "(Tensor) A tensor of the same shape as X.");
AddComment(R"DOC(
"Input shape: $(N, C, H, W)$
Scale shape: $(C, 1)$
Output shape: $(N, C, H, W)$
Where
forward
$$
[\frac {x_{1}}{\sqrt{\sum{x_{i}^{2}}}} \frac {x_{2}}{\sqrt{\sum{x_{i}^{2}}}} \frac {x_{3}}{\sqrt{\sum{x_{i}^{2}}}} \cdot \cdot \cdot \frac {x_{n}}{\sqrt{\sum{x_{i}^{2}}}}]
$$
backward
$$
\frac{\frac{\mathrm{d}L }{\mathrm{d}y_{1}} - \frac {x_{1}\sum {\frac{\mathrm{d} L}{\mathrm{d} y_{j}}}x_{j}}{\sum x_{j}^{2}} }{\sqrt{\sum{x_{j}^{2}}}}
$$
)DOC");
Given a tensor, apply 2-normalization along the provided axis.
$$
y = \frac{x}{ \sqrt{\sum {x^2} + epsion }}
$$
where, $\sum {x^2}$ is calculated along the `axis` dimension.
)DOC");
}
};
......@@ -58,15 +52,15 @@ class NormOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of NormOp"
"should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Scale"),
"Input(Scale) of NormOp"
"should not be null.");
"Input(X) of NormOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of NormOp should not be null.");
auto in_x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", in_x_dims);
auto xdim = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", xdim);
int axis = ctx->Attrs().Get<int>("axis");
if (axis < 0) axis = xdim.size() + axis;
xdim[axis] = 1;
ctx->SetOutputDim("Norm", xdim);
}
};
......@@ -84,12 +78,12 @@ class NormOpGrad : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(norm, ops::NormOp, ops::NormOpMaker<float>,
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(norm, ops::NormOp, ops::NormOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(norm_grad, ops::NormOpGrad);
REGISTER_OP_CPU_KERNEL(
norm, ops::NormKernel<paddle::platform::CPUDeviceContext, float>,
ops::NormKernel<paddle::platform::CPUDeviceContext, double, float>);
REGISTER_OP_CPU_KERNEL(
norm_grad, ops::NormGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::NormGradKernel<paddle::platform::CPUDeviceContext, double, float>);
REGISTER_OP_CPU_KERNEL(norm, ops::NormKernel<CPU, float>,
ops::NormKernel<CPU, double>);
REGISTER_OP_CPU_KERNEL(norm_grad, ops::NormGradKernel<CPU, float>,
ops::NormGradKernel<CPU, double>);
......@@ -16,9 +16,9 @@ limitations under the License. */
#include "paddle/fluid/operators/norm_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
norm, ops::NormKernel<paddle::platform::CUDADeviceContext, float>,
ops::NormKernel<paddle::platform::CUDADeviceContext, double, float>);
REGISTER_OP_CUDA_KERNEL(
norm_grad, ops::NormGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::NormGradKernel<paddle::platform::CUDADeviceContext, double, float>);
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(norm, ops::NormKernel<CUDA, float>,
ops::NormKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(norm_grad, ops::NormGradKernel<CUDA, float>,
ops::NormGradKernel<CUDA, double>);
......@@ -19,156 +19,110 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T, typename AttrType = T>
inline void GetDims(const framework::DDim& dim, int axis, int* pre, int* n,
int* post) {
*pre = 1;
*post = 1;
*n = dim[axis];
for (int i = 0; i < axis; ++i) {
(*pre) *= dim[i];
}
for (int i = axis + 1; i < dim.size(); ++i) {
(*post) *= dim[i];
}
}
template <typename DeviceContext, typename T>
class NormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
const framework::Tensor* scale = context.Input<framework::Tensor>("Scale");
auto* out = context.Output<framework::Tensor>("Out");
auto epsilon = static_cast<T>(context.Attr<AttrType>("epsilon"));
out->mutable_data<T>(context.GetPlace());
int batch_size = in_x->dims()[0];
int channels = in_x->dims()[1];
int height = in_x->dims()[2];
int width = in_x->dims()[3];
int fea_len = height * width;
auto* place =
context.template device_context<DeviceContext>().eigen_device();
auto x =
framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From(
*in_x, framework::make_ddim({batch_size, fea_len * channels}));
// get square
framework::Tensor x_square;
x_square.mutable_data<T>(in_x->dims(), context.GetPlace());
auto x_square_eigen =
framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From(
x_square, framework::make_ddim({batch_size, fea_len * channels}));
x_square_eigen.device(*place) = x.square();
auto scale_eigen =
framework::EigenVector<T, Eigen::RowMajor, Eigen::DenseIndex>::Flatten(
*scale);
for (int n = 0; n < batch_size; ++n) {
framework::Tensor in_x_batch = in_x->Slice(n, n + 1);
auto in_x_batch_eigen =
framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From(
in_x_batch, framework::make_ddim({channels, fea_len}));
framework::Tensor x_square_batch = x_square.Slice(n, n + 1);
auto x_square_batch_eigen =
framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From(
x_square_batch, framework::make_ddim({channels, fea_len}));
framework::Tensor out_batch = out->Slice(n, n + 1);
auto out_batch_eigen =
framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From(
out_batch, framework::make_ddim({channels, fea_len}));
framework::Tensor tmp_tensor;
tmp_tensor.mutable_data<T>(framework::make_ddim({1, fea_len}),
context.GetPlace());
auto tmp = framework::EigenVector<T, Eigen::RowMajor,
Eigen::DenseIndex>::Flatten(tmp_tensor);
// get colsum and sqrt , inverse
auto dim = Eigen::array<int, 1>({{0}});
tmp.device(*place) = x_square_batch_eigen.sum(dim);
tmp.device(*place) = (tmp + epsilon).sqrt().inverse();
Eigen::array<int, 2> broadcast_dim_col;
broadcast_dim_col[1] = 1;
broadcast_dim_col[0] = channels;
out_batch_eigen.device(*place) =
in_x_batch_eigen * (tmp.broadcast(broadcast_dim_col));
Eigen::array<int, 2> broadcast_dim_row;
broadcast_dim_row[1] = fea_len;
broadcast_dim_row[0] = 1;
out_batch_eigen.device(*place) =
out_batch_eigen * (scale_eigen.broadcast(broadcast_dim_row));
}
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_x = ctx.Input<framework::Tensor>("X");
auto* out_y = ctx.Output<framework::Tensor>("Out");
auto* out_norm = ctx.Output<framework::Tensor>("Norm");
out_y->mutable_data<T>(ctx.GetPlace());
out_norm->mutable_data<T>(ctx.GetPlace());
auto xdim = in_x->dims();
auto ndim = out_norm->dims();
T eps = static_cast<T>(ctx.Attr<float>("epsilon"));
int axis = ctx.Attr<int>("axis");
if (axis < 0) axis = xdim.size() + axis;
int pre, n, post;
GetDims(xdim, axis, &pre, &n, &post);
auto* place = ctx.template device_context<DeviceContext>().eigen_device();
Eigen::DSizes<int, 3> shape(pre, n, post);
Eigen::DSizes<int, 2> norm_shape(pre, post);
auto x_e = framework::EigenVector<T>::Flatten(*in_x);
auto y_e = framework::EigenVector<T>::Flatten(*out_y);
auto norm_e = framework::EigenVector<T>::Flatten(*out_norm);
auto x = x_e.reshape(shape);
auto y = y_e.reshape(shape);
auto norm = norm_e.reshape(norm_shape);
Eigen::DSizes<int, 1> rdim(1);
// y = x / sqrt((sum(x * x) + epsilon))
// norm = sqrt(sum(x * x) + epsilon)
auto sum = x.pow(2).sum(rdim) + eps;
norm.device(*place) = sum.sqrt();
// y = x / norm
Eigen::DSizes<int, 3> rshape(pre, 1, post);
Eigen::DSizes<int, 3> bcast(1, n, 1);
y.device(*place) = x / norm.reshape(rshape).broadcast(bcast);
}
};
template <typename DeviceContext, typename T, typename AttrType = T>
class NormGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
const framework::Tensor* scale = context.Input<framework::Tensor>("Scale");
const framework::Tensor* out_grad =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto epsilon = static_cast<T>(context.Attr<AttrType>("epsilon"));
framework::Tensor* in_x_grad =
context.Output<framework::Tensor>(framework::GradVarName("X"));
in_x_grad->mutable_data<T>(context.GetPlace());
int batch_size = in_x->dims()[0];
int channels = in_x->dims()[1];
int height = in_x->dims()[2];
int width = in_x->dims()[3];
int fea_len = height * width;
auto* place =
context.template device_context<DeviceContext>().eigen_device();
auto scale_eigen =
framework::EigenVector<T, Eigen::RowMajor, Eigen::DenseIndex>::Flatten(
*scale);
auto x =
framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From(
*in_x, framework::make_ddim({batch_size, fea_len * channels}));
// get square
framework::Tensor x_square;
x_square.mutable_data<T>(in_x->dims(), context.GetPlace());
auto x_square_eigen =
framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From(
x_square, framework::make_ddim({batch_size, fea_len * channels}));
x_square_eigen.device(*place) = x.square();
for (int n = 0; n < batch_size; ++n) {
framework::Tensor in_x_batch = in_x->Slice(n, n + 1);
auto in_x_batch_eigen =
framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From(
in_x_batch, framework::make_ddim({channels, fea_len}));
framework::Tensor in_g_batch = in_x_grad->Slice(n, n + 1);
auto in_g_batch_eigen =
framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From(
in_g_batch, framework::make_ddim({channels, fea_len}));
framework::Tensor x_square_batch = x_square.Slice(n, n + 1);
auto x_square_batch_eigen =
framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From(
x_square_batch, framework::make_ddim({channels, fea_len}));
framework::Tensor outg_batch = out_grad->Slice(n, n + 1);
auto outg_batch_eigen =
framework::EigenMatrix<T, Eigen::RowMajor, Eigen::DenseIndex>::From(
outg_batch, framework::make_ddim({channels, fea_len}));
framework::Tensor tmp_tensor;
tmp_tensor.mutable_data<T>(framework::make_ddim({1, fea_len}),
context.GetPlace());
auto tmp_eigen =
framework::EigenVector<T, Eigen::RowMajor,
Eigen::DenseIndex>::Flatten(tmp_tensor);
auto dim = Eigen::array<int, 1>({{0}});
tmp_eigen.device(*place) = (in_x_batch_eigen * outg_batch_eigen).sum(dim);
framework::Tensor norm_tmp_tensor;
norm_tmp_tensor.mutable_data<T>(framework::make_ddim({1, fea_len}),
context.GetPlace());
auto norm_tmp_eigen =
framework::EigenVector<T, Eigen::RowMajor,
Eigen::DenseIndex>::Flatten(norm_tmp_tensor);
norm_tmp_eigen.device(*place) =
(x_square_batch_eigen.sum(dim) + epsilon).sqrt();
Eigen::array<int, 2> broadcast_dim_col;
broadcast_dim_col[1] = 1;
broadcast_dim_col[0] = channels;
in_g_batch_eigen.device(*place) =
in_x_batch_eigen * tmp_eigen.broadcast(broadcast_dim_col);
in_g_batch_eigen.device(*place) =
in_g_batch_eigen /
(norm_tmp_eigen * norm_tmp_eigen).broadcast(broadcast_dim_col);
in_g_batch_eigen.device(*place) = outg_batch_eigen - in_g_batch_eigen;
// outg_batch_eigen + (in_g_batch_eigen * -1);
in_g_batch_eigen.device(*place) =
in_g_batch_eigen / norm_tmp_eigen.broadcast(broadcast_dim_col);
Eigen::array<int, 2> broadcast_dim_row;
broadcast_dim_row[1] = fea_len;
broadcast_dim_row[0] = 1;
in_g_batch_eigen.device(*place) =
in_g_batch_eigen * (scale_eigen.broadcast(broadcast_dim_row));
}
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in_x = ctx.Input<framework::Tensor>("X");
auto* in_norm = ctx.Input<framework::Tensor>("Norm");
auto* in_dy = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* out_dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
out_dx->mutable_data<T>(ctx.GetPlace());
auto xdim = in_x->dims();
int axis = ctx.Attr<int>("axis");
if (axis < 0) axis = xdim.size() + axis;
int pre, n, post;
GetDims(xdim, axis, &pre, &n, &post);
auto* place = ctx.template device_context<DeviceContext>().eigen_device();
auto x_e = framework::EigenVector<T>::Flatten(*in_x);
auto dy_e = framework::EigenVector<T>::Flatten(*in_dy);
auto norm_e = framework::EigenVector<T>::Flatten(*in_norm);
auto dx_e = framework::EigenVector<T>::Flatten(*out_dx);
Eigen::DSizes<int, 3> shape(pre, n, post);
Eigen::DSizes<int, 2> norm_shape(pre, post);
auto x = x_e.reshape(shape);
auto dy = dy_e.reshape(shape);
auto norm = norm_e.reshape(norm_shape);
auto dx = dx_e.reshape(shape);
framework::Tensor rsum;
rsum.mutable_data<T>({pre, post}, ctx.GetPlace());
auto sum = framework::EigenTensor<T, 2>::From(rsum);
Eigen::DSizes<int, 1> rdim(1);
Eigen::DSizes<int, 3> bcast(1, n, 1);
Eigen::DSizes<int, 3> rshape(pre, 1, post);
// dx = ( dy/sqrt(sum(x*x)) ) * [1 - x*sum(x) / (sum(x*x) + e)]
// = [dy - dy * x * sum(x) / (sum(x*x) + e)] / sqrt(sum(x*x))
// = [dy - x * sum(x*dy) / (sum(x*x) + e)] / sqrt(sum(x*x))
// 1. sum = sum(x*dy)
sum.device(*place) = (x * dy).sum(rdim);
// 2. dx = x * sum
dx.device(*place) = sum.reshape(rshape).broadcast(bcast) * x;
// 3. dx / (sum(x*x) + e)
// where, norm.pow(2) = sum(x*x) + e, which is calculated in forward.
dx.device(*place) = dx / norm.pow(2).broadcast(bcast);
// 4. [dy - dx] / sqrt(sum(x*x))
dx.device(*place) = (dy - dx) / norm.broadcast(bcast);
}
};
} // namespace operators
......
......@@ -2467,19 +2467,21 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
The l2 normalize layer normalizes `x` along dimension `axis` using an L2
norm. For a 1-D tensor (`dim` is fixed to 0), this layer computes
output = x / sqrt(max(sum(x**2), epsilon))
.. math::
y = \frac{x}{ \sqrt{\sum {x^2} + epsion }}
For `x` with more dimensions, this layer independently normalizes each 1-D
slice along dimension `axis`.
Args:
x(Variable|list): The input tensor to l2_normalize layer.
axis(int): Dimension along which to normalize the input.
epsilon(float): A lower bound value for `x`'s l2 norm. sqrt(epsilon) will
be used as the divisor if the l2 norm of `x` is less than
sqrt(epsilon).
axis(int): The axis on which to apply normalization. If `axis < 0`,
the dimension to normalization is rank(X) + axis. -1 is the
last dimension.
epsilon(float): The epsilon value is used to avoid division by zero,
the defalut value is 1e-10.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
will be named automatically.
Returns:
......@@ -2498,46 +2500,17 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
axis = 0
helper = LayerHelper("l2_normalize", **locals())
square = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(type="square", inputs={"X": x}, outputs={"Out": square})
reduced_sum = helper.create_tmp_variable(dtype=x.dtype)
out = helper.create_tmp_variable(dtype=x.dtype)
norm = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(
type="reduce_sum",
inputs={"X": square},
outputs={"Out": reduced_sum},
type="norm",
inputs={"X": x},
outputs={"Out": out,
"Norm": norm},
attrs={
"dim": [1] if axis is None else [axis],
"keep_dim": True,
"reduce_all": False
"axis": 1 if axis is None else axis,
"epsilon": epsilon,
})
# TODO(caoying) A lower bound value epsilon for the norm is needed to
# imporve the numeric stability of reciprocal. This requires a maximum_op.
rsquare = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(
type="reciprocal", inputs={"X": reduced_sum}, outputs={"Out": rsquare})
# TODO(caoying) the current elementwise_mul operator does not support a
# general broadcast rule which broadcasts input(Y) to have the same
# dimension with Input(X) starting from a specified dimension. So this
# exanpsion is requred. Once a general broadcast rule is spported, this
# expanding canbe removed.
rsquare_expanded = helper.create_tmp_variable(dtype=x.dtype)
expand_times = [1] * len(x.shape)
expand_times[axis] = int(x.shape[axis])
helper.append_op(
type="expand",
inputs={"X": rsquare},
outputs={"Out": rsquare_expanded},
attrs={"expand_times": expand_times})
out = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(
type="elementwise_mul",
inputs={"X": x,
"Y": rsquare_expanded},
outputs={"Out": out})
return out
......
......@@ -387,6 +387,12 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(output)
print(str(program))
def test_l2_normalize(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[8, 7, 10], dtype="float32")
output = layers.l2_normalize(x, axis=1)
def test_maxout(self):
program = Program()
with program_guard(program):
......
......@@ -17,44 +17,23 @@ import numpy as np
from op_test import OpTest
def norm(input, scale, epsilon):
s0, s1, s2, s3 = input.shape
x_square = input * input
for i in xrange(s0):
input_batch = input[i:i + 1, :, :, :]
input_batch = input_batch.reshape(s1, s2 * s3)
x_square_batch = x_square[i:i + 1, :, :, :]
x_square_batch = x_square_batch.reshape(s1, s2 * s3)
square_colsum = x_square_batch.sum(axis=0) + epsilon
tmp = pow(square_colsum, 0.5)
tmp = np.reciprocal(tmp)
tmp_tile = np.tile(tmp, s1)
tmp_tile = tmp_tile.reshape(s1, s2 * s3)
scale_tile = np.tile(scale, (1, s2 * s3))
scale_tile = scale_tile.reshape(s1, s2 * s3)
out_batch = input_batch * tmp_tile * scale_tile
out_batch = out_batch.reshape(1, s1, s2, s3)
if i == 0:
out = out_batch
else:
out = np.concatenate((out, out_batch), 0)
out.reshape(s0, s1, s2, s3)
return out
def l2_norm(x, axis, epsilon):
x2 = x**2
s = np.sum(x2, axis=axis, keepdims=True)
r = np.sqrt(s + epsilon)
y = x / np.broadcast_to(r, x.shape)
return y, r
class TestNormOp(OpTest):
def setUp(self):
self.op_type = "norm"
self.init_test_case()
input = np.random.random(self.shape).astype("float32")
scale = np.array([10, 10, 10])
self.inputs = {
'X': input.astype('float32'),
'Scale': scale.astype('float32')
}
self.attrs = {'epsilon': self.epsilon}
output = norm(input, scale, self.epsilon)
self.outputs = {'Out': output.astype('float32')}
x = np.random.random(self.shape).astype("float64")
y, norm = l2_norm(x, self.axis, self.epsilon)
self.inputs = {'X': x}
self.attrs = {'epsilon': self.epsilon, 'axis': self.axis}
self.outputs = {'Out': y, 'Norm': norm}
def test_check_output(self):
self.check_output()
......@@ -63,8 +42,23 @@ class TestNormOp(OpTest):
self.check_grad(['X'], 'Out')
def init_test_case(self):
self.shape = [2, 3, 2, 2]
self.epsilon = 1e-6
self.shape = [2, 3, 4, 4]
self.axis = 1
self.epsilon = 1e-8
class TestNormOp2(TestNormOp):
def init_test_case(self):
self.shape = [5, 3, 9, 7]
self.axis = 0
self.epsilon = 1e-8
class TestNormOp3(TestNormOp):
def init_test_case(self):
self.shape = [5, 3, 2, 7]
self.axis = -1
self.epsilon = 1e-8
if __name__ == '__main__':
......
......@@ -70,8 +70,9 @@ class TestNormalization(unittest.TestCase):
def l2_normalize(self, data, axis, epsilon):
""" Compute the groundtruth.
"""
output = data * np.reciprocal(
np.sum(np.square(data), axis=axis, keepdims=True))
output = data / np.broadcast_to(
np.sqrt(np.sum(np.square(data), axis=axis, keepdims=True)),
data.shape)
return output
def test_l2_normalize(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册