Unverified · Commit e930f496 authored by Abhinav Arora, committed by GitHub

Improve the initializer Interface for fc, sequence_conv and conv2d layers (#5760)

* Improve the initializer Interface for fc, sequence_conv and conv2d layers
* Fix some typos in python code
* Fix CI
Parent 7ce06c8b
@@ -15,6 +15,37 @@ def unique_name(prefix):
return "_".join([prefix, str(uid)])
def convert_np_dtype_to_dtype_(np_dtype):
dtype = np.dtype(np_dtype)
if dtype == np.float32:
return core.DataType.FP32
elif dtype == np.float64:
return core.DataType.FP64
elif dtype == np.float16:
return core.DataType.FP16
elif dtype == np.int32:
return core.DataType.INT32
elif dtype == np.int16:
return core.DataType.INT16
elif dtype == np.int64:
return core.DataType.INT64
elif dtype == np.bool:
return core.DataType.BOOL
else:
raise ValueError("Not supported numpy dtype " + str(dtype))
def dtype_is_floating(dtype):
if not isinstance(dtype, core.DataType):
dtype = convert_np_dtype_to_dtype_(dtype)
if (dtype == core.DataType.FP16 or dtype == core.DataType.FP32 or
dtype == core.DataType.FP64):
return True
else:
return False
def _debug_string_(proto, throw_on_error=True):
error_fields = list()
if not proto.IsInitialized(error_fields) and throw_on_error:
@@ -66,7 +97,7 @@ class Variable(object):
"matched.".format(self.name, old_shape, shape))
if dtype is not None:
if not isinstance(dtype, core.DataType):
dtype = Variable._convert_np_dtype_to_dtype_(dtype)
dtype = convert_np_dtype_to_dtype_(dtype)
if is_new_var:
self.desc.set_data_type(dtype)
else:
@@ -148,26 +179,6 @@ class Variable(object):
uid = core.unique_integer(prefix) # unique during whole process.
return "_".join([prefix, str(uid)])
@staticmethod
def _convert_np_dtype_to_dtype_(np_dtype):
dtype = np.dtype(np_dtype)
if dtype == np.float32:
return core.DataType.FP32
elif dtype == np.float64:
return core.DataType.FP64
elif dtype == np.float16:
return core.DataType.FP16
elif dtype == np.int32:
return core.DataType.INT32
elif dtype == np.int16:
return core.DataType.INT16
elif dtype == np.int64:
return core.DataType.INT64
elif dtype == np.bool:
return core.DataType.BOOL
else:
raise ValueError("Not supported numpy dtype " + str(dtype))
def get_all_op_protos():
"""
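The two helpers promoted to module level above can be exercised directly; a minimal sketch of their behaviour, using only names that appear in this diff:

import numpy as np
import paddle.v2.fluid.core as core
from paddle.v2.fluid.framework import convert_np_dtype_to_dtype_, dtype_is_floating

# Numpy dtypes (or their string names) map onto core.DataType values.
assert convert_np_dtype_to_dtype_(np.float32) == core.DataType.FP32
assert convert_np_dtype_to_dtype_("int64") == core.DataType.INT64

# dtype_is_floating accepts either a core.DataType or anything numpy understands.
assert dtype_is_floating(core.DataType.FP16) is True
assert dtype_is_floating(np.int32) is False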
@@ -2,7 +2,7 @@ import copy
import itertools
from paddle.v2.fluid.framework import Variable, g_main_program, \
g_startup_program, unique_name, Program
g_startup_program, unique_name, Program, dtype_is_floating
from paddle.v2.fluid.initializer import ConstantInitializer, \
UniformInitializer, XavierInitializer
@@ -61,7 +61,7 @@ class LayerHelper(object):
@property
def param_attr(self):
default = {'name': None, 'initializer': XavierInitializer()}
default = {'name': None}
actual = self.kwargs.get('param_attr', None)
if actual is None:
actual = default
@@ -72,7 +72,7 @@ class LayerHelper(object):
@property
def bias_attr(self):
default = {'name': None, 'initializer': ConstantInitializer()}
default = {'name': None}
bias_attr = self.kwargs.get('bias_attr', None)
if bias_attr is None:
bias_attr = default
@@ -119,6 +119,8 @@ class LayerHelper(object):
attr_copy = copy.deepcopy(attr)
if initializer is not None:
attr_copy['initializer'] = initializer
else:
attr_copy['initializer'] = self._get_default_initializer(dtype)
if attr_copy['name'] is None:
attr_copy['name'] = unique_name(".".join([self.name, suffix]))
self.startup_program.global_block().create_parameter(
@@ -149,13 +151,19 @@ class LayerHelper(object):
persistable=True,
initializer=initializer)
def append_bias_op(self, input_var, dim_start=1, dim_end=None):
def append_bias_op(self,
input_var,
bias_initializer,
dim_start=1,
dim_end=None):
"""
Append bias operator and return its output. If the user does not set
bias_attr, append_bias_op will return input_var
:param input_var: the input variable. The len(input_var.shape) is larger
or equal than 2.
:param input_var: the input variable. The len(input_var.shape) is
larger or equal than 2.
:bias_initializer: an instance of a subclass of Initializer used to
initialize the bias
:param dim_start:
:param dim_end: the shape of the bias will be
input_var.shape[dim_start:dim_end]. The bias is broadcasted to other
@@ -167,7 +175,11 @@ class LayerHelper(object):
return input_var
b = self.create_parameter(
attr=bias_attr, shape=size, dtype=input_var.data_type, suffix='b')
attr=bias_attr,
shape=size,
dtype=input_var.data_type,
suffix='b',
initializer=bias_initializer)
tmp = self.create_tmp_variable(dtype=input_var.data_type)
self.append_op(
type='elementwise_add',
@@ -191,3 +203,10 @@ class LayerHelper(object):
outputs={"Y": [tmp]},
attrs=act)
return tmp
def _get_default_initializer(self, dtype):
if dtype is None or dtype_is_floating(dtype) is True:
return XavierInitializer()
else:
# For integer and boolean types, initialize with all zeros
return ConstantInitializer()
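The practical effect of _get_default_initializer is that a parameter created without an explicit initializer now falls back to Xavier for floating-point dtypes and to an all-zeros fill otherwise. A standalone sketch of that selection rule (a plain function mirroring the method, not the method itself):

from paddle.v2.fluid.framework import dtype_is_floating
from paddle.v2.fluid.initializer import ConstantInitializer, XavierInitializer

def pick_default_initializer(dtype):
    # Mirrors LayerHelper._get_default_initializer: Xavier for floating-point
    # parameters, all zeros for integer and boolean parameters.
    if dtype is None or dtype_is_floating(dtype):
        return XavierInitializer()
    return ConstantInitializer()

print(type(pick_default_initializer("float32")).__name__)  # XavierInitializer
print(type(pick_default_initializer("int64")).__name__)    # ConstantInitializer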
@@ -3,7 +3,7 @@ import paddle.v2.fluid.proto.framework_pb2 as framework_pb2
from paddle.v2.fluid.framework import OpProtoHolder, Variable, Program, \
Operator
from paddle.v2.fluid.initializer import ConstantInitializer, \
NormalInitializer
NormalInitializer, XavierInitializer
from paddle.v2.fluid.layer_helper import LayerHelper, unique_name
import re
import cStringIO
@@ -18,7 +18,9 @@ __all__ = [
def fc(input,
size,
param_attr=None,
param_initializer=None,
bias_attr=None,
bias_initializer=None,
name=None,
act=None,
num_flatten_dims=1,
@@ -31,7 +33,11 @@ def fc(input,
input: The input tensor to the function
size: The size of the layer
param_attr: The parameters/weights to the FC Layer
param_initializer: Initializer used for the weight/parameter.
If None, XavierInitializer() is used
bias_attr: The bias parameter for the FC layer
bias_initializer: Initializer used for the bias.
If None, then ConstantInitializer() is used
name: Name/alias of the function
act: Activation to be applied to the output of FC layer
num_flatten_dims: Number of columns in input
@@ -50,10 +56,23 @@ def fc(input,
to the LayerHelper constructor.
"""
def _get_default_param_initializer():
return XavierInitializer()
def _get_default_bias_initializer():
return ConstantInitializer()
helper = LayerHelper('fc', **locals())
dtype = helper.input_dtype()
if param_initializer is None:
param_initializer = _get_default_param_initializer()
if bias_initializer is None:
bias_initializer = _get_default_bias_initializer()
mul_results = []
for input_var, param_attr in helper.iter_inputs_and_params():
input_shape = input_var.shape
@@ -61,7 +80,10 @@ def fc(input,
reduce(lambda a, b: a * b, input_shape[num_flatten_dims:], 1)
] + [size]
w = helper.create_parameter(
attr=param_attr, shape=param_shape, dtype=dtype)
attr=param_attr,
initializer=param_initializer,
shape=param_shape,
dtype=dtype)
tmp = helper.create_tmp_variable(dtype)
helper.append_op(
type="mul",
@@ -82,7 +104,7 @@ def fc(input,
helper.append_op(
type="sum", inputs={"X": mul_results}, outputs={"Out": pre_bias})
# add bias
pre_activation = helper.append_bias_op(pre_bias)
pre_activation = helper.append_bias_op(pre_bias, bias_initializer)
# add activation
return helper.append_activation(pre_activation)
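With the new keywords a caller can override just the weight initializer while the bias keeps its ConstantInitializer default. A usage sketch follows; the layers.data call and the UniformInitializer argument order (low, high, seed) are assumptions about this version of the API and are not part of this diff.

import paddle.v2.fluid.layers as layers
from paddle.v2.fluid.initializer import UniformInitializer

# Hypothetical input variable; the data layer itself is not touched by this diff.
x = layers.data(name='x', shape=[32], data_type='float32')

# Weights drawn from U(-0.05, 0.05); the bias falls back to the default initializer.
y = layers.fc(input=x,
              size=64,
              param_initializer=UniformInitializer(-0.05, 0.05, 0))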
@@ -599,7 +621,9 @@ def sequence_conv(input,
act=None,
padding=None,
bias_attr=None,
bias_initializer=None,
param_attr=None,
param_initializer=None,
main_program=None,
startup_program=None):
"""
@@ -607,6 +631,13 @@ def sequence_conv(input,
other convolutional configurations for the filters and stride as given
in the input parameters to the function.
"""
def _get_default_bias_initializer():
return ConstantInitializer()
def _get_default_param_initializer():
return XavierInitializer()
# FIXME(dzh) : want to unify the argument of python layer
# function. So we ignore some unnecessary attributes.
# such as, padding_trainable, context_start.
@@ -614,9 +645,17 @@
helper = LayerHelper('sequence_conv', **locals())
dtype = helper.input_dtype()
if param_initializer is None:
param_initializer = _get_default_param_initializer()
if bias_initializer is None:
bias_initializer = _get_default_bias_initializer()
filter_shape = [filter_size * input.shape[1], num_filters]
filter = helper.create_parameter(
attr=helper.param_attr, shape=filter_shape, dtype=dtype)
attr=helper.param_attr,
shape=filter_shape,
dtype=dtype,
initializer=param_initializer)
pre_bias = helper.create_tmp_variable(dtype)
helper.append_op(
@@ -631,7 +670,7 @@
'contextStart': -int(filter_size / 2),
'contextLength': filter_size
})
pre_act = helper.append_bias_op(pre_bias)
pre_act = helper.append_bias_op(pre_bias, bias_initializer)
return helper.append_activation(pre_act)
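The weight created for sequence_conv is a 2-D matrix that folds the context window into the feature dimension, as filter_shape = [filter_size * input.shape[1], num_filters] above shows. A quick numeric check of that shape arithmetic (the concrete values are illustrative only):

filter_size = 3    # context window length (illustrative)
input_dim = 128    # input.shape[1], the per-timestep feature width (illustrative)
num_filters = 64

# Same expression as in the layer: one row block per context position.
filter_shape = [filter_size * input_dim, num_filters]
print(filter_shape)  # [384, 64]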
@@ -644,7 +683,9 @@ def conv2d(input,
stride=[1, 1],
padding=None,
bias_attr=None,
bias_initializer=None,
param_attr=None,
param_initializer=None,
main_program=None,
startup_program=None):
"""
@@ -654,6 +695,14 @@
This function can also append an activation on top of the
conv-2d output, if mentioned in the input parameters.
"""
def _get_default_bias_initializer():
return ConstantInitializer()
def _get_default_param_initializer(filter_size, num_channels):
std = (2.0 / (filter_size[0]**2 * num_channels))**0.5
return NormalInitializer(0.0, std, 0)
helper = LayerHelper('conv2d', **locals())
dtype = helper.input_dtype()
@@ -675,12 +724,17 @@
input_shape = input.shape
filter_shape = [num_filters, num_filter_channels] + filter_size
std = (2.0 / (filter_size[0]**2 * num_channels))**0.5
if param_initializer is None:
param_initializer = _get_default_param_initializer(filter_size,
num_channels)
if bias_initializer is None:
bias_initializer = _get_default_bias_initializer()
filter = helper.create_parameter(
attr=helper.param_attr,
shape=filter_shape,
dtype=dtype,
initializer=NormalInitializer(0.0, std, 0))
initializer=param_initializer)
pre_bias = helper.create_tmp_variable(dtype)
helper.append_op(
@@ -694,7 +748,8 @@
'paddings': padding,
'groups': groups})
pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)
pre_act = helper.append_bias_op(
pre_bias, bias_initializer, dim_start=1, dim_end=2)
return helper.append_activation(pre_act)
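conv2d keeps its previous Normal(0, std) weight default, now wrapped in _get_default_param_initializer; the standard deviation follows the same formula the old inline code used. A quick numeric check with illustrative values:

# std = sqrt(2 / (filter_size[0]**2 * num_channels)), as computed in
# _get_default_param_initializer(filter_size, num_channels).
filter_size = [3, 3]   # illustrative kernel size
num_channels = 64      # illustrative number of input channels

std = (2.0 / (filter_size[0] ** 2 * num_channels)) ** 0.5
print(round(std, 4))   # 0.0589 -> NormalInitializer(0.0, std, 0) by default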
import unittest
from paddle.v2.fluid.framework import Variable, g_main_program, Program
from paddle.v2.fluid.framework import g_main_program, Program, convert_np_dtype_to_dtype_
import paddle.v2.fluid.core as core
import numpy as np
@@ -7,7 +7,7 @@ import numpy as np
class TestVariable(unittest.TestCase):
def test_np_dtype_convert(self):
DT = core.DataType
convert = Variable._convert_np_dtype_to_dtype_
convert = convert_np_dtype_to_dtype_
self.assertEqual(DT.FP32, convert(np.float32))
self.assertEqual(DT.FP16, convert("float16"))
self.assertEqual(DT.FP64, convert("float64"))