Unverified commit 2998a7d2, authored by Weilong Wu, committed by GitHub

[Eager] Remove retain_grad_flag in accumulation_node, add is_new_grad arg in operator (#42240)

Parent: 12311ddc
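Summary of the change: GradNodeAccumulation used to consult the global flag FLAGS_retain_grad_for_all_tensor to decide whether a gradient arriving at a leaf tensor should be copied or added into that tensor's .grad buffer. This commit drops the flag and instead threads a per-call bool is_new_grad argument through GradNodeBase::operator() and all of its overrides (accumulation, scale, custom op, PyLayer, run_program, and the GradTestNode used in tests), plus the two code generators that emit generated overrides. RunBackward forwards its is_general_grad flag as this argument, so accumulation into leaf gradients can be suppressed per run instead of via global state.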
@@ -24,7 +24,7 @@
 #include "paddle/fluid/platform/errors.h"
 #include "glog/logging.h"
-DECLARE_bool(retain_grad_for_all_tensor);
 namespace egr {
 static void CopyOrAddTensor(paddle::experimental::Tensor* tensor,
@@ -41,7 +41,7 @@ static void CopyOrAddTensor(paddle::experimental::Tensor* tensor,
 std::vector<std::vector<paddle::experimental::Tensor>> GradNodeAccumulation::
 operator()(
     std::vector<std::vector<paddle::experimental::Tensor>>& grads,  // NOLINT
-    bool create_graph) {
+    bool create_graph, bool is_new_grad) {
   VLOG(3) << "Running Eager Backward Node: GradNodeAccumulation";
   PADDLE_ENFORCE(grads.size() == 1,
                  paddle::platform::errors::Fatal(
@@ -63,7 +63,7 @@ operator()(
     grad_out = grads[0][0];
   }
-  if (!weak_grad_.expired() && FLAGS_retain_grad_for_all_tensor) {
+  if (!weak_grad_.expired() && !is_new_grad) {
    auto grad = weak_grad_.lock();
    CopyOrAddTensor(grad.get(), grad_out);
  }
...
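The gate above is the heart of the change. To make it concrete, here is a toy, self-contained Python model of the new control flow (not Paddle code; ToyAccumulationNode and its members are invented for illustration):

# Toy model of GradNodeAccumulation::operator(): accumulation into the
# leaf's grad happens only when is_new_grad is False, replacing the old
# global FLAGS_retain_grad_for_all_tensor read.
class ToyAccumulationNode:
    def __init__(self):
        self.leaf_grad = None  # stands in for the tensor behind weak_grad_

    def __call__(self, grad_out, create_graph=False, is_new_grad=False):
        if not is_new_grad:  # old gate: FLAGS_retain_grad_for_all_tensor
            if self.leaf_grad is None:
                self.leaf_grad = grad_out        # CopyOrAddTensor: copy
            else:
                self.leaf_grad += grad_out       # CopyOrAddTensor: add
        return grad_out

node = ToyAccumulationNode()
node(4.0)                    # backward()-style call: writes leaf_grad
node(2.0, is_new_grad=True)  # paddle.grad-style call: leaf_grad untouched
assert node.leaf_grad == 4.0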
@@ -39,7 +39,7 @@ class GradNodeAccumulation : public GradNodeBase {
   // Functor: perform backward computations
   virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
       std::vector<std::vector<paddle::experimental::Tensor>>& grads,  // NOLINT
-      bool create_graph = false) override;
+      bool create_graph = false, bool is_new_grad = false) override;
   void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
...
@@ -147,7 +147,7 @@ void GradNodeScale::SetAttributes_scale(float scale) { scale_ = scale; }
 std::vector<std::vector<paddle::experimental::Tensor>> GradNodeScale::
 operator()(
     std::vector<std::vector<paddle::experimental::Tensor>>& grads,  // NOLINT
-    bool create_graph) {
+    bool create_graph, bool is_new_grad) {
   // 1. Check Output Size
   PADDLE_ENFORCE(
       ((grads.size() == 1) && (grads[0].size() == 1)),
...
@@ -40,7 +40,7 @@ class GradNodeScale : public GradNodeBase {
   // Functor: perform backward computations
   virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
       std::vector<std::vector<paddle::experimental::Tensor>>& grads,  // NOLINT
-      bool create_graph = false) override;
+      bool create_graph = false, bool is_new_grad = false) override;
   void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
...
@@ -2444,7 +2444,7 @@ static std::string GenerateGradNodeCCContents(
       "std::vector<std::vector<paddle::experimental::Tensor>> "
       "GradNode%s::operator()("
       "std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool "
-      "create_graph) {\n"
+      "create_graph, bool is_new_grad) {\n"
       "%s"
       "%s"
       "\n}";
@@ -2490,7 +2490,7 @@ static std::string GenerateGradNodeHeaderContents(
       " virtual std::vector<std::vector<paddle::experimental::Tensor>> "
       "operator()("
       "std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool "
-      "create_graph = false) "
+      "create_graph = false, bool is_new_grad = false) "
       "override;\n"
       "\n"
       " void ClearTensorWrappers() override { \n"
...
@@ -119,7 +119,7 @@ class {} : public egr::GradNodeBase {{
   ~{}() override = default;
   virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
-      std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph = false) override;
+      std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph = false, bool is_new_grad = false) override;
   std::string name() override {{ return \"{}\"; }}
   void ClearTensorWrappers() override {{
@@ -149,7 +149,7 @@ class {} : public egr::GradNodeBase {{
 GRAD_FUNCTION_TEMPLATE = \
 """
-std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph) {{
+std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph, bool is_new_grad) {{
   // Fill Zero For GradIn Tensors
 {}
...
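Both code generators (the legacy C++ one above and eager_gen.py here) have to emit the new parameter in lockstep with the handwritten nodes: GradNodeBase::operator() is pure virtual, so any generated GradNode subclass whose operator() kept the old two-parameter signature would no longer override it and would fail to compile.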
@@ -690,7 +690,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
     VLOG(6) << "Run Backward Kernel with GradTensorHolder.";
     // Run Pre Backward Node and get outputs
     std::vector<std::vector<paddle::experimental::Tensor>> grad_output_tensors =
-        (*node)(node_input_buffer->Buffers(), create_graph);
+        (*node)(node_input_buffer->Buffers(), create_graph, is_general_grad);
     // retain_grad or not
     if (!retain_graph) {
...
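This is the one call site that actually supplies a value: RunBackward forwards its is_general_grad flag (set when the backward pass is entered through the general-grad path used by paddle.grad) as is_new_grad, which is what lets GradNodeAccumulation skip writing into leaf .grad buffers on that path.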
@@ -20,8 +20,9 @@
 namespace egr {
 std::vector<std::vector<paddle::experimental::Tensor>> RunCustomOpNode::
-operator()(std::vector<std::vector<paddle::experimental::Tensor>>& grads,
-           bool create_graph) {  // NOLINT
+operator()(
+    std::vector<std::vector<paddle::experimental::Tensor>>& grads,  // NOLINT
+    bool create_graph, bool is_new_grad) {
   paddle::CustomOpKernelContext ctx;
   auto grad_inputs_name = paddle::framework::OpMetaInfoHelper::GetInputs(
       egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]);
...
@@ -39,7 +39,7 @@ class RunCustomOpNode : public GradNodeBase {
   virtual std::vector<std::vector<paddle::experimental::Tensor>>
   operator()(  // NOLINT
       std::vector<std::vector<paddle::experimental::Tensor>>& grads,  // NOLINT
-      bool create_graph = false)  // NOLINT
+      bool create_graph = false, bool is_new_grad = false)  // NOLINT
       override;
   std::string name() {
...
@@ -109,7 +109,7 @@ class GradNodeBase {
    * **/
   virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
       std::vector<std::vector<paddle::experimental::Tensor>>& grads,  // NOLINT
-      bool create_graph = false) = 0;
+      bool create_graph = false, bool is_new_grad = false) = 0;
   virtual void ClearTensorWrappers() = 0;
...
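This is the pure-virtual base signature every override above must match; defaulting is_new_grad to false keeps existing two-argument call sites compiling and preserves the old accumulate-into-.grad behavior unless a caller opts out.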
@@ -32,7 +32,7 @@ namespace egr {
 std::vector<std::vector<paddle::experimental::Tensor>> GradNodePyLayer::
 operator()(
     std::vector<std::vector<paddle::experimental::Tensor>>& grads,  // NOLINT
-    bool create_graph) {
+    bool create_graph, bool is_new_grad) {
   VLOG(3) << "Running Eager Backward Node: " << name();
   std::vector<std::vector<paddle::experimental::Tensor>> hooked_grads =
...
@@ -36,7 +36,7 @@ class GradNodePyLayer : public GradNodeBase {
   virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
       std::vector<std::vector<paddle::experimental::Tensor>>& grads,  // NOLINT
-      bool create_graph = false) override;
+      bool create_graph = false, bool is_new_grad = false) override;
   void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
...
@@ -33,7 +33,7 @@ class GradTestNode : public egr::GradNodeBase {
   std::string name() override { return "GradTestNode"; }
   std::vector<std::vector<paddle::experimental::Tensor>> operator()(
       std::vector<std::vector<paddle::experimental::Tensor>>& grads,  // NOLINT
-      bool create_graph = false) override {
+      bool create_graph = false, bool is_new_grad = false) override {
     val_ = std::dynamic_pointer_cast<phi::DenseTensor>(grads[0][0].impl())
                ->data<float>()[0];
     phi::DenseTensorMeta meta =
...
@@ -366,7 +366,7 @@ class GradNodeRunProgram : public egr::GradNodeBase {
   // Functor: perform backward computations
   virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
       std::vector<std::vector<paddle::experimental::Tensor>> &grads,  // NOLINT
-      bool create_graph) override {
+      bool create_graph, bool is_new_grad) override {
     VLOG(3) << "Running Eager Backward Node: GradNodeRunProgram";
     std::vector<std::vector<paddle::experimental::Tensor>> hooked_grads =
         GradNodeRunProgram::ApplyGradientHooks(grads);
...
@@ -462,11 +462,9 @@ class TestTensorRegisterHook(unittest.TestCase):
         x.register_hook(double_print_hook)
         y = x * x
-        fluid.set_flags({'FLAGS_retain_grad_for_all_tensor': False})
         # Since y = x * x, dx = 2 * x
         dx = paddle.grad(
             outputs=[y], inputs=[x], create_graph=True, retain_graph=True)[0]
-        fluid.set_flags({'FLAGS_retain_grad_for_all_tensor': True})
         z = y + dx
         self.assertTrue(x.grad is None)
...
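With the set_flags calls gone, the test relies on the new default behavior: paddle.grad alone leaves leaf .grad buffers untouched. A minimal standalone sketch of that behavior, assuming a Paddle build that includes this commit:

import paddle

x = paddle.to_tensor([2.0], stop_gradient=False)
y = x * x

# paddle.grad enters RunBackward with is_general_grad=True, so the
# accumulation node skips writing into x.grad; no flag toggling needed.
dx = paddle.grad(
    outputs=[y], inputs=[x], create_graph=True, retain_graph=True)[0]
assert x.grad is None  # leaf .grad untouched, as the test asserts
print(dx)              # dy/dx = 2 * x -> Tensor([4.0])

# An ordinary backward() still accumulates into .grad as before.
(y + dx).sum().backward()
print(x.grad)          # now populated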