未验证 提交 1c094d3e 编写于 作者: X xiongkun 提交者: GitHub

refine run_program_op_grad output var name (#38470)

* refine run_program_op_grad output var name

* add a default for global_block, to pass the eager_generator_cmd

* fix

* ;

* fix

* const cast

* mutable block
上级 ed8ba011
......@@ -21,6 +21,7 @@
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/pybind/op_function_generator.h"
#include "paddle/fluid/pybind/pybind.h"
......@@ -2037,6 +2038,11 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
}
static void PrepareAttrMapForOps() {
// Handle "run_program_op"
static framework::ProgramDesc fake_prog;
operators_with_attrs["run_program"] = {};
operators_with_attrs["run_program"]["global_block"] =
fake_prog.MutableBlock(0);
// Handle "fused_elemwise_add_activation"
std::vector<std::string> functor_list = {"a", "b"};
operators_with_attrs["fused_elemwise_add_activation"] = {};
......
......@@ -153,6 +153,31 @@ class RunProgramGradOp : public framework::OperatorWithKernel {
}
};
// Primary template; intentionally empty. Only the explicit specializations
// below (imperative::OpBase for dygraph tracing, framework::OpDesc for static
// graph) are ever instantiated — any other T is a compile-time error.
template <typename T>
struct FilterHelper {};
// Dygraph-mode specialization: drops from `vec` every VarBase whose name is
// not declared in `desc`, so only gradient outputs that the global block
// actually contains are kept.
//
// @param desc  block to check variable names against (must be non-null).
// @param vec   traced var list filtered in place.
template <>
struct FilterHelper<imperative::OpBase> {
  static void filter(const BlockDesc* desc,
                     imperative::TracedVarList<imperative::VarBase,
                                               imperative::kBackward>* vec) {
    // Take the shared_ptr by const reference: passing it by value would do an
    // atomic refcount increment/decrement for every element tested.
    auto not_in_block = [desc](
        const std::shared_ptr<imperative::VarBase>& var) {
      return !desc->HasVar(var->Name());
    };
    // Erase–remove idiom: destroy the moved-from tail explicitly rather than
    // relying on resize() arithmetic.
    vec->erase(std::remove_if(vec->begin(), vec->end(), not_in_block),
               vec->end());
  }
};
// Static-graph specialization: drops from `vec` every variable name that is
// not declared in `desc`, mirroring the imperative specialization above.
//
// @param desc  block to check variable names against (must be non-null).
// @param vec   list of variable names filtered in place.
template <>
struct FilterHelper<framework::OpDesc> {
  static void filter(const BlockDesc* desc, std::vector<std::string>* vec) {
    auto not_in_block = [desc](const std::string& name) {
      return !desc->HasVar(name);
    };
    // Erase–remove idiom instead of resize(new_end - begin): same effect,
    // but states the intent directly.
    vec->erase(std::remove_if(vec->begin(), vec->end(), not_in_block),
               vec->end());
  }
};
template <typename T>
class RunProgramGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
......@@ -167,8 +192,12 @@ class RunProgramGradOpMaker : public framework::SingleGradOpMaker<T> {
grad_op->SetInput("OutScope", this->Output("OutScope"));
grad_op->SetInput("DOut", this->Output("DOut"));
grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
grad_op->SetOutput(framework::GradVarName("Params"),
this->InputGrad("Params"));
auto block_desc =
BOOST_GET_CONST(BlockDesc*, this->GetAttr("global_block"));
auto params_grad = this->InputGrad("Params");
FilterHelper<T>::filter(block_desc, &params_grad); // filter the vector.
grad_op->SetOutput(framework::GradVarName("Params"), params_grad);
grad_op->SetAttrMap(this->Attrs());
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册