Commit a5c9e6ac authored by xuwei06

Fix conv2d bias

The size of the bias parameter should be the number of filters.

Parent 374e1685
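For context, the shape rule this commit enforces, as a minimal NumPy sketch (mine, not part of the commit): a conv bias holds one scalar per output filter, so its size is num_filters, and it is broadcast over the batch and spatial dimensions.

    import numpy as np

    num_filters = 8
    conv_out = np.random.randn(2, num_filters, 16, 16)  # NCHW conv output
    bias = np.zeros(num_filters)                        # one entry per filter

    # Broadcasting aligns the bias with the channel axis (axis 1).
    out = conv_out + bias.reshape(1, num_filters, 1, 1)
    assert out.shape == conv_out.shape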
python/paddle/v2/fluid/io.py
@@ -35,7 +35,7 @@ def save_vars(executor, dirname, main_program=None, vars=None, predicate=None):
     :param executor: executor that save variable
     :param dirname: directory path
     :param main_program: program. If vars is None, then filter all variables in this
                          program which fit `predicate`. Default g_program.
     :param predicate: The Predicate describes a callable that returns a variable
                       as a bool. If it returns true, the variables will be saved.
@@ -96,11 +96,11 @@ def load_vars(executor, dirname, main_program=None, vars=None, predicate=None):
     :param executor: executor that save variable
     :param dirname: directory path
     :param main_program: program. If vars is None, then filter all variables in this
                          program which fit `predicate`. Default g_program.
     :param predicate: The Predicate describes a callable that returns a variable
                       as a bool. If it returns true, the variables will be loaded.
     :param vars: variables need to be loaded. If specify vars, program &
                  predicate will be ignored
     :return: None
     """
@@ -157,15 +157,15 @@ def save_inference_model(dirname,
                          executor,
                          main_program=None):
     """
     Build a model especially for inference,
     and save it to directory by the executor.

     :param dirname: directory path
     :param feeded_var_names: Names of variables that need to be feeded data during inference
     :param target_vars: Variables from which we can get inference results.
     :param executor: executor that save inference model
     :param main_program: original program, which will be pruned to build the inference model.
-                         Default g_program.
+                         Default g_main_program.
     :return: None
     """
@@ -234,3 +234,34 @@ def load_inference_model(dirname, executor):
     fetch_vars = [program.global_block().var(name) for name in fetch_var_names]

     return [program, feed_var_names, fetch_vars]
+
+
+def get_parameter_value(para, executor):
+    """
+    Get the LoDTensor for the parameter
+
+    :param executor: executor for retrieving the value
+    :param para: the given parameter
+    :return: the LoDTensor for the parameter
+    """
+    get_program = Program()
+    block = get_program.global_block()
+    new_var = _clone_var_in_block_(block, para)
+    return executor.run(get_program, feed={}, fetch_list=[new_var])[0]
+
+
+def get_parameter_value_by_name(name, executor, program=None):
+    """
+    Get the LoDTensor for paramter with the given name
+
+    :param executor: executor for retrieving the value
+    :param name: the name of the parameter
+    :param program: the program where the variable is found
+                    Default g_main_program.
+    :return: the LoDTensor for the variable
+    """
+    if program is None:
+        program = g_main_program
+    var = program.global_block().var(name)
+    assert is_parameter(var)
+    return get_parameter_value(var, executor)
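The two getters added above can be exercised like this (sketch reusing `exe` from earlier; 'fc.w' is the parameter name used in the test at the bottom of this commit):

    # Fetch the current value of parameter 'fc.w' as a LoDTensor by
    # running a throwaway program that clones just that variable.
    t = io.get_parameter_value_by_name('fc.w', exe, g_main_program)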
python/paddle/v2/fluid/layer_helper.py
@@ -72,7 +72,7 @@ class LayerHelper(object):
     @property
     def bias_attr(self):
-        default = {'name': None, 'initializer': XavierInitializer()}
+        default = {'name': None, 'initializer': ConstantInitializer()}
         bias_attr = self.kwargs.get('bias_attr', None)
         if bias_attr is None:
             bias_attr = default
@@ -149,24 +149,19 @@ class LayerHelper(object):
             persistable=True,
             initializer=initializer)

-    def append_bias_op(self, input_var, num_flatten_dims=None):
+    def append_bias_op(self, input_var, dim_start=1, dim_end=None):
         """
         Append bias operator and return its output. If the user does not set
         bias_attr, append_bias_op will return input_var

         :param input_var: the input variable. The len(input_var.shape) is larger
                           or equal than 2.
-        :param num_flatten_dims: The input tensor will be flatten as a matrix
-                                 when adding bias.
-                                 `matrix.shape = product(input_var.shape[0:num_flatten_dims]), product(
-                                 input_var.shape[num_flatten_dims:])`
+        :param dim_start:
+        :param dim_end: the shape of the bias will be
+                        input_var.shape(dim_start:dim_end). The bias is broadcast to other
+                        dimensions and added to input_var to get the output
         """
-        if num_flatten_dims is None:
-            num_flatten_dims = self.kwargs.get('num_flatten_dims', None)
-            if num_flatten_dims is None:
-                num_flatten_dims = 1
-
-        size = list(input_var.shape[num_flatten_dims:])
+        size = list(input_var.shape[dim_start:dim_end])
         bias_attr = self.bias_attr
         if not bias_attr:
             return input_var
@@ -178,7 +173,8 @@ class LayerHelper(object):
             type='elementwise_add',
             inputs={'X': [input_var],
                     'Y': [b]},
-            outputs={'Out': [tmp]})
+            outputs={'Out': [tmp]},
+            attrs={'axis': dim_start})
         return tmp

     def append_activation(self, input_var):
......
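A rough NumPy model (mine, not fluid code) of what the new attrs={'axis': dim_start} asks elementwise_add to do: align Y with X starting at that axis and broadcast over the remaining dimensions.

    import numpy as np

    def add_bias(x, bias, axis):
        # Place bias.shape into x's shape starting at `axis`, padding the
        # other positions with 1s so NumPy broadcasting applies.
        shape = [1] * x.ndim
        shape[axis:axis + bias.ndim] = bias.shape
        return x + bias.reshape(shape)

    x = np.random.randn(4, 8, 14, 14)  # e.g. an NCHW conv output
    b = np.zeros(8)                    # size = input_var.shape[1:2]
    assert add_bias(x, b, axis=1).shape == x.shape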
python/paddle/v2/fluid/layers.py
@@ -250,7 +250,7 @@ def _convert_(name):
 def _generate_doc_string_(op_proto):
     """
     Generate docstring by OpProto

     Args:
         op_proto (framework_pb2.OpProto): a protobuf message typed OpProto
@@ -676,6 +676,7 @@ def conv2d(input,
     filter_shape = [num_filters, num_filter_channels] + filter_size
     std = (2.0 / (filter_size[0]**2 * num_channels))**0.5
+    print 'name=', name, 'std=', std
     filter = helper.create_parameter(
         attr=helper.param_attr,
         shape=filter_shape,
@@ -694,7 +695,7 @@ def conv2d(input,
             'paddings': padding,
             'groups': groups})

-    pre_act = helper.append_bias_op(pre_bias, 1)
+    pre_act = helper.append_bias_op(pre_bias, dim_start=1, dim_end=2)

     return helper.append_activation(pre_act)
......
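The call site above is the heart of the fix: conv2d's pre-bias output is NCHW, so shape[1:2] is exactly [num_filters], whereas the old append_bias_op(pre_bias, 1) took shape[1:] and created one bias entry per output element. With hypothetical sizes:

    pre_bias_shape = (4, 8, 14, 14)       # (batch, num_filters, out_h, out_w)
    old_size = list(pre_bias_shape[1:])   # [8, 14, 14] -- the bug
    new_size = list(pre_bias_shape[1:2])  # [8] == [num_filters] -- the fix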
python/paddle/v2/fluid/tests/test_parameter.py
 import unittest
 from paddle.v2.fluid.framework import g_main_program
 import paddle.v2.fluid.core as core
+from paddle.v2.fluid.executor import Executor
+import paddle.v2.fluid.io as io
+from paddle.v2.fluid.initializer import ConstantInitializer
+import numpy as np
 class TestParameter(unittest.TestCase):
     def test_param(self):
-        b = g_main_program.create_block()
+        shape = [784, 100]
+        val = 1.0625
+        b = g_main_program.global_block()
         param = b.create_parameter(
             name='fc.w',
-            shape=[784, 100],
+            shape=shape,
             dtype='float32',
-            initialize_attr={
-                'type': 'uniform_random',
-                'seed': 13,
-                'min': -5.0,
-                'max': 5.0
-            })
+            initializer=ConstantInitializer(val))
         self.assertIsNotNone(param)
         self.assertEqual('fc.w', param.name)
         self.assertEqual((784, 100), param.shape)
         self.assertEqual(core.DataType.FP32, param.data_type)
         self.assertEqual(0, param.block.idx)
+        exe = Executor(core.CPUPlace())
+        p = exe.run(g_main_program, fetch_list=[param])[0]
+        self.assertTrue(np.allclose(np.array(p), np.ones(shape) * val))
+        p = io.get_parameter_value_by_name('fc.w', exe, g_main_program)
+        self.assertTrue(np.allclose(np.array(p), np.ones(shape) * val))

 if __name__ == '__main__':
......