未验证 提交 c18aa8a3 编写于 作者: J Jiabin Yang 提交者: GitHub

[Eager] Optimize python api to speed up eager exec (#45262)

* optimize python api to speed up eager exec

* optimize python api to speed up eager exec

* optimize python api to speed up eager exec
上级 e3574f72
...@@ -379,8 +379,9 @@ def monkey_patch_math_varbase(): ...@@ -379,8 +379,9 @@ def monkey_patch_math_varbase():
('__rtruediv__', ('__rtruediv__',
_binary_creator_('rtruediv__', 'elementwise_div', True, None)), _binary_creator_('rtruediv__', 'elementwise_div', True, None)),
('__pow__', ('__pow__',
_binary_creator_('__pow__', 'final_state_elementwise_pow', False, None, _binary_creator_('__pow__', 'final_state_elementwise_pow', False,
True)) if framework._in_eager_mode_ else _C_ops.final_state_pow, True))
if framework._in_eager_mode_ else
('__pow__', ('__pow__',
_binary_creator_('__pow__', 'elementwise_pow', False, None)), _binary_creator_('__pow__', 'elementwise_pow', False, None)),
('__rpow__', _binary_creator_('__rpow__', 'elementwise_pow', True, ('__rpow__', _binary_creator_('__rpow__', 'elementwise_pow', True,
......
...@@ -736,7 +736,9 @@ def floor_divide(x, y, name=None): ...@@ -736,7 +736,9 @@ def floor_divide(x, y, name=None):
""" """
op_type = 'elementwise_floordiv' op_type = 'elementwise_floordiv'
axis = -1 axis = -1
if paddle.in_dynamic_mode(): if in_dygraph_mode():
return _C_ops.final_state_floor_divide(x, y)
if _in_legacy_dygraph():
return _elementwise_op_in_dygraph( return _elementwise_op_in_dygraph(
x, y, axis=axis, op_name=op_type) x, y, axis=axis, op_name=op_type)
...@@ -776,7 +778,9 @@ def remainder(x, y, name=None): ...@@ -776,7 +778,9 @@ def remainder(x, y, name=None):
""" """
op_type = 'elementwise_mod' op_type = 'elementwise_mod'
axis = -1 axis = -1
if paddle.in_dynamic_mode(): if in_dygraph_mode():
return _C_ops.final_state_modulo(x, y)
if _in_legacy_dygraph():
return _elementwise_op_in_dygraph( return _elementwise_op_in_dygraph(
x, y, axis=axis, op_name=op_type) x, y, axis=axis, op_name=op_type)
...@@ -894,7 +898,9 @@ def maximum(x, y, name=None): ...@@ -894,7 +898,9 @@ def maximum(x, y, name=None):
op_type = 'elementwise_max' op_type = 'elementwise_max'
axis = -1 axis = -1
act = None act = None
if paddle.in_dynamic_mode(): if in_dygraph_mode():
return _C_ops.final_state_maximum(x, y)
if _in_legacy_dygraph():
return _elementwise_op_in_dygraph( return _elementwise_op_in_dygraph(
x, y, axis=axis, act=act, op_name=op_type) x, y, axis=axis, act=act, op_name=op_type)
return _elementwise_op(LayerHelper(op_type, **locals())) return _elementwise_op(LayerHelper(op_type, **locals()))
...@@ -953,7 +959,9 @@ def minimum(x, y, name=None): ...@@ -953,7 +959,9 @@ def minimum(x, y, name=None):
op_type = 'elementwise_min' op_type = 'elementwise_min'
axis = -1 axis = -1
act = None act = None
if paddle.in_dynamic_mode(): if in_dygraph_mode():
return _C_ops.final_state_minimum(x, y)
if _in_legacy_dygraph():
return _elementwise_op_in_dygraph( return _elementwise_op_in_dygraph(
x, y, axis=axis, act=act, op_name=op_type) x, y, axis=axis, act=act, op_name=op_type)
return _elementwise_op(LayerHelper(op_type, **locals())) return _elementwise_op(LayerHelper(op_type, **locals()))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册