未验证 提交 50714d5c 编写于 作者: A Aurelius84 提交者: GitHub

[Eager]Fix eager no take effect problem (#41291)

* [Eager]Fix eager no take effect problem

* add element_wise and fix greater_than
上级 78200976
...@@ -1279,6 +1279,15 @@ static PyObject* tensor__inplace_version(TensorObject* self, PyObject* args, ...@@ -1279,6 +1279,15 @@ static PyObject* tensor__inplace_version(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL EAGER_CATCH_AND_THROW_RETURN_NULL
} }
static PyObject* tensor_method_element_size(TensorObject* self, PyObject* args,
PyObject* kwargs) {
EAGER_TRY
uint32_t element_size = framework::DataTypeSize(self->tensor.dtype());
return ToPyObject(element_size);
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__bump_inplace_version(TensorObject* self, static PyObject* tensor__bump_inplace_version(TensorObject* self,
PyObject* args, PyObject* args,
PyObject* kwargs) { PyObject* kwargs) {
...@@ -1417,6 +1426,8 @@ PyMethodDef variable_methods[] = { ...@@ -1417,6 +1426,8 @@ PyMethodDef variable_methods[] = {
METH_VARARGS | METH_KEYWORDS, NULL}, METH_VARARGS | METH_KEYWORDS, NULL},
{"to_dense", (PyCFunction)(void (*)(void))tensor_method_to_dense, {"to_dense", (PyCFunction)(void (*)(void))tensor_method_to_dense,
METH_VARARGS | METH_KEYWORDS, NULL}, METH_VARARGS | METH_KEYWORDS, NULL},
{"element_size", (PyCFunction)(void (*)(void))tensor_method_element_size,
METH_VARARGS | METH_KEYWORDS, NULL},
/***the method of sparse tensor****/ /***the method of sparse tensor****/
{"_inplace_version", (PyCFunction)(void (*)(void))tensor__inplace_version, {"_inplace_version", (PyCFunction)(void (*)(void))tensor__inplace_version,
METH_VARARGS | METH_KEYWORDS, NULL}, METH_VARARGS | METH_KEYWORDS, NULL},
......
...@@ -48,7 +48,10 @@ from .framework.dtype import bfloat16 # noqa: F401 ...@@ -48,7 +48,10 @@ from .framework.dtype import bfloat16 # noqa: F401
from .framework.dtype import bool # noqa: F401 from .framework.dtype import bool # noqa: F401
from .framework.dtype import complex64 # noqa: F401 from .framework.dtype import complex64 # noqa: F401
from .framework.dtype import complex128 # noqa: F401 from .framework.dtype import complex128 # noqa: F401
from .framework import VarBase as Tensor # noqa: F401 if fluid.framework._in_eager_mode_:
Tensor = framework.core.eager.Tensor
else:
from .framework import VarBase as Tensor # noqa: F401
Tensor.__qualname__ = 'Tensor' # noqa: F401 Tensor.__qualname__ = 'Tensor' # noqa: F401
import paddle.compat # noqa: F401 import paddle.compat # noqa: F401
import paddle.distributed # noqa: F401 import paddle.distributed # noqa: F401
......
...@@ -48,10 +48,10 @@ class TestCrossOp(OpTest): ...@@ -48,10 +48,10 @@ class TestCrossOp(OpTest):
self.outputs = {'Out': np.array(z_list).reshape(self.shape)} self.outputs = {'Out': np.array(z_list).reshape(self.shape)}
def test_check_output(self): def test_check_output(self):
self.check_output(check_eager=False) self.check_output(check_eager=True)
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', check_eager=False) self.check_grad(['X', 'Y'], 'Out', check_eager=True)
class TestCrossOpCase1(TestCrossOp): class TestCrossOpCase1(TestCrossOp):
......
...@@ -27,6 +27,9 @@ from paddle import _C_ops ...@@ -27,6 +27,9 @@ from paddle import _C_ops
__all__ = [] __all__ = []
# Consistent with kDefaultDim from C++ Backend
K_DEFAULT_DIM = 9
def matmul(x, y, transpose_x=False, transpose_y=False, name=None): def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
""" """
...@@ -1157,6 +1160,7 @@ def cross(x, y, axis=None, name=None): ...@@ -1157,6 +1160,7 @@ def cross(x, y, axis=None, name=None):
# [0. 0. 0.]] # [0. 0. 0.]]
""" """
if in_dygraph_mode(): if in_dygraph_mode():
axis = K_DEFAULT_DIM if axis is None else axis
return _C_ops.final_state_cross(x, y, axis) return _C_ops.final_state_cross(x, y, axis)
else: else:
if _in_legacy_dygraph(): if _in_legacy_dygraph():
......
...@@ -280,7 +280,8 @@ def greater_than(x, y, name=None): ...@@ -280,7 +280,8 @@ def greater_than(x, y, name=None):
print(result1) # result1 = [False False True] print(result1) # result1 = [False False True]
""" """
if in_dygraph_mode(): if in_dygraph_mode():
return _C_ops.final_state_greater_than(x, y) axis = -1 # default value
return _C_ops.final_state_greater_than(x, y, axis)
else: else:
if _in_legacy_dygraph(): if _in_legacy_dygraph():
return _C_ops.greater_than(x, y) return _C_ops.greater_than(x, y)
......
...@@ -610,21 +610,21 @@ ...@@ -610,21 +610,21 @@
func : gelu func : gelu
backward : gelu_grad backward : gelu_grad
- api : greater - api : greater_equal
args : (Tensor x, Tensor y, int axis = -1) args : (Tensor x, Tensor y, int axis = -1)
output : Tensor output : Tensor
infer_meta : infer_meta :
func : CompareInferMeta func : CompareInferMeta
kernel : kernel :
func : greater func : greater_equal
- api : greater_equal - api : greater_than
args : (Tensor x, Tensor y, int axis = -1) args : (Tensor x, Tensor y, int axis = -1)
output : Tensor output : Tensor
infer_meta : infer_meta :
func : CompareInferMeta func : CompareInferMeta
kernel : kernel :
func : greater_equal func : greater_than
- api : gumbel_softmax - api : gumbel_softmax
args : (Tensor x, float temperature, bool hard, int axis) args : (Tensor x, float temperature, bool hard, int axis)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册