Unverified commit 3792a49d, authored by LielinJiang, committed by GitHub

Fix bilinear_initializer bug when type of input data is float64 (#24771)

* fix bilinear initializer, test=develop
Parent 9bbb9542
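Before this change, BilinearInitializer.__call__ special-cased only FP16 (write fp32 values, then cast), so a float64 parameter fell through to the unsupported-dtype branch and raised. The commit routes FP64 through the same fp32-then-cast path and tightens the error to TypeError. A minimal sketch of the now-working scenario, mirroring the updated unit test further down (assumes a paddle.fluid build of this era; the parameter shape is illustrative, since the real test's shape is not shown in this diff):

    import paddle.fluid as fluid

    program = fluid.framework.Program()
    block = program.global_block()
    # Before the fix this raised for dtype="float64"; after it, the
    # initializer appends an fp32 assign_value op plus a cast op.
    block.create_parameter(
        dtype="float64",
        shape=[8, 1, 3, 3],
        lod_level=0,
        name="param",
        initializer=fluid.initializer.BilinearInitializer())
    print([op.type for op in block.ops])  # expect ['assign_value', 'cast']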
@@ -783,7 +783,7 @@ class BilinearInitializer(Initializer):
         weight = np.reshape(weight, shape)
         # to be compatible of fp16 initalizers
-        if var.dtype == VarDesc.VarType.FP16:
+        if var.dtype == VarDesc.VarType.FP16 or var.dtype == VarDesc.VarType.FP64:
             out_dtype = VarDesc.VarType.FP32
             out_var = block.create_var(
                 name=unique_name.generate(".".join(
@@ -800,7 +800,8 @@ class BilinearInitializer(Initializer):
             value_name = "fp32_values"
             values = [float(v) for v in weight.flat]
         else:
-            raise ValueError("Unsupported dtype %s", input.dtype)
+            raise TypeError("Unsupported dtype %s", var.dtype)
+
         if np.prod(shape) > 1024 * 1024:
             raise ValueError("The size of input is too big. ")
         op = block.append_op(
@@ -812,7 +813,7 @@ class BilinearInitializer(Initializer):
                 value_name: values
             })
-        if var.dtype == VarDesc.VarType.FP16:
+        if var.dtype == VarDesc.VarType.FP16 or var.dtype == VarDesc.VarType.FP64:
             block.append_op(
                 type="cast",
                 inputs={"X": out_var},
...
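The fp32 round trip exists because assign_value carries the weights inline as Python floats (values = [float(v) for v in weight.flat], with value_name = "fp32_values" in the context above), so FP16 and now FP64 parameters are first materialized as FP32 into a temporary variable and then cast to the parameter's real dtype. For reference, the bilinear kernel those values encode can be sketched in plain NumPy (a sketch of the standard bilinear-upsampling weight formula, not code copied from the file; the shape is illustrative):

    import numpy as np

    shape = (1, 1, 4, 4)                   # [C_out, C_in, H, W]
    f = np.ceil(shape[3] / 2.0)            # upscale factor
    c = (2 * f - 1 - f % 2) / (2.0 * f)    # kernel center
    weight = np.zeros(int(np.prod(shape)), dtype=np.float32)
    for i in range(weight.size):
        x = i % shape[3]                   # column inside the kernel
        y = (i // shape[3]) % shape[2]     # row inside the kernel
        weight[i] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
    weight = np.reshape(weight, shape)     # matches the first context line above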
@@ -474,18 +474,24 @@ class TestBilinearInitializer(unittest.TestCase):
                 lod_level=0,
                 name="param",
                 initializer=initializer.BilinearInitializer())
-        num_ops = 2 if dtype == "float16" else 1
+        num_ops = 2 if dtype == "float16" or dtype == "float64" else 1
         self.assertEqual(len(block.ops), num_ops)
         init_op = block.ops[0]
         self.assertEqual(init_op.type, 'assign_value')
         return block
 
+    def test_bilinear_initializer_fp64(self):
+        self.test_bilinear_initializer(dtype='float64')
+
     def test_bilinear_initializer_fp16(self):
         """Test the bilinear initializer with supplied arguments
         """
         block = self.test_bilinear_initializer("float16")
         self.assertTrue(check_cast_op(block.ops[1]))
 
+    def test_type_error(self):
+        self.assertRaises(TypeError, self.test_bilinear_initializer, 'int32')
+
 
 class TestNumpyArrayInitializer(unittest.TestCase):
     def test_numpy_array_initializer(self, dtype="float32"):
...
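Note that the new fp64 test reuses the parametrized op-count check but, unlike the fp16 case, does not assert the trailing cast op. A companion check analogous to test_bilinear_initializer_fp16 could look like this (hypothetical, not part of the commit; it reuses the file's check_cast_op helper, whose definition is outside this diff):

    def test_bilinear_initializer_fp64_cast(self):
        # Hypothetical mirror of the fp16 assertion: the float64 path
        # should also end with a cast back from the fp32 temporary.
        block = self.test_bilinear_initializer(dtype='float64')
        self.assertTrue(check_cast_op(block.ops[1]))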