未验证 提交 1042f42e 编写于 作者: C crystal 提交者: GitHub

remove set_value numpy (#41017)

* remove set_value numpy

* optimize code

* optimize to_tensor

* use common function
Co-authored-by: Nroot <root@yq01-sys-hic-k8s-v100-box-a225-0186.yq01.baidu.com>
上级 95265d5c
...@@ -20,6 +20,7 @@ import sys ...@@ -20,6 +20,7 @@ import sys
import paddle import paddle
from .. import framework from .. import framework
from ..framework import convert_np_dtype_to_dtype_
from .. import core from .. import core
from .. import unique_name from .. import unique_name
from ..framework import Variable, Parameter, ParamBase, _getitem_impl_, _setitem_impl_, EagerParamBase from ..framework import Variable, Parameter, ParamBase, _getitem_impl_, _setitem_impl_, EagerParamBase
...@@ -172,25 +173,24 @@ def monkey_patch_varbase(): ...@@ -172,25 +173,24 @@ def monkey_patch_varbase():
else: else:
self.value().set_string_list(value) self.value().set_string_list(value)
else: else:
value_np = value assert self.shape == list(value.shape), \
if isinstance(value, base_tensor):
value_np = value.numpy()
self_tensor_np = self.numpy()
assert self_tensor_np.shape == value_np.shape, \
"Variable Shape not match, Variable [ {} ] need tensor with shape {} but load set tensor with shape {}".format( "Variable Shape not match, Variable [ {} ] need tensor with shape {} but load set tensor with shape {}".format(
self.name, self_tensor_np.shape, value_np.shape) self.name, self.shape, value.shape)
if isinstance(value, base_tensor):
dtype = value.dtype
else:
dtype = convert_np_dtype_to_dtype_(value.dtype)
assert self_tensor_np.dtype == value_np.dtype, \ assert self.dtype == dtype, \
"Variable dtype not match, Variable [ {} ] need tensor with dtype {} but load tensor with dtype {}".format( "Variable dtype not match, Variable [ {} ] need tensor with dtype {} but load tensor with dtype {}".format(
self.name, self_tensor_np.dtype, value_np.dtype) self.name, self.dtype, dtype)
# NOTE(wuweilong): self could be VarBase or Tensor, the subsequent behavior are defined in different files # NOTE(wuweilong): self could be VarBase or Tensor, the subsequent behavior are defined in different files
# if self is VarBase, method value() return Variable that bindded in imperative.cc, get_tensor() bindded in pybind.cc # if self is VarBase, method value() return Variable that bindded in imperative.cc, get_tensor() bindded in pybind.cc
# if self is Tensor, method value() return self that defined in this file, get_tensor() defined in eager_method.cc # if self is Tensor, method value() return self that defined in this file, get_tensor() defined in eager_method.cc
# this Interface behavior will be unifed in the future. # this Interface behavior will be unifed in the future.
self.value().get_tensor().set(value_np, self.value().get_tensor().set(value,
framework._current_expected_place()) framework._current_expected_place())
@framework.dygraph_only @framework.dygraph_only
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册