Unverified commit 9aed9ea0 authored by Zhanlue Yang, committed by GitHub

Adjusted Eager AutoCodeGen to Support Operators with Multiple OpBases & Enable Passing Output Tensor as Input Argument (#37943)

* Rearranged Eager AutoCodeGen directory structure

* Removed USE_OP in Eager AutoCodeGen

* Enabled generation for Operators without Grad/Inputs/Outputs

* Resolved operators without input

* Fixed merge conflicts

* Enabled Eager AutoCodeGen for 10+ more operators

* Refactored Eager AutoCodeGen with more organized helper objects

* Enabled Eager AutoCodeGen for operators with multiple OpBases

* Adjusted Eager AutoCodeGen to Enable Passing Output Tensor as Input Argument
Parent aff7397b
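
For orientation before the diff: the snippet below is NOT part of this commit. It is a minimal standalone sketch of the output-handling branch that the patch adds to GenerateForwardFunctionContents; OutVarSpec and GenOutputSnippets are hypothetical helper names used only for illustration. For an output listed in op_passing_outs_map, the generated forward function now takes a shared_ptr<egr::EagerTensor> (or a vector of them) by reference and forwards it into the "outs" map; otherwise a fresh tensor, or N duplicable tensors, is constructed as before.

// Standalone sketch, not part of this commit. OutVarSpec/GenOutputSnippets are
// hypothetical helpers mirroring the string templates used in
// GenerateForwardFunctionContents for the "pass output as argument" case.
#include <iostream>
#include <string>
#include <utility>

struct OutVarSpec {
  std::string name;       // e.g. "Out"
  bool duplicable;        // slot holds a vector of tensors
  bool passed_by_caller;  // op/output pair listed in op_passing_outs_map
};

// Returns {extra argument appended to the generated signature,
//          entry appended to the generated "outs" map initializer}.
std::pair<std::string, std::string> GenOutputSnippets(const OutVarSpec& out) {
  if (out.passed_by_caller) {
    std::string var = out.name + "Var";
    if (out.duplicable) {
      return {", std::vector<std::shared_ptr<egr::EagerTensor>>& " + var,
              "{ \"" + out.name + "\", " + var + " },"};
    }
    return {", std::shared_ptr<egr::EagerTensor>& " + var,
            "{ \"" + out.name + "\", {" + var + "} },"};
  }
  if (out.duplicable) {
    std::string num = out.name + "Num";
    return {", size_t " + num,
            "{ \"" + out.name +
                "\", egr::EagerUtils::ConstructDuplicableOutput(" + num + ") },"};
  }
  return {"",  // no extra argument: a fresh tensor is created internally
          "{ \"" + out.name +
              "\", {std::make_shared<egr::EagerTensor>("
              "egr::Controller::Instance().GenerateUniqueName())}},"};
}

int main() {
  auto passed = GenOutputSnippets({"Out", /*duplicable=*/false, /*passed_by_caller=*/true});
  std::cout << "arg : " << passed.first << "\n";
  std::cout << "outs: " << passed.second << "\n";
  return 0;
}
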
@@ -29,15 +29,11 @@
namespace paddle {
namespace framework {
/* --- Static maps to handle corner cases --- */
static std::unordered_map<std::string, paddle::framework::AttributeMap>
operators_with_attrs = {};
static std::unordered_set<std::string> operators_to_skip = {
"minus",
};
static std::unordered_set<std::string> operators_to_codegen = {};
static std::unordered_set<std::string> skipped_operators = {};
static std::string LegalizeVariableName(const std::string& var_name) {
std::string ret = var_name;
@@ -45,6 +41,132 @@ static std::string LegalizeVariableName(const std::string& var_name) {
return ret;
}
/* --- Helper Objects --- */
class ForwardGenerationInfo {
public:
const std::string& GetOpType() const { return op_type_; }
void SetOpType(const std::string& op_type) { op_type_ = op_type; }
const std::unordered_map<std::string, size_t>& GetFwdInputsNamePosMap()
const {
return fwd_inputs_name_pos_map_;
}
std::unordered_map<std::string, size_t>* GetMutableFwdInputsNamePosMap() {
return &fwd_inputs_name_pos_map_;
}
const std::unordered_map<std::string, size_t>& GetFwdOutputsNamePosMap()
const {
return fwd_outputs_name_pos_map_;
}
std::unordered_map<std::string, size_t>* GetMutableFwdOutputsNamePosMap() {
return &fwd_outputs_name_pos_map_;
}
const std::vector<proto::OpProto::Var>& GetInVars() const { return in_vars_; }
std::vector<proto::OpProto::Var>* GetMutableInVars() { return &in_vars_; }
const std::vector<proto::OpProto::Var>& GetOutVars() const {
return out_vars_;
}
std::vector<proto::OpProto::Var>* GetMutableOutVars() { return &out_vars_; }
private:
std::string op_type_;
std::unordered_map<std::string, size_t> fwd_inputs_name_pos_map_;
std::unordered_map<std::string, size_t> fwd_outputs_name_pos_map_;
std::vector<proto::OpProto::Var> in_vars_;
std::vector<proto::OpProto::Var> out_vars_;
};
class GradNodeGenerationInfo {
class OpBaseGenerationInfo {
public:
const std::string& GetOpBaseType() const { return op_base_type_; }
void SetOpBaseType(const std::string& op_type) { op_base_type_ = op_type; }
const std::map<std::string, std::string>& GetGradOutsSlotnameMap() const {
return grad_outs_slotname_map_;
}
std::map<std::string, std::string>* GetMutableGradOutsSlotnameMap() {
return &grad_outs_slotname_map_;
}
const std::map<std::string, std::string>& GetGradInsFwdSlotnameMap() const {
return grad_ins_fwd_slotname_map_;
}
std::map<std::string, std::string>* GetMutableGradInsFwdSlotnameMap() {
return &grad_ins_fwd_slotname_map_;
}
const std::map<std::string, std::string>& GetGradInsGradSlotnameMap()
const {
return grad_ins_grad_slotname_map_;
}
std::map<std::string, std::string>* GetMutableGradInsGradSlotnameMap() {
return &grad_ins_grad_slotname_map_;
}
const std::map<
std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>&
GetGradIns() const {
return grad_ins_;
}
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
GetMutableGradIns() {
return &grad_ins_;
}
const std::map<
std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>&
GetGradOuts() const {
return grad_outs_;
}
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
GetMutableGradOuts() {
return &grad_outs_;
}
private:
std::string op_base_type_;
std::map<std::string, std::string> grad_outs_slotname_map_;
std::map<std::string, std::string> grad_ins_fwd_slotname_map_;
std::map<std::string, std::string> grad_ins_grad_slotname_map_;
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>
grad_ins_;
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>
grad_outs_;
};
public:
const std::string& GetFwdOpType() const { return fwd_op_type_; }
void SetFwdOpType(const std::string& op_type) { fwd_op_type_ = op_type; }
bool GenerateForwardOnly() const { return generate_forward_only_; }
void SetGenerateForwardOnly(bool generate_forward_only) {
generate_forward_only_ = generate_forward_only;
}
const std::vector<OpBaseGenerationInfo>& GetOpBaseInfos() const {
return op_base_infos_;
}
std::vector<OpBaseGenerationInfo>* GetMutableOpBaseInfos() {
return &op_base_infos_;
}
private:
std::string fwd_op_type_;
bool generate_forward_only_ = false;
std::vector<OpBaseGenerationInfo> op_base_infos_;
};
/* --- Helper Functions --- */
static std::string AttrTypeToString(const proto::AttrType& type) {
std::string ret;
switch (type) {
@@ -348,7 +470,6 @@ static bool CheckOpProto(proto::OpProto* op_proto) {
VLOG(1) << "------ Analyzing Op ------: " << op_type;
if (!operators_to_codegen.count(op_type)) return false;
if (operators_to_skip.count(op_type)) return false;
return true;
}
@@ -356,15 +477,16 @@ static bool CheckOpProto(proto::OpProto* op_proto) {
/* --------------------------------------- */
/* --------- Preprocess Ins/Outs --------- */
/* --------------------------------------- */
static void PurifyForwardOpProto(
const proto::OpProto& op_proto,
std::unordered_map<std::string, size_t>* fwd_inputs_name_pos_map,
std::unordered_map<std::string, size_t>* fwd_outputs_name_pos_map,
std::vector<proto::OpProto::Var>* in_vars,
std::vector<proto::OpProto::Var>* out_vars) {
static void PurifyForwardOpProto(const proto::OpProto& op_proto,
ForwardGenerationInfo* fwd_info) {
// Op Name
const std::string op_name = op_proto.type();
auto* in_vars = fwd_info->GetMutableInVars();
auto* out_vars = fwd_info->GetMutableOutVars();
auto* fwd_inputs_name_pos_map = fwd_info->GetMutableFwdInputsNamePosMap();
auto* fwd_outputs_name_pos_map = fwd_info->GetMutableFwdOutputsNamePosMap();
// Handle dispensable inputs
for (const proto::OpProto::Var& input : op_proto.inputs()) {
std::string input_name = input.name();
@@ -426,6 +548,104 @@ static void PurifyForwardOpProto(
}
}
static void PurifyGradNodeGenerationInfo(const proto::OpProto& op_proto,
GradNodeGenerationInfo* bwd_info) {
auto* op_base_infos = bwd_info->GetMutableOpBaseInfos();
for (auto& iter : *op_base_infos) {
std::map<std::string, std::string>* grad_outs_slotname_map =
iter.GetMutableGradOutsSlotnameMap();
std::map<std::string, std::string>* grad_ins_fwd_slotname_map =
iter.GetMutableGradInsFwdSlotnameMap();
std::map<std::string, std::string>* grad_ins_grad_slotname_map =
iter.GetMutableGradInsGradSlotnameMap();
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
grad_ins = iter.GetMutableGradIns();
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
grad_outs = iter.GetMutableGradOuts();
// Op Name
const std::string op_name = op_proto.type();
// Handle dispensable inputs
for (const proto::OpProto::Var& input : op_proto.inputs()) {
std::string input_name = input.name();
// Delete dispensable tensor unless specified in op_ins_map
if (input.dispensable()) {
if (!op_ins_map.count(op_name) ||
!op_ins_map[op_name].count(input_name)) {
VLOG(6) << "Removing Dispensable Input: " << input_name;
// grad_outs_slotname_map
auto grad_outs_slotname_map_purified = *grad_outs_slotname_map;
for (const auto& iter : *grad_outs_slotname_map) {
const std::string& grad_output_name = iter.first;
const std::string& matched_input_name = iter.second;
if (matched_input_name == input_name) {
grad_outs_slotname_map_purified.erase(grad_output_name);
PADDLE_ENFORCE(
grad_outs->count(grad_output_name) > 0,
paddle::platform::errors::Fatal(
"Unable to find gradient output name in grad_outs."));
// grad_outs
grad_outs->erase(grad_output_name);
}
}
*grad_outs_slotname_map = grad_outs_slotname_map_purified;
// grad_ins_fwd_slotname_map: output as tensorwrapper
if (grad_ins_fwd_slotname_map->count(input_name))
grad_ins_fwd_slotname_map->erase(input_name);
// grad_ins: output as tensorwrapper
if (grad_ins->count(input_name)) grad_ins->erase(input_name);
}
}
}
for (const proto::OpProto::Var& output : op_proto.outputs()) {
std::string output_name = output.name();
// Delete dispensable tensor unless specified in op_outs_map
if (output.dispensable()) {
if (!op_outs_map.count(op_name) ||
!op_outs_map[op_name].count(output_name)) {
VLOG(6) << "Removing Dispensable Output: " << output_name;
// grad_ins_grad_slotname_map
auto grad_ins_grad_slotname_map_purified =
*grad_ins_grad_slotname_map;
for (const auto& iter : *grad_ins_grad_slotname_map) {
const std::string& grad_input_name = iter.first;
const std::string& matched_output_name = iter.second;
if (matched_output_name == output_name) {
grad_ins_grad_slotname_map_purified.erase(grad_input_name);
PADDLE_ENFORCE(
grad_ins->count(grad_input_name) > 0,
paddle::platform::errors::Fatal(
"Unable to find gradient input name in grad_ins."));
// grad_ins
grad_ins->erase(grad_input_name);
}
}
*grad_ins_grad_slotname_map = grad_ins_grad_slotname_map_purified;
// grad_ins_fwd_slotname_map: output as tensorwrapper
if (grad_ins_fwd_slotname_map->count(output_name))
grad_ins_fwd_slotname_map->erase(output_name);
// grad_ins: output as tensorwrapper
if (grad_ins->count(output_name)) grad_ins->erase(output_name);
}
}
}
}
}
static void PurifyGradOpProto(
const proto::OpProto& op_proto,
std::map<std::string, std::string>* grad_outs_slotname_map,
@@ -520,31 +740,22 @@ static void PurifyGradOpProto(
/* --------- Collect Info --------- */
/* -------------------------------- */
static void CollectForwardInformationFromOpInfo(
const paddle::framework::OpInfo& op_info,
std::vector<proto::OpProto::Var>* in_vars,
std::vector<proto::OpProto::Var>* out_vars) {
const paddle::framework::OpInfo& op_info, ForwardGenerationInfo* fwd_info) {
const proto::OpProto& op_proto = *op_info.proto_;
fwd_info->SetOpType(op_proto.type());
for (const proto::OpProto::Var& input : op_proto.inputs()) {
in_vars->push_back(input);
fwd_info->GetMutableInVars()->push_back(input);
}
for (const proto::OpProto::Var& output : op_proto.outputs()) {
out_vars->push_back(output);
fwd_info->GetMutableOutVars()->push_back(output);
}
}
static bool CollectGradInformationFromOpInfo(
const paddle::framework::OpInfo& op_info, bool* generate_forward_only,
std::vector<std::string>* grad_op_types, // grad
std::map<std::string, std::string>* grad_outs_slotname_map, // grad
std::map<std::string, std::string>* grad_ins_fwd_slotname_map, // grad
std::map<std::string, std::string>* grad_ins_grad_slotname_map, // grad
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
grad_ins, // grad
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
grad_outs // grad
) {
const paddle::framework::OpInfo& op_info,
GradNodeGenerationInfo* bwd_info) {
const proto::OpProto& op_proto = *op_info.proto_;
const std::string& op_type = op_proto.type();
std::vector<int64_t> dims = {1, 1, 1, 1};
@@ -645,7 +856,7 @@ static bool CollectGradInformationFromOpInfo(
/* ------ Run GradOpMaker ------ */
if (!op_info.dygraph_grad_op_maker_) {
VLOG(6) << op_type << " has no GradOpMaker";
*generate_forward_only = true;
bwd_info->SetGenerateForwardOnly(true);
return false;
}
@@ -656,32 +867,31 @@ static bool CollectGradInformationFromOpInfo(
if (!grad_node) {
VLOG(6) << "Got nullptr GradOpNode for " << op_type
<< " likely registered EmptyGradOpMaker";
*generate_forward_only = true;
bwd_info->SetGenerateForwardOnly(true);
return false;
}
/*
if (grad_node->size() > 1) {
// Backward attributes can be super complicated
VLOG(6) << "Skip GradOpNode with multiple OpBases for now: " << op_type;
skipped_operators.insert(op_type);
return false;
}
*/
VLOG(6) << "Prepared GradOpNode";
/* ---- Collect Default Attr Map ---- */
/* ---- Collect OpBase's op_types ---- */
bwd_info->SetFwdOpType(op_type);
auto* op_base_infos = bwd_info->GetMutableOpBaseInfos();
op_base_infos->resize(grad_node->size());
for (auto iter = grad_node->begin(); iter < grad_node->end(); iter++) {
// Each OpBase
int index = std::distance(grad_node->begin(), iter);
paddle::imperative::OpBase& op_base = *iter;
grad_op_types->push_back(op_base.Type());
(*op_base_infos)[index].SetOpBaseType(op_base.Type());
}
/* ------ Get Grad ins/outs ---- */
// In case of multiple OpBase, stitch all the respective ins/outs into one
VLOG(6) << "In function size: " << grad_node->size();
for (auto iter = grad_node->begin(); iter < grad_node->end(); iter++) {
int index = std::distance(grad_node->begin(), iter);
auto* op_base_grad_ins = (*op_base_infos)[index].GetMutableGradIns();
auto* op_base_grad_outs = (*op_base_infos)[index].GetMutableGradOuts();
const paddle::imperative::OpBase& op_base = *iter;
const std::map<std::string, paddle::imperative::SavedVariableWrapperList>&
g_ins = op_base.GetInsMap();
@@ -689,34 +899,47 @@ static bool CollectGradInformationFromOpInfo(
g_outs = op_base.GetOutsMap();
for (const auto& it : g_ins) {
if (!grad_ins->count(it.first)) (*grad_ins)[it.first] = {};
if (!op_base_grad_ins->count(it.first))
(*op_base_grad_ins)[it.first] = {};
for (auto vw_iter = it.second.begin(); vw_iter != it.second.end();
vw_iter++) {
std::shared_ptr<paddle::imperative::VariableWrapper> vw = *vw_iter;
(*grad_ins)[it.first].push_back(vw);
(*op_base_grad_ins)[it.first].push_back(vw);
VLOG(6) << "GradIns Name: " << it.first;
}
}
for (const auto& it : g_outs) {
if (!grad_outs->count(it.first)) (*grad_outs)[it.first] = {};
if (!op_base_grad_outs->count(it.first))
(*op_base_grad_outs)[it.first] = {};
for (auto vw_iter = it.second.begin(); vw_iter != it.second.end();
vw_iter++) {
std::shared_ptr<paddle::imperative::VariableWrapper> vw = *vw_iter;
(*grad_outs)[it.first].push_back(vw);
(*op_base_grad_outs)[it.first].push_back(vw);
VLOG(6) << "GradOuts Name: " << it.first;
}
}
}
/* ------ Slot Name Matching ---- */
// grad_ins -> fwd_ins, fwd_outs
SlotNameMatching(*grad_ins, fwd_ins, fwd_outs, grad_ins_fwd_slotname_map,
grad_ins_grad_slotname_map);
VLOG(6) << "Finished Slotname Matching for Grad_Ins";
// grad_outs -> fwd_ins, fwd_outs
SlotNameMatching(*grad_outs, fwd_ins, fwd_outs, grad_outs_slotname_map,
grad_outs_slotname_map);
VLOG(6) << "Finished Slotname Matching for Grad_Outs";
for (auto& iter : *op_base_infos) {
// grad_ins -> fwd_ins, fwd_outs
SlotNameMatching(iter.GetGradIns(), fwd_ins, fwd_outs,
iter.GetMutableGradInsFwdSlotnameMap(),
iter.GetMutableGradInsGradSlotnameMap());
// grad_outs -> fwd_ins, fwd_outs
SlotNameMatching(iter.GetGradOuts(), fwd_ins, fwd_outs,
iter.GetMutableGradOutsSlotnameMap(),
iter.GetMutableGradOutsSlotnameMap());
}
VLOG(6) << "Finished Slotname Matching";
return true;
}
@@ -725,13 +948,20 @@ static bool CollectGradInformationFromOpInfo(
/* --------- CodeGen: Forward GradNode Creation ------ */
/* --------------------------------------------------- */
static std::string GenerateGradNodeCreationContent(
const std::unordered_map<std::string, size_t>& fwd_inputs_name_pos_map,
const std::unordered_map<std::string, size_t>& fwd_outputs_name_pos_map,
const std::map<std::string, std::string>& grad_ins_fwd_slotname_map,
const std::string& op_type, const std::vector<proto::OpProto::Var>& in_vars,
const std::vector<proto::OpProto::Var>& out_vars) {
const ForwardGenerationInfo& fwd_info,
const GradNodeGenerationInfo& bwd_info) {
VLOG(6) << "Generating GradNode Creation codes";
const std::string& op_type = fwd_info.GetOpType();
const std::unordered_map<std::string, size_t>& fwd_inputs_name_pos_map =
fwd_info.GetFwdInputsNamePosMap();
const std::unordered_map<std::string, size_t>& fwd_outputs_name_pos_map =
fwd_info.GetFwdOutputsNamePosMap();
const std::vector<proto::OpProto::Var>& in_vars = fwd_info.GetInVars();
const std::vector<proto::OpProto::Var>& out_vars = fwd_info.GetOutVars();
const auto& op_base_infos = bwd_info.GetOpBaseInfos();
// [Generation] Construct GradOpNode
// Run ComputeRequiredGrad
@@ -817,12 +1047,17 @@ static std::string GenerateGradNodeCreationContent(
// [GradOpNode] Set TensorWrappers
grad_node_creation_str += " // Set Tensor Wrappers\n";
for (auto& kv : grad_ins_fwd_slotname_map) {
const std::string& tensor_wrapper_name = kv.second;
const char* SET_TENSOR_WRAPPER_TEMPLATE =
" grad_node->SetTensorWrapper%s(%s);\n";
grad_node_creation_str += paddle::string::Sprintf(
SET_TENSOR_WRAPPER_TEMPLATE, tensor_wrapper_name, tensor_wrapper_name);
for (const auto& iter : op_base_infos) {
const std::map<std::string, std::string>& grad_ins_fwd_slotname_map =
iter.GetGradInsFwdSlotnameMap();
for (auto& kv : grad_ins_fwd_slotname_map) {
const std::string& tensor_wrapper_name = kv.second;
const char* SET_TENSOR_WRAPPER_TEMPLATE =
" grad_node->SetTensorWrapper%s(%s);\n";
grad_node_creation_str +=
paddle::string::Sprintf(SET_TENSOR_WRAPPER_TEMPLATE,
tensor_wrapper_name, tensor_wrapper_name);
}
}
grad_node_creation_str += "\n";
VLOG(6) << "Generated SetTensorWrapper";
@@ -892,22 +1127,17 @@ static std::string GenerateGradNodeCreationContent(
/* --------- CodeGen: Forward ----- */
/* -------------------------------- */
static std::pair<std::string, std::string> GenerateForwardFunctionContents(
bool generate_forward_only,
const std::unordered_map<std::string, size_t>& fwd_inputs_name_pos_map,
const std::unordered_map<std::string, size_t>& fwd_outputs_name_pos_map,
const std::map<std::string, std::string>& grad_ins_fwd_slotname_map,
const std::map<std::string, std::string>& grad_ins_grad_slotname_map,
const std::map<std::string, std::string>& grad_outs_slotname_map,
const std::map<
std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>&
grad_ins,
const std::map<
std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>&
grad_outs,
const std::string& op_type, const std::vector<proto::OpProto::Var>& in_vars,
const std::vector<proto::OpProto::Var>& out_vars) {
const ForwardGenerationInfo& fwd_info,
const GradNodeGenerationInfo& bwd_info) {
/* --- Process Forward Info ---*/
const std::string& op_type = fwd_info.GetOpType();
const std::unordered_map<std::string, size_t>& fwd_inputs_name_pos_map =
fwd_info.GetFwdInputsNamePosMap();
const std::unordered_map<std::string, size_t>& fwd_outputs_name_pos_map =
fwd_info.GetFwdOutputsNamePosMap();
const std::vector<proto::OpProto::Var>& in_vars = fwd_info.GetInVars();
const std::vector<proto::OpProto::Var>& out_vars = fwd_info.GetOutVars();
/*
// Forward Function Example:
std::tuple<vector<Tensor>, Tensor, vector<Tensor>>
@@ -999,24 +1229,53 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
for (const proto::OpProto::Var& output : out_vars) {
const std::string& output_name = output.name();
std::string outnum = "1";
if (output.duplicable()) {
outnum = output_name + "Num";
const char* FWD_NUM_ARG_TEMPLATE = ", size_t %s";
std::string arg_str =
paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, outnum);
dygraph_function_args_str += arg_str;
const char* FWD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput(%s) },";
outs_contents_str += paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE,
output_name, outnum);
if (op_passing_outs_map[op_type].count(output_name)) {
const std::string output_var_name = output_name + "Var";
// Pass Output from function argument,
// in form of shared_ptr<EagerTensor>/vector<shared_ptr<EagerTensor>>
if (output.duplicable()) {
const char* FWD_NUM_ARG_TEMPLATE =
", std::vector<std::shared_ptr<egr::EagerTensor>>& %s";
std::string arg_str =
paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, output_var_name);
dygraph_function_args_str += arg_str;
const char* FWD_OUTS_CONTENT_TEMPLATE = "{ \"%s\", %s },";
outs_contents_str += paddle::string::Sprintf(
FWD_OUTS_CONTENT_TEMPLATE, output_name, output_var_name);
} else {
const char* FWD_NUM_ARG_TEMPLATE =
", std::shared_ptr<egr::EagerTensor>& %s";
std::string arg_str =
paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, output_var_name);
dygraph_function_args_str += arg_str;
const char* FWD_OUTS_CONTENT_TEMPLATE = "{ \"%s\", {%s} },";
outs_contents_str += paddle::string::Sprintf(
FWD_OUTS_CONTENT_TEMPLATE, output_name, output_var_name);
}
} else {
const char* FWD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", "
"{std::make_shared<egr::EagerTensor>(egr::Controller::Instance()."
"GenerateUniqueName())}},";
outs_contents_str +=
paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE, output_name);
if (output.duplicable()) {
outnum = output_name + "Num";
const char* FWD_NUM_ARG_TEMPLATE = ", size_t %s";
std::string arg_str =
paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, outnum);
dygraph_function_args_str += arg_str;
const char* FWD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput(%s) },";
outs_contents_str += paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE,
output_name, outnum);
} else {
const char* FWD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", "
"{std::make_shared<egr::EagerTensor>(egr::Controller::Instance()."
"GenerateUniqueName())}},";
outs_contents_str +=
paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE, output_name);
}
}
}
if (outs_contents_str.size() > 0)
@@ -1084,10 +1343,9 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
VLOG(6) << "Converted Output VarBase to EagerTensor(s)";
// [Generation] ComputeRequireGrad -> GradNodeCreation
if (!generate_forward_only) {
std::string grad_node_creation_body_str = GenerateGradNodeCreationContent(
fwd_inputs_name_pos_map, fwd_outputs_name_pos_map,
grad_ins_fwd_slotname_map, op_type, in_vars, out_vars);
if (!bwd_info.GenerateForwardOnly()) {
std::string grad_node_creation_body_str =
GenerateGradNodeCreationContent(fwd_info, bwd_info);
generated_function_body += grad_node_creation_body_str;
generated_function_body += "\n";
VLOG(6) << "Generated GradNode Creation codes";
@@ -1162,22 +1420,16 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
/* --------- CodeGen: GradNode::operator() ------ */
/* ---------------------------------------------- */
static std::string GenerateGradNodeCCContents(
const std::vector<std::string>& grad_op_types,
const std::unordered_map<std::string, size_t>& fwd_inputs_name_pos_map,
const std::unordered_map<std::string, size_t>& fwd_outputs_name_pos_map,
const std::map<std::string, std::string>& grad_ins_fwd_slotname_map,
const std::map<std::string, std::string>& grad_ins_grad_slotname_map,
const std::map<std::string, std::string>& grad_outs_slotname_map,
const std::map<
std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>&
grad_ins,
const std::map<
std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>&
grad_outs,
const std::string& op_type, const std::vector<proto::OpProto::Var>& in_vars,
const std::vector<proto::OpProto::Var>& out_vars) {
const ForwardGenerationInfo& fwd_info,
const GradNodeGenerationInfo& bwd_info) {
/* --- Process Forward Info --- */
const std::string& fwd_op_type = fwd_info.GetOpType();
const std::unordered_map<std::string, size_t>& fwd_inputs_name_pos_map =
fwd_info.GetFwdInputsNamePosMap();
const std::unordered_map<std::string, size_t>& fwd_outputs_name_pos_map =
fwd_info.GetFwdOutputsNamePosMap();
const std::vector<proto::OpProto::Var>& in_vars = fwd_info.GetInVars();
VLOG(6) << "Generating Grad Node CC";
/* [Outline]
@@ -1224,227 +1476,247 @@ static std::string GenerateGradNodeCCContents(
*/
std::string generated_grad_function_body = "";
size_t outs_size = 0;
const auto& op_base_infos = bwd_info.GetOpBaseInfos();
for (size_t i = 0; i < op_base_infos.size(); i++) {
const auto& op_base_info = op_base_infos[i];
const auto& grad_ins_fwd_slotname_map =
op_base_info.GetGradInsFwdSlotnameMap();
const auto& grad_ins_grad_slotname_map =
op_base_info.GetGradInsGradSlotnameMap();
const auto& grad_outs_slotname_map = op_base_info.GetGradOutsSlotnameMap();
const auto& grad_ins = op_base_info.GetGradIns();
const auto& grad_outs = op_base_info.GetGradOuts();
const std::string& op_base_type = op_base_info.GetOpBaseType();
const std::string& ins_name = "ins" + std::to_string(i);
const std::string& outs_name = "outs" + std::to_string(i);
outs_size += grad_outs.size();
// [Generation] Get Ins Map
std::string ins_contents_str = "";
for (auto iter : grad_ins) {
const std::string& grad_input_name = iter.first;
if (grad_ins_fwd_slotname_map.count(grad_input_name)) {
// Fwd Tensor
std::string struct_fwd_input_name =
grad_ins_fwd_slotname_map.at(grad_input_name) + "_";
const char* GRAD_INS_FWD_CONTENT_TEMPLATE =
"{ \"%s\", "
"egr::EagerUtils::SyncToVars(egr::EagerUtils::RecoverTensorWrapper("
"&"
"this->%s, "
"nullptr)) },";
ins_contents_str +=
paddle::string::Sprintf(GRAD_INS_FWD_CONTENT_TEMPLATE,
grad_input_name, struct_fwd_input_name);
} else if (grad_ins_grad_slotname_map.count(grad_input_name)) {
// Fwd Tensor's Grad
size_t fwd_output_position = fwd_outputs_name_pos_map.at(
grad_ins_grad_slotname_map.at(grad_input_name));
const char* GRAD_INS_GRAD_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::SyncToVars(grads[%d]) },";
ins_contents_str +=
paddle::string::Sprintf(GRAD_INS_GRAD_CONTENT_TEMPLATE,
grad_input_name, fwd_output_position);
// [Generation] Get Tracer
generated_grad_function_body += "\n";
generated_grad_function_body += "\n";
// [Generation] Get Ins Map
std::string ins_contents_str = "";
for (auto iter : grad_ins) {
const std::string& grad_input_name = iter.first;
if (grad_ins_fwd_slotname_map.count(grad_input_name)) {
// Fwd Tensor
std::string struct_fwd_input_name =
grad_ins_fwd_slotname_map.at(grad_input_name) + "_";
const char* GRAD_INS_FWD_CONTENT_TEMPLATE =
"{ \"%s\", "
"egr::EagerUtils::SyncToVars(egr::EagerUtils::RecoverTensorWrapper(&"
"this->%s, "
"nullptr)) },";
ins_contents_str +=
paddle::string::Sprintf(GRAD_INS_FWD_CONTENT_TEMPLATE,
grad_input_name, struct_fwd_input_name);
} else if (grad_ins_grad_slotname_map.count(grad_input_name)) {
// Fwd Tensor's Grad
size_t fwd_output_position = fwd_outputs_name_pos_map.at(
grad_ins_grad_slotname_map.at(grad_input_name));
const char* GRAD_INS_GRAD_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::SyncToVars(grads[%d]) },";
ins_contents_str += paddle::string::Sprintf(
GRAD_INS_GRAD_CONTENT_TEMPLATE, grad_input_name, fwd_output_position);
} else {
PADDLE_THROW(platform::errors::Fatal(
"Detected mismatched slot names."
"Unable to find forward slot name that matches %s",
grad_input_name));
} else {
PADDLE_THROW(platform::errors::Fatal(
"Detected mismatched slot names."
"Unable to find forward slot name that matches %s",
grad_input_name));
}
}
if (ins_contents_str.size() > 0)
ins_contents_str.pop_back(); // // Remove trailing ","
const char* BWD_INS_MAP_TEMPLATE =
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerTensor>>> %s = { "
"%s };\n";
std::string ins_map_str = paddle::string::Sprintf(
BWD_INS_MAP_TEMPLATE, ins_name, ins_contents_str);
generated_grad_function_body += ins_map_str;
VLOG(6) << "Generated Ins Map";
// [Generation] Get Outs Map
std::unordered_set<std::string> duplicable_input_name_set;
for (const auto& in : in_vars) {
if (in.duplicable()) duplicable_input_name_set.insert(in.name());
}
}
if (ins_contents_str.size() > 0)
ins_contents_str.pop_back(); // // Remove trailing ","
const char* BWD_INS_MAP_TEMPLATE =
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerTensor>>> ins = { "
"%s };\n";
std::string ins_map_str =
paddle::string::Sprintf(BWD_INS_MAP_TEMPLATE, ins_contents_str);
generated_grad_function_body += ins_map_str;
VLOG(6) << "Generated Ins Map";
// [Generation] Get Outs Map
std::unordered_set<std::string> duplicable_input_name_set;
for (const auto& in : in_vars) {
if (in.duplicable()) duplicable_input_name_set.insert(in.name());
}
std::string outs_contents_str = "";
for (auto iter : grad_outs) {
const std::string& grad_output_name = iter.first;
if (grad_outs_slotname_map.count(grad_output_name)) {
// Fwd Tensor
const std::string& fwd_name = grad_outs_slotname_map.at(grad_output_name);
/* Handle Special Case: "PullSparseOp", etc
Forward:
Ids W
| |
PullSparseOp
|
Out
Backward:
Ids GradOut W
| | |
PullSparseGradOp
|
GradOut
Its grad output "GradOut" corresponds to forward output "Out",
where there is a hiden inplace involved. So we find "GradOut"'s index
in
grads, and perform the inplace operation by constructing outs =
{{"Out", grads[i]}}
GradOut -> Out -> fwd_output_pos -> grads position -> grads[i]
outs = {{"Out", grads[i]}}
For returns, append "GradOut" to the very end of return list.
*/
if (!fwd_inputs_name_pos_map.count(fwd_name)) {
PADDLE_ENFORCE(fwd_outputs_name_pos_map.count(fwd_name),
paddle::platform::errors::Fatal(
"fwd_name not found in fwd_inputs_name_pos_map nor "
"fwd_outputs_name_pos_map"));
size_t grads_position = fwd_outputs_name_pos_map.at(fwd_name);
std::string grad_ptr_name = fwd_name + "_ptrs";
const char* GET_GRADS_PTR_TEMPLATE =
" std::vector<std::shared_ptr<egr::EagerTensor>> %s;\n"
" for(const auto& t : grads[%d]) {\n "
"%s.emplace_back(std::move(std::make_shared<egr::EagerTensor>(t)));"
"\n }\n";
std::string grads_ptr_str =
paddle::string::Sprintf(GET_GRADS_PTR_TEMPLATE, grad_ptr_name,
grads_position, grad_ptr_name);
generated_grad_function_body += grads_ptr_str;
generated_grad_function_body += "\n";
const char* GRAD_OUTS_CONTENT_TEMPLATE = "{ \"%s\", %s },";
outs_contents_str += paddle::string::Sprintf(
GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, grad_ptr_name);
} else {
size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
if (duplicable_input_name_set.count(fwd_name)) {
const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput( "
"this->OutputMeta()[%d].Size() ) },";
std::string outs_contents_str = "";
for (auto iter : grad_outs) {
const std::string& grad_output_name = iter.first;
if (grad_outs_slotname_map.count(grad_output_name)) {
// Fwd Tensor
const std::string& fwd_name =
grad_outs_slotname_map.at(grad_output_name);
/* Handle Special Case: "PullSparseOp", etc
Forward:
Ids W
| |
PullSparseOp
|
Out
Backward:
Ids GradOut W
| | |
PullSparseGradOp
|
GradOut
Its grad output "GradOut" corresponds to forward output "Out",
where there is a hiden inplace involved. So we find "GradOut"'s
index
in
grads, and perform the inplace operation by constructing outs =
{{"Out", grads[i]}}
GradOut -> Out -> fwd_output_pos -> grads position -> grads[i]
outs = {{"Out", grads[i]}}
For returns, append "GradOut" to the very end of return list.
*/
if (!fwd_inputs_name_pos_map.count(fwd_name)) {
PADDLE_ENFORCE(
fwd_outputs_name_pos_map.count(fwd_name),
paddle::platform::errors::Fatal(
"fwd_name not found in fwd_inputs_name_pos_map nor "
"fwd_outputs_name_pos_map"));
size_t grads_position = fwd_outputs_name_pos_map.at(fwd_name);
std::string grad_ptr_name = fwd_name + "_ptrs";
const char* GET_GRADS_PTR_TEMPLATE =
" std::vector<std::shared_ptr<egr::EagerTensor>> %s;\n"
" for(const auto& t : grads[%d]) {\n "
"%s.emplace_back(std::move(std::make_shared<egr::EagerTensor>(t))"
");"
"\n }\n";
std::string grads_ptr_str =
paddle::string::Sprintf(GET_GRADS_PTR_TEMPLATE, grad_ptr_name,
grads_position, grad_ptr_name);
generated_grad_function_body += grads_ptr_str;
generated_grad_function_body += "\n";
const char* GRAD_OUTS_CONTENT_TEMPLATE = "{ \"%s\", %s },";
outs_contents_str += paddle::string::Sprintf(
GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, fwd_input_position);
GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, grad_ptr_name);
} else {
const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", "
"{std::make_shared<egr::EagerTensor>(egr::Controller::Instance()."
"GenerateUniqueName())}},";
outs_contents_str += paddle::string::Sprintf(
GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name);
size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
if (duplicable_input_name_set.count(fwd_name)) {
const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput( "
"this->OutputMeta()[%d].Size() ) },";
outs_contents_str +=
paddle::string::Sprintf(GRAD_OUTS_CONTENT_TEMPLATE,
grad_output_name, fwd_input_position);
} else {
const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", "
"{std::make_shared<egr::EagerTensor>(egr::Controller::Instance("
")."
"GenerateUniqueName())}},";
outs_contents_str += paddle::string::Sprintf(
GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name);
}
}
} else {
PADDLE_THROW(platform::errors::Fatal(
"Detected mismatched slot names."
"Unable to find forward slot name that matches %s",
grad_output_name));
}
} else {
PADDLE_THROW(platform::errors::Fatal(
"Detected mismatched slot names."
"Unable to find forward slot name that matches %s",
grad_output_name));
}
}
if (outs_contents_str.size() > 0)
outs_contents_str.pop_back(); // // Remove trailing ","
if (outs_contents_str.size() > 0)
outs_contents_str.pop_back(); // // Remove trailing ","
const char* BWD_OUTS_MAP_TEMPLATE =
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerTensor>>> outs = { "
"%s };\n";
std::string outs_map_str =
paddle::string::Sprintf(BWD_OUTS_MAP_TEMPLATE, outs_contents_str);
generated_grad_function_body += outs_map_str;
generated_grad_function_body += "\n";
VLOG(6) << "Generated Outs Map";
const char* BWD_OUTS_MAP_TEMPLATE =
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerTensor>>> %s = { "
"%s };\n";
std::string outs_map_str = paddle::string::Sprintf(
BWD_OUTS_MAP_TEMPLATE, outs_name, outs_contents_str);
generated_grad_function_body += outs_map_str;
generated_grad_function_body += "\n";
// [Generation] Get Attrs Map
std::string trace_opbase_str = "";
for (size_t i = 0; i < grad_op_types.size(); i++) {
const std::string& op_base_type = grad_op_types[i];
VLOG(6) << "Generated Outs Map";
// [Generation] Get Attrs Map
const char* TRACE_OP_TEMPLATE =
" // Pass the entire attribute map to TraceOp\n"
" // The underlying kernel will pickup whatever attribute they need "
"at runtime\n"
" egr::legacy::RunOp(\"%s\", ins, outs, this->attr_map_,\n"
" egr::legacy::RunOp(\"%s\", %s, %s, this->attr_map_,\n"
" egr::Controller::Instance().GetExpectedPlace(),\n"
" &this->default_attr_map_, false, {});\n";
trace_opbase_str = paddle::string::Sprintf(TRACE_OP_TEMPLATE, op_base_type);
}
std::string trace_opbase_str = paddle::string::Sprintf(
TRACE_OP_TEMPLATE, op_base_type, ins_name, outs_name);
generated_grad_function_body += trace_opbase_str;
generated_grad_function_body += trace_opbase_str;
VLOG(6) << "Generated Attrs Map";
VLOG(6) << "Generated Attrs Map";
// [Generation] Get Return
std::string outputs_str = "";
size_t num_appended_outputs = 0;
for (auto iter : grad_outs) {
const std::string& grad_out_name = iter.first;
const std::string& fwd_name = grad_outs_slotname_map.at(grad_out_name);
// [Generation] Get Return
std::string outputs_str = "";
size_t num_appended_outputs = 0;
for (auto iter : grad_outs) {
const std::string& grad_out_name = iter.first;
const std::string& fwd_name = grad_outs_slotname_map.at(grad_out_name);
if (fwd_inputs_name_pos_map.count(fwd_name)) {
size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
const char* BWD_OUTPUT_TEMPLATE =
" outputs[%d] = egr::EagerUtils::GetOutputs(outs[\"%s\"]);\n";
outputs_str += paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE,
fwd_input_position, grad_out_name);
num_appended_outputs++;
} else {
PADDLE_ENFORCE(fwd_outputs_name_pos_map.count(fwd_name),
paddle::platform::errors::Fatal(
"fwd_name not found in fwd_inputs_name_pos_map nor "
"fwd_outputs_name_pos_map"));
if (fwd_inputs_name_pos_map.count(fwd_name)) {
size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
const char* BWD_OUTPUT_TEMPLATE =
" outputs[%d] = egr::EagerUtils::GetOutputs(%s[\"%s\"]);\n";
outputs_str += paddle::string::Sprintf(
BWD_OUTPUT_TEMPLATE, fwd_input_position, outs_name, grad_out_name);
num_appended_outputs++;
} else {
PADDLE_ENFORCE(fwd_outputs_name_pos_map.count(fwd_name),
paddle::platform::errors::Fatal(
"fwd_name not found in fwd_inputs_name_pos_map nor "
"fwd_outputs_name_pos_map"));
}
}
}
/* Handle Special Case: "PullSparseOp", etc
For returns, append "GradOut" to the very end of return list. */
for (auto iter : grad_outs) {
const std::string& grad_out_name = iter.first;
const std::string& fwd_name = grad_outs_slotname_map.at(grad_out_name);
if (fwd_outputs_name_pos_map.count(fwd_name)) {
const char* BWD_OUTPUT_TEMPLATE =
" outputs[%d] = egr::EagerUtils::GetOutputs(outs[\"%s\"]);\n";
outputs_str += paddle::string::Sprintf(
BWD_OUTPUT_TEMPLATE, num_appended_outputs, grad_out_name);
num_appended_outputs++;
/* Handle Special Case: "PullSparseOp", etc
For returns, append "GradOut" to the very end of return list. */
for (auto iter : grad_outs) {
const std::string& grad_out_name = iter.first;
const std::string& fwd_name = grad_outs_slotname_map.at(grad_out_name);
if (fwd_outputs_name_pos_map.count(fwd_name)) {
const char* BWD_OUTPUT_TEMPLATE =
" outputs[%d] = egr::EagerUtils::GetOutputs(%s[\"%s\"]);\n";
outputs_str +=
paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE, num_appended_outputs,
outs_name, grad_out_name);
num_appended_outputs++;
}
}
generated_grad_function_body += outputs_str;
generated_grad_function_body += "\n";
}
const char* BWD_RETURN_TEMPLATE =
" std::vector<std::vector<egr::EagerTensor>> "
"outputs(outs.size());\n%s\n "
"return outputs;";
std::string return_str =
paddle::string::Sprintf(BWD_RETURN_TEMPLATE, outputs_str);
generated_grad_function_body += "\n";
generated_grad_function_body += return_str;
" std::vector<std::vector<egr::EagerTensor>> outputs(%d);\n"
" %s\n"
" return outputs;\n";
generated_grad_function_body = paddle::string::Sprintf(
BWD_RETURN_TEMPLATE, outs_size, generated_grad_function_body);
// [Generation] Get Full Grad Function
const char* GRAD_FUNCTION_TEMPLATE =
@@ -1452,7 +1724,7 @@ static std::string GenerateGradNodeCCContents(
"GradNode%s::operator()(const "
"std::vector<std::vector<egr::EagerTensor>>& grads) {\n%s\n}";
std::string grad_function_str = paddle::string::Sprintf(
GRAD_FUNCTION_TEMPLATE, op_type, generated_grad_function_body);
GRAD_FUNCTION_TEMPLATE, fwd_op_type, generated_grad_function_body);
VLOG(6) << "Generated returns";
@@ -1463,9 +1735,14 @@ static std::string GenerateGradNodeCCContents(
/* --------- CodeGen: GradNode Header ------ */
/* ----------------------------------------- */
static std::string GenerateGradNodeHeaderContents(
const std::map<std::string, std::string>& grad_ins_fwd_slotname_map,
const std::string& op_type, const std::vector<proto::OpProto::Var>& in_vars,
const std::vector<proto::OpProto::Var>& out_vars) {
const ForwardGenerationInfo& fwd_info,
const GradNodeGenerationInfo& bwd_info) {
const std::string& op_type = fwd_info.GetOpType();
const std::vector<proto::OpProto::Var>& in_vars = fwd_info.GetInVars();
const std::vector<proto::OpProto::Var>& out_vars = fwd_info.GetOutVars();
const auto& op_base_infos = bwd_info.GetOpBaseInfos();
VLOG(6) << "Generating Grad Node Header";
const char* GRAD_NODE_TEMPLATE =
@@ -1522,55 +1799,60 @@ static std::string GenerateGradNodeHeaderContents(
std::string set_tensor_wrappers_str = "";
std::string tensor_wrapper_members_str = "";
for (const auto& kv : grad_ins_fwd_slotname_map) {
const std::string& tensor_wrapper_name = kv.second;
const std::string& struct_tensor_wrapper_name = kv.second + "_";
std::string tensor_wrapper_arg_str;
std::string tensor_wrapper_body_str;
if (duplicable_tensors.count(tensor_wrapper_name)) {
const char* ATTR_TENSOR_WRAPPER_ARG_TEMPLATE =
"const std::vector<egr::EagerTensor>& %s";
tensor_wrapper_arg_str = paddle::string::Sprintf(
ATTR_TENSOR_WRAPPER_ARG_TEMPLATE, tensor_wrapper_name);
const char* TENSOR_WRAPPER_MEMBER_TEMPLATE =
" std::vector<egr::TensorWrapper> %s;\n";
tensor_wrapper_members_str += paddle::string::Sprintf(
TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name);
const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE =
"for(const auto& eager_tensor : %s) {\n"
" %s.emplace_back( egr::TensorWrapper(eager_tensor, true "
"/*full_reserved*/) );\n"
" }\n";
tensor_wrapper_body_str = paddle::string::Sprintf(
SET_TENSOR_WRAPPER_BODY_TEMPLATE, tensor_wrapper_name,
struct_tensor_wrapper_name);
for (const auto& iter : op_base_infos) {
const std::map<std::string, std::string>& grad_ins_fwd_slotname_map =
iter.GetGradInsFwdSlotnameMap();
for (const auto& kv : grad_ins_fwd_slotname_map) {
const std::string& tensor_wrapper_name = kv.second;
const std::string& struct_tensor_wrapper_name = kv.second + "_";
std::string tensor_wrapper_arg_str;
std::string tensor_wrapper_body_str;
if (duplicable_tensors.count(tensor_wrapper_name)) {
const char* ATTR_TENSOR_WRAPPER_ARG_TEMPLATE =
"const std::vector<egr::EagerTensor>& %s";
tensor_wrapper_arg_str = paddle::string::Sprintf(
ATTR_TENSOR_WRAPPER_ARG_TEMPLATE, tensor_wrapper_name);
const char* TENSOR_WRAPPER_MEMBER_TEMPLATE =
" std::vector<egr::TensorWrapper> %s;\n";
tensor_wrapper_members_str += paddle::string::Sprintf(
TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name);
const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE =
"for(const auto& eager_tensor : %s) {\n"
" %s.emplace_back( egr::TensorWrapper(eager_tensor, true "
"/*full_reserved*/) );\n"
" }\n";
tensor_wrapper_body_str = paddle::string::Sprintf(
SET_TENSOR_WRAPPER_BODY_TEMPLATE, tensor_wrapper_name,
struct_tensor_wrapper_name);
} else {
const char* ATTR_TENSOR_WRAPPER_ARG_TEMPLATE =
"const egr::EagerTensor& %s";
tensor_wrapper_arg_str = paddle::string::Sprintf(
ATTR_TENSOR_WRAPPER_ARG_TEMPLATE, tensor_wrapper_name);
const char* TENSOR_WRAPPER_MEMBER_TEMPLATE =
" egr::TensorWrapper %s;\n";
tensor_wrapper_members_str += paddle::string::Sprintf(
TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name);
const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE =
"%s = egr::TensorWrapper(%s, true /*full_reserved*/);";
tensor_wrapper_body_str = paddle::string::Sprintf(
SET_TENSOR_WRAPPER_BODY_TEMPLATE, struct_tensor_wrapper_name,
tensor_wrapper_name);
}
const char* SET_TENSOR_WRAPPER_TEMPLATE =
" void SetTensorWrapper%s(%s) {\n %s\n }\n";
set_tensor_wrappers_str += paddle::string::Sprintf(
SET_TENSOR_WRAPPER_TEMPLATE, tensor_wrapper_name,
tensor_wrapper_arg_str, tensor_wrapper_body_str);
} else {
const char* ATTR_TENSOR_WRAPPER_ARG_TEMPLATE =
"const egr::EagerTensor& %s";
tensor_wrapper_arg_str = paddle::string::Sprintf(
ATTR_TENSOR_WRAPPER_ARG_TEMPLATE, tensor_wrapper_name);
const char* TENSOR_WRAPPER_MEMBER_TEMPLATE =
" egr::TensorWrapper %s;\n";
tensor_wrapper_members_str += paddle::string::Sprintf(
TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name);
const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE =
"%s = egr::TensorWrapper(%s, true /*full_reserved*/);";
tensor_wrapper_body_str = paddle::string::Sprintf(
SET_TENSOR_WRAPPER_BODY_TEMPLATE, struct_tensor_wrapper_name,
tensor_wrapper_name);
}
const char* SET_TENSOR_WRAPPER_TEMPLATE =
" void SetTensorWrapper%s(%s) {\n %s\n }\n";
set_tensor_wrappers_str += paddle::string::Sprintf(
SET_TENSOR_WRAPPER_TEMPLATE, tensor_wrapper_name,
tensor_wrapper_arg_str, tensor_wrapper_body_str);
}
}
VLOG(6) << "Generated TensorWrapper";
@@ -1682,97 +1964,62 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
/* ----------------------------- */
/* ---- Collect Information ---- */
/* ----------------------------- */
std::vector<std::string> grad_op_types;
std::vector<proto::OpProto::Var> in_vars;
std::vector<proto::OpProto::Var> out_vars;
std::map<std::string, std::string> grad_outs_slotname_map;
std::map<std::string, std::string> grad_ins_fwd_slotname_map;
std::map<std::string, std::string> grad_ins_grad_slotname_map;
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>
grad_ins;
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>
grad_outs;
ForwardGenerationInfo fwd_info;
GradNodeGenerationInfo bwd_info;
VLOG(6) << "-------- CollectInformationFromOpInfo -------";
CollectForwardInformationFromOpInfo(op_info, &in_vars, &out_vars);
CollectForwardInformationFromOpInfo(op_info, &fwd_info);
bool generate_forward_only = false;
bool is_available = CollectGradInformationFromOpInfo(
op_info, &generate_forward_only, &grad_op_types,
&grad_outs_slotname_map, &grad_ins_fwd_slotname_map,
&grad_ins_grad_slotname_map, &grad_ins, &grad_outs);
bool is_available = CollectGradInformationFromOpInfo(op_info, &bwd_info);
if (!is_available && !generate_forward_only) {
if (!is_available && !bwd_info.GenerateForwardOnly()) {
VLOG(6) << "Skipped operator: " << op_type;
continue;
}
VLOG(6) << "-------- PurifyOpProto -------";
std::unordered_map<std::string, size_t> fwd_inputs_name_pos_map;
std::unordered_map<std::string, size_t> fwd_outputs_name_pos_map;
PurifyForwardOpProto(*op_proto, &fwd_inputs_name_pos_map,
&fwd_outputs_name_pos_map, &in_vars, &out_vars);
if (!generate_forward_only) {
PurifyGradOpProto(*op_proto, &grad_outs_slotname_map,
&grad_ins_fwd_slotname_map, &grad_ins_grad_slotname_map,
&grad_ins, &grad_outs);
PurifyForwardOpProto(*op_proto, &fwd_info);
if (!bwd_info.GenerateForwardOnly()) {
PurifyGradNodeGenerationInfo(*op_proto, &bwd_info);
}
/* --------------------------- */
/* --------- CodeGen --------- */
/* --------------------------- */
/* ---- forward_dygraph_functions.cc ---- */
VLOG(6) << "-------- GenerateForwardFunctionContents -------";
std::pair<std::string, std::string> body_and_declaration =
GenerateForwardFunctionContents(
generate_forward_only, fwd_inputs_name_pos_map,
fwd_outputs_name_pos_map, grad_ins_fwd_slotname_map,
grad_ins_grad_slotname_map, grad_outs_slotname_map, grad_ins,
grad_outs, op_type, in_vars, out_vars);
GenerateForwardFunctionContents(fwd_info, bwd_info);
fwd_function_str += body_and_declaration.first + "\n";
/* ---- dygraph_forward_api.h ---- */
VLOG(6) << "-------- GenerateDygraphForwardAPIContents -------";
std::string fwd_function_declare_str = body_and_declaration.second;
dygraph_forward_api_str += fwd_function_declare_str;
if (generate_forward_only) continue;
if (bwd_info.GenerateForwardOnly()) continue;
/* ---- nodes.h ---- */
VLOG(6) << "-------- GenerateGradNodeHeaderContents -------";
grad_node_h_str +=
GenerateGradNodeHeaderContents(grad_ins_fwd_slotname_map, op_type,
in_vars, out_vars) +
"\n";
grad_node_h_str += GenerateGradNodeHeaderContents(fwd_info, bwd_info);
grad_node_h_str += "\n";
/* ---- nodes.cc ---- */
VLOG(6) << "-------- GenerateGradNodeCCContents -------";
grad_node_cc_str += GenerateGradNodeCCContents(
grad_op_types, fwd_inputs_name_pos_map,
fwd_outputs_name_pos_map, grad_ins_fwd_slotname_map,
grad_ins_grad_slotname_map, grad_outs_slotname_map,
grad_ins, grad_outs, op_type, in_vars, out_vars) +
"\n";
grad_node_cc_str += GenerateGradNodeCCContents(fwd_info, bwd_info);
grad_node_cc_str += "\n";
VLOG(6) << op_type << ": Finished Generating Op: " << op_type;
}
/* ---- dygraph_forward_function.cc ---- */
VLOG(6) << "-------- GenerateDygraphForwardCCFile -------";
GenerateForwardDygraphFile(output_dir, fwd_function_str);
/* ---- dygraph_forward_api.h ---- */
VLOG(6) << "-------- GenerateForwardHFile -------";
GenerateForwardHFile(output_dir, dygraph_forward_api_str);
/* ---- nodes.h ---- */
VLOG(6) << "-------- GenerateNodeHFile -------";
GenerateNodeHFile(output_dir, grad_node_h_str);
/* ---- nodes.cc ---- */
VLOG(6) << "-------- GenerateNodeCCFile -------";
GenerateNodeCCFile(output_dir, grad_node_cc_str);
}
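
Before the op-list change below, one more standalone sketch that is NOT part of this commit: it illustrates the per-OpBase loop structure that the reworked GenerateGradNodeCCContents emits. Each OpBase of a grad node gets its own ins<i>/outs<i> maps and its own egr::legacy::RunOp call, so a grad node with multiple OpBases traces each OpBase separately. GenGradNodeBodySkeleton and the OpBase type names in main() are hypothetical, used only to show the shape of the generated body.

// Standalone sketch, not part of this commit: builds the skeleton of the body
// that GenerateGradNodeCCContents produces for a grad node with N OpBases.
#include <iostream>
#include <string>
#include <vector>

std::string GenGradNodeBodySkeleton(const std::vector<std::string>& op_base_types) {
  std::string body;
  for (size_t i = 0; i < op_base_types.size(); ++i) {
    const std::string ins_name = "ins" + std::to_string(i);
    const std::string outs_name = "outs" + std::to_string(i);
    body += "  std::map<std::string, "
            "std::vector<std::shared_ptr<egr::EagerTensor>>> " +
            ins_name + " = { /* slot-name matched grad inputs */ };\n";
    body += "  std::map<std::string, "
            "std::vector<std::shared_ptr<egr::EagerTensor>>> " +
            outs_name + " = { /* slot-name matched grad outputs */ };\n";
    body += "  egr::legacy::RunOp(\"" + op_base_types[i] + "\", " + ins_name +
            ", " + outs_name + ", this->attr_map_,\n"
            "      egr::Controller::Instance().GetExpectedPlace(),\n"
            "      &this->default_attr_map_, false, {});\n\n";
  }
  return body;
}

int main() {
  // Hypothetical op whose grad node carries two OpBases.
  std::cout << GenGradNodeBodySkeleton({"some_op_grad", "some_op_grad_extra"});
  return 0;
}
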
@@ -237,6 +237,7 @@ spp
floor
gelu
retinanet_detection_output
minus
push_dense
silu
sequence_erase