From ab4715840d0da3cde6f024fd5268f4d55701bbba Mon Sep 17 00:00:00 2001 From: Wu Yi Date: Tue, 29 Jan 2019 09:25:16 +0800 Subject: [PATCH] fix default create_parameter dtype maching initializers (#15521) * fix default create_parameter dtype maching initializers test=develop * update type check test=develop * update test=develop --- python/paddle/fluid/layer_helper.py | 11 +++++++++++ python/paddle/fluid/tests/unittests/test_layers.py | 3 ++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/layer_helper.py b/python/paddle/fluid/layer_helper.py index 972c51938f2..a172141b3a0 100644 --- a/python/paddle/fluid/layer_helper.py +++ b/python/paddle/fluid/layer_helper.py @@ -300,6 +300,17 @@ class LayerHelper(object): attr.name = unique_name.generate(".".join([self.name, suffix])) if default_initializer is None and attr.initializer is None: + if isinstance(dtype, core.VarDesc.VarType): + if dtype != core.VarDesc.VarType.FP32 and \ + dtype != core.VarDesc.VarType.FP64: + raise TypeError( + "Can not create parameter with default initializer when dtype is not float type. Set default_initializer to fit the parameter dtype!" + ) + else: + if not (dtype.startswith("float") or dtype == "double"): + raise TypeError( + "Can not create parameter with default initializer when dtype is not float type. Set default_initializer to fit the parameter dtype!" + ) if is_bias: attr._set_default_bias_initializer() else: diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index c13f03e86f3..e7bc1601a54 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -58,7 +58,8 @@ class TestBook(unittest.TestCase): def test_simple_conv2d(self): program = Program() with program_guard(program, startup_program=Program()): - images = layers.data(name='pixel', shape=[3, 48, 48], dtype='int32') + images = layers.data( + name='pixel', shape=[3, 48, 48], dtype='float32') layers.conv2d(input=images, num_filters=3, filter_size=[4, 4]) print(str(program)) -- GitLab