提交 fccaa4c5 编写于 作者: K Kavya Srinet

Fixing python tests

上级 8ab2b5ce
......@@ -67,24 +67,24 @@ def convert_np_dtype_to_dtype_(np_dtype):
Args:
np_dtype(np.dtype): the data type in numpy
Returns(core.DataType): the data type in Paddle
Returns(core.VarType): the data type in Paddle
"""
dtype = np.dtype(np_dtype)
if dtype == np.float32:
return core.DataType.FP32
return core.VarType.FP32
elif dtype == np.float64:
return core.DataType.FP64
return core.VarType.FP64
elif dtype == np.float16:
return core.DataType.FP16
return core.VarType.FP16
elif dtype == np.int32:
return core.DataType.INT32
return core.VarType.INT32
elif dtype == np.int16:
return core.DataType.INT16
return core.VarType.INT16
elif dtype == np.int64:
return core.DataType.INT64
return core.VarType.INT64
elif dtype == np.bool:
return core.DataType.BOOL
return core.VarType.BOOL
else:
raise ValueError("Not supported numpy dtype " + str(dtype))
......@@ -99,10 +99,10 @@ def dtype_is_floating(dtype):
Returns(bool): True if data type is a float value
"""
if not isinstance(dtype, core.DataType):
if not isinstance(dtype, core.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
return dtype in [core.DataType.FP16, core.DataType.FP32, core.DataType.FP64]
return dtype in [core.VarType.FP16, core.VarType.FP32, core.VarType.FP64]
def _debug_string_(proto, throw_on_error=True):
......@@ -200,7 +200,7 @@ class Variable(object):
"shape is {1}; the new shape is {2}. They are not "
"matched.".format(self.name, old_shape, shape))
if dtype is not None:
if not isinstance(dtype, core.DataType):
if not isinstance(dtype, core.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
if is_new_var:
self.desc.set_dtype(dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册