Unverified commit 2998a7d2, authored by Weilong Wu, committed by GitHub

[Eager] Remove retain_grad_flag in accumulation_node, add is_new_grad args in operator (#42240)

Parent 12311ddc
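Summary of the change: every grad node's functor gains a per-call is_new_grad flag, replacing the process-wide gflag FLAGS_retain_grad_for_all_tensor. Each invocation now states whether it is producing a new grad tensor (as paddle.grad does) or filling a leaf tensor's .grad slot. A minimal sketch of the pattern, with hypothetical names and plain floats standing in for paddle::experimental::Tensor (an illustration of the idea, not Paddle's actual classes):

// Minimal sketch of the new interface; a "tensor" is just a float here,
// and GradNode mirrors only the changed signature.
#include <iostream>
#include <memory>
#include <vector>

using Grads = std::vector<std::vector<float>>;

struct GradNode {
  virtual ~GradNode() = default;
  // The third parameter is the addition; it defaults to false so existing
  // call sites that fill Tensor.grad keep their old behavior.
  virtual Grads operator()(Grads& grads, bool create_graph = false,
                           bool is_new_grad = false) = 0;
};

// Toy accumulation node: it writes into the leaf tensor's grad slot only
// when the caller is filling .grad (is_new_grad == false), replacing the
// old check against FLAGS_retain_grad_for_all_tensor.
struct AccumulationNode : public GradNode {
  explicit AccumulationNode(const std::shared_ptr<float>& slot)
      : weak_grad_(slot) {}

  Grads operator()(Grads& grads, bool /*create_graph*/ = false,
                   bool is_new_grad = false) override {
    auto grad = weak_grad_.lock();
    if (grad && !is_new_grad) {
      *grad += grads[0][0];  // CopyOrAddTensor, in miniature
    }
    return grads;
  }

  std::weak_ptr<float> weak_grad_;
};

int main() {
  auto slot = std::make_shared<float>(0.0f);
  AccumulationNode node(slot);
  Grads g{{2.0f}};
  node(g, /*create_graph=*/false, /*is_new_grad=*/true);   // skipped
  node(g, /*create_graph=*/false, /*is_new_grad=*/false);  // accumulated
  std::cout << *slot << std::endl;  // prints 2
}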
@@ -24,7 +24,7 @@
#include "paddle/fluid/platform/errors.h"
#include "glog/logging.h"
- DECLARE_bool(retain_grad_for_all_tensor);
namespace egr {
static void CopyOrAddTensor(paddle::experimental::Tensor* tensor,
@@ -41,7 +41,7 @@ static void CopyOrAddTensor(paddle::experimental::Tensor* tensor,
std::vector<std::vector<paddle::experimental::Tensor>> GradNodeAccumulation::
operator()(
std::vector<std::vector<paddle::experimental::Tensor>>& grads, // NOLINT
- bool create_graph) {
+ bool create_graph, bool is_new_grad) {
VLOG(3) << "Running Eager Backward Node: GradNodeAccumulation";
PADDLE_ENFORCE(grads.size() == 1,
paddle::platform::errors::Fatal(
@@ -63,7 +63,7 @@ operator()(
grad_out = grads[0][0];
}
- if (!weak_grad_.expired() && FLAGS_retain_grad_for_all_tensor) {
+ if (!weak_grad_.expired() && !is_new_grad) {
auto grad = weak_grad_.lock();
CopyOrAddTensor(grad.get(), grad_out);
}
......
@@ -39,7 +39,7 @@ class GradNodeAccumulation : public GradNodeBase {
// Functor: perform backward computations
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
std::vector<std::vector<paddle::experimental::Tensor>>& grads, // NOLINT
- bool create_graph = false) override;
+ bool create_graph = false, bool is_new_grad = false) override;
void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
......
@@ -147,7 +147,7 @@ void GradNodeScale::SetAttributes_scale(float scale) { scale_ = scale; }
std::vector<std::vector<paddle::experimental::Tensor>> GradNodeScale::
operator()(
std::vector<std::vector<paddle::experimental::Tensor>>& grads, // NOLINT
- bool create_graph) {
+ bool create_graph, bool is_new_grad) {
// 1. Check Output Size
PADDLE_ENFORCE(
((grads.size() == 1) && (grads[0].size() == 1)),
......
@@ -40,7 +40,7 @@ class GradNodeScale : public GradNodeBase {
// Functor: perform backward computations
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
std::vector<std::vector<paddle::experimental::Tensor>>& grads, // NOLINT
- bool create_graph = false) override;
+ bool create_graph = false, bool is_new_grad = false) override;
void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
......
@@ -2444,7 +2444,7 @@ static std::string GenerateGradNodeCCContents(
"std::vector<std::vector<paddle::experimental::Tensor>> "
"GradNode%s::operator()("
"std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool "
"create_graph) {\n"
"create_graph, bool is_new_grad) {\n"
"%s"
"%s"
"\n}";
@@ -2490,7 +2490,7 @@ static std::string GenerateGradNodeHeaderContents(
" virtual std::vector<std::vector<paddle::experimental::Tensor>> "
"operator()("
"std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool "
"create_graph = false) "
"create_graph = false, bool is_new_grad = false) "
"override;\n"
"\n"
" void ClearTensorWrappers() override { \n"
......
@@ -119,7 +119,7 @@ class {} : public egr::GradNodeBase {{
~{}() override = default;
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
- std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph = false) override;
+ std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph = false, bool is_new_grad = false) override;
std::string name() override {{ return \"{}\"; }}
void ClearTensorWrappers() override {{
@@ -149,7 +149,7 @@ class {} : public egr::GradNodeBase {{
GRAD_FUNCTION_TEMPLATE = \
"""
- std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph) {{
+ std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph, bool is_new_grad) {{
// Fill Zero For GradIn Tensors
{}
......
@@ -690,7 +690,7 @@ std::vector<paddle::experimental::Tensor> RunBackward(
VLOG(6) << "Run Backward Kernel with GradTensorHolder.";
// Run Pre Backward Node and get outputs
std::vector<std::vector<paddle::experimental::Tensor>> grad_output_tensors =
- (*node)(node_input_buffer->Buffers(), create_graph);
+ (*node)(node_input_buffer->Buffers(), create_graph, is_general_grad);
// retain_grad or not
if (!retain_graph) {
......
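In the RunBackward hunk above, the engine's is_general_grad state becomes each node's is_new_grad argument: a paddle.grad-style call returns freshly created grad tensors, so accumulation nodes must not also write into the leaves' .grad slots. A compile-only sketch of that call site, reusing the toy GradNode shape from the earlier example (hypothetical driver, not Paddle's engine):

// Hypothetical driver helper: the engine-level is_general_grad flag is
// forwarded verbatim as the node's is_new_grad argument.
#include <vector>

using Grads = std::vector<std::vector<float>>;

struct GradNode {
  virtual ~GradNode() = default;
  virtual Grads operator()(Grads& grads, bool create_graph = false,
                           bool is_new_grad = false) = 0;
};

Grads RunNode(GradNode& node, Grads& input_buffer, bool create_graph,
              bool is_general_grad) {
  // Mirrors: (*node)(node_input_buffer->Buffers(), create_graph, is_general_grad);
  return node(input_buffer, create_graph, /*is_new_grad=*/is_general_grad);
}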
@@ -20,8 +20,9 @@
namespace egr {
std::vector<std::vector<paddle::experimental::Tensor>> RunCustomOpNode::
- operator()(std::vector<std::vector<paddle::experimental::Tensor>>& grads,
- bool create_graph) { // NOLINT
+ operator()(
+ std::vector<std::vector<paddle::experimental::Tensor>>& grads, // NOLINT
+ bool create_graph, bool is_new_grad) {
paddle::CustomOpKernelContext ctx;
auto grad_inputs_name = paddle::framework::OpMetaInfoHelper::GetInputs(
egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]);
......
@@ -39,7 +39,7 @@ class RunCustomOpNode : public GradNodeBase {
virtual std::vector<std::vector<paddle::experimental::Tensor>>
operator()( // NOLINT
std::vector<std::vector<paddle::experimental::Tensor>>& grads, // NOLINT
- bool create_graph = false) // NOLINT
+ bool create_graph = false, bool is_new_grad = false) // NOLINT
override;
std::string name() {
......
@@ -109,7 +109,7 @@ class GradNodeBase {
* **/
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
std::vector<std::vector<paddle::experimental::Tensor>>& grads, // NOLINT
- bool create_graph = false) = 0;
+ bool create_graph = false, bool is_new_grad = false) = 0;
virtual void ClearTensorWrappers() = 0;
......
@@ -32,7 +32,7 @@ namespace egr {
std::vector<std::vector<paddle::experimental::Tensor>> GradNodePyLayer::
operator()(
std::vector<std::vector<paddle::experimental::Tensor>>& grads, // NOLINT
- bool create_graph) {
+ bool create_graph, bool is_new_grad) {
VLOG(3) << "Running Eager Backward Node: " << name();
std::vector<std::vector<paddle::experimental::Tensor>> hooked_grads =
......
@@ -36,7 +36,7 @@ class GradNodePyLayer : public GradNodeBase {
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
std::vector<std::vector<paddle::experimental::Tensor>>& grads, // NOLINT
- bool create_graph = false) override;
+ bool create_graph = false, bool is_new_grad = false) override;
void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
......
@@ -33,7 +33,7 @@ class GradTestNode : public egr::GradNodeBase {
std::string name() override { return "GradTestNode"; }
std::vector<std::vector<paddle::experimental::Tensor>> operator()(
std::vector<std::vector<paddle::experimental::Tensor>>& grads, // NOLINT
- bool create_graph = false) override {
+ bool create_graph = false, bool is_new_grad = false) override {
val_ = std::dynamic_pointer_cast<phi::DenseTensor>(grads[0][0].impl())
->data<float>()[0];
phi::DenseTensorMeta meta =
......
@@ -366,7 +366,7 @@ class GradNodeRunProgram : public egr::GradNodeBase {
// Functor: perform backward computations
virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
std::vector<std::vector<paddle::experimental::Tensor>> &grads, // NOLINT
- bool create_graph) override {
+ bool create_graph, bool is_new_grad) override {
VLOG(3) << "Running Eager Backward Node: GradNodeRunProgram";
std::vector<std::vector<paddle::experimental::Tensor>> hooked_grads =
GradNodeRunProgram::ApplyGradientHooks(grads);
......
@@ -462,11 +462,9 @@ class TestTensorRegisterHook(unittest.TestCase):
x.register_hook(double_print_hook)
y = x * x
- fluid.set_flags({'FLAGS_retain_grad_for_all_tensor': False})
# Since y = x * x, dx = 2 * x
dx = paddle.grad(
outputs=[y], inputs=[x], create_graph=True, retain_graph=True)[0]
- fluid.set_flags({'FLAGS_retain_grad_for_all_tensor': True})
z = y + dx
self.assertTrue(x.grad is None)
......
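Usage note: with is_new_grad threaded through the node functors, the double-grad call paddle.grad(outputs=[y], inputs=[x], create_graph=True, retain_graph=True) no longer writes into leaf .grad slots on its own. That is why the test above can drop the temporary fluid.set_flags({'FLAGS_retain_grad_for_all_tensor': ...}) toggles around the call and still assert that x.grad is None.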