Unverified commit 9eb18c75, authored by Jiabin Yang, committed by GitHub

[Eager] Support is empty (#43032)

* support is empty

* fix error

* fix code error

* change to fake empty

* using fake empty first

* using fake empty first
Parent: 4d32f417
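
What "fake empty" means: tensor_clear_gradient with set_to_zero keeps the gradient buffer alive and zeroes it in place, so the tensor still reports defined() and initialized(). Without extra bookkeeping, the next backward pass would accumulate into those placeholder zeros; the new is_fake_empty flag tells CopyOrAddTensor to overwrite the buffer instead. A minimal standalone sketch of that contract (GradSlot and its methods are invented for illustration; this is not Paddle code):

#include <algorithm>
#include <cstddef>
#include <vector>

struct GradSlot {
  std::vector<float> data;
  bool is_fake_empty = false;  // set when the buffer was zeroed in place

  void ClearToZero() {
    std::fill(data.begin(), data.end(), 0.0f);
    is_fake_empty = true;  // contents are placeholder zeros from now on
  }

  void Accumulate(const std::vector<float>& g) {
    if (is_fake_empty) {
      data = g;               // overwrite the zeros instead of adding
      is_fake_empty = false;  // the flag is one-shot
    } else {
      for (std::size_t i = 0; i < data.size(); ++i) data[i] += g[i];
    }
  }
};

int main() {
  GradSlot slot{{1.f, 2.f}};
  slot.ClearToZero();           // analogue of clear_gradient(set_to_zero)
  slot.Accumulate({3.f, 4.f});  // next backward: overwrites, does not add
  return slot.data[0] == 3.f ? 0 : 1;
}

Numerically, adding into zeros would give the same values; the flag exists so the zeroed buffer behaves like a truly empty gradient (and the add kernel is skipped), as the TODO in the header file below also hints.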
@@ -28,7 +28,11 @@
 namespace egr {
 static void CopyOrAddTensor(paddle::experimental::Tensor* tensor,
-                            const paddle::experimental::Tensor& t) {
+                            const paddle::experimental::Tensor& t,
+                            bool is_fake_empty) {
+  if (is_fake_empty) {
+    *tensor = t;
+  } else {
   if (!tensor->defined() || !tensor->initialized()) {
     // Simply copy tensor->impl
     *tensor = t;
   } else {
@@ -36,7 +40,8 @@ static void CopyOrAddTensor(paddle::experimental::Tensor* tensor,
     // Accumulation
     if (LIKELY(t.is_dense_tensor())) {
       if (LIKELY(tensor->is_dense_tensor())) {
-        paddle::imperative::TensorAdd<paddle::experimental::Tensor>(t, tensor);
+        paddle::imperative::TensorAdd<paddle::experimental::Tensor>(t,
+                                                                    tensor);
       } else {
         // TODO(jiabin): Support Other TensorBase later
         // TODO(zhanlve): Replace SelectedRowsAddTensor with
@@ -48,7 +53,8 @@ static void CopyOrAddTensor(paddle::experimental::Tensor* tensor,
       }
     } else {
       // TODO(jiabin): Support Other TensorBase later
-      // TODO(zhanlve): Replace SelectedRowsAddTensor with add_dygraph_function
+      // TODO(zhanlve): Replace SelectedRowsAddTensor with
+      // add_dygraph_function
       // once it's supported
       if (tensor->is_dense_tensor()) {
         paddle::imperative::SelectedRowsAddToTensor(t, tensor);
@@ -58,6 +64,7 @@ static void CopyOrAddTensor(paddle::experimental::Tensor* tensor,
       }
     }
   }
+  }
 }
 
 paddle::small_vector<std::vector<paddle::experimental::Tensor>,
@@ -91,7 +98,8 @@ GradNodeAccumulation::operator()(
   if (!weak_grad_.expired() && !is_new_grad) {
     auto grad = weak_grad_.lock();
-    CopyOrAddTensor(grad.get(), grad_out);
+    CopyOrAddTensor(grad.get(), grad_out, is_fake_empty_);
+    is_fake_empty_ = false;
   }
 
   // Apply Reduce Hooks
......
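
To make the branch order in the updated CopyOrAddTensor easier to follow, here is a compilable toy in which DenseGrad and SparseGrad stand in for DenseTensor and SelectedRows (all names are invented; only the control flow mirrors the hunks above):

#include <cstddef>
#include <utility>
#include <variant>
#include <vector>

struct DenseGrad { std::vector<float> v; };
struct SparseGrad { std::vector<std::pair<int, float>> rows; };
// monostate models an undefined/uninitialized gradient
using AnyGrad = std::variant<std::monostate, DenseGrad, SparseGrad>;

void CopyOrAdd(AnyGrad* tensor, const AnyGrad& t, bool is_fake_empty) {
  if (is_fake_empty) {
    *tensor = t;  // branch 1: placeholder zeros, overwrite
    return;
  }
  if (std::holds_alternative<std::monostate>(*tensor)) {
    *tensor = t;  // branch 2: first gradient, simple copy
    return;
  }
  if (auto* dst = std::get_if<DenseGrad>(tensor)) {
    if (const auto* src = std::get_if<DenseGrad>(&t)) {
      // dense += dense, i.e. the TensorAdd path
      for (std::size_t i = 0; i < dst->v.size(); ++i) dst->v[i] += src->v[i];
    } else {
      // sparse into dense, i.e. the SelectedRowsAddToTensor path
      for (const auto& [row, val] : std::get<SparseGrad>(t).rows)
        dst->v[static_cast<std::size_t>(row)] += val;
    }
  }
  // remaining combinations are still TODO in the real code
}

int main() {
  AnyGrad grad = DenseGrad{{0.f, 0.f}};
  CopyOrAdd(&grad, AnyGrad{DenseGrad{{3.f, 4.f}}}, /*is_fake_empty=*/true);
  return std::get<DenseGrad>(grad).v[0] == 3.f ? 0 : 1;  // overwrite happened
}

The two overwrite branches (fake-empty buffer, never-initialized gradient) come first; only after that does real accumulation happen, dispatched on the operand types.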
@@ -64,14 +64,16 @@ class GradNodeAccumulation : public GradNodeBase {
         new GradNodeAccumulation(nullptr));
   }
 
+  void SetFakeEmpty(bool is_fake_empty) { is_fake_empty_ = is_fake_empty; }
  private:
+  // TODO(Jiabin): remove this when we make our clear gradient really cleared;
+  bool is_fake_empty_ = {false};
   std::weak_ptr<paddle::experimental::Tensor> weak_grad_;
+  std::vector<std::shared_ptr<TensorVoidHook>> reduce_hooks_;
 
   std::function<paddle::experimental::Tensor(
       const paddle::experimental::Tensor&)>
       retain_grad_hook_;
-
-  std::vector<std::shared_ptr<TensorVoidHook>> reduce_hooks_;
 };
 
 } // namespace egr
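
Two notes on the header change. SetFakeEmpty must be public because the flag is set from the Python binding layer, not from inside the node; and the static cast that binding performs is only valid because a leaf tensor's grad node is guaranteed to be a GradNodeAccumulation, which is why the call is gated on is_leaf in the next file. A toy of that guarded-downcast pattern (types mocked; the leaf-implies-accumulation-node invariant is assumed as just described):

#include <memory>

struct GradNodeBase { virtual ~GradNodeBase() = default; };
struct GradNodeAccumulation : GradNodeBase {
  void SetFakeEmpty(bool v) { is_fake_empty_ = v; }
  bool is_fake_empty_ = false;
};

void MarkFakeEmpty(const std::shared_ptr<GradNodeBase>& node, bool is_leaf) {
  if (is_leaf) {  // invariant: leaf tensor => accumulation node
    std::static_pointer_cast<GradNodeAccumulation>(node)->SetFakeEmpty(true);
  }
}

int main() {
  auto node = std::make_shared<GradNodeAccumulation>();
  MarkFakeEmpty(node, /*is_leaf=*/true);
  return node->is_fake_empty_ ? 0 : 1;  // flag was set
}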
@@ -494,7 +494,8 @@ static PyObject* tensor_clear_gradient(TensorObject* self, PyObject* args,
   }
 
   paddle::experimental::Tensor* grad;
-  if (egr::egr_utils_api::IsLeafTensor(self->tensor)) {
+  bool is_leaf = egr::egr_utils_api::IsLeafTensor(self->tensor);
+  if (is_leaf) {
     grad = egr::EagerUtils::mutable_grad(self->tensor);
     PADDLE_ENFORCE(grad != nullptr,
                    paddle::platform::errors::Fatal(
@@ -518,6 +519,11 @@ static PyObject* tensor_clear_gradient(TensorObject* self, PyObject* args,
   if (grad->initialized()) {
     if (set_to_zero) {
       grad->set_impl(paddle::experimental::zeros_like(*grad).impl());
+      if (is_leaf) {
+        std::static_pointer_cast<egr::GradNodeAccumulation>(
+            egr::EagerUtils::grad_node(self->tensor))
+            ->SetFakeEmpty(true);
+      }
     } else {
       VLOG(4) << "Gradient of " << self->tensor.name()
               << " is initialized, will be released.";
......
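
Putting it together, an end-to-end toy of the new flow (everything mocked; only the sequencing mirrors the patch): clearing the gradient with set_to_zero marks the leaf's accumulation node fake-empty, so the next backward pass takes the overwrite branch rather than accumulating into the zeroed buffer.

#include <cstddef>
#include <cstdio>
#include <vector>

struct Node {
  bool is_fake_empty = false;
  void SetFakeEmpty(bool v) { is_fake_empty = v; }
};

int main() {
  std::vector<float> grad = {1.f, 2.f};
  Node node;

  // analogue of tensor_clear_gradient with set_to_zero on a leaf tensor
  for (float& g : grad) g = 0.f;
  node.SetFakeEmpty(true);

  // next backward pass delivers {3, 4}
  std::vector<float> out = {3.f, 4.f};
  if (node.is_fake_empty) {
    grad = out;  // CopyOrAddTensor's new overwrite branch
    node.is_fake_empty = false;
    std::puts("took overwrite branch");
  } else {
    for (std::size_t i = 0; i < grad.size(); ++i) grad[i] += out[i];
    std::puts("took accumulate branch");
  }

  std::printf("grad = {%g, %g}\n", grad[0], grad[1]);  // prints {3, 4}
  return 0;
}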