Unverified · Commit 9aed9ea0 authored by Zhanlue Yang, committed by GitHub


Adjusted Eager AutoCodeGen to Support Operators with Multiple OpBases & Enable Passing Output Tensor as Input Argument (#37943)

* Rearranged Eager AutoCodeGen directory structure

* Removed USE_OP in Eager AutoCodeGen

* Enabled generation for Operators without Grad/Inputs/Outputs

* Resolved operators without input

* Fixed merge conflicts

* Enabled Eager AutoCodeGen for 10+ more operators

* Refactored Eager AutoCodeGen with more organized helper objects

* Enabled Eager AutoCodeGen for operators with multiple OpBases

* Adjusted Eager AutoCodeGen to Enable Passing Output Tensor as Input Argument
Parent aff7397b
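For context, the user-visible effect of the "passing output tensor as input argument" change is on the signatures the generator emits: an output listed in op_passing_outs_map is now accepted as a function argument and forwarded into the outs map, instead of being freshly constructed inside the generated function. A minimal sketch, assuming a hypothetical operator "foo" with input "X" and passed-in output "Out" (illustrative only, not the literal generated code; other generated parameters such as attributes and outnum are elided):

// Hypothetical sketch, not generator output.
// Non-duplicable passed-in output: taken by reference and mapped as { "Out", {OutVar} }.
egr::EagerTensor foo_dygraph_function(
    const egr::EagerTensor& X,
    std::shared_ptr<egr::EagerTensor>& OutVar /* caller-supplied output */);
// A duplicable passed-in output would instead be taken as
// std::vector<std::shared_ptr<egr::EagerTensor>>& OutVar and mapped as { "Out", OutVar }.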
@@ -29,15 +29,11 @@
namespace paddle {
namespace framework {
/* --- Static maps to handle corner cases --- */
static std::unordered_map<std::string, paddle::framework::AttributeMap>
operators_with_attrs = {};
static std::unordered_set<std::string> operators_to_skip = {
"minus",
};
static std::unordered_set<std::string> operators_to_codegen = {};
static std::unordered_set<std::string> skipped_operators = {};
static std::string LegalizeVariableName(const std::string& var_name) {
std::string ret = var_name;
@@ -45,6 +41,132 @@ static std::string LegalizeVariableName(const std::string& var_name) {
return ret;
}
/* --- Helper Objects --- */
class ForwardGenerationInfo {
public:
const std::string& GetOpType() const { return op_type_; }
void SetOpType(const std::string& op_type) { op_type_ = op_type; }
const std::unordered_map<std::string, size_t>& GetFwdInputsNamePosMap()
const {
return fwd_inputs_name_pos_map_;
}
std::unordered_map<std::string, size_t>* GetMutableFwdInputsNamePosMap() {
return &fwd_inputs_name_pos_map_;
}
const std::unordered_map<std::string, size_t>& GetFwdOutputsNamePosMap()
const {
return fwd_outputs_name_pos_map_;
}
std::unordered_map<std::string, size_t>* GetMutableFwdOutputsNamePosMap() {
return &fwd_outputs_name_pos_map_;
}
const std::vector<proto::OpProto::Var>& GetInVars() const { return in_vars_; }
std::vector<proto::OpProto::Var>* GetMutableInVars() { return &in_vars_; }
const std::vector<proto::OpProto::Var>& GetOutVars() const {
return out_vars_;
}
std::vector<proto::OpProto::Var>* GetMutableOutVars() { return &out_vars_; }
private:
std::string op_type_;
std::unordered_map<std::string, size_t> fwd_inputs_name_pos_map_;
std::unordered_map<std::string, size_t> fwd_outputs_name_pos_map_;
std::vector<proto::OpProto::Var> in_vars_;
std::vector<proto::OpProto::Var> out_vars_;
};
class GradNodeGenerationInfo {
class OpBaseGenerationInfo {
public:
const std::string& GetOpBaseType() const { return op_base_type_; }
void SetOpBaseType(const std::string& op_type) { op_base_type_ = op_type; }
const std::map<std::string, std::string>& GetGradOutsSlotnameMap() const {
return grad_outs_slotname_map_;
}
std::map<std::string, std::string>* GetMutableGradOutsSlotnameMap() {
return &grad_outs_slotname_map_;
}
const std::map<std::string, std::string>& GetGradInsFwdSlotnameMap() const {
return grad_ins_fwd_slotname_map_;
}
std::map<std::string, std::string>* GetMutableGradInsFwdSlotnameMap() {
return &grad_ins_fwd_slotname_map_;
}
const std::map<std::string, std::string>& GetGradInsGradSlotnameMap()
const {
return grad_ins_grad_slotname_map_;
}
std::map<std::string, std::string>* GetMutableGradInsGradSlotnameMap() {
return &grad_ins_grad_slotname_map_;
}
const std::map<
std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>&
GetGradIns() const {
return grad_ins_;
}
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
GetMutableGradIns() {
return &grad_ins_;
}
const std::map<
std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>&
GetGradOuts() const {
return grad_outs_;
}
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
GetMutableGradOuts() {
return &grad_outs_;
}
private:
std::string op_base_type_;
std::map<std::string, std::string> grad_outs_slotname_map_;
std::map<std::string, std::string> grad_ins_fwd_slotname_map_;
std::map<std::string, std::string> grad_ins_grad_slotname_map_;
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>
grad_ins_;
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>
grad_outs_;
};
public:
const std::string& GetFwdOpType() const { return fwd_op_type_; }
void SetFwdOpType(const std::string& op_type) { fwd_op_type_ = op_type; }
bool GenerateForwardOnly() const { return generate_forward_only_; }
void SetGenerateForwardOnly(bool generate_forward_only) {
generate_forward_only_ = generate_forward_only;
}
const std::vector<OpBaseGenerationInfo>& GetOpBaseInfos() const {
return op_base_infos_;
}
std::vector<OpBaseGenerationInfo>* GetMutableOpBaseInfos() {
return &op_base_infos_;
}
private:
std::string fwd_op_type_;
bool generate_forward_only_ = false;
std::vector<OpBaseGenerationInfo> op_base_infos_;
};
/* --- Helper Functions --- */
static std::string AttrTypeToString(const proto::AttrType& type) {
std::string ret;
switch (type) {
@@ -348,7 +470,6 @@ static bool CheckOpProto(proto::OpProto* op_proto) {
VLOG(1) << "------ Analyzing Op ------: " << op_type;
if (!operators_to_codegen.count(op_type)) return false;
if (operators_to_skip.count(op_type)) return false;
return true;
}
@@ -356,15 +477,16 @@ static bool CheckOpProto(proto::OpProto* op_proto) {
/* --------------------------------------- */
/* --------- Preprocess Ins/Outs --------- */
/* --------------------------------------- */
static void PurifyForwardOpProto(const proto::OpProto& op_proto,
ForwardGenerationInfo* fwd_info) {
std::unordered_map<std::string, size_t>* fwd_inputs_name_pos_map,
std::unordered_map<std::string, size_t>* fwd_outputs_name_pos_map,
std::vector<proto::OpProto::Var>* in_vars,
std::vector<proto::OpProto::Var>* out_vars) {
// Op Name
const std::string op_name = op_proto.type();
auto* in_vars = fwd_info->GetMutableInVars();
auto* out_vars = fwd_info->GetMutableOutVars();
auto* fwd_inputs_name_pos_map = fwd_info->GetMutableFwdInputsNamePosMap();
auto* fwd_outputs_name_pos_map = fwd_info->GetMutableFwdOutputsNamePosMap();
// Handle dispensable inputs
for (const proto::OpProto::Var& input : op_proto.inputs()) {
std::string input_name = input.name();
@@ -426,6 +548,104 @@ static void PurifyForwardOpProto(
}
}
static void PurifyGradNodeGenerationInfo(const proto::OpProto& op_proto,
GradNodeGenerationInfo* bwd_info) {
auto* op_base_infos = bwd_info->GetMutableOpBaseInfos();
for (auto& iter : *op_base_infos) {
std::map<std::string, std::string>* grad_outs_slotname_map =
iter.GetMutableGradOutsSlotnameMap();
std::map<std::string, std::string>* grad_ins_fwd_slotname_map =
iter.GetMutableGradInsFwdSlotnameMap();
std::map<std::string, std::string>* grad_ins_grad_slotname_map =
iter.GetMutableGradInsGradSlotnameMap();
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
grad_ins = iter.GetMutableGradIns();
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
grad_outs = iter.GetMutableGradOuts();
// Op Name
const std::string op_name = op_proto.type();
// Handle dispensable inputs
for (const proto::OpProto::Var& input : op_proto.inputs()) {
std::string input_name = input.name();
// Delete dispensable tensor unless specified in op_ins_map
if (input.dispensable()) {
if (!op_ins_map.count(op_name) ||
!op_ins_map[op_name].count(input_name)) {
VLOG(6) << "Removing Dispensable Input: " << input_name;
// grad_outs_slotname_map
auto grad_outs_slotname_map_purified = *grad_outs_slotname_map;
for (const auto& iter : *grad_outs_slotname_map) {
const std::string& grad_output_name = iter.first;
const std::string& matched_input_name = iter.second;
if (matched_input_name == input_name) {
grad_outs_slotname_map_purified.erase(grad_output_name);
PADDLE_ENFORCE(
grad_outs->count(grad_output_name) > 0,
paddle::platform::errors::Fatal(
"Unable to find gradient output name in grad_outs."));
// grad_outs
grad_outs->erase(grad_output_name);
}
}
*grad_outs_slotname_map = grad_outs_slotname_map_purified;
// grad_ins_fwd_slotname_map: output as tensorwrapper
if (grad_ins_fwd_slotname_map->count(input_name))
grad_ins_fwd_slotname_map->erase(input_name);
// grad_ins: output as tensorwrapper
if (grad_ins->count(input_name)) grad_ins->erase(input_name);
}
}
}
for (const proto::OpProto::Var& output : op_proto.outputs()) {
std::string output_name = output.name();
// Delete dispensable tensor unless specified in op_outs_map
if (output.dispensable()) {
if (!op_outs_map.count(op_name) ||
!op_outs_map[op_name].count(output_name)) {
VLOG(6) << "Removing Dispensable Output: " << output_name;
// grad_ins_grad_slotname_map
auto grad_ins_grad_slotname_map_purified =
*grad_ins_grad_slotname_map;
for (const auto& iter : *grad_ins_grad_slotname_map) {
const std::string& grad_input_name = iter.first;
const std::string& matched_output_name = iter.second;
if (matched_output_name == output_name) {
grad_ins_grad_slotname_map_purified.erase(grad_input_name);
PADDLE_ENFORCE(
grad_ins->count(grad_input_name) > 0,
paddle::platform::errors::Fatal(
"Unable to find gradient input name in grad_ins."));
// grad_ins
grad_ins->erase(grad_input_name);
}
}
*grad_ins_grad_slotname_map = grad_ins_grad_slotname_map_purified;
// grad_ins_fwd_slotname_map: output as tensorwrapper
if (grad_ins_fwd_slotname_map->count(output_name))
grad_ins_fwd_slotname_map->erase(output_name);
// grad_ins: output as tensorwrapper
if (grad_ins->count(output_name)) grad_ins->erase(output_name);
}
}
}
}
}
static void PurifyGradOpProto(
const proto::OpProto& op_proto,
std::map<std::string, std::string>* grad_outs_slotname_map,
@@ -520,31 +740,22 @@ static void PurifyGradOpProto(
/* --------- Collect Info --------- */
/* -------------------------------- */
static void CollectForwardInformationFromOpInfo(
const paddle::framework::OpInfo& op_info, ForwardGenerationInfo* fwd_info) {
std::vector<proto::OpProto::Var>* in_vars,
std::vector<proto::OpProto::Var>* out_vars) {
const proto::OpProto& op_proto = *op_info.proto_;
fwd_info->SetOpType(op_proto.type());
for (const proto::OpProto::Var& input : op_proto.inputs()) {
fwd_info->GetMutableInVars()->push_back(input);
}
for (const proto::OpProto::Var& output : op_proto.outputs()) {
fwd_info->GetMutableOutVars()->push_back(output);
}
}
static bool CollectGradInformationFromOpInfo(
const paddle::framework::OpInfo& op_info,
GradNodeGenerationInfo* bwd_info) {
std::map<std::string, std::string>* grad_outs_slotname_map, // grad
std::map<std::string, std::string>* grad_ins_fwd_slotname_map, // grad
std::map<std::string, std::string>* grad_ins_grad_slotname_map, // grad
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
grad_ins, // grad
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>*
grad_outs // grad
) {
const proto::OpProto& op_proto = *op_info.proto_;
const std::string& op_type = op_proto.type();
std::vector<int64_t> dims = {1, 1, 1, 1};
@@ -645,7 +856,7 @@ static bool CollectGradInformationFromOpInfo(
/* ------ Run GradOpMaker ------ */
if (!op_info.dygraph_grad_op_maker_) {
VLOG(6) << op_type << " has no GradOpMaker";
bwd_info->SetGenerateForwardOnly(true);
return false;
}
@@ -656,32 +867,31 @@ static bool CollectGradInformationFromOpInfo(
if (!grad_node) {
VLOG(6) << "Got nullptr GradOpNode for " << op_type
<< " likely registered EmptyGradOpMaker";
bwd_info->SetGenerateForwardOnly(true);
return false;
}
/*
if (grad_node->size() > 1) {
// Backward attributes can be super complicated
VLOG(6) << "Skip GradOpNode with multiple OpBases for now: " << op_type;
skipped_operators.insert(op_type);
return false;
}
*/
VLOG(6) << "Prepared GradOpNode"; VLOG(6) << "Prepared GradOpNode";
/* ---- Collect Default Attr Map ---- */ /* ---- Collect OpBase's op_types ---- */
bwd_info->SetFwdOpType(op_type);
auto* op_base_infos = bwd_info->GetMutableOpBaseInfos();
op_base_infos->resize(grad_node->size());
for (auto iter = grad_node->begin(); iter < grad_node->end(); iter++) {
// Each OpBase
int index = std::distance(grad_node->begin(), iter);
paddle::imperative::OpBase& op_base = *iter;
(*op_base_infos)[index].SetOpBaseType(op_base.Type());
}
/* ------ Get Grad ins/outs ---- */
// In case of multiple OpBase, stitch all the respective ins/outs into one
VLOG(6) << "In function size: " << grad_node->size();
for (auto iter = grad_node->begin(); iter < grad_node->end(); iter++) {
int index = std::distance(grad_node->begin(), iter);
auto* op_base_grad_ins = (*op_base_infos)[index].GetMutableGradIns();
auto* op_base_grad_outs = (*op_base_infos)[index].GetMutableGradOuts();
const paddle::imperative::OpBase& op_base = *iter;
const std::map<std::string, paddle::imperative::SavedVariableWrapperList>&
g_ins = op_base.GetInsMap();
@@ -689,34 +899,47 @@ static bool CollectGradInformationFromOpInfo(
g_outs = op_base.GetOutsMap();
for (const auto& it : g_ins) {
if (!op_base_grad_ins->count(it.first))
(*op_base_grad_ins)[it.first] = {};
for (auto vw_iter = it.second.begin(); vw_iter != it.second.end();
vw_iter++) {
std::shared_ptr<paddle::imperative::VariableWrapper> vw = *vw_iter;
(*grad_ins)[it.first].push_back(vw);
(*op_base_grad_ins)[it.first].push_back(vw);
VLOG(6) << "GradIns Name: " << it.first;
}
}
for (const auto& it : g_outs) {
if (!op_base_grad_outs->count(it.first))
(*op_base_grad_outs)[it.first] = {};
for (auto vw_iter = it.second.begin(); vw_iter != it.second.end();
vw_iter++) {
std::shared_ptr<paddle::imperative::VariableWrapper> vw = *vw_iter;
(*grad_outs)[it.first].push_back(vw);
(*op_base_grad_outs)[it.first].push_back(vw);
VLOG(6) << "GradOuts Name: " << it.first;
}
}
}
/* ------ Slot Name Matching ---- */
for (auto& iter : *op_base_infos) {
// grad_ins -> fwd_ins, fwd_outs
SlotNameMatching(iter.GetGradIns(), fwd_ins, fwd_outs,
iter.GetMutableGradInsFwdSlotnameMap(),
iter.GetMutableGradInsGradSlotnameMap());
// grad_outs -> fwd_ins, fwd_outs
SlotNameMatching(iter.GetGradOuts(), fwd_ins, fwd_outs,
iter.GetMutableGradOutsSlotnameMap(),
iter.GetMutableGradOutsSlotnameMap());
}
VLOG(6) << "Finished Slotname Matching";
return true;
}
@@ -725,13 +948,20 @@ static bool CollectGradInformationFromOpInfo(
/* --------- CodeGen: Forward GradNode Creation ------ */
/* --------------------------------------------------- */
static std::string GenerateGradNodeCreationContent(
const ForwardGenerationInfo& fwd_info,
const GradNodeGenerationInfo& bwd_info) {
const std::map<std::string, std::string>& grad_ins_fwd_slotname_map,
const std::string& op_type, const std::vector<proto::OpProto::Var>& in_vars,
const std::vector<proto::OpProto::Var>& out_vars) {
VLOG(6) << "Generating GradNode Creation codes";
const std::string& op_type = fwd_info.GetOpType();
const std::unordered_map<std::string, size_t>& fwd_inputs_name_pos_map =
fwd_info.GetFwdInputsNamePosMap();
const std::unordered_map<std::string, size_t>& fwd_outputs_name_pos_map =
fwd_info.GetFwdOutputsNamePosMap();
const std::vector<proto::OpProto::Var>& in_vars = fwd_info.GetInVars();
const std::vector<proto::OpProto::Var>& out_vars = fwd_info.GetOutVars();
const auto& op_base_infos = bwd_info.GetOpBaseInfos();
// [Generation] Construct GradOpNode
// Run ComputeRequiredGrad
@@ -817,12 +1047,17 @@ static std::string GenerateGradNodeCreationContent(
// [GradOpNode] Set TensorWrappers
grad_node_creation_str += " // Set Tensor Wrappers\n";
for (const auto& iter : op_base_infos) {
const std::map<std::string, std::string>& grad_ins_fwd_slotname_map =
iter.GetGradInsFwdSlotnameMap();
for (auto& kv : grad_ins_fwd_slotname_map) {
const std::string& tensor_wrapper_name = kv.second;
const char* SET_TENSOR_WRAPPER_TEMPLATE =
" grad_node->SetTensorWrapper%s(%s);\n";
grad_node_creation_str +=
paddle::string::Sprintf(SET_TENSOR_WRAPPER_TEMPLATE,
tensor_wrapper_name, tensor_wrapper_name);
}
}
grad_node_creation_str += "\n";
VLOG(6) << "Generated SetTensorWrapper";
@@ -892,22 +1127,17 @@ static std::string GenerateGradNodeCreationContent(
/* --------- CodeGen: Forward ----- */
/* -------------------------------- */
static std::pair<std::string, std::string> GenerateForwardFunctionContents(
const ForwardGenerationInfo& fwd_info,
const GradNodeGenerationInfo& bwd_info) {
/* --- Process Forward Info ---*/
const std::string& op_type = fwd_info.GetOpType();
const std::unordered_map<std::string, size_t>& fwd_inputs_name_pos_map =
fwd_info.GetFwdInputsNamePosMap();
const std::unordered_map<std::string, size_t>& fwd_outputs_name_pos_map =
fwd_info.GetFwdOutputsNamePosMap();
const std::vector<proto::OpProto::Var>& in_vars = fwd_info.GetInVars();
const std::vector<proto::OpProto::Var>& out_vars = fwd_info.GetOutVars();
const std::map<
std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>&
grad_outs,
const std::string& op_type, const std::vector<proto::OpProto::Var>& in_vars,
const std::vector<proto::OpProto::Var>& out_vars) {
/*
// Forward Function Example:
std::tuple<vector<Tensor>, Tensor, vector<Tensor>>
@@ -999,6 +1229,34 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
for (const proto::OpProto::Var& output : out_vars) {
const std::string& output_name = output.name();
std::string outnum = "1";
if (op_passing_outs_map[op_type].count(output_name)) {
const std::string output_var_name = output_name + "Var";
// Pass Output from function argument,
// in form of shared_ptr<EagerTensor>/vector<shared_ptr<EagerTensor>>
if (output.duplicable()) {
const char* FWD_NUM_ARG_TEMPLATE =
", std::vector<std::shared_ptr<egr::EagerTensor>>& %s";
std::string arg_str =
paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, output_var_name);
dygraph_function_args_str += arg_str;
const char* FWD_OUTS_CONTENT_TEMPLATE = "{ \"%s\", %s },";
outs_contents_str += paddle::string::Sprintf(
FWD_OUTS_CONTENT_TEMPLATE, output_name, output_var_name);
} else {
const char* FWD_NUM_ARG_TEMPLATE =
", std::shared_ptr<egr::EagerTensor>& %s";
std::string arg_str =
paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, output_var_name);
dygraph_function_args_str += arg_str;
const char* FWD_OUTS_CONTENT_TEMPLATE = "{ \"%s\", {%s} },";
outs_contents_str += paddle::string::Sprintf(
FWD_OUTS_CONTENT_TEMPLATE, output_name, output_var_name);
}
} else {
if (output.duplicable()) {
outnum = output_name + "Num";
@@ -1019,6 +1277,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE, output_name);
}
}
}
if (outs_contents_str.size() > 0)
outs_contents_str.pop_back(); // Remove trailing ","
@@ -1084,10 +1343,9 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
VLOG(6) << "Converted Output VarBase to EagerTensor(s)";
// [Generation] ComputeRequireGrad -> GradNodeCreation
if (!bwd_info.GenerateForwardOnly()) {
std::string grad_node_creation_body_str =
GenerateGradNodeCreationContent(fwd_info, bwd_info);
grad_ins_fwd_slotname_map, op_type, in_vars, out_vars);
generated_function_body += grad_node_creation_body_str;
generated_function_body += "\n";
VLOG(6) << "Generated GradNode Creation codes";
@@ -1162,22 +1420,16 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
/* --------- CodeGen: GradNode::operator() ------ */
/* ---------------------------------------------- */
static std::string GenerateGradNodeCCContents(
const ForwardGenerationInfo& fwd_info,
const GradNodeGenerationInfo& bwd_info) {
/* --- Process Forward Info --- */
const std::string& fwd_op_type = fwd_info.GetOpType();
const std::unordered_map<std::string, size_t>& fwd_inputs_name_pos_map =
fwd_info.GetFwdInputsNamePosMap();
const std::unordered_map<std::string, size_t>& fwd_outputs_name_pos_map =
fwd_info.GetFwdOutputsNamePosMap();
const std::vector<proto::OpProto::Var>& in_vars = fwd_info.GetInVars();
grad_ins,
const std::map<
std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>&
grad_outs,
const std::string& op_type, const std::vector<proto::OpProto::Var>& in_vars,
const std::vector<proto::OpProto::Var>& out_vars) {
VLOG(6) << "Generating Grad Node CC"; VLOG(6) << "Generating Grad Node CC";
/* [Outline] /* [Outline]
...@@ -1224,10 +1476,24 @@ static std::string GenerateGradNodeCCContents( ...@@ -1224,10 +1476,24 @@ static std::string GenerateGradNodeCCContents(
*/ */
std::string generated_grad_function_body = ""; std::string generated_grad_function_body = "";
size_t outs_size = 0;
const auto& op_base_infos = bwd_info.GetOpBaseInfos();
for (size_t i = 0; i < op_base_infos.size(); i++) {
const auto& op_base_info = op_base_infos[i];
const auto& grad_ins_fwd_slotname_map =
op_base_info.GetGradInsFwdSlotnameMap();
const auto& grad_ins_grad_slotname_map =
op_base_info.GetGradInsGradSlotnameMap();
const auto& grad_outs_slotname_map = op_base_info.GetGradOutsSlotnameMap();
const auto& grad_ins = op_base_info.GetGradIns();
const auto& grad_outs = op_base_info.GetGradOuts();
const std::string& op_base_type = op_base_info.GetOpBaseType();
const std::string& ins_name = "ins" + std::to_string(i);
const std::string& outs_name = "outs" + std::to_string(i);
outs_size += grad_outs.size();
// [Generation] Get Ins Map
std::string ins_contents_str = "";
@@ -1240,7 +1506,8 @@ static std::string GenerateGradNodeCCContents(
grad_ins_fwd_slotname_map.at(grad_input_name) + "_";
const char* GRAD_INS_FWD_CONTENT_TEMPLATE =
"{ \"%s\", "
"egr::EagerUtils::SyncToVars(egr::EagerUtils::RecoverTensorWrapper("
"&"
"this->%s, " "this->%s, "
"nullptr)) },"; "nullptr)) },";
ins_contents_str += ins_contents_str +=
...@@ -1253,8 +1520,9 @@ static std::string GenerateGradNodeCCContents( ...@@ -1253,8 +1520,9 @@ static std::string GenerateGradNodeCCContents(
grad_ins_grad_slotname_map.at(grad_input_name)); grad_ins_grad_slotname_map.at(grad_input_name));
const char* GRAD_INS_GRAD_CONTENT_TEMPLATE = const char* GRAD_INS_GRAD_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::SyncToVars(grads[%d]) },"; "{ \"%s\", egr::EagerUtils::SyncToVars(grads[%d]) },";
ins_contents_str += paddle::string::Sprintf( ins_contents_str +=
GRAD_INS_GRAD_CONTENT_TEMPLATE, grad_input_name, fwd_output_position); paddle::string::Sprintf(GRAD_INS_GRAD_CONTENT_TEMPLATE,
grad_input_name, fwd_output_position);
} else {
PADDLE_THROW(platform::errors::Fatal(
@@ -1268,10 +1536,10 @@ static std::string GenerateGradNodeCCContents(
const char* BWD_INS_MAP_TEMPLATE =
" std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerTensor>>> %s = { "
"%s };\n";
std::string ins_map_str = paddle::string::Sprintf(
BWD_INS_MAP_TEMPLATE, ins_name, ins_contents_str);
generated_grad_function_body += ins_map_str;
VLOG(6) << "Generated Ins Map";
@@ -1288,7 +1556,8 @@ static std::string GenerateGradNodeCCContents(
if (grad_outs_slotname_map.count(grad_output_name)) {
// Fwd Tensor
const std::string& fwd_name =
grad_outs_slotname_map.at(grad_output_name);
/* Handle Special Case: "PullSparseOp", etc
@@ -1309,7 +1578,8 @@ static std::string GenerateGradNodeCCContents(
GradOut
Its grad output "GradOut" corresponds to forward output "Out",
where there is a hidden inplace involved. So we find "GradOut"'s
index
in
grads, and perform the inplace operation by constructing outs =
{{"Out", grads[i]}}
@@ -1320,7 +1590,8 @@ static std::string GenerateGradNodeCCContents(
For returns, append "GradOut" to the very end of return list.
*/
if (!fwd_inputs_name_pos_map.count(fwd_name)) {
PADDLE_ENFORCE(
fwd_outputs_name_pos_map.count(fwd_name),
paddle::platform::errors::Fatal(
"fwd_name not found in fwd_inputs_name_pos_map nor "
"fwd_outputs_name_pos_map"));
@@ -1330,7 +1601,8 @@ static std::string GenerateGradNodeCCContents(
const char* GET_GRADS_PTR_TEMPLATE =
" std::vector<std::shared_ptr<egr::EagerTensor>> %s;\n"
" for(const auto& t : grads[%d]) {\n "
"%s.emplace_back(std::move(std::make_shared<egr::EagerTensor>(t))"
");"
"\n }\n"; "\n }\n";
std::string grads_ptr_str = std::string grads_ptr_str =
paddle::string::Sprintf(GET_GRADS_PTR_TEMPLATE, grad_ptr_name, paddle::string::Sprintf(GET_GRADS_PTR_TEMPLATE, grad_ptr_name,
...@@ -1348,12 +1620,14 @@ static std::string GenerateGradNodeCCContents( ...@@ -1348,12 +1620,14 @@ static std::string GenerateGradNodeCCContents(
const char* GRAD_OUTS_CONTENT_TEMPLATE = const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput( " "{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput( "
"this->OutputMeta()[%d].Size() ) },"; "this->OutputMeta()[%d].Size() ) },";
outs_contents_str += paddle::string::Sprintf( outs_contents_str +=
GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, fwd_input_position); paddle::string::Sprintf(GRAD_OUTS_CONTENT_TEMPLATE,
grad_output_name, fwd_input_position);
} else {
const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", "
"{std::make_shared<egr::EagerTensor>(egr::Controller::Instance("
")."
"GenerateUniqueName())}},"; "GenerateUniqueName())}},";
outs_contents_str += paddle::string::Sprintf( outs_contents_str += paddle::string::Sprintf(
GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name); GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name);
...@@ -1371,29 +1645,25 @@ static std::string GenerateGradNodeCCContents( ...@@ -1371,29 +1645,25 @@ static std::string GenerateGradNodeCCContents(
const char* BWD_OUTS_MAP_TEMPLATE = const char* BWD_OUTS_MAP_TEMPLATE =
" std::map<std::string, " " std::map<std::string, "
"std::vector<std::shared_ptr<egr::EagerTensor>>> outs = { " "std::vector<std::shared_ptr<egr::EagerTensor>>> %s = { "
"%s };\n"; "%s };\n";
std::string outs_map_str = std::string outs_map_str = paddle::string::Sprintf(
paddle::string::Sprintf(BWD_OUTS_MAP_TEMPLATE, outs_contents_str); BWD_OUTS_MAP_TEMPLATE, outs_name, outs_contents_str);
generated_grad_function_body += outs_map_str; generated_grad_function_body += outs_map_str;
generated_grad_function_body += "\n"; generated_grad_function_body += "\n";
VLOG(6) << "Generated Outs Map"; VLOG(6) << "Generated Outs Map";
// [Generation] Get Attrs Map // [Generation] Get Attrs Map
std::string trace_opbase_str = "";
for (size_t i = 0; i < grad_op_types.size(); i++) {
const std::string& op_base_type = grad_op_types[i];
const char* TRACE_OP_TEMPLATE =
" // Pass the entire attribute map to TraceOp\n"
" // The underlying kernel will pickup whatever attribute they need "
"at runtime\n"
" egr::legacy::RunOp(\"%s\", %s, %s, this->attr_map_,\n"
" egr::Controller::Instance().GetExpectedPlace(),\n"
" &this->default_attr_map_, false, {});\n";
std::string trace_opbase_str = paddle::string::Sprintf(
TRACE_OP_TEMPLATE, op_base_type, ins_name, outs_name);
generated_grad_function_body += trace_opbase_str;
@@ -1409,9 +1679,9 @@ static std::string GenerateGradNodeCCContents(
if (fwd_inputs_name_pos_map.count(fwd_name)) {
size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
const char* BWD_OUTPUT_TEMPLATE =
" outputs[%d] = egr::EagerUtils::GetOutputs(%s[\"%s\"]);\n";
outputs_str += paddle::string::Sprintf(
BWD_OUTPUT_TEMPLATE, fwd_input_position, outs_name, grad_out_name);
num_appended_outputs++;
} else {
PADDLE_ENFORCE(fwd_outputs_name_pos_map.count(fwd_name),
@@ -1429,22 +1699,24 @@ static std::string GenerateGradNodeCCContents(
if (fwd_outputs_name_pos_map.count(fwd_name)) {
const char* BWD_OUTPUT_TEMPLATE =
" outputs[%d] = egr::EagerUtils::GetOutputs(%s[\"%s\"]);\n";
outputs_str +=
paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE, num_appended_outputs,
outs_name, grad_out_name);
num_appended_outputs++;
}
}
generated_grad_function_body += outputs_str;
" std::vector<std::vector<egr::EagerTensor>> "
"outputs(outs.size());\n%s\n "
"return outputs;";
std::string return_str =
paddle::string::Sprintf(BWD_RETURN_TEMPLATE, outputs_str);
generated_grad_function_body += "\n"; generated_grad_function_body += "\n";
generated_grad_function_body += return_str; }
const char* BWD_RETURN_TEMPLATE =
" std::vector<std::vector<egr::EagerTensor>> outputs(%d);\n"
" %s\n"
" return outputs;\n";
generated_grad_function_body = paddle::string::Sprintf(
BWD_RETURN_TEMPLATE, outs_size, generated_grad_function_body);
// [Generation] Get Full Grad Function
const char* GRAD_FUNCTION_TEMPLATE =
@@ -1452,7 +1724,7 @@ static std::string GenerateGradNodeCCContents(
"GradNode%s::operator()(const "
"std::vector<std::vector<egr::EagerTensor>>& grads) {\n%s\n}";
std::string grad_function_str = paddle::string::Sprintf(
GRAD_FUNCTION_TEMPLATE, fwd_op_type, generated_grad_function_body);
VLOG(6) << "Generated returns";
@@ -1463,9 +1735,14 @@ static std::string GenerateGradNodeCCContents(
/* --------- CodeGen: GradNode Header ------ */
/* ----------------------------------------- */
static std::string GenerateGradNodeHeaderContents(
const ForwardGenerationInfo& fwd_info,
const GradNodeGenerationInfo& bwd_info) {
const std::string& op_type = fwd_info.GetOpType();
const std::vector<proto::OpProto::Var>& in_vars = fwd_info.GetInVars();
const std::vector<proto::OpProto::Var>& out_vars = fwd_info.GetOutVars();
const auto& op_base_infos = bwd_info.GetOpBaseInfos();
VLOG(6) << "Generating Grad Node Header"; VLOG(6) << "Generating Grad Node Header";
const char* GRAD_NODE_TEMPLATE = const char* GRAD_NODE_TEMPLATE =
...@@ -1522,6 +1799,10 @@ static std::string GenerateGradNodeHeaderContents( ...@@ -1522,6 +1799,10 @@ static std::string GenerateGradNodeHeaderContents(
std::string set_tensor_wrappers_str = ""; std::string set_tensor_wrappers_str = "";
std::string tensor_wrapper_members_str = ""; std::string tensor_wrapper_members_str = "";
for (const auto& iter : op_base_infos) {
const std::map<std::string, std::string>& grad_ins_fwd_slotname_map =
iter.GetGradInsFwdSlotnameMap();
for (const auto& kv : grad_ins_fwd_slotname_map) {
const std::string& tensor_wrapper_name = kv.second;
const std::string& struct_tensor_wrapper_name = kv.second + "_";
@@ -1572,6 +1853,7 @@ static std::string GenerateGradNodeHeaderContents(
SET_TENSOR_WRAPPER_TEMPLATE, tensor_wrapper_name,
tensor_wrapper_arg_str, tensor_wrapper_body_str);
}
}
VLOG(6) << "Generated TensorWrapper"; VLOG(6) << "Generated TensorWrapper";
std::string grad_node_str = paddle::string::Sprintf( std::string grad_node_str = paddle::string::Sprintf(
...@@ -1682,97 +1964,62 @@ static void DygraphCodeGeneration(const std::string& output_dir) { ...@@ -1682,97 +1964,62 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
/* ----------------------------- */ /* ----------------------------- */
/* ---- Collect Information ---- */ /* ---- Collect Information ---- */
/* ----------------------------- */ /* ----------------------------- */
std::vector<std::string> grad_op_types;
std::vector<proto::OpProto::Var> in_vars; ForwardGenerationInfo fwd_info;
std::vector<proto::OpProto::Var> out_vars; GradNodeGenerationInfo bwd_info;
std::map<std::string, std::string> grad_outs_slotname_map;
std::map<std::string, std::string> grad_ins_fwd_slotname_map;
std::map<std::string, std::string> grad_ins_grad_slotname_map;
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>
grad_ins;
std::map<std::string,
std::vector<std::shared_ptr<paddle::imperative::VariableWrapper>>>
grad_outs;
VLOG(6) << "-------- CollectInformationFromOpInfo -------"; VLOG(6) << "-------- CollectInformationFromOpInfo -------";
CollectForwardInformationFromOpInfo(op_info, &in_vars, &out_vars); CollectForwardInformationFromOpInfo(op_info, &fwd_info);
bool generate_forward_only = false; bool is_available = CollectGradInformationFromOpInfo(op_info, &bwd_info);
bool is_available = CollectGradInformationFromOpInfo(
op_info, &generate_forward_only, &grad_op_types,
&grad_outs_slotname_map, &grad_ins_fwd_slotname_map,
&grad_ins_grad_slotname_map, &grad_ins, &grad_outs);
if (!is_available && !bwd_info.GenerateForwardOnly()) {
VLOG(6) << "Skipped operator: " << op_type;
continue;
}
VLOG(6) << "-------- PurifyOpProto -------";
PurifyForwardOpProto(*op_proto, &fwd_info);
if (!bwd_info.GenerateForwardOnly()) {
PurifyGradNodeGenerationInfo(*op_proto, &bwd_info);
&fwd_outputs_name_pos_map, &in_vars, &out_vars);
if (!generate_forward_only) {
PurifyGradOpProto(*op_proto, &grad_outs_slotname_map,
&grad_ins_fwd_slotname_map, &grad_ins_grad_slotname_map,
&grad_ins, &grad_outs);
}
/* --------------------------- */
/* --------- CodeGen --------- */
/* --------------------------- */
/* ---- forward_dygraph_functions.cc ---- */
VLOG(6) << "-------- GenerateForwardFunctionContents -------"; VLOG(6) << "-------- GenerateForwardFunctionContents -------";
std::pair<std::string, std::string> body_and_declaration = std::pair<std::string, std::string> body_and_declaration =
GenerateForwardFunctionContents( GenerateForwardFunctionContents(fwd_info, bwd_info);
generate_forward_only, fwd_inputs_name_pos_map,
fwd_outputs_name_pos_map, grad_ins_fwd_slotname_map,
grad_ins_grad_slotname_map, grad_outs_slotname_map, grad_ins,
grad_outs, op_type, in_vars, out_vars);
fwd_function_str += body_and_declaration.first + "\n";
VLOG(6) << "-------- GenerateDygraphForwardAPIContents -------";
std::string fwd_function_declare_str = body_and_declaration.second;
dygraph_forward_api_str += fwd_function_declare_str;
if (bwd_info.GenerateForwardOnly()) continue;
/* ---- nodes.h ---- */
VLOG(6) << "-------- GenerateGradNodeHeaderContents -------"; VLOG(6) << "-------- GenerateGradNodeHeaderContents -------";
grad_node_h_str += grad_node_h_str += GenerateGradNodeHeaderContents(fwd_info, bwd_info);
GenerateGradNodeHeaderContents(grad_ins_fwd_slotname_map, op_type, grad_node_h_str += "\n";
in_vars, out_vars) +
"\n";
/* ---- nodes.cc ---- */
VLOG(6) << "-------- GenerateGradNodeCCContents -------"; VLOG(6) << "-------- GenerateGradNodeCCContents -------";
grad_node_cc_str += GenerateGradNodeCCContents( grad_node_cc_str += GenerateGradNodeCCContents(fwd_info, bwd_info);
grad_op_types, fwd_inputs_name_pos_map, grad_node_cc_str += "\n";
fwd_outputs_name_pos_map, grad_ins_fwd_slotname_map,
grad_ins_grad_slotname_map, grad_outs_slotname_map,
grad_ins, grad_outs, op_type, in_vars, out_vars) +
"\n";
VLOG(6) << op_type << ": Finished Generating Op: " << op_type;
}
/* ---- dygraph_forward_function.cc ---- */
VLOG(6) << "-------- GenerateDygraphForwardCCFile -------"; VLOG(6) << "-------- GenerateDygraphForwardCCFile -------";
GenerateForwardDygraphFile(output_dir, fwd_function_str); GenerateForwardDygraphFile(output_dir, fwd_function_str);
/* ---- dygraph_forward_api.h ---- */
VLOG(6) << "-------- GenerateForwardHFile -------"; VLOG(6) << "-------- GenerateForwardHFile -------";
GenerateForwardHFile(output_dir, dygraph_forward_api_str); GenerateForwardHFile(output_dir, dygraph_forward_api_str);
/* ---- nodes.h ---- */
VLOG(6) << "-------- GenerateNodeHFile -------"; VLOG(6) << "-------- GenerateNodeHFile -------";
GenerateNodeHFile(output_dir, grad_node_h_str); GenerateNodeHFile(output_dir, grad_node_h_str);
/* ---- nodes.cc ---- */
VLOG(6) << "-------- GenerateNodeCCFile -------"; VLOG(6) << "-------- GenerateNodeCCFile -------";
GenerateNodeCCFile(output_dir, grad_node_cc_str); GenerateNodeCCFile(output_dir, grad_node_cc_str);
} }
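To make the multi-OpBase change concrete, below is a rough, hand-written sketch (an assumption about the output shape, not literal generator output) of the GradNode operator() body that the templates above would emit for a hypothetical forward op "foo" whose grad node contains one OpBase "foo_grad": each OpBase gets its own ins%d/outs%d maps and its own egr::legacy::RunOp call, and the collected results are copied into a single outputs vector sized by the total number of grad outputs.

// Hypothetical generated code for op "foo"; names and sizes are illustrative only.
std::vector<std::vector<egr::EagerTensor>> GradNodefoo::operator()(
    const std::vector<std::vector<egr::EagerTensor>>& grads) {
  // Sized by the total number of grad outputs across all OpBases (here: 1).
  std::vector<std::vector<egr::EagerTensor>> outputs(1);
  // OpBase #0: per-OpBase ins0/outs0 maps instead of a single shared ins/outs.
  std::map<std::string, std::vector<std::shared_ptr<egr::EagerTensor>>> ins0 =
      {{"Out@GRAD", egr::EagerUtils::SyncToVars(grads[0])}};
  std::map<std::string, std::vector<std::shared_ptr<egr::EagerTensor>>> outs0 =
      {{"X@GRAD", egr::EagerUtils::ConstructDuplicableOutput(
                      this->OutputMeta()[0].Size())}};
  egr::legacy::RunOp("foo_grad", ins0, outs0, this->attr_map_,
                     egr::Controller::Instance().GetExpectedPlace(),
                     &this->default_attr_map_, false, {});
  // A second OpBase would repeat the pattern with ins1/outs1 and its own RunOp.
  outputs[0] = egr::EagerUtils::GetOutputs(outs0["X@GRAD"]);
  return outputs;
}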
...
@@ -237,6 +237,7 @@ spp
floor
gelu
retinanet_detection_output
minus
push_dense
silu
sequence_erase
...