diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index 43c40beeebb391422ad6f442fc4700771436a2dd..dfdd0f1e5ce1b932bfa16d19a3d4489a5538451b 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -21,6 +21,7 @@ #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/variable.h" #include "paddle/fluid/pybind/op_function_generator.h" #include "paddle/fluid/pybind/pybind.h" @@ -2037,6 +2038,11 @@ static void DygraphCodeGeneration(const std::string& output_dir) { } static void PrepareAttrMapForOps() { + // Handle "run_program_op" + static framework::ProgramDesc fake_prog; + operators_with_attrs["run_program"] = {}; + operators_with_attrs["run_program"]["global_block"] = + fake_prog.MutableBlock(0); // Handle "fused_elemwise_add_activation" std::vector functor_list = {"a", "b"}; operators_with_attrs["fused_elemwise_add_activation"] = {}; diff --git a/paddle/fluid/operators/run_program_op.cc b/paddle/fluid/operators/run_program_op.cc index 80758e1718be4634a7ee87a1fc59811ae9af9758..ec62feb07bc80711665fa8179a1a11cb040fa130 100644 --- a/paddle/fluid/operators/run_program_op.cc +++ b/paddle/fluid/operators/run_program_op.cc @@ -153,6 +153,31 @@ class RunProgramGradOp : public framework::OperatorWithKernel { } }; +template +struct FilterHelper {}; + +template <> +struct FilterHelper { + static void filter(const BlockDesc* desc, + imperative::TracedVarList* vec) { + auto f = [desc](std::shared_ptr ptr) { + return !desc->HasVar(ptr->Name()); + }; + auto new_end = std::remove_if(vec->begin(), vec->end(), f); + vec->resize(new_end - vec->begin()); + } +}; + +template <> +struct FilterHelper { + static void filter(const BlockDesc* desc, std::vector* vec) { + auto f = [desc](const std::string& name) { return !desc->HasVar(name); }; + auto new_end = std::remove_if(vec->begin(), vec->end(), f); + vec->resize(new_end - vec->begin()); + } +}; + template class RunProgramGradOpMaker : public framework::SingleGradOpMaker { public: @@ -167,8 +192,12 @@ class RunProgramGradOpMaker : public framework::SingleGradOpMaker { grad_op->SetInput("OutScope", this->Output("OutScope")); grad_op->SetInput("DOut", this->Output("DOut")); grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); - grad_op->SetOutput(framework::GradVarName("Params"), - this->InputGrad("Params")); + + auto block_desc = + BOOST_GET_CONST(BlockDesc*, this->GetAttr("global_block")); + auto params_grad = this->InputGrad("Params"); + FilterHelper::filter(block_desc, ¶ms_grad); // filter the vector. + grad_op->SetOutput(framework::GradVarName("Params"), params_grad); grad_op->SetAttrMap(this->Attrs()); } };