From 1042f42ea60d3f1a862e6adc3fc3fbc27ab7112f Mon Sep 17 00:00:00 2001 From: crystal <62974595+Zjq9409@users.noreply.github.com> Date: Wed, 30 Mar 2022 14:23:53 +0800 Subject: [PATCH] remove set_value numpy (#41017) * remove set_value numpy * optimize code * optimize to_tensor * use common function Co-authored-by: root --- .../fluid/dygraph/varbase_patch_methods.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/python/paddle/fluid/dygraph/varbase_patch_methods.py b/python/paddle/fluid/dygraph/varbase_patch_methods.py index 2107b12b3a..60c144d550 100644 --- a/python/paddle/fluid/dygraph/varbase_patch_methods.py +++ b/python/paddle/fluid/dygraph/varbase_patch_methods.py @@ -20,6 +20,7 @@ import sys import paddle from .. import framework +from ..framework import convert_np_dtype_to_dtype_ from .. import core from .. import unique_name from ..framework import Variable, Parameter, ParamBase, _getitem_impl_, _setitem_impl_, EagerParamBase @@ -172,25 +173,24 @@ def monkey_patch_varbase(): else: self.value().set_string_list(value) else: - value_np = value - if isinstance(value, base_tensor): - value_np = value.numpy() - - self_tensor_np = self.numpy() - - assert self_tensor_np.shape == value_np.shape, \ + assert self.shape == list(value.shape), \ "Variable Shape not match, Variable [ {} ] need tensor with shape {} but load set tensor with shape {}".format( - self.name, self_tensor_np.shape, value_np.shape) + self.name, self.shape, value.shape) + + if isinstance(value, base_tensor): + dtype = value.dtype + else: + dtype = convert_np_dtype_to_dtype_(value.dtype) - assert self_tensor_np.dtype == value_np.dtype, \ + assert self.dtype == dtype, \ "Variable dtype not match, Variable [ {} ] need tensor with dtype {} but load tensor with dtype {}".format( - self.name, self_tensor_np.dtype, value_np.dtype) + self.name, self.dtype, dtype) # NOTE(wuweilong): self could be VarBase or Tensor, the subsequent behavior are defined in different files # if self is VarBase, method value() return Variable that bindded in imperative.cc, get_tensor() bindded in pybind.cc # if self is Tensor, method value() return self that defined in this file, get_tensor() defined in eager_method.cc # this Interface behavior will be unifed in the future. - self.value().get_tensor().set(value_np, + self.value().get_tensor().set(value, framework._current_expected_place()) @framework.dygraph_only -- GitLab