From e27a030072d23701aeda91cc6c1b76c84ca361d2 Mon Sep 17 00:00:00 2001
From: guosheng
Date: Wed, 24 Jan 2018 19:17:49 +0800
Subject: [PATCH] Add weight normalization

---
 python/paddle/v2/fluid/layer_helper.py        | 186 +++++++++++++++++-
 python/paddle/v2/fluid/param_attr.py          |  17 ++
 .../fluid/tests/test_weight_normalization.py  | 122 ++++++++++++
 3 files changed, 321 insertions(+), 4 deletions(-)
 create mode 100644 python/paddle/v2/fluid/tests/test_weight_normalization.py

diff --git a/python/paddle/v2/fluid/layer_helper.py b/python/paddle/v2/fluid/layer_helper.py
index 0b0064ade9..368cf8ed2e 100644
--- a/python/paddle/v2/fluid/layer_helper.py
+++ b/python/paddle/v2/fluid/layer_helper.py
@@ -18,7 +18,7 @@ import itertools
 from framework import Variable, Parameter, default_main_program, default_startup_program, \
     unique_name, dtype_is_floating
 from paddle.v2.fluid.initializer import Constant, Xavier
-from param_attr import ParamAttr
+from param_attr import ParamAttr, WeightNormParamAttr
 
 
 class LayerHelper(object):
@@ -103,6 +103,177 @@ class LayerHelper(object):
             raise ValueError("Data Type mismatch")
         return dtype
 
+    def _create_weight_normalize(self, attr, shape, dtype):
+        from .layers import elementwise_mul, elementwise_div, reshape
+
+        # Remove these ops once LayerHelper and layers support specifying
+        # the program and block.
+        def __norm_op(x,
+                      out=None,
+                      p=2,
+                      dim=None,
+                      keep_dim=False,
+                      block=self.startup_program.global_block()):
+            if out is None:
+                out = block.create_var(
+                    name=unique_name(".".join([self.name, 'weight_norm_norm'])),
+                    dtype=dtype,
+                    persistable=False)
+            abs_out = block.create_var(
+                name=unique_name(".".join([self.name, 'weight_norm_abs'])),
+                dtype=dtype,
+                persistable=False)
+            block.append_op(
+                type='abs', inputs={'X': x}, outputs={'Out': abs_out})
+            pow_out = block.create_var(
+                name=unique_name(".".join([self.name, 'weight_norm_pow'])),
+                dtype=dtype,
+                persistable=False)
+            block.append_op(
+                type='pow',
+                inputs={'X': abs_out},
+                outputs={'Out': pow_out},
+                attrs={'factor': float(p)})
+            sum_out = block.create_var(
+                name=unique_name(".".join([self.name, 'weight_norm_sum'])),
+                dtype=dtype,
+                persistable=False)
+            block.append_op(
+                type='reduce_sum',
+                inputs={'X': pow_out},
+                outputs={'Out': sum_out},
+                attrs={
+                    'dim': dim,
+                    'keep_dim': keep_dim,
+                    'reduce_all': True if dim is None else False
+                })
+            block.append_op(
+                type='pow',
+                inputs={'X': sum_out},
+                outputs={'Out': out},
+                attrs={'factor': 1. / p})
+            return out
+
+        def __reshape_op(x,
+                         shape,
+                         out=None,
+                         block=self.startup_program.global_block()):
+            if out is None:
+                out = block.create_var(
+                    name=unique_name(".".join(
+                        [self.name, 'weight_norm_reshape'])),
+                    dtype=dtype,
+                    persistable=False)
+            block.append_op(
+                type='reshape',
+                inputs={'X': x},
+                outputs={'Out': out},
+                attrs={'shape': shape})
+            return out
+
+        def __transpose_op(x,
+                           axis,
+                           out=None,
+                           block=self.startup_program.global_block()):
+            if out is None:
+                out = block.create_var(
+                    name=unique_name(".".join(
+                        [self.name, 'weight_norm_transpose'])),
+                    dtype=dtype,
+                    persistable=False)
+            block.append_op(
+                type='transpose',
+                inputs={'X': x},
+                outputs={'Out': out},
+                attrs={'axis': axis})
+            return out
+
+        def __norm_except_dim(x,
+                              out=None,
+                              dim=None,
+                              block=self.startup_program.global_block()):
+            """Computes the norm over all dimensions except dim."""
+            if out is None:
+                out = block.create_var(
+                    name=unique_name(".".join([self.name, 'weight_norm_norm'])),
+                    dtype=dtype,
+                    persistable=False)
+            if dim is None:
+                __norm_op(x, out, dim=dim, block=block)
+            elif dim == 0:
+                out_shape = [x.shape[0]] + [1] * (len(x.shape) - 1)
+                reshape = __reshape_op(x, shape=[x.shape[0], -1], block=block)
+                norm = __norm_op(reshape, dim=1, block=block)
+                __reshape_op(norm, out=out, shape=out_shape, block=block)
+            elif dim == len(x.shape) - 1:
+                out_shape = [1] * (len(x.shape) - 1) + [x.shape[-1]]
+                reshape = __reshape_op(x, shape=[-1, x.shape[-1]], block=block)
+                norm = __norm_op(reshape, dim=0, block=block)
+                __reshape_op(norm, out=out, shape=out_shape, block=block)
+            else:
+                perm = range(len(x.shape))
+                perm[0], perm[dim] = dim, 0
+                transpose = __transpose_op(x, perm, block=block)
+                norm = __norm_op(transpose, dim=0, block=block)
+                __transpose_op(norm, perm, out=out, block=block)
+            return out
+
+        def __weight_normalize(g, v, dim):
+            """Computes weight normalization: w = g * v / ||v||."""
+            norm = __norm_except_dim(
+                v, dim=dim, block=self.main_program.current_block())
+            scale = elementwise_div(
+                x=g, y=norm)  # The shapes of g and norm are the same.
+            # Currently, elementwise_mul only supports broadcasting when the
+            # shape of y is a subset of the shape of x. Thus, we reshape y
+            # to squeeze it to achieve this.
+            w = elementwise_mul(
+                x=v,
+                y=scale if dim is None else reshape(
+                    x=scale, shape=[v.shape[dim]]),
+                axis=-1 if dim is None else dim)
+            # To serialize the original parameter for inference, a parameter
+            # rather than a variable may need to be returned here.
+            return w
+
+        g_param_attr = copy.deepcopy(attr)
+        g_param_attr.name = attr.name + '_g'
+        g_param_shape = [1] * len(shape)
+        if attr.dim is not None:
+            g_param_shape[attr.dim] = shape[attr.dim]
+        v_param_attr = copy.deepcopy(attr)
+        v_param_attr.name = attr.name + '_v'
+        v_param_shape = shape
+
+        # Add ops to startup_program to initialize g and v.
+        # Reconstruct the initializer of w by initializing g and v:
+        # with the initializers of g and v set as below, the distribution
+        # of w is the same as initializing w with the given initializer.
+        # For data-dependent initialization, compute the initial values
+        # of g and v externally and then feed the values to g and v by
+        # executing an extra program.
+        g_param = self.startup_program.global_block().create_parameter(
+            dtype=dtype,
+            shape=g_param_shape,
+            **g_param_attr.to_kwargs(with_initializer=False))
+        v_param = self.startup_program.global_block().create_parameter(
+            dtype=dtype,
+            shape=v_param_shape,
+            **v_param_attr.to_kwargs(with_initializer=True))
+        __norm_except_dim(
+            x=v_param,
+            out=g_param,
+            dim=attr.dim,
+            block=self.startup_program.global_block())
+
+        # Add the weight-normalization reparameterization to main_program.
+        g_param = self.main_program.global_block().create_parameter(
+            dtype=dtype, shape=g_param_shape, **g_param_attr.to_kwargs())
+        v_param = self.main_program.global_block().create_parameter(
+            dtype=dtype, shape=v_param_shape, **v_param_attr.to_kwargs())
+        w_param = __weight_normalize(g_param, v_param, dim=attr.dim)
+        return w_param
+
     def create_parameter(self,
                          attr,
                          shape,
@@ -112,16 +283,23 @@ class LayerHelper(object):
         # Deepcopy the attr so that parameters can be shared in program
         assert isinstance(attr, ParamAttr)
         suffix = 'b' if is_bias else 'w'
+        if attr.name is None:
+            attr.name = unique_name(".".join([self.name, suffix]))
 
-        if default_initializer is None:
+        if default_initializer is None and attr.initializer is None:
             if is_bias:
                 attr.set_default_bias_initializer()
             else:
                 attr.set_default_param_initializer()
         else:
             attr.set_default_initializer(default_initializer)
-        if attr.name is None:
-            attr.name = unique_name(".".join([self.name, suffix]))
+
+        # If weight normalization is set, insert extra parameters and ops.
+        # Refer to https://arxiv.org/pdf/1602.07868.pdf
+        if isinstance(attr, WeightNormParamAttr):
+            param = self._create_weight_normalize(attr, shape, dtype)
+            WeightNormParamAttr.params_with_weight_norm.append(param)
+            return param
 
         self.startup_program.global_block().create_parameter(
             dtype=dtype, shape=shape, **attr.to_kwargs(with_initializer=True))
diff --git a/python/paddle/v2/fluid/param_attr.py b/python/paddle/v2/fluid/param_attr.py
index dcca8b6c54..1218e71ca1 100644
--- a/python/paddle/v2/fluid/param_attr.py
+++ b/python/paddle/v2/fluid/param_attr.py
@@ -82,3 +82,20 @@ class ParamAttr(object):
         if with_initializer:
             kwargs['initializer'] = self.initializer
         return kwargs
+
+
+class WeightNormParamAttr(ParamAttr):
+    """
+    Used for weight normalization. Any field in ParamAttr can also be set here.
+    In addition, an extra attribute `dim` can be set to indicate the dimension
+    to keep; the norm is then computed over all the other dimensions.
+    """
+    # List to record the parameters reparameterized by weight normalization.
+    # Since these parameters are treated as Variables rather than Parameters,
+    # this list can be used to distinguish them and helps to serialize
+    # these parameters for inference.
+    params_with_weight_norm = []
+
+    def __init__(self, dim=None, **kwargs):
+        super(WeightNormParamAttr, self).__init__(**kwargs)
+        self.dim = dim
diff --git a/python/paddle/v2/fluid/tests/test_weight_normalization.py b/python/paddle/v2/fluid/tests/test_weight_normalization.py
new file mode 100644
index 0000000000..200b5b9dc0
--- /dev/null
+++ b/python/paddle/v2/fluid/tests/test_weight_normalization.py
@@ -0,0 +1,122 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy
+import collections
+import paddle.v2.fluid as fluid
+import paddle.v2.fluid.core as core
+from paddle.v2.fluid.initializer import ConstantInitializer
+from paddle.v2.fluid.param_attr import WeightNormParamAttr
+
+
+class TestWeightNormalization(unittest.TestCase):
+    batch_size = 3
+    hidden_size = 5
+    data_desc = (['x', [10], 0], )
+
+    @classmethod
+    def setUpClass(cls):
+        cls.set_program()
+
+    @classmethod
+    def set_program(cls):
+        data = fluid.layers.data(
+            name=cls.data_desc[0][0], shape=cls.data_desc[0][1])
+        out = fluid.layers.fc(input=data,
+                              size=cls.hidden_size,
+                              param_attr=WeightNormParamAttr(
+                                  dim=None,
+                                  name='weight_norm_param',
+                                  initializer=ConstantInitializer(1.0)),
+                              bias_attr=False,
+                              act=None)
+        loss = fluid.layers.reduce_sum(out)
+        fluid.backward.append_backward(loss=loss)
+        cls.fetch_list = [
+            'weight_norm_param_g', 'weight_norm_param_v',
+            'weight_norm_param_g@GRAD'
+        ]
+
+    def run_program(self):
+        outputs = []
+        places = [core.CPUPlace()]
+        if core.is_compile_gpu():
+            places.append(core.CUDAPlace(0))
+        for place in places:
+            self.set_inputs(place)
+            exe = fluid.Executor(place)
+            exe.run(fluid.default_startup_program())
+            output = exe.run(fluid.default_main_program(),
+                             feed=self.inputs,
+                             fetch_list=self.fetch_list,
+                             return_numpy=False)
+            outputs.append(output)
+        self.actual_outputs = outputs
+
+    def set_data(self):
+        self.data = collections.OrderedDict()
+        for desc in self.data_desc:
+            data_name = desc[0]
+            data_shape = desc[1]
+            data_lod_level = desc[2]
+            data_lod = []
+            for i in range(data_lod_level):
+                lod_level_i = numpy.random.randint(
+                    low=1,
+                    high=5,
+                    size=self.batch_size if i == 0 else lod_level_i[-1])
+                lod_level_i = [0] + numpy.cumsum(lod_level_i).tolist()
+                data_lod.append(lod_level_i)
+            data_value = numpy.random.random(
+                size=[data_lod[-1][-1] if data_lod else self.batch_size
+                      ] + data_shape).astype('float32')
+            self.data[data_name] = (data_value, data_lod)
+
+    def set_inputs(self, place):
+        self.inputs = {}
+        for desc in self.data_desc:
+            tensor = fluid.Tensor()
+            tensor.set(self.data[desc[0]][0], place)
+            if self.data[desc[0]][1]:
+                tensor.set_lod(self.data[desc[0]][1])
+            self.inputs[desc[0]] = tensor
+
+    def weight_normalize(self):
+        v = numpy.ones((self.data[self.data_desc[0][0]][0].shape[-1],
+                        self.hidden_size))
+        g = numpy.linalg.norm(v, axis=None, keepdims=True)
+        w = g * v / numpy.linalg.norm(v, axis=None, keepdims=True)
+        x = self.data[self.data_desc[0][0]][0]
+        out = numpy.dot(x, w)
+        g_grad = (numpy.dot(x.T, numpy.ones_like(out)) * (v / numpy.linalg.norm(
+            v, axis=None, keepdims=True))).sum(axis=None, keepdims=True)
+        return g, v, g_grad
+
+    def test_weight_normalization(self):
+        self.set_data()
+        self.run_program()
+        expect_outputs = self.weight_normalize()
+        # Compare g, v and g@GRAD for each place against the numpy reference.
+        for actual_outputs in self.actual_outputs:
+            for expect_output, actual_output in zip(expect_outputs,
+                                                    actual_outputs):
+                self.assertTrue(
+                    numpy.allclose(
+                        numpy.array(actual_output), expect_output,
+                        atol=0.001))
+
+
+if __name__ == '__main__':
+    unittest.main()
--
GitLab
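
Note for reviewers: below is a minimal NumPy sketch (not part of the patch) of
the reparameterization w = g * v / ||v|| that this PR implements, following
https://arxiv.org/pdf/1602.07868.pdf. The norm is taken over all dimensions
except `dim`, matching the semantics of __norm_except_dim above; the helper
names here are illustrative only.

    import numpy as np

    def norm_except_dim(v, dim=None):
        # Norm over all dimensions except `dim`, shaped to broadcast with v.
        if dim is None:
            return np.linalg.norm(v)  # scalar norm over the whole tensor
        out_shape = [1] * v.ndim
        out_shape[dim] = v.shape[dim]
        # Move `dim` to the front, flatten the rest, take row-wise norms.
        rows = np.swapaxes(v, 0, dim).reshape(v.shape[dim], -1)
        return np.linalg.norm(rows, axis=1).reshape(out_shape)

    def weight_norm(g, v, dim=None):
        # w = g * v / ||v||, norm computed over all dims except `dim`.
        return g * v / norm_except_dim(v, dim)

    # Example: a 3x4 fc weight, normalized per output column (dim=1).
    v = np.random.randn(3, 4).astype('float32')
    g = norm_except_dim(v, dim=1)  # init g to ||v|| so that w == v initially
    w = weight_norm(g, v, dim=1)
    assert np.allclose(w, v, atol=1e-5)
    assert np.allclose(norm_except_dim(w, dim=1), g, atol=1e-5)

Initializing g to ||v|| in the startup program (as __norm_except_dim does for
g_param above) keeps the initial w identical to v, so the distribution of w
matches that of the initializer given for v.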