diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index 2107b12b3a8f525fbf78b038a74a92a8d756f2f5..60c144d550028dbc103911ed12e8333792a829f6 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -20,6 +20,7 @@ import sys import paddle from .. import framework +from ..framework import convert_np_dtype_to_dtype_ from .. import core from .. import unique_name from ..framework import Variable, Parameter, ParamBase, _getitem_impl_, _setitem_impl_, EagerParamBase @@ -172,25 +173,24 @@ def monkey_patch_varbase(): else: self.value().set_string_list(value) else: - value_np = value - if isinstance(value, base_tensor): - value_np = value.numpy() - - self_tensor_np = self.numpy() - - assert self_tensor_np.shape == value_np.shape, \ + assert self.shape == list(value.shape), \ "Variable Shape not match, Variable [ {} ] need tensor with shape {} but load set tensor with shape {}".format( - self.name, self_tensor_np.shape, value_np.shape) + self.name, self.shape, value.shape) + + if isinstance(value, base_tensor): + dtype = value.dtype + else: + dtype = convert_np_dtype_to_dtype_(value.dtype) - assert self_tensor_np.dtype == value_np.dtype, \ + assert self.dtype == dtype, \ "Variable dtype not match, Variable [ {} ] need tensor with dtype {} but load tensor with dtype {}".format( - self.name, self_tensor_np.dtype, value_np.dtype) + self.name, self.dtype, dtype) # NOTE(wuweilong): self could be VarBase or Tensor, the subsequent behavior are defined in different files # if self is VarBase, method value() return Variable that bindded in imperative.cc, get_tensor() bindded in pybind.cc # if self is Tensor, method value() return self that defined in this file, get_tensor() defined in eager_method.cc # this Interface behavior will be unifed in the future. - self.value().get_tensor().set(value_np, + self.value().get_tensor().set(value, framework._current_expected_place()) @framework.dygraph_only