未验证 提交 50714d5c 编写于 作者: A Aurelius84 提交者: GitHub

[Eager] Fix problem of eager mode not taking effect (#41291)

* [Eager] Fix problem of eager mode not taking effect

* add element_wise and fix greater_than
上级 78200976
......@@ -1279,6 +1279,15 @@ static PyObject* tensor__inplace_version(TensorObject* self, PyObject* args,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
// Python binding for Tensor.element_size(): returns the size in bytes of a
// single element of this tensor's dtype (framework::DataTypeSize of the
// dtype), converted to a Python int via ToPyObject.
// Takes no Python arguments; `args`/`kwargs` are unused in the body.
// EAGER_TRY / EAGER_CATCH_AND_THROW_RETURN_NULL bracket the body so C++
// exceptions are converted into Python exceptions (returning NULL),
// following this file's convention for eager tensor methods.
static PyObject* tensor_method_element_size(TensorObject* self, PyObject* args,
                                            PyObject* kwargs) {
  EAGER_TRY
  uint32_t element_size = framework::DataTypeSize(self->tensor.dtype());
  return ToPyObject(element_size);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__bump_inplace_version(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
......@@ -1417,6 +1426,8 @@ PyMethodDef variable_methods[] = {
METH_VARARGS | METH_KEYWORDS, NULL},
{"to_dense", (PyCFunction)(void (*)(void))tensor_method_to_dense,
METH_VARARGS | METH_KEYWORDS, NULL},
{"element_size", (PyCFunction)(void (*)(void))tensor_method_element_size,
METH_VARARGS | METH_KEYWORDS, NULL},
/***the method of sparse tensor****/
{"_inplace_version", (PyCFunction)(void (*)(void))tensor__inplace_version,
METH_VARARGS | METH_KEYWORDS, NULL},
......
......@@ -48,7 +48,10 @@ from .framework.dtype import bfloat16 # noqa: F401
from .framework.dtype import bool # noqa: F401
from .framework.dtype import complex64 # noqa: F401
from .framework.dtype import complex128 # noqa: F401
from .framework import VarBase as Tensor # noqa: F401
# Select the class exposed as `paddle.Tensor`: the eager-mode C++ Tensor when
# eager mode is enabled, otherwise the legacy VarBase.
if fluid.framework._in_eager_mode_:
    Tensor = framework.core.eager.Tensor
else:
    from .framework import VarBase as Tensor  # noqa: F401
# Normalize the public qualified name so both branches report 'Tensor'.
Tensor.__qualname__ = 'Tensor'  # noqa: F401
import paddle.compat # noqa: F401
import paddle.distributed # noqa: F401
......
......@@ -48,10 +48,10 @@ class TestCrossOp(OpTest):
self.outputs = {'Out': np.array(z_list).reshape(self.shape)}
    def test_check_output(self):
        # NOTE(review): both a legacy (check_eager=False) and an eager
        # (check_eager=True) check appear here; this looks like an old/new
        # diff-line pair rather than two intentional calls — confirm which
        # one the committed code keeps.
        self.check_output(check_eager=False)
        self.check_output(check_eager=True)
    def test_check_grad_normal(self):
        # Checks gradients of Out w.r.t. both inputs X and Y.
        # NOTE(review): the check_eager=False / check_eager=True pair looks
        # like an old/new diff-line pair rather than two intentional calls —
        # confirm which one the committed code keeps.
        self.check_grad(['X', 'Y'], 'Out', check_eager=False)
        self.check_grad(['X', 'Y'], 'Out', check_eager=True)
class TestCrossOpCase1(TestCrossOp):
......
......@@ -27,6 +27,9 @@ from paddle import _C_ops
__all__ = []
# Consistent with kDefaultDim from C++ Backend
K_DEFAULT_DIM = 9
def matmul(x, y, transpose_x=False, transpose_y=False, name=None):
"""
......@@ -1157,6 +1160,7 @@ def cross(x, y, axis=None, name=None):
# [0. 0. 0.]]
"""
if in_dygraph_mode():
axis = K_DEFAULT_DIM if axis is None else axis
return _C_ops.final_state_cross(x, y, axis)
else:
if _in_legacy_dygraph():
......
......@@ -280,7 +280,8 @@ def greater_than(x, y, name=None):
print(result1) # result1 = [False False True]
"""
if in_dygraph_mode():
return _C_ops.final_state_greater_than(x, y)
axis = -1 # default value
return _C_ops.final_state_greater_than(x, y, axis)
else:
if _in_legacy_dygraph():
return _C_ops.greater_than(x, y)
......
......@@ -610,21 +610,21 @@
func : gelu
backward : gelu_grad
- api : greater
- api : greater_equal
args : (Tensor x, Tensor y, int axis = -1)
output : Tensor
infer_meta :
func : CompareInferMeta
kernel :
func : greater
func : greater_equal
- api : greater_equal
- api : greater_than
args : (Tensor x, Tensor y, int axis = -1)
output : Tensor
infer_meta :
func : CompareInferMeta
kernel :
func : greater_equal
func : greater_than
- api : gumbel_softmax
args : (Tensor x, float temperature, bool hard, int axis)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册