未验证 提交 9c63b7c1 编写于 作者: K Kaipeng Deng 提交者: GitHub

[cherry-pick] add bn momentum variable (#21435)

* batch_norm momentum support variable. test=develop
上级 5c7c6b1e
......@@ -58,6 +58,13 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout"));
if (ctx->IsRuntime() && ctx->HasInput("MomentumTensor")) {
auto mom = ctx->Inputs("MomentumTensor");
PADDLE_ENFORCE_EQ(mom.size(), 1,
platform::errors::InvalidArgument(
"Input(MomentumTensor) size must be 1"));
}
PADDLE_ENFORCE_GE(
x_dims.size(), 2,
"ShapeError: the dimension of input X must greater than or equal to 2."
......@@ -173,6 +180,11 @@ void BatchNormOpMaker::Make() {
AddInput("Variance",
"The global variance (for training) "
"or estimated Variance (for testing)");
AddInput("MomentumTensor",
"(Tensor<float32>, optional) If provided, batch_norm will "
"use this as momentum, this has a higher priority than "
"attr(momentum), the shape of this tensor MUST BE [1].")
.AsDispensable();
AddOutput("Y", "result after normalization");
AddOutput("MeanOut",
"Share memory with Mean. "
......@@ -221,7 +233,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const float epsilon = ctx.Attr<float>("epsilon");
const float momentum = ctx.Attr<float>("momentum");
float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
......@@ -306,6 +318,13 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
PADDLE_THROW("Unknown storage order: %s", data_layout_str);
}
// if MomentumTensor is set, use MomentumTensor value, momentum
// is only used in this training branch
if (ctx.HasInput("MomentumTensor")) {
const auto *mom_tensor = ctx.Input<Tensor>("MomentumTensor");
momentum = mom_tensor->data<float>()[0];
}
running_mean_arr =
running_mean_arr * momentum + saved_mean_e * (1. - momentum);
running_var_arr =
......
......@@ -43,7 +43,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace.");
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
const float momentum = ctx.Attr<float>("momentum");
float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
......@@ -133,6 +133,15 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
est_mean->template data<BatchNormParamType<T>>(),
est_var->template data<BatchNormParamType<T>>(), epsilon));
} else {
// if MomentumTensor is set, use MomentumTensor value, momentum
// is only used in this training branch
if (ctx.HasInput("MomentumTensor")) {
const auto *mom_tensor = ctx.Input<Tensor>("MomentumTensor");
Tensor mom_cpu;
TensorCopySync(*mom_tensor, platform::CPUPlace(), &mom_cpu);
momentum = mom_cpu.data<float>()[0];
}
// Run training mode.
// obtain running mean and running inv var, and see if we need to
// initialize them.
......
......@@ -4176,13 +4176,14 @@ def batch_norm(input,
sync_batch_norm automatically.
Args:
input(variable): The rank of input variable can be 2, 3, 4, 5. The data type
input(Variable): The rank of input variable can be 2, 3, 4, 5. The data type
is float16 or float32 or float64.
act(string, Default None): Activation type, linear|relu|prelu|...
is_test (bool, Default False): A flag indicating whether it is in
test phrase or not.
momentum(float, Default 0.9): The value used for the moving_mean and
moving_var computation. The updated formula is:
momentum(float|Variable, Default 0.9): The value used for the moving_mean and
moving_var computation. This should be a float number or a Variable with
shape [1] and data type as float32. The updated formula is:
:math:`moving\_mean = moving\_mean * momentum + new\_mean * (1. - momentum)`
:math:`moving\_var = moving\_var * momentum + new\_var * (1. - momentum)`
Default is 0.9.
......@@ -4228,6 +4229,33 @@ def batch_norm(input,
x = fluid.data(name='x', shape=[3, 7, 3, 7], dtype='float32')
hidden1 = fluid.layers.fc(input=x, size=200, param_attr='fc1.w')
hidden2 = fluid.layers.batch_norm(input=hidden1)
.. code-block:: python
# batch_norm with momentum as Variable
import paddle.fluid as fluid
import paddle.fluid.layers.learning_rate_scheduler as lr_scheduler
def get_decay_momentum(momentum_init, decay_steps, decay_rate):
global_step = lr_scheduler._decay_step_counter()
momentum = fluid.layers.create_global_var(
shape=[1],
value=float(momentum_init),
dtype='float32',
# set persistable for save checkpoints and resume
persistable=True,
name="momentum")
div_res = global_step / decay_steps
decayed_momentum = momentum_init * (decay_rate**div_res)
fluid.layers.assign(decayed_momentum, momentum)
return momentum
x = fluid.data(name='x', shape=[3, 7, 3, 7], dtype='float32')
hidden1 = fluid.layers.fc(input=x, size=200, param_attr='fc1.w')
momentum = get_decay_momentum(0.9, 1e5, 0.9)
hidden2 = fluid.layers.batch_norm(input=hidden1, momentum=momentum)
"""
assert bias_attr is not False, "bias_attr should not be False in batch_norm."
helper = LayerHelper('batch_norm', **locals())
......@@ -4303,15 +4331,28 @@ def batch_norm(input,
batch_norm_out = input if in_place else helper.create_variable_for_type_inference(
dtype)
helper.append_op(
type="batch_norm",
inputs={
inputs = {
"X": input,
"Scale": scale,
"Bias": bias,
"Mean": mean,
"Variance": variance
},
}
attrs = {
"epsilon": epsilon,
"is_test": is_test,
"data_layout": data_layout,
"use_mkldnn": False,
"fuse_with_relu": False,
"use_global_stats": use_global_stats
}
if isinstance(momentum, Variable):
inputs['MomemtumTensor'] = momentum
else:
attrs['momentum'] = momentum
helper.append_op(
type="batch_norm",
inputs=inputs,
outputs={
"Y": batch_norm_out,
"MeanOut": mean_out,
......@@ -4319,15 +4360,7 @@ def batch_norm(input,
"SavedMean": saved_mean,
"SavedVariance": saved_variance
},
attrs={
"momentum": momentum,
"epsilon": epsilon,
"is_test": is_test,
"data_layout": data_layout,
"use_mkldnn": False,
"fuse_with_relu": False,
"use_global_stats": use_global_stats
})
attrs=attrs)
return helper.append_activation(batch_norm_out)
......
......@@ -310,6 +310,7 @@ class TestBatchNormOpTraining(unittest.TestCase):
self.fuse_with_relu = False
self.data_formats = ["NCHW", "NHWC"]
self.momentum = 0.9
self.use_momentum_variable = False
self.epsilon = 0.00001
self.init_kernel_type()
self.init_test_case()
......@@ -367,6 +368,7 @@ class TestBatchNormOpTraining(unittest.TestCase):
bias = np.random.random_sample(scale_shape).astype(np.float32)
mean, variance = self.set_mean_variance(scale_shape, x, data_layout)
y_grad = np.random.random_sample(shape).astype(np.float32)
momentum_var = np.array([momentum]).astype(np.float32)
y, mean_out, variance_out, saved_mean, saved_variance, x_grad, scale_grad, bias_grad = self.ref_forward_backward(
x, y_grad, scale, bias, mean, variance, epsilon, momentum,
......@@ -380,7 +382,7 @@ class TestBatchNormOpTraining(unittest.TestCase):
var_names = [
'x', 'scale', 'bias', 'mean', 'variance', 'y', 'saved_mean',
'saved_variance'
'saved_variance', 'momentum_var'
]
ground_truth = {name: var_dict[name] for name in var_names}
......@@ -392,15 +394,28 @@ class TestBatchNormOpTraining(unittest.TestCase):
name=name,
dtype='float32',
shape=ground_truth[name].shape)
bn_op = block.append_op(
type="batch_norm",
inputs={
inputs = {
"X": block.var('x'),
"Scale": block.var('scale'),
"Bias": block.var('bias'),
"Mean": block.var('mean'),
"Variance": block.var('variance')
},
}
attrs = {
"epsilon": epsilon,
"is_test": False,
"data_layout": data_layout,
"use_mkldnn": self.use_mkldnn,
"fuse_with_relu": self.fuse_with_relu,
"use_global_stats": self.use_global_stats
}
if self.use_momentum_variable:
inputs['MomentumTensor'] = block.var('momentum_var')
else:
attrs['momentum'] = momentum
bn_op = block.append_op(
type="batch_norm",
inputs=inputs,
outputs={
"Y": block.var('y'),
"MeanOut": block.var('mean'), # share memory
......@@ -408,15 +423,7 @@ class TestBatchNormOpTraining(unittest.TestCase):
"SavedMean": block.var('saved_mean'),
"SavedVariance": block.var('saved_variance')
},
attrs={
"momentum": momentum,
"epsilon": epsilon,
"is_test": False,
"data_layout": data_layout,
"use_mkldnn": self.use_mkldnn,
"fuse_with_relu": self.fuse_with_relu,
"use_global_stats": self.use_global_stats
})
attrs=attrs)
block.create_var(name='y@GRAD', dtype='float32', shape=y.shape)
# generate backward op_desc
......@@ -434,12 +441,13 @@ class TestBatchNormOpTraining(unittest.TestCase):
grad_var.set_dtype(core.VarDesc.VarType.FP32)
exe = fluid.Executor(place)
out = exe.run(
program,
out = exe.run(program,
feed={
name: var_dict[name]
for name in
['x', 'scale', 'bias', 'mean', 'variance', 'y@GRAD']
for name in [
'x', 'scale', 'bias', 'mean', 'variance',
'y@GRAD', 'momentum_var'
]
},
fetch_list=self.fetch_list)
......@@ -471,6 +479,17 @@ class TestBatchNormOpTrainingCase1(TestBatchNormOpTraining):
self.fetch_list = ['y', 'mean', 'variance', 'x@GRAD']
class TestBatchNormOpTrainingMomentumVariable(TestBatchNormOpTraining):
def init_test_case(self):
self.use_momentum_variable = True
self.use_global_stats = False
self.no_grad_set = set()
self.fetch_list = [
'y', 'mean', 'variance', 'saved_mean', 'saved_variance', 'x@GRAD',
'scale@GRAD', 'bias@GRAD'
]
class TestBatchNormOpFreezeStatsTraining(TestBatchNormOpTraining):
def init_test_case(self):
self.use_global_stats = True
......
......@@ -2465,6 +2465,19 @@ class TestBook(LayerTest):
out = layers.batch_norm(data)
return (out)
def make_batch_norm_momentum_variable(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
data = self._get_data(
name='data', shape=[32, 128, 128], dtype="float32")
momentum = self._get_data(
name='momentum',
shape=[1],
dtype='float32',
append_batch_size=False)
out = layers.batch_norm(data, momentum=momentum)
return (out)
def make_range(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册