提交 b9a1b5f2 编写于 作者: H Hui Zhang

patch func to var

上级 26524031
......@@ -131,12 +131,14 @@ if not hasattr(paddle.Tensor, 'long'):
"override long of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.long = func_long
paddle.static.Variable.long = func_long
if not hasattr(paddle.Tensor, 'numel'):
logger.debug(
"override numel of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.numel = paddle.numel
paddle.static.Variable.numel = paddle.numel
def new_full(x: paddle.Tensor,
......@@ -151,6 +153,7 @@ if not hasattr(paddle.Tensor, 'new_full'):
"override new_full of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.new_full = new_full
paddle.static.Variable.new_full = new_full
def eq(xs: paddle.Tensor, ys: Union[paddle.Tensor, float]) -> paddle.Tensor:
......@@ -166,6 +169,7 @@ if not hasattr(paddle.Tensor, 'eq'):
"override eq of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.eq = eq
paddle.static.Variable.eq = eq
if not hasattr(paddle, 'eq'):
logger.debug(
......@@ -182,6 +186,7 @@ if not hasattr(paddle.Tensor, 'contiguous'):
"override contiguous of paddle.Tensor if exists or register, remove this when fixed!"
)
paddle.Tensor.contiguous = contiguous
paddle.static.Variable.contiguous = contiguous
def size(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
......@@ -200,6 +205,7 @@ logger.debug(
"(`to_static` do not process `size` property, maybe some `paddle` api dependent on it), remove this when fixed!"
)
paddle.Tensor.size = size
paddle.static.Variable.size = size
def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
......@@ -209,6 +215,7 @@ def view(xs: paddle.Tensor, *args: int) -> paddle.Tensor:
if not hasattr(paddle.Tensor, 'view'):
logger.debug("register user view to paddle.Tensor, remove this when fixed!")
paddle.Tensor.view = view
paddle.static.Variable.view = view
def view_as(xs: paddle.Tensor, ys: paddle.Tensor) -> paddle.Tensor:
......@@ -219,6 +226,7 @@ if not hasattr(paddle.Tensor, 'view_as'):
logger.debug(
"register user view_as to paddle.Tensor, remove this when fixed!")
paddle.Tensor.view_as = view_as
paddle.static.Variable.view_as = view_as
def is_broadcastable(shp1, shp2):
......@@ -246,6 +254,7 @@ if not hasattr(paddle.Tensor, 'masked_fill'):
logger.debug(
"register user masked_fill to paddle.Tensor, remove this when fixed!")
paddle.Tensor.masked_fill = masked_fill
paddle.static.Variable.masked_fill = masked_fill
def masked_fill_(xs: paddle.Tensor,
......@@ -264,6 +273,7 @@ if not hasattr(paddle.Tensor, 'masked_fill_'):
logger.debug(
"register user masked_fill_ to paddle.Tensor, remove this when fixed!")
paddle.Tensor.masked_fill_ = masked_fill_
paddle.static.Variable.maksed_fill_ = masked_fill_
def fill_(xs: paddle.Tensor, value: Union[float, int]) -> paddle.Tensor:
......@@ -276,6 +286,7 @@ if not hasattr(paddle.Tensor, 'fill_'):
logger.debug(
"register user fill_ to paddle.Tensor, remove this when fixed!")
paddle.Tensor.fill_ = fill_
paddle.static.Variable.fill_ = fill_
def repeat(xs: paddle.Tensor, *size: Any) -> paddle.Tensor:
......@@ -286,6 +297,7 @@ if not hasattr(paddle.Tensor, 'repeat'):
logger.debug(
"register user repeat to paddle.Tensor, remove this when fixed!")
paddle.Tensor.repeat = repeat
paddle.static.Variable.repeat = repeat
if not hasattr(paddle.Tensor, 'softmax'):
logger.debug(
......@@ -310,6 +322,8 @@ if not hasattr(paddle.Tensor, 'type_as'):
logger.debug(
"register user type_as to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'type_as', type_as)
setattr(paddle.static.Variable, 'type_as', type_as)
def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:
......@@ -325,6 +339,7 @@ def to(x: paddle.Tensor, *args, **kwargs) -> paddle.Tensor:
if not hasattr(paddle.Tensor, 'to'):
logger.debug("register user to to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'to', to)
setattr(paddle.static.Variable, 'to', to)
def func_float(x: paddle.Tensor) -> paddle.Tensor:
......@@ -335,6 +350,7 @@ if not hasattr(paddle.Tensor, 'float'):
logger.debug(
"register user float to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'float', func_float)
setattr(paddle.static.Variable, 'float', func_float)
def func_int(x: paddle.Tensor) -> paddle.Tensor:
......@@ -344,6 +360,7 @@ def func_int(x: paddle.Tensor) -> paddle.Tensor:
if not hasattr(paddle.Tensor, 'int'):
logger.debug("register user int to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'int', func_int)
setattr(paddle.static.Variable, 'int', func_int)
def tolist(x: paddle.Tensor) -> List[Any]:
......@@ -354,6 +371,8 @@ if not hasattr(paddle.Tensor, 'tolist'):
logger.debug(
"register user tolist to paddle.Tensor, remove this when fixed!")
setattr(paddle.Tensor, 'tolist', tolist)
setattr(paddle.static.Variable, 'tolist', tolist)
########### hack paddle.nn #############
from paddle.nn import Layer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册