提交 fccaa4c5 编写于 作者: K Kavya Srinet

Fixing python tests

上级 8ab2b5ce
...@@ -67,24 +67,24 @@ def convert_np_dtype_to_dtype_(np_dtype): ...@@ -67,24 +67,24 @@ def convert_np_dtype_to_dtype_(np_dtype):
Args: Args:
np_dtype(np.dtype): the data type in numpy np_dtype(np.dtype): the data type in numpy
Returns(core.DataType): the data type in Paddle Returns(core.VarType): the data type in Paddle
""" """
dtype = np.dtype(np_dtype) dtype = np.dtype(np_dtype)
if dtype == np.float32: if dtype == np.float32:
return core.DataType.FP32 return core.VarType.FP32
elif dtype == np.float64: elif dtype == np.float64:
return core.DataType.FP64 return core.VarType.FP64
elif dtype == np.float16: elif dtype == np.float16:
return core.DataType.FP16 return core.VarType.FP16
elif dtype == np.int32: elif dtype == np.int32:
return core.DataType.INT32 return core.VarType.INT32
elif dtype == np.int16: elif dtype == np.int16:
return core.DataType.INT16 return core.VarType.INT16
elif dtype == np.int64: elif dtype == np.int64:
return core.DataType.INT64 return core.VarType.INT64
elif dtype == np.bool: elif dtype == np.bool:
return core.DataType.BOOL return core.VarType.BOOL
else: else:
raise ValueError("Not supported numpy dtype " + str(dtype)) raise ValueError("Not supported numpy dtype " + str(dtype))
...@@ -99,10 +99,10 @@ def dtype_is_floating(dtype): ...@@ -99,10 +99,10 @@ def dtype_is_floating(dtype):
Returns(bool): True if data type is a float value Returns(bool): True if data type is a float value
""" """
if not isinstance(dtype, core.DataType): if not isinstance(dtype, core.VarType):
dtype = convert_np_dtype_to_dtype_(dtype) dtype = convert_np_dtype_to_dtype_(dtype)
return dtype in [core.DataType.FP16, core.DataType.FP32, core.DataType.FP64] return dtype in [core.VarType.FP16, core.VarType.FP32, core.VarType.FP64]
def _debug_string_(proto, throw_on_error=True): def _debug_string_(proto, throw_on_error=True):
...@@ -200,7 +200,7 @@ class Variable(object): ...@@ -200,7 +200,7 @@ class Variable(object):
"shape is {1}; the new shape is {2}. They are not " "shape is {1}; the new shape is {2}. They are not "
"matched.".format(self.name, old_shape, shape)) "matched.".format(self.name, old_shape, shape))
if dtype is not None: if dtype is not None:
if not isinstance(dtype, core.DataType): if not isinstance(dtype, core.VarType):
dtype = convert_np_dtype_to_dtype_(dtype) dtype = convert_np_dtype_to_dtype_(dtype)
if is_new_var: if is_new_var:
self.desc.set_dtype(dtype) self.desc.set_dtype(dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册