未验证 提交 d257acc6 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] Add get_tensor_from_selected_rows (#45227)

* [Eager] add get_tensor_from_selected_rows

* add PADDLE_ENFORCE to check SelectedRows

* use _ prefix in temp
上级 92125870
......@@ -722,6 +722,33 @@ static PyObject* tensor_method_get_underline_selected_rows(TensorObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor_method__get_tensor_from_selected_rows(
TensorObject* self, PyObject* args, PyObject* kwargs) {
EAGER_TRY
PADDLE_ENFORCE(self->tensor.is_selected_rows(),
paddle::platform::errors::Fatal(
"this method is only effective for SelectedRows."));
auto* selected_rows =
static_cast<phi::SelectedRows*>(self->tensor.impl().get());
PADDLE_ENFORCE(
selected_rows->initialized(),
paddle::platform::errors::Fatal("SelectedRows must be initialized."));
auto* dense_tensor = static_cast<paddle::framework::LoDTensor*>(
selected_rows->mutable_value());
VLOG(1) << "dense_tensor: " << dense_tensor->IsInitialized();
auto t = paddle::experimental::Tensor(
egr::Controller::Instance().GenerateUniqueName());
t.set_impl(std::make_shared<phi::DenseTensor>(*dense_tensor));
return ToPyObject(t);
EAGER_CATCH_AND_THROW_RETURN_NULL
}
static PyObject* tensor__getitem_index_not_tensor(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
......@@ -1852,6 +1879,10 @@ PyMethodDef variable_methods[] = {
(PyCFunction)(void (*)(void))tensor_method_get_underline_selected_rows,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_get_tensor_from_selected_rows",
(PyCFunction)(void (*)(void))tensor_method__get_tensor_from_selected_rows,
METH_VARARGS | METH_KEYWORDS,
NULL},
{"_getitem_index_not_tensor",
(PyCFunction)(void (*)(void))tensor__getitem_index_not_tensor,
METH_VARARGS | METH_KEYWORDS,
......
......@@ -71,13 +71,9 @@ def _squared_l2_norm(x):
return sum_square
if in_dygraph_mode():
if x.is_selected_rows():
new_x = paddle.to_tensor(x.numpy())
return _C_ops.final_state_squared_l2_norm(new_x)
return _C_ops.final_state_squared_l2_norm(x)
else:
if _in_legacy_dygraph():
return _C_ops.squared_l2_norm(x)
elif _in_legacy_dygraph():
return _C_ops.squared_l2_norm(x)
op_type = 'squared_l2_norm'
check_variable_and_dtype(x, 'x', ['float32', 'float64'], op_type)
......@@ -495,7 +491,12 @@ class ClipGradByGlobalNorm(ClipGradBase):
if getattr(p, 'need_clip', True) is False:
continue
merge_grad = g
if g.type == core.VarDesc.VarType.SELECTED_ROWS:
if in_dygraph_mode() and g.is_selected_rows():
merge_grad = layers.merge_selected_rows(g)
merge_grad = merge_grad._get_tensor_from_selected_rows()
elif g.type == core.VarDesc.VarType.SELECTED_ROWS:
merge_grad = layers.merge_selected_rows(g)
merge_grad = layers.get_tensor_from_selected_rows(merge_grad)
......
......@@ -13150,6 +13150,8 @@ def merge_selected_rows(x, name=None):
type=fluid.core.VarDesc.VarType.SELECTED_ROWS)
y = fluid.layers.merge_selected_rows(var)
"""
if _non_static_mode():
return _C_ops.merge_selected_rows(x)
helper = LayerHelper("merge_selected_rows", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册