未验证 提交 3792a49d 编写于 作者: L LielinJiang 提交者: GitHub

Fix bilinear_initializer bug when type of input data is float64 (#24771)

* fix bilinear initializer, test=develop
上级 9bbb9542
......@@ -783,7 +783,7 @@ class BilinearInitializer(Initializer):
weight = np.reshape(weight, shape)
# to be compatible of fp16 initalizers
if var.dtype == VarDesc.VarType.FP16:
if var.dtype == VarDesc.VarType.FP16 or var.dtype == VarDesc.VarType.FP64:
out_dtype = VarDesc.VarType.FP32
out_var = block.create_var(
name=unique_name.generate(".".join(
......@@ -800,7 +800,8 @@ class BilinearInitializer(Initializer):
value_name = "fp32_values"
values = [float(v) for v in weight.flat]
else:
raise ValueError("Unsupported dtype %s", input.dtype)
raise TypeError("Unsupported dtype %s", var.dtype)
if np.prod(shape) > 1024 * 1024:
raise ValueError("The size of input is too big. ")
op = block.append_op(
......@@ -812,7 +813,7 @@ class BilinearInitializer(Initializer):
value_name: values
})
if var.dtype == VarDesc.VarType.FP16:
if var.dtype == VarDesc.VarType.FP16 or var.dtype == VarDesc.VarType.FP64:
block.append_op(
type="cast",
inputs={"X": out_var},
......
......@@ -474,18 +474,24 @@ class TestBilinearInitializer(unittest.TestCase):
lod_level=0,
name="param",
initializer=initializer.BilinearInitializer())
num_ops = 2 if dtype == "float16" else 1
num_ops = 2 if dtype == "float16" or dtype == "float64" else 1
self.assertEqual(len(block.ops), num_ops)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'assign_value')
return block
def test_bilinear_initializer_fp64(self):
self.test_bilinear_initializer(dtype='float64')
def test_bilinear_initializer_fp16(self):
"""Test the bilinear initializer with supplied arguments
"""
block = self.test_bilinear_initializer("float16")
self.assertTrue(check_cast_op(block.ops[1]))
def test_type_error(self):
self.assertRaises(TypeError, self.test_bilinear_initializer, 'int32')
class TestNumpyArrayInitializer(unittest.TestCase):
def test_numpy_array_initializer(self, dtype="float32"):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册