未验证 提交 67c836fb 编写于 作者: K Kaipeng Deng 提交者: GitHub

batch_norm momentum support variable (#21246)

* batch_norm momentum support variable. test=develop

* fix format. test=develop

* add batch_norm momentum variable example. test=develop

* move MomentumTensor to training branch. test=develop

* split example. test=develop

* fix doc. test=develop

* fix PADDLE_ENFORCE ci. test=develop

* fix format. test=develop
上级 c0aa1367
...@@ -58,6 +58,13 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -58,6 +58,13 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
const DataLayout data_layout = framework::StringToDataLayout( const DataLayout data_layout = framework::StringToDataLayout(
ctx->Attrs().Get<std::string>("data_layout")); ctx->Attrs().Get<std::string>("data_layout"));
if (ctx->IsRuntime() && ctx->HasInput("MomentumTensor")) {
auto mom = ctx->Inputs("MomentumTensor");
PADDLE_ENFORCE_EQ(mom.size(), 1,
platform::errors::InvalidArgument(
"Input(MomentumTensor) size must be 1"));
}
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
x_dims.size(), 2, x_dims.size(), 2,
"ShapeError: the dimension of input X must greater than or equal to 2." "ShapeError: the dimension of input X must greater than or equal to 2."
...@@ -173,6 +180,11 @@ void BatchNormOpMaker::Make() { ...@@ -173,6 +180,11 @@ void BatchNormOpMaker::Make() {
AddInput("Variance", AddInput("Variance",
"The global variance (for training) " "The global variance (for training) "
"or estimated Variance (for testing)"); "or estimated Variance (for testing)");
AddInput("MomentumTensor",
"(Tensor<float32>, optional) If provided, batch_norm will "
"use this as momentum, this has a higher priority than "
"attr(momentum), the shape of this tensor MUST BE [1].")
.AsDispensable();
AddOutput("Y", "result after normalization"); AddOutput("Y", "result after normalization");
AddOutput("MeanOut", AddOutput("MeanOut",
"Share memory with Mean. " "Share memory with Mean. "
...@@ -221,7 +233,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T> ...@@ -221,7 +233,7 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
const float epsilon = ctx.Attr<float>("epsilon"); const float epsilon = ctx.Attr<float>("epsilon");
const float momentum = ctx.Attr<float>("momentum"); float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats"); const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
...@@ -306,6 +318,13 @@ class BatchNormKernel<platform::CPUDeviceContext, T> ...@@ -306,6 +318,13 @@ class BatchNormKernel<platform::CPUDeviceContext, T>
PADDLE_THROW("Unknown storage order: %s", data_layout_str); PADDLE_THROW("Unknown storage order: %s", data_layout_str);
} }
// if MomentumTensor is set, use MomentumTensor value, momentum
// is only used in this training branch
if (ctx.HasInput("MomentumTensor")) {
const auto *mom_tensor = ctx.Input<Tensor>("MomentumTensor");
momentum = mom_tensor->data<float>()[0];
}
running_mean_arr = running_mean_arr =
running_mean_arr * momentum + saved_mean_e * (1. - momentum); running_mean_arr * momentum + saved_mean_e * (1. - momentum);
running_var_arr = running_var_arr =
......
...@@ -43,7 +43,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -43,7 +43,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace."); "It must use CUDAPlace.");
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon")); double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
const float momentum = ctx.Attr<float>("momentum"); float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats"); const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const std::string data_layout_str = ctx.Attr<std::string>("data_layout"); const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
...@@ -133,6 +133,15 @@ class BatchNormKernel<platform::CUDADeviceContext, T> ...@@ -133,6 +133,15 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
est_mean->template data<BatchNormParamType<T>>(), est_mean->template data<BatchNormParamType<T>>(),
est_var->template data<BatchNormParamType<T>>(), epsilon)); est_var->template data<BatchNormParamType<T>>(), epsilon));
} else { } else {
// if MomentumTensor is set, use MomentumTensor value, momentum
// is only used in this training branch
if (ctx.HasInput("MomentumTensor")) {
const auto *mom_tensor = ctx.Input<Tensor>("MomentumTensor");
Tensor mom_cpu;
TensorCopySync(*mom_tensor, platform::CPUPlace(), &mom_cpu);
momentum = mom_cpu.data<float>()[0];
}
// Run training mode. // Run training mode.
// obtain running mean and running inv var, and see if we need to // obtain running mean and running inv var, and see if we need to
// initialize them. // initialize them.
......
...@@ -2432,13 +2432,14 @@ def batch_norm(input, ...@@ -2432,13 +2432,14 @@ def batch_norm(input,
sync_batch_norm automatically. sync_batch_norm automatically.
Args: Args:
input(variable): The rank of input variable can be 2, 3, 4, 5. The data type input(Variable): The rank of input variable can be 2, 3, 4, 5. The data type
is float16 or float32 or float64. is float16 or float32 or float64.
act(string, Default None): Activation type, linear|relu|prelu|... act(string, Default None): Activation type, linear|relu|prelu|...
is_test (bool, Default False): A flag indicating whether it is in is_test (bool, Default False): A flag indicating whether it is in
test phrase or not. test phrase or not.
momentum(float, Default 0.9): The value used for the moving_mean and momentum(float|Variable, Default 0.9): The value used for the moving_mean and
moving_var computation. The updated formula is: moving_var computation. This should be a float number or a Variable with
shape [1] and data type as float32. The updated formula is:
:math:`moving\_mean = moving\_mean * momentum + new\_mean * (1. - momentum)` :math:`moving\_mean = moving\_mean * momentum + new\_mean * (1. - momentum)`
:math:`moving\_var = moving\_var * momentum + new\_var * (1. - momentum)` :math:`moving\_var = moving\_var * momentum + new\_var * (1. - momentum)`
Default is 0.9. Default is 0.9.
...@@ -2487,6 +2488,33 @@ def batch_norm(input, ...@@ -2487,6 +2488,33 @@ def batch_norm(input,
x = fluid.data(name='x', shape=[3, 7, 3, 7], dtype='float32') x = fluid.data(name='x', shape=[3, 7, 3, 7], dtype='float32')
hidden1 = fluid.layers.fc(input=x, size=200, param_attr='fc1.w') hidden1 = fluid.layers.fc(input=x, size=200, param_attr='fc1.w')
hidden2 = fluid.layers.batch_norm(input=hidden1) hidden2 = fluid.layers.batch_norm(input=hidden1)
.. code-block:: python
# batch_norm with momentum as Variable
import paddle.fluid as fluid
import paddle.fluid.layers.learning_rate_scheduler as lr_scheduler
def get_decay_momentum(momentum_init, decay_steps, decay_rate):
global_step = lr_scheduler._decay_step_counter()
momentum = fluid.layers.create_global_var(
shape=[1],
value=float(momentum_init),
dtype='float32',
# set persistable for save checkpoints and resume
persistable=True,
name="momentum")
div_res = global_step / decay_steps
decayed_momentum = momentum_init * (decay_rate**div_res)
fluid.layers.assign(decayed_momentum, momentum)
return momentum
x = fluid.data(name='x', shape=[3, 7, 3, 7], dtype='float32')
hidden1 = fluid.layers.fc(input=x, size=200, param_attr='fc1.w')
momentum = get_decay_momentum(0.9, 1e5, 0.9)
hidden2 = fluid.layers.batch_norm(input=hidden1, momentum=momentum)
""" """
assert bias_attr is not False, "bias_attr should not be False in batch_norm." assert bias_attr is not False, "bias_attr should not be False in batch_norm."
helper = LayerHelper('batch_norm', **locals()) helper = LayerHelper('batch_norm', **locals())
...@@ -2551,15 +2579,28 @@ def batch_norm(input, ...@@ -2551,15 +2579,28 @@ def batch_norm(input,
batch_norm_out = input if in_place else helper.create_variable_for_type_inference( batch_norm_out = input if in_place else helper.create_variable_for_type_inference(
dtype) dtype)
inputs = {
"X": input,
"Scale": scale,
"Bias": bias,
"Mean": mean,
"Variance": variance
}
attrs = {
"epsilon": epsilon,
"is_test": is_test,
"data_layout": data_layout,
"use_mkldnn": False,
"fuse_with_relu": False,
"use_global_stats": use_global_stats
}
if isinstance(momentum, Variable):
inputs['MomemtumTensor'] = momentum
else:
attrs['momentum'] = momentum
helper.append_op( helper.append_op(
type="batch_norm", type="batch_norm",
inputs={ inputs=inputs,
"X": input,
"Scale": scale,
"Bias": bias,
"Mean": mean,
"Variance": variance
},
outputs={ outputs={
"Y": batch_norm_out, "Y": batch_norm_out,
"MeanOut": mean_out, "MeanOut": mean_out,
...@@ -2567,15 +2608,7 @@ def batch_norm(input, ...@@ -2567,15 +2608,7 @@ def batch_norm(input,
"SavedMean": saved_mean, "SavedMean": saved_mean,
"SavedVariance": saved_variance "SavedVariance": saved_variance
}, },
attrs={ attrs=attrs)
"momentum": momentum,
"epsilon": epsilon,
"is_test": is_test,
"data_layout": data_layout,
"use_mkldnn": False,
"fuse_with_relu": False,
"use_global_stats": use_global_stats
})
return helper.append_activation(batch_norm_out) return helper.append_activation(batch_norm_out)
......
...@@ -310,6 +310,7 @@ class TestBatchNormOpTraining(unittest.TestCase): ...@@ -310,6 +310,7 @@ class TestBatchNormOpTraining(unittest.TestCase):
self.fuse_with_relu = False self.fuse_with_relu = False
self.data_formats = ["NCHW", "NHWC"] self.data_formats = ["NCHW", "NHWC"]
self.momentum = 0.9 self.momentum = 0.9
self.use_momentum_variable = False
self.epsilon = 0.00001 self.epsilon = 0.00001
self.init_kernel_type() self.init_kernel_type()
self.init_test_case() self.init_test_case()
...@@ -367,6 +368,7 @@ class TestBatchNormOpTraining(unittest.TestCase): ...@@ -367,6 +368,7 @@ class TestBatchNormOpTraining(unittest.TestCase):
bias = np.random.random_sample(scale_shape).astype(np.float32) bias = np.random.random_sample(scale_shape).astype(np.float32)
mean, variance = self.set_mean_variance(scale_shape, x, data_layout) mean, variance = self.set_mean_variance(scale_shape, x, data_layout)
y_grad = np.random.random_sample(shape).astype(np.float32) y_grad = np.random.random_sample(shape).astype(np.float32)
momentum_var = np.array([momentum]).astype(np.float32)
y, mean_out, variance_out, saved_mean, saved_variance, x_grad, scale_grad, bias_grad = self.ref_forward_backward( y, mean_out, variance_out, saved_mean, saved_variance, x_grad, scale_grad, bias_grad = self.ref_forward_backward(
x, y_grad, scale, bias, mean, variance, epsilon, momentum, x, y_grad, scale, bias, mean, variance, epsilon, momentum,
...@@ -380,7 +382,7 @@ class TestBatchNormOpTraining(unittest.TestCase): ...@@ -380,7 +382,7 @@ class TestBatchNormOpTraining(unittest.TestCase):
var_names = [ var_names = [
'x', 'scale', 'bias', 'mean', 'variance', 'y', 'saved_mean', 'x', 'scale', 'bias', 'mean', 'variance', 'y', 'saved_mean',
'saved_variance' 'saved_variance', 'momentum_var'
] ]
ground_truth = {name: var_dict[name] for name in var_names} ground_truth = {name: var_dict[name] for name in var_names}
...@@ -392,15 +394,28 @@ class TestBatchNormOpTraining(unittest.TestCase): ...@@ -392,15 +394,28 @@ class TestBatchNormOpTraining(unittest.TestCase):
name=name, name=name,
dtype='float32', dtype='float32',
shape=ground_truth[name].shape) shape=ground_truth[name].shape)
inputs = {
"X": block.var('x'),
"Scale": block.var('scale'),
"Bias": block.var('bias'),
"Mean": block.var('mean'),
"Variance": block.var('variance')
}
attrs = {
"epsilon": epsilon,
"is_test": False,
"data_layout": data_layout,
"use_mkldnn": self.use_mkldnn,
"fuse_with_relu": self.fuse_with_relu,
"use_global_stats": self.use_global_stats
}
if self.use_momentum_variable:
inputs['MomentumTensor'] = block.var('momentum_var')
else:
attrs['momentum'] = momentum
bn_op = block.append_op( bn_op = block.append_op(
type="batch_norm", type="batch_norm",
inputs={ inputs=inputs,
"X": block.var('x'),
"Scale": block.var('scale'),
"Bias": block.var('bias'),
"Mean": block.var('mean'),
"Variance": block.var('variance')
},
outputs={ outputs={
"Y": block.var('y'), "Y": block.var('y'),
"MeanOut": block.var('mean'), # share memory "MeanOut": block.var('mean'), # share memory
...@@ -408,15 +423,7 @@ class TestBatchNormOpTraining(unittest.TestCase): ...@@ -408,15 +423,7 @@ class TestBatchNormOpTraining(unittest.TestCase):
"SavedMean": block.var('saved_mean'), "SavedMean": block.var('saved_mean'),
"SavedVariance": block.var('saved_variance') "SavedVariance": block.var('saved_variance')
}, },
attrs={ attrs=attrs)
"momentum": momentum,
"epsilon": epsilon,
"is_test": False,
"data_layout": data_layout,
"use_mkldnn": self.use_mkldnn,
"fuse_with_relu": self.fuse_with_relu,
"use_global_stats": self.use_global_stats
})
block.create_var(name='y@GRAD', dtype='float32', shape=y.shape) block.create_var(name='y@GRAD', dtype='float32', shape=y.shape)
# generate backward op_desc # generate backward op_desc
...@@ -434,14 +441,15 @@ class TestBatchNormOpTraining(unittest.TestCase): ...@@ -434,14 +441,15 @@ class TestBatchNormOpTraining(unittest.TestCase):
grad_var.set_dtype(core.VarDesc.VarType.FP32) grad_var.set_dtype(core.VarDesc.VarType.FP32)
exe = fluid.Executor(place) exe = fluid.Executor(place)
out = exe.run( out = exe.run(program,
program, feed={
feed={ name: var_dict[name]
name: var_dict[name] for name in [
for name in 'x', 'scale', 'bias', 'mean', 'variance',
['x', 'scale', 'bias', 'mean', 'variance', 'y@GRAD'] 'y@GRAD', 'momentum_var'
}, ]
fetch_list=self.fetch_list) },
fetch_list=self.fetch_list)
for id, name in enumerate(self.fetch_list): for id, name in enumerate(self.fetch_list):
if name == 'variance': if name == 'variance':
...@@ -471,6 +479,17 @@ class TestBatchNormOpTrainingCase1(TestBatchNormOpTraining): ...@@ -471,6 +479,17 @@ class TestBatchNormOpTrainingCase1(TestBatchNormOpTraining):
self.fetch_list = ['y', 'mean', 'variance', 'x@GRAD'] self.fetch_list = ['y', 'mean', 'variance', 'x@GRAD']
class TestBatchNormOpTrainingMomentumVariable(TestBatchNormOpTraining):
def init_test_case(self):
self.use_momentum_variable = True
self.use_global_stats = False
self.no_grad_set = set()
self.fetch_list = [
'y', 'mean', 'variance', 'saved_mean', 'saved_variance', 'x@GRAD',
'scale@GRAD', 'bias@GRAD'
]
class TestBatchNormOpFreezeStatsTraining(TestBatchNormOpTraining): class TestBatchNormOpFreezeStatsTraining(TestBatchNormOpTraining):
def init_test_case(self): def init_test_case(self):
self.use_global_stats = True self.use_global_stats = True
......
...@@ -2465,6 +2465,19 @@ class TestBook(LayerTest): ...@@ -2465,6 +2465,19 @@ class TestBook(LayerTest):
out = layers.batch_norm(data) out = layers.batch_norm(data)
return (out) return (out)
def make_batch_norm_momentum_variable(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
data = self._get_data(
name='data', shape=[32, 128, 128], dtype="float32")
momentum = self._get_data(
name='momentum',
shape=[1],
dtype='float32',
append_batch_size=False)
out = layers.batch_norm(data, momentum=momentum)
return (out)
def make_range(self): def make_range(self):
with program_guard(fluid.default_main_program(), with program_guard(fluid.default_main_program(),
fluid.default_startup_program()): fluid.default_startup_program()):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册