From 1c094d3ee698410eaf6be08267c0d34c8665489d Mon Sep 17 00:00:00 2001
From: xiongkun
Date: Thu, 30 Dec 2021 13:18:14 +0800
Subject: [PATCH] refine run_program_op_grad output var name (#38470)

* refine run_program_op_grad output var name

* add default for global_block, to pass the eager_generator_cmd

* fix

* ;

* fix

* const cast

* mutable block
---
 .../auto_code_generator/eager_generator.cc |  6 ++++
 paddle/fluid/operators/run_program_op.cc   | 33 +++++++++++++++++--
 2 files changed, 37 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc
index 43c40beeebb..dfdd0f1e5ce 100644
--- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc
+++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc
@@ -21,6 +21,7 @@
 #include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/variable.h"
 #include "paddle/fluid/pybind/op_function_generator.h"
 #include "paddle/fluid/pybind/pybind.h"
@@ -2037,6 +2038,11 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
 }
 
 static void PrepareAttrMapForOps() {
+  // Handle "run_program_op"
+  static framework::ProgramDesc fake_prog;
+  operators_with_attrs["run_program"] = {};
+  operators_with_attrs["run_program"]["global_block"] =
+      fake_prog.MutableBlock(0);
   // Handle "fused_elemwise_add_activation"
   std::vector<std::string> functor_list = {"a", "b"};
   operators_with_attrs["fused_elemwise_add_activation"] = {};
diff --git a/paddle/fluid/operators/run_program_op.cc b/paddle/fluid/operators/run_program_op.cc
index 80758e1718b..ec62feb07bc 100644
--- a/paddle/fluid/operators/run_program_op.cc
+++ b/paddle/fluid/operators/run_program_op.cc
@@ -153,6 +153,31 @@ class RunProgramGradOp : public framework::OperatorWithKernel {
   }
 };
 
+template <typename T>
+struct FilterHelper {};
+
+template <>
+struct FilterHelper<imperative::OpBase> {
+  static void filter(const BlockDesc* desc,
+                     imperative::TracedVarList<imperative::VarBase,
+                                               imperative::kBackward>* vec) {
+    auto f = [desc](std::shared_ptr<imperative::VarBase> ptr) {
+      return !desc->HasVar(ptr->Name());
+    };
+    auto new_end = std::remove_if(vec->begin(), vec->end(), f);
+    vec->resize(new_end - vec->begin());
+  }
+};
+
+template <>
+struct FilterHelper<framework::OpDesc> {
+  static void filter(const BlockDesc* desc, std::vector<std::string>* vec) {
+    auto f = [desc](const std::string& name) { return !desc->HasVar(name); };
+    auto new_end = std::remove_if(vec->begin(), vec->end(), f);
+    vec->resize(new_end - vec->begin());
+  }
+};
+
 template <typename T>
 class RunProgramGradOpMaker : public framework::SingleGradOpMaker<T> {
  public:
@@ -167,8 +192,12 @@ class RunProgramGradOpMaker : public framework::SingleGradOpMaker<T> {
     grad_op->SetInput("OutScope", this->Output("OutScope"));
     grad_op->SetInput("DOut", this->Output("DOut"));
     grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
-    grad_op->SetOutput(framework::GradVarName("Params"),
-                       this->InputGrad("Params"));
+
+    auto block_desc =
+        BOOST_GET_CONST(BlockDesc*, this->GetAttr("global_block"));
+    auto params_grad = this->InputGrad("Params");
+    FilterHelper<T>::filter(block_desc, &params_grad);  // filter the vector.
+    grad_op->SetOutput(framework::GradVarName("Params"), params_grad);
     grad_op->SetAttrMap(this->Attrs());
   }
 };
--
GitLab
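
The behavioral core of this patch is the FilterHelper step in RunProgramGradOpMaker::Apply: the candidate Params gradients returned by InputGrad("Params") are checked against the block passed in the "global_block" attribute, and only variables that the block actually declares are wired up as Params@GRAD outputs. Below is a minimal, self-contained sketch of that erase/remove filtering in isolation; block_vars, has_var(), and the sample gradient names are hypothetical stand-ins for BlockDesc::HasVar and real program data, not code taken from the patch.

// Standalone illustration of the filtering pattern used by
// FilterHelper<framework::OpDesc>::filter (remove_if + resize).
#include <algorithm>
#include <iostream>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-in for BlockDesc::HasVar: the block "knows" only
// the variable names stored in block_vars.
bool has_var(const std::set<std::string>& block_vars, const std::string& name) {
  return block_vars.count(name) > 0;
}

int main() {
  // Variables the (hypothetical) backward block actually contains.
  std::set<std::string> block_vars = {"w@GRAD", "b@GRAD"};
  // Candidate gradient names, playing the role of InputGrad("Params").
  std::vector<std::string> params_grad = {"w@GRAD", "b@GRAD", "frozen_w@GRAD"};

  // Same pattern as the patch: drop every name the block does not declare,
  // then shrink the vector to the surviving prefix.
  auto new_end = std::remove_if(
      params_grad.begin(), params_grad.end(),
      [&](const std::string& name) { return !has_var(block_vars, name); });
  params_grad.resize(new_end - params_grad.begin());

  for (const auto& name : params_grad) {
    std::cout << name << std::endl;  // prints w@GRAD and b@GRAD only
  }
  return 0;
}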