未验证 提交 3800f192 编写于 作者: C Chen Weihang 提交者: GitHub

fix add_n incompatible error (#43395)

上级 f551d9fe
...@@ -1417,6 +1417,16 @@ static PyObject* tensor_method_get_non_zero_cols(TensorObject* self, ...@@ -1417,6 +1417,16 @@ static PyObject* tensor_method_get_non_zero_cols(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL EAGER_CATCH_AND_THROW_RETURN_NULL
} }
static PyObject* tensor_method_is_dense(TensorObject* self, PyObject* args,
PyObject* kwargs) {
EAGER_TRY
if (!self->tensor.defined()) {
return ToPyObject(false);
}
return ToPyObject(self->tensor.is_dense_tensor());
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method_is_sparse(TensorObject* self, PyObject* args, static PyObject* tensor_method_is_sparse(TensorObject* self, PyObject* args,
PyObject* kwargs) { PyObject* kwargs) {
EAGER_TRY EAGER_TRY
...@@ -1682,6 +1692,8 @@ PyMethodDef variable_methods[] = { ...@@ -1682,6 +1692,8 @@ PyMethodDef variable_methods[] = {
METH_VARARGS | METH_KEYWORDS, NULL}, METH_VARARGS | METH_KEYWORDS, NULL},
{"clear_gradient", (PyCFunction)(void (*)(void))tensor_clear_gradient, {"clear_gradient", (PyCFunction)(void (*)(void))tensor_clear_gradient,
METH_VARARGS | METH_KEYWORDS, NULL}, METH_VARARGS | METH_KEYWORDS, NULL},
{"is_dense", (PyCFunction)(void (*)(void))tensor_method_is_dense,
METH_VARARGS | METH_KEYWORDS, NULL},
{"_zero_grads", (PyCFunction)(void (*)(void))tensor__zero_grads, {"_zero_grads", (PyCFunction)(void (*)(void))tensor__zero_grads,
METH_VARARGS | METH_KEYWORDS, NULL}, METH_VARARGS | METH_KEYWORDS, NULL},
{"_share_buffer_to", (PyCFunction)(void (*)(void))tensor__share_buffer_to, {"_share_buffer_to", (PyCFunction)(void (*)(void))tensor__share_buffer_to,
......
...@@ -1032,7 +1032,9 @@ class Optimizer(object): ...@@ -1032,7 +1032,9 @@ class Optimizer(object):
assert regularization_term is not None assert regularization_term is not None
if framework.in_dygraph_mode(): if framework.in_dygraph_mode():
if grad.is_dense() and regularization_term.is_dense():
return _C_ops.final_state_add_n([grad, regularization_term]) return _C_ops.final_state_add_n([grad, regularization_term])
return _C_ops.sum([grad, regularization_term])
elif framework._in_legacy_dygraph(): elif framework._in_legacy_dygraph():
return _C_ops.sum([grad, regularization_term]) return _C_ops.sum([grad, regularization_term])
......
...@@ -1377,6 +1377,9 @@ def add_n(inputs, name=None): ...@@ -1377,6 +1377,9 @@ def add_n(inputs, name=None):
if in_dygraph_mode(): if in_dygraph_mode():
if isinstance(inputs, Variable): if isinstance(inputs, Variable):
inputs = [inputs] inputs = [inputs]
for x in inputs:
if not x.is_dense():
return _C_ops.sum(inputs, 'use_mkldnn', False)
return _C_ops.final_state_add_n(inputs) return _C_ops.final_state_add_n(inputs)
if _in_legacy_dygraph(): if _in_legacy_dygraph():
if isinstance(inputs, Variable): if isinstance(inputs, Variable):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册