未验证 提交 9eb18c75 编写于 作者: J Jiabin Yang 提交者: GitHub

[Eager] Support is empty (#43032)

* support is empty

* fix error

* fix code error

* change to fake empty

* using fake empty first

* using fake empty first
上级 4d32f417
......@@ -28,33 +28,40 @@
namespace egr {
static void CopyOrAddTensor(paddle::experimental::Tensor* tensor,
const paddle::experimental::Tensor& t) {
if (!tensor->defined() || !tensor->initialized()) {
// Simply copy tensor->impl
const paddle::experimental::Tensor& t,
bool is_fake_empty) {
if (is_fake_empty) {
*tensor = t;
} else {
// Accumulation
if (LIKELY(t.is_dense_tensor())) {
if (LIKELY(tensor->is_dense_tensor())) {
paddle::imperative::TensorAdd<paddle::experimental::Tensor>(t, tensor);
if (!tensor->defined() || !tensor->initialized()) {
// Simply copy tensor->impl
*tensor = t;
} else {
// Accumulation
if (LIKELY(t.is_dense_tensor())) {
if (LIKELY(tensor->is_dense_tensor())) {
paddle::imperative::TensorAdd<paddle::experimental::Tensor>(t,
tensor);
} else {
// TODO(jiabin): Support Other TensorBase later
// TODO(zhanlve): Replace SelectedRowsAddTensor with
// add_dygraph_function once it's supported
paddle::experimental::Tensor new_buffer(
std::make_shared<phi::DenseTensor>(), "tmp_accumulator");
paddle::imperative::SelectedRowsAddTensor(*tensor, t, &new_buffer);
tensor->set_impl(new_buffer.impl());
}
} else {
// TODO(jiabin): Support Other TensorBase later
// TODO(zhanlve): Replace SelectedRowsAddTensor with
// add_dygraph_function once it's supported
paddle::experimental::Tensor new_buffer(
std::make_shared<phi::DenseTensor>(), "tmp_accumulator");
paddle::imperative::SelectedRowsAddTensor(*tensor, t, &new_buffer);
tensor->set_impl(new_buffer.impl());
}
} else {
// TODO(jiabin): Support Other TensorBase later
// TODO(zhanlve): Replace SelectedRowsAddTensor with add_dygraph_function
// once it's supported
if (tensor->is_dense_tensor()) {
paddle::imperative::SelectedRowsAddToTensor(t, tensor);
} else {
*tensor = std::move(*paddle::imperative::SelectedRowsMerge<
paddle::experimental::Tensor>(t, *tensor));
// add_dygraph_function
// once it's supported
if (tensor->is_dense_tensor()) {
paddle::imperative::SelectedRowsAddToTensor(t, tensor);
} else {
*tensor = std::move(*paddle::imperative::SelectedRowsMerge<
paddle::experimental::Tensor>(t, *tensor));
}
}
}
}
......@@ -91,7 +98,8 @@ GradNodeAccumulation::operator()(
if (!weak_grad_.expired() && !is_new_grad) {
auto grad = weak_grad_.lock();
CopyOrAddTensor(grad.get(), grad_out);
CopyOrAddTensor(grad.get(), grad_out, is_fake_empty_);
is_fake_empty_ = false;
}
// Apply Reduce Hooks
......
......@@ -64,14 +64,16 @@ class GradNodeAccumulation : public GradNodeBase {
new GradNodeAccumulation(nullptr));
}
void SetFakeEmpty(bool is_fake_empty) { is_fake_empty_ = is_fake_empty; }
private:
// TODO(Jiabin): remove this when we make our clear gradient really cleared;
bool is_fake_empty_ = {false};
std::weak_ptr<paddle::experimental::Tensor> weak_grad_;
std::vector<std::shared_ptr<TensorVoidHook>> reduce_hooks_;
std::function<paddle::experimental::Tensor(
const paddle::experimental::Tensor&)>
retain_grad_hook_;
std::vector<std::shared_ptr<TensorVoidHook>> reduce_hooks_;
};
} // namespace egr
......@@ -494,7 +494,8 @@ static PyObject* tensor_clear_gradient(TensorObject* self, PyObject* args,
}
paddle::experimental::Tensor* grad;
if (egr::egr_utils_api::IsLeafTensor(self->tensor)) {
bool is_leaf = egr::egr_utils_api::IsLeafTensor(self->tensor);
if (is_leaf) {
grad = egr::EagerUtils::mutable_grad(self->tensor);
PADDLE_ENFORCE(grad != nullptr,
paddle::platform::errors::Fatal(
......@@ -518,6 +519,11 @@ static PyObject* tensor_clear_gradient(TensorObject* self, PyObject* args,
if (grad->initialized()) {
if (set_to_zero) {
grad->set_impl(paddle::experimental::zeros_like(*grad).impl());
if (is_leaf) {
std::static_pointer_cast<egr::GradNodeAccumulation>(
egr::EagerUtils::grad_node(self->tensor))
->SetFakeEmpty(true);
}
} else {
VLOG(4) << "Gradient of " << self->tensor.name()
<< " is initialized, will be released.";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册