提交 ee998a9c 编写于 作者: Q Qiao Longfei 提交者: GitHub

CPU Batch Norm Op (#4964)

* init batch norm op

* prepare input output

* compute mean_out var_out save_mean save_var on CPU

* active is test

* use eigen to do computation

* complete batch norm forward

* set default momentum to 0.9

* add batch norm grad op in CPU

* add tensor_format and NHWC support, add python test

* add test training

* add batch norm gradient test

* improve comment, fix foward Python UnitTest

* add gradient test

* fix eigen warning

* follow name style

* fix a bug

* change float to T

* add simple forward test

* test with different place

* add backward test

* refine python test

* remove old python test code

* code clean

* follow code style

* update comment
上级 b54990e9
......@@ -8,7 +8,7 @@ ExternalProject_Add(
extern_eigen3
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/RLovelett/eigen.git"
GIT_TAG 4e79cb69b9425f5f8c3a84be4350d4ab75b5fd9d
GIT_TAG 70661066beef694cadf6c304d0d07e0758825c10
PREFIX ${EIGEN_SOURCE_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/batch_norm_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T>
using EigenArrayMap =
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
template <typename T>
using ConstEigenArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
template <typename T>
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>;
template <typename T>
using ConstEigenVectorArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
class BatchNormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "");
PADDLE_ENFORCE(ctx->HasInput("Scale"), "");
PADDLE_ENFORCE(ctx->HasInput("Bias"), "");
PADDLE_ENFORCE(ctx->HasInput("Mean"), "");
PADDLE_ENFORCE(ctx->HasInput("Variance"), "");
PADDLE_ENFORCE(ctx->HasOutput("Y"), "");
PADDLE_ENFORCE(ctx->HasOutput("MeanOut"), "");
PADDLE_ENFORCE(ctx->HasOutput("VarianceOut"), "");
PADDLE_ENFORCE(ctx->HasOutput("SavedMean"), "");
PADDLE_ENFORCE(ctx->HasOutput("SavedVariance"), "");
// make sure Mean/MeanOut and Variance/VarianceOut share memory in Python
PADDLE_ENFORCE_EQ(ctx->Inputs("Mean")[0], ctx->Outputs("MeanOut")[0],
"Mean and MeanOut should share the same memory");
PADDLE_ENFORCE_EQ(ctx->Inputs("Variance")[0],
ctx->Outputs("VarianceOut")[0],
"Variance and VarianceOut should share the same memory");
const auto x_dims = ctx->GetInputDim("X");
const TensorFormat tensor_format =
StringToTensorFormat(ctx->Attrs().Get<std::string>("tensor_format"));
const int C =
(tensor_format == TensorFormat::NCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale").size(), 1UL);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Scale")[0], C);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias").size(), 1UL);
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Bias")[0], C);
ctx->SetOutputDim("Y", x_dims);
ctx->SetOutputDim("MeanOut", {C});
ctx->SetOutputDim("VarianceOut", {C});
ctx->SetOutputDim("SavedMean", {C});
ctx->SetOutputDim("SavedVariance", {C});
}
};
class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BatchNormOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<bool>("is_test", "").SetDefault(false);
AddAttr<float>("momentum", "").SetDefault(0.9);
AddAttr<float>("epsilon", "").SetDefault(1e-5);
AddAttr<std::string>("tensor_format", "").SetDefault("NCHW");
AddInput("X", "The input tensor");
AddInput("Scale",
"Scale is a 1-dimensional tensor of size C "
"to be applied to the output");
AddInput("Bias",
"Bias is a 1-dimensional tensor of size C "
"to be applied to the output");
AddInput("Mean",
"The global mean (for training) or the "
"estimated mean (for testing)");
AddInput("Variance",
"The global variance (for training) "
"or the estimated Variance (for testing)");
AddOutput("Y", "result after normalization");
AddOutput("MeanOut",
"Share memory with Mean. "
"Store the global mean when training");
AddOutput("VarianceOut",
"Share memory with Variance. "
"Store the global Variance when training");
AddOutput("SavedMean",
"Mean of the current mini batch, "
"will apply to output when training");
AddOutput("SavedVariance",
"Variance of the current mini batch, "
"will apply to output when training");
AddComment(R"DOC(
https://arxiv.org/pdf/1502.03167.pdf
NHWC `[batch, in_height, in_width, in_channels]`
NCHW `[batch, in_channels, in_height, in_width]`
)DOC");
}
};
template <typename T>
class BatchNormKernel<platform::CPUPlace, T> : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const float epsilon = ctx.Attr<float>("epsilon");
const float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test");
const std::string tensor_format_str =
ctx.Attr<std::string>("tensor_format");
const TensorFormat tensor_format = StringToTensorFormat(tensor_format_str);
const auto *x = ctx.Input<Tensor>("X");
const auto &x_dims = x->dims();
PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5,
"The Input dim size should be between 3 and 5");
const int N = x_dims[0];
const int C =
(tensor_format == TensorFormat::NCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
const int sample_size = x->numel() / N / C;
auto *y = ctx.Output<Tensor>("Y");
auto *mean_out = ctx.Output<Tensor>("MeanOut");
auto *variance_out = ctx.Output<Tensor>("VarianceOut");
auto *saved_mean = ctx.Output<Tensor>("SavedMean");
auto *saved_variance = ctx.Output<Tensor>("SavedVariance");
// alloc memory
y->mutable_data<T>(ctx.GetPlace());
mean_out->mutable_data<T>(ctx.GetPlace());
variance_out->mutable_data<T>(ctx.GetPlace());
saved_mean->mutable_data<T>(ctx.GetPlace());
saved_variance->mutable_data<T>(ctx.GetPlace());
if (!is_test) {
// saved_xx is use just in this batch of data
EigenVectorArrayMap<T> saved_mean_e(
saved_mean->mutable_data<T>(ctx.GetPlace()), C);
EigenVectorArrayMap<T> saved_variance_e(
saved_variance->mutable_data<T>(ctx.GetPlace()), C);
saved_mean_e.setZero();
saved_variance_e.setZero();
switch (tensor_format) {
case TensorFormat::NCHW: {
ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, N * C);
for (int nc = 0; nc < N * C; ++nc) {
saved_mean_e(nc % C) += x_arr.col(nc).sum();
}
saved_mean_e /= N * sample_size;
for (int nc = 0; nc < N * C; ++nc) {
saved_variance_e(nc % C) +=
(x_arr.col(nc) - saved_mean_e(nc % C)).matrix().squaredNorm();
}
saved_variance_e /= N * sample_size;
break;
}
case TensorFormat::NHWC: {
ConstEigenArrayMap<T> x_arr(x->data<T>(), C, N * sample_size);
for (int i = 0; i < N * sample_size; ++i) {
saved_mean_e += x_arr.col(i);
}
saved_mean_e /= N * sample_size;
for (int i = 0; i < N * sample_size; ++i) {
saved_variance_e +=
(x_arr.col(i) - saved_mean_e) * (x_arr.col(i) - saved_mean_e);
}
saved_variance_e /= N * sample_size;
break;
}
default:
PADDLE_THROW("Unknown storage order: %s", tensor_format_str);
}
EigenVectorArrayMap<T> running_mean_arr(
mean_out->mutable_data<T>(ctx.GetPlace()), C);
EigenVectorArrayMap<T> running_var_arr(
variance_out->mutable_data<T>(ctx.GetPlace()), C);
running_mean_arr =
running_mean_arr * momentum + saved_mean_e * (1. - momentum);
running_var_arr =
running_var_arr * momentum + saved_variance_e * (1. - momentum);
}
// use SavedMean and SavedVariance to do normalize
Eigen::Array<T, Eigen::Dynamic, 1> inv_std(C);
if (is_test) {
ConstEigenVectorArrayMap<T> var_arr(
ctx.Input<Tensor>("Variance")->data<T>(), C);
inv_std = (var_arr + epsilon).sqrt().inverse();
} else {
EigenVectorArrayMap<T> saved_inv_std(
ctx.Output<Tensor>("SavedVariance")->data<T>(), C);
// inverse SavedVariance first, gradient will use it too.
saved_inv_std = (saved_inv_std + epsilon).inverse().sqrt();
inv_std = saved_inv_std;
}
ConstEigenVectorArrayMap<T> mean_arr(
is_test ? ctx.Input<Tensor>("Mean")->data<T>()
: ctx.Output<Tensor>("SavedMean")->data<T>(),
C);
// ((x - est_mean) * (inv_var) * scale + bias
// formula transform ====>
// (x * inv_var * scale) + (bias - est_mean * inv_var * scale)
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *bias = ctx.Input<Tensor>("Bias");
ConstEigenVectorArrayMap<T> scale_arr(scale->data<T>(), C);
ConstEigenVectorArrayMap<T> bias_arr(bias->data<T>(), C);
Eigen::Array<T, Eigen::Dynamic, 1> new_scale = inv_std * scale_arr;
Eigen::Array<T, Eigen::Dynamic, 1> new_bias =
bias_arr - mean_arr * inv_std * scale_arr;
switch (tensor_format) {
case TensorFormat::NCHW: {
EigenArrayMap<T> y_arr(y->mutable_data<T>(ctx.GetPlace()), sample_size,
N * C);
ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, N * C);
for (int nc = 0; nc < N * C; ++nc) {
y_arr.col(nc) = x_arr.col(nc) * new_scale(nc % C) + new_bias(nc % C);
}
break;
}
case TensorFormat::NHWC: {
EigenArrayMap<T>(y->mutable_data<T>(ctx.GetPlace()), C,
N * sample_size) =
(ConstEigenArrayMap<T>(x->data<T>(), C, N * sample_size).colwise() *
new_scale)
.colwise() +
new_bias;
break;
}
default:
PADDLE_THROW("Unknown storage order: %d", tensor_format);
}
}
};
class BatchNormGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
// check input
PADDLE_ENFORCE(ctx->HasInput("X"));
PADDLE_ENFORCE(ctx->HasInput("Scale"), "");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), "");
PADDLE_ENFORCE(ctx->HasInput("SavedMean"), "");
PADDLE_ENFORCE(ctx->HasInput("SavedVariance"), "");
// check output
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), "");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Scale")), "");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")), "");
const auto x_dims = ctx->GetInputDim("X");
const TensorFormat tensor_format =
StringToTensorFormat(ctx->Attrs().Get<std::string>("tensor_format"));
const int C =
(tensor_format == TensorFormat::NCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
}
};
template <typename T>
class BatchNormGradKernel<platform::CPUPlace, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *x = ctx.Input<Tensor>("X");
const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto *scale = ctx.Input<Tensor>("Scale");
const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
// SavedVariance have been reverted in forward operator
const auto *saved_inv_variance = ctx.Input<Tensor>("SavedVariance");
const std::string tensor_format_str =
ctx.Attr<std::string>("tensor_format");
const TensorFormat tensor_format = StringToTensorFormat(tensor_format_str);
// Get the size for each dimension.
// NCHW [batch_size, in_channels, in_height, in_width]
const auto &x_dims = x->dims();
PADDLE_ENFORCE(x_dims.size() >= 3 && x_dims.size() <= 5,
"The Input dim size should be between 3 and 5");
const int N = x_dims[0];
const int C =
(tensor_format == TensorFormat::NCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
const int sample_size = x->numel() / N / C;
ConstEigenVectorArrayMap<T> scale_arr(scale->data<T>(), C);
ConstEigenVectorArrayMap<T> mean_arr(saved_mean->data<T>(), C);
ConstEigenVectorArrayMap<T> inv_var_arr(saved_inv_variance->data<T>(), C);
// init output
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
d_x->mutable_data<T>(ctx.GetPlace());
d_scale->mutable_data<T>(ctx.GetPlace());
d_bias->mutable_data<T>(ctx.GetPlace());
// d_bias = np.sum(d_y, axis=0)
// d_scale = np.sum((X - mean) / inv_std * dy, axis=0)
// d_x = (1. / N) * scale * inv_var * (N * d_y - np.sum(d_y, axis=0)
// - (X - mean) * inv_var * inv_var * np.sum(d_y * (X - mean), axis=0))
EigenVectorArrayMap<T> d_bias_arr(d_bias->mutable_data<T>(ctx.GetPlace()),
C);
EigenVectorArrayMap<T> d_scale_arr(d_scale->mutable_data<T>(ctx.GetPlace()),
C);
d_bias_arr.setZero();
d_scale_arr.setZero();
const auto scale_inv_var_nhw = scale_arr * inv_var_arr / (N * sample_size);
switch (tensor_format) {
case TensorFormat::NCHW: {
ConstEigenArrayMap<T> x_arr(x->data<T>(), sample_size, N * C);
ConstEigenArrayMap<T> d_y_arr(d_y->data<T>(), sample_size, N * C);
EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()),
sample_size, N * C);
d_x_arr.setZero();
for (int nc = 0; nc < N * C; ++nc) {
int c = nc % C;
d_bias_arr(c) += d_y_arr.col(nc).sum();
d_scale_arr(c) +=
((x_arr.col(nc) - mean_arr(c)) * inv_var_arr(c) * d_y_arr.col(nc))
.sum();
}
for (int nc = 0; nc < N * C; ++nc) {
int c = nc % C;
d_x_arr.col(nc) +=
scale_inv_var_nhw(c) *
(d_y_arr.col(nc) * N * sample_size - d_bias_arr(c) -
(x_arr.col(nc) - mean_arr[c]) * d_scale_arr(c) * inv_var_arr(c));
}
break;
}
case TensorFormat::NHWC: {
ConstEigenArrayMap<T> x_arr(x->data<T>(), C, N * sample_size);
ConstEigenArrayMap<T> d_y_arr(d_y->data<T>(), C, N * sample_size);
EigenArrayMap<T> d_x_arr(d_x->mutable_data<T>(ctx.GetPlace()), C,
N * sample_size);
d_x_arr.setZero();
const auto d_y_row_sum = d_y_arr.rowwise().sum();
const auto x_minus_mean = x_arr.colwise() - mean_arr;
const auto d_y_mul_x_minus_mean_row_sum =
(d_y_arr * x_minus_mean).rowwise().sum();
const auto inv_var_sqr = inv_var_arr * inv_var_arr;
for (int nhw = 0; nhw < N * sample_size; ++nhw) {
d_bias_arr += d_y_arr.col(nhw);
d_scale_arr +=
(x_arr.col(nhw) - mean_arr) * inv_var_arr * d_y_arr.col(nhw);
d_x_arr.col(nhw) +=
scale_inv_var_nhw *
(d_y_arr.col(nhw) * N * sample_size - d_y_row_sum -
x_minus_mean.col(nhw) * inv_var_sqr *
d_y_mul_x_minus_mean_row_sum);
}
break;
}
default:
PADDLE_THROW("Unknown storage order: %s", tensor_format_str);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(batch_norm, ops::BatchNormOp, ops::BatchNormOpMaker,
batch_norm_grad, ops::BatchNormGradOp);
REGISTER_OP_CPU_KERNEL(batch_norm,
ops::BatchNormKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
batch_norm_grad,
ops::BatchNormGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
enum TensorFormat {
NHWC = 0,
NCHW = 1,
};
inline TensorFormat StringToTensorFormat(const std::string& str) {
if (str == "NHWC" || str == "nhwc") {
return TensorFormat::NHWC;
} else if (str == "NCHW" || str == "nchw") {
return TensorFormat::NCHW;
} else {
PADDLE_THROW("Unknown storage order string: %s", str);
}
}
template <typename Place, typename T>
class BatchNormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
};
template <typename Place, typename T>
class BatchNormGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
};
} // namespace operators
} // namespace paddle
......@@ -390,7 +390,8 @@ class OpTest(unittest.TestCase):
output_names,
no_grad_set=None,
in_place=False,
max_relative_error=0.005):
max_relative_error=0.005,
user_defined_grads=None):
self.scope = core.Scope()
op_inputs = self.inputs if hasattr(self, "inputs") else dict()
op_outputs = self.outputs if hasattr(self, "outputs") else dict()
......@@ -403,7 +404,7 @@ class OpTest(unittest.TestCase):
if not type(output_names) is list:
output_names = [output_names]
numeric_grads = [
numeric_grads = user_defined_grads or [
get_numeric_gradient(
self.scope,
self.op,
......
import unittest
import numpy as np
from op_test import OpTest, get_backward_op, grad_var_name
import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator
def _reference_training(x, scale, offset, epsilon, data_format):
if data_format != "NHWC":
raise ValueError("data_format must be NHWC, got %s." % data_format)
x_square = x * x
x_square_sum = np.sum(x_square, (0, 1, 2))
x_sum = np.sum(x, axis=(0, 1, 2))
element_count = np.size(x) / int(np.shape(x)[-1])
mean = x_sum / element_count
var = x_square_sum / element_count - mean * mean
normalized = (x - mean) / np.sqrt(var + epsilon)
return (normalized * scale + offset), mean, var
def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format):
# Use the following formulas to calculate gradients:
# grad_scale =
# sum(grad_y * (x - mean)) * rsqrt(var + epsilon)
#
# grad_offset = sum(output_y)
#
# grad_x =
# 1/N * scale * rsqrt(var + epsilon) * (N * grad_y - sum(grad_y) -
# (x - mean) * sum(grad_y * (x - mean)) / (var + epsilon))
if data_format != "NHWC":
raise ValueError("data_format must be NHWC, got %s." % data_format)
grad_x = scale * (grad_y - np.mean(
grad_y, axis=(0, 1, 2)) - (x - mean) * np.mean(
grad_y * (x - mean), axis=(0, 1, 2)) /
(var + epsilon)) / np.sqrt(var + epsilon)
grad_scale = np.sum(grad_y * (x - mean) / np.sqrt(var + epsilon),
axis=(0, 1, 2))
grad_offset = np.sum(grad_y, axis=(0, 1, 2))
return grad_x, grad_scale, grad_offset
def create_or_get_tensor(scope, var_name, var, place):
tensor = scope.var(var_name).get_tensor()
if var is not None:
assert isinstance(var, np.ndarray)
tensor.set_lod([[]])
tensor.set_dims(var.shape)
tensor.set(var, place)
return tensor
def set_output_grad(scope, outputs, place):
def __set_tensor__(name):
out_tensor = scope.find_var(name).get_tensor()
grad_tensor = scope.var(grad_var_name(name)).get_tensor()
out_dtype = out_tensor.dtype()
if out_dtype == core.DataType.FP64:
data = np.ones(out_tensor.shape(), dtype=np.float64)
elif out_dtype == core.DataType.FP32:
data = np.ones(out_tensor.shape(), dtype=np.float32)
else:
raise ValueError("Not supported data type " + str(out_dtype))
grad_tensor.set(data, place)
for output in outputs:
__set_tensor__(output)
class TestBatchNormOp(OpTest):
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
def test_forward_backward(self):
# attr
data_format = "NHWC"
epsilon = 0.00001
momentum = 0.9
channel_num = 2
x_shape = [2, 3, 4, channel_num]
scale_shape = [channel_num]
# input
x_val = np.random.random_sample(x_shape).astype(np.float32)
scale_val = np.random.random_sample(scale_shape).astype(np.float32)
bias_val = np.random.random_sample(scale_shape).astype(np.float32)
mean = np.zeros(scale_shape).astype(np.float32)
variance = np.zeros(scale_shape).astype(np.float32)
# run forward
y_out, saved_mean, var_ref = _reference_training(
x_val, scale_val, bias_val, epsilon, data_format)
# run backward
mean_out = saved_mean * (1 - momentum)
variance_out = var_ref * (1 - momentum)
saved_variance = 1 / np.sqrt(var_ref + epsilon)
# for gradient test
y_grad = np.ones(x_shape).astype(np.float32)
x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad(
x_val, y_grad, scale_val, saved_mean, var_ref, epsilon, data_format)
def test_with_place(place):
scope = core.Scope()
# create input
x_tensor = create_or_get_tensor(scope, "x_val", x_val, place)
scale_tensor = create_or_get_tensor(scope, "scale_val", scale_val,
place)
bias_tensor = create_or_get_tensor(scope, "bias_val", bias_val,
place)
mean_tensor = create_or_get_tensor(scope, "mean", mean, place)
variance_tensor = create_or_get_tensor(scope, "variance", variance,
place)
# create output
y_tensor = create_or_get_tensor(scope, "y_out", None, place)
saved_mean_tensor = create_or_get_tensor(scope, "saved_mean", None,
place)
saved_variance_tensor = create_or_get_tensor(
scope, "saved_variance", None, place)
mean_out_tensor = mean_tensor
variance_out_tensor = variance_tensor
batch_norm_op = Operator(
"batch_norm",
# inputs
X="x_val",
Scale="scale_val",
Bias="bias_val",
Mean="mean",
Variance="variance",
# outputs
Y="y_out",
MeanOut="mean",
VarianceOut="variance",
SavedMean="saved_mean",
SavedVariance="saved_variance",
# attrs
is_test=False,
tensor_format=data_format,
momentum=momentum,
epsilon=epsilon)
ctx = core.DeviceContext.create(place)
batch_norm_op.run(scope, ctx)
# check forward result
self.__assert_close(y_tensor, y_out, "y_out")
self.__assert_close(saved_mean_tensor, saved_mean, "saved_mean")
self.__assert_close(saved_variance_tensor, saved_variance,
"saved_variance")
self.__assert_close(mean_out_tensor, mean_out, "mean_out")
# FIXME(qiao) figure out why with cuDNN variance_out have a higher error rate
if isinstance(place, core.GPUPlace):
atol = 5e-2
else:
atol = 1e-4
self.__assert_close(variance_out_tensor, variance_out,
"variance_out", atol)
# run backward
batch_norm_op_grad = get_backward_op(scope, batch_norm_op, set())
set_output_grad(
scope,
["y_out", "mean", "variance", "saved_mean", "saved_variance"],
place)
batch_norm_op_grad.run(scope, ctx)
x_grad_tensor = create_or_get_tensor(scope,
grad_var_name("x_val"), None,
place)
scale_grad_tensor = create_or_get_tensor(scope,
grad_var_name("scale_val"),
None, place)
bias_grad_tensor = create_or_get_tensor(scope,
grad_var_name("bias_val"),
None, place)
# check gradient output
self.__assert_close(x_grad_tensor, x_grad_ref, "x_grad")
self.__assert_close(scale_grad_tensor, scale_grad_ref, "scale_grad")
self.__assert_close(bias_grad_tensor, bias_grad_ref, "bias_grad")
places = [core.CPUPlace()]
if core.is_compile_gpu() and core.op_support_gpu("batch_norm"):
places.append(core.GPUPlace(0))
for place in places:
test_with_place(place)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册