From e7bda1ddaeb152df632c4da10b1efa061f6a70f1 Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Sat, 27 Nov 2021 20:32:18 +0800 Subject: [PATCH] Added Eager Dygraph AutoCodeGen dependencies #2 (#37575) --- .../auto_code_generator/eager_generator.cc | 811 +++++++++++++++++- 1 file changed, 804 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index 1fc2d480e8e..2cb8b4b9904 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -40,11 +40,10 @@ static std::unordered_set operators_to_skip = { "fused_attention", "diag_v2", }; -/* + static std::unordered_set operators_to_codegen = { "sigmoid", "matmul_v2", "reduce_sum", "elementwise_add", "share_buffer", "var_conv_2d", "split"}; -*/ static std::unordered_set skipped_operators = {}; @@ -107,8 +106,10 @@ static std::string AttrTypeToString(const proto::AttrType& type) { break; } default: { - PADDLE_THROW( - platform::errors::Fatal("Unable to recognize AttrType: %d", type)); + PADDLE_THROW(platform::errors::Fatal( + "AttrType of type boost::variant only supports specific data types." + "However, detected unrecognized AttrType: %d", + type)); } } return ret; @@ -214,8 +215,10 @@ static std::pair GetAttrType( break; } default: { - PADDLE_THROW(platform::errors::Fatal("Unable to recognize AttrType: %d", - variant_pos)); + PADDLE_THROW(platform::errors::Fatal( + "AttrType of type boost::variant only supports specific data types." + "However, detected unrecognized AttrType: %d", + variant_pos)); } } return {ret, val}; @@ -259,6 +262,7 @@ static void SlotNameMatching( if (grad_fwd_slotname_map.count(grad_slot_name) && grad_fwd_slotname_map[grad_slot_name] != fwd_slot_name) { PADDLE_THROW(platform::errors::Fatal( + "Detected mismatched slot names." "grad_slot_name %s matches both %s and %s fwd_slot_name", grad_slot_name, grad_fwd_slotname_map[grad_slot_name], fwd_slot_name)); @@ -271,6 +275,7 @@ static void SlotNameMatching( if (grad_grad_slotname_map.count(grad_slot_name) && grad_grad_slotname_map[grad_slot_name] != fwd_slot_name) { PADDLE_THROW(platform::errors::Fatal( + "Detected mismatched slot names." "grad_slot_name %s matches both %s and %s fwd_slot_name", grad_slot_name, grad_grad_slotname_map[grad_slot_name], fwd_slot_name)); @@ -290,6 +295,7 @@ static void SlotNameMatching( if (grad_fwd_slotname_map.count(grad_slot_name) && grad_fwd_slotname_map[grad_slot_name] != fwd_slot_name) { PADDLE_THROW(platform::errors::Fatal( + "Detected mismatched slot names" "grad_slot_name %s matches both %s and %s fwd_slot_name", grad_slot_name, grad_fwd_slotname_map[grad_slot_name], fwd_slot_name)); @@ -302,6 +308,7 @@ static void SlotNameMatching( if (grad_grad_slotname_map.count(grad_slot_name) && grad_grad_slotname_map[grad_slot_name] != fwd_slot_name) { PADDLE_THROW(platform::errors::Fatal( + "Detected mismatched slot names." "grad_slot_name %s matches both %s and %s fwd_slot_name", grad_slot_name, grad_grad_slotname_map[grad_slot_name], fwd_slot_name)); @@ -315,6 +322,7 @@ static void SlotNameMatching( if (!found_matching) { PADDLE_THROW(platform::errors::Fatal( + "Detected mismatched slot names." "Found no matching fwd_slot_name for grad_slot_name: %s", grad_slot_name)); @@ -344,7 +352,7 @@ static bool CheckOpProto(proto::OpProto* op_proto) { // Only handle matmul_v2 for now VLOG(1) << "------ Analyzing Op ------: " << op_type; - // if (!operators_to_codegen.count(op_type)) return false; + if (!operators_to_codegen.count(op_type)) return false; if (operators_to_skip.count(op_type)) return false; return true; @@ -741,5 +749,794 @@ static std::string AppendUseOp(const std::string& op_type) { return return_str; } +/* -------------------------------- */ +/* --------- CodeGen: Forward ----- */ +/* -------------------------------- */ +static std::pair GenerateForwardFunctionContents( + const std::vector& + grad_node_default_attr_maps, + const std::unordered_map& fwd_inputs_name_pos_map, + const std::unordered_map& fwd_outputs_name_pos_map, + const std::map& grad_ins_fwd_slotname_map, + const std::map& grad_ins_grad_slotname_map, + const std::map& grad_outs_slotname_map, + const std::map< + std::string, + std::vector>>& + grad_ins, + const std::map< + std::string, + std::vector>>& + grad_outs, + const proto::OpProto& op_proto) { + /* + // Forward Function Example: + std::tuple, Tensor, vector> + kernel_function(vector& X, Tensor& Y, const paddle::AttributeMap& + attr_map, size_t + Out0Num, size_t Out1Num) { + + // Forward Function Body + // According to fwd_inputs_name_pos_map + std::map>> + ins = + { {"X" , SyncToVars(X)}, { "Y" , SyncToVars(Y)} }; + + std::map>> + outs = + { + {"Out0" , ConstructDuplicableOutput(Out0Num)}, {"Out1" + ,ConstructDuplicableOutput(Out1Num)} }; + + // According to op_proto->attrs() + egr::RunOp("op_type", ins, outs, attr_map, + Controller.Instance().GetExpectedPlace(), {}); + + // According to fwd_outputs_names + std::vector Out0 = GetOutputs(outs["Out0"]); + egr::EagerTensor Out1 = GetOutputs(outs["Out1"][0]); + std::vector Out2 = GetOutputs(outs["Out2"]); + + // Grad Node Generation Codes + ... + + return std::make_tuple(Out0, Out1, Out2); + } + */ + VLOG(6) << "Generating Dygraph Forward Function"; + + const std::string& op_type = op_proto.type(); + + std::string generated_function_body = ""; + std::string dygraph_function_args_str = ""; + + /* ------ Dygraph forward function generation ------ */ + generated_function_body += " // Dygraph Forward Pass\n"; + generated_function_body += "\n"; + + // [Generation] Get Ins Map + std::string ins_contents_str = ""; + std::vector input_args_str_list(op_proto.inputs().size()); + for (const proto::OpProto::Var& input : op_proto.inputs()) { + const std::string& input_name = input.name(); + size_t input_position = fwd_inputs_name_pos_map.at(input_name); + if (input.duplicable()) { + const char* FWD_INS_ARG_TEMPLATE = + "const std::vector& %s"; + input_args_str_list[input_position] = + paddle::string::Sprintf(FWD_INS_ARG_TEMPLATE, input_name); + } else { + const char* FWD_INS_ARG_TEMPLATE = "const egr::EagerTensor& %s"; + input_args_str_list[input_position] = + paddle::string::Sprintf(FWD_INS_ARG_TEMPLATE, input_name); + } + const char* FWD_INS_CONTENT_TEMPLATE = "{ \"%s\", egr::SyncToVars(%s) },"; + ins_contents_str += paddle::string::Sprintf(FWD_INS_CONTENT_TEMPLATE, + input_name, input_name); + } + if (ins_contents_str.size() > 0) + ins_contents_str.pop_back(); // // Remove trailing "," + + for (const std::string& arg : input_args_str_list) { + dygraph_function_args_str += arg; + dygraph_function_args_str += ","; + } + if (dygraph_function_args_str.size() > 0) + dygraph_function_args_str.pop_back(); + + const char* FWD_INS_MAP_TEMPLATE = + " std::map>> ins = { " + "%s };\n"; + std::string ins_map_str = + paddle::string::Sprintf(FWD_INS_MAP_TEMPLATE, ins_contents_str); + generated_function_body += ins_map_str; + generated_function_body += "\n"; + + VLOG(6) << "Generated Ins Map"; + + // [Generation] Get Outs Map + std::string outs_contents_str = ""; + for (const proto::OpProto::Var& output : op_proto.outputs()) { + const std::string& output_name = output.name(); + std::string outnum = "1"; + if (output.duplicable()) { + outnum = output_name + "Num"; + + const char* FWD_NUM_ARG_TEMPLATE = ", size_t %s"; + std::string arg_str = + paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, outnum); + dygraph_function_args_str += arg_str; + const char* FWD_OUTS_CONTENT_TEMPLATE = + "{ \"%s\", egr::ConstructDuplicableOutput(%s) },"; + outs_contents_str += paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE, + output_name, outnum); + } else { + const char* FWD_OUTS_CONTENT_TEMPLATE = + "{ \"%s\", " + "{std::make_shared(egr::Controller::Instance()." + "GenerateUniqueName())}},"; + outs_contents_str += + paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE, output_name); + } + } + if (outs_contents_str.size() > 0) + outs_contents_str.pop_back(); // Remove trailing "," + + const char* FWD_OUTS_MAP_TEMPLATE = + " std::map>> outs = { " + "%s };\n"; + std::string outs_map_str = + paddle::string::Sprintf(FWD_OUTS_MAP_TEMPLATE, outs_contents_str); + generated_function_body += outs_map_str; + generated_function_body += "\n"; + + VLOG(6) << "Generated Outs Map"; + + // [Generation] Get Attrs + dygraph_function_args_str += + ", const paddle::framework::AttributeMap& attr_map"; + generated_function_body += "\n"; + + // [Generation] Get TraceOp + const char* FWD_TRACE_OP_TEMPLATE = + " paddle::framework::AttributeMap attrs = attr_map;\n" + " paddle::framework::AttributeMap default_attrs;\n" + " egr::RunOp(\"%s\", ins, outs, attrs, \n" + " egr::Controller::Instance().GetExpectedPlace(),\n" + " &default_attrs, true, {});\n"; + std::string trace_op_str = + paddle::string::Sprintf(FWD_TRACE_OP_TEMPLATE, op_proto.type()); + generated_function_body += trace_op_str; + generated_function_body += "\n"; + + VLOG(6) << "Generated AttrMap & TraceOp"; + + // [Generation] Convert output VarBase to Vector/Tensor + size_t output_size = op_proto.outputs().size(); + std::vector return_contents(output_size); + std::vector return_types(output_size); + for (const proto::OpProto::Var& output : op_proto.outputs()) { + const std::string& output_name = output.name(); + std::string out_tensor_str; + size_t return_position = fwd_outputs_name_pos_map.at(output_name); + + if (output.duplicable()) { + const char* FWD_OUT_TENSORS_TEMPLATE = + " std::vector %s = " + "egr::GetOutputs(outs[\"%s\"]);\n"; + out_tensor_str = paddle::string::Sprintf(FWD_OUT_TENSORS_TEMPLATE, + output_name, output_name); + return_types[return_position] = "std::vector"; + } else { + const char* FWD_OUT_TENSOR_TEMPLATE = + " egr::EagerTensor %s = " + "egr::GetOutput(outs[\"%s\"][0]);\n"; + out_tensor_str = paddle::string::Sprintf(FWD_OUT_TENSOR_TEMPLATE, + output_name, output_name); + return_types[return_position] = "egr::EagerTensor"; + } + + return_contents[return_position] = output_name; + generated_function_body += out_tensor_str; + } + generated_function_body += "\n"; + VLOG(6) << "Converted Output VarBase to EagerTensor(s)"; + + // [Generation] ComputeRequireGrad -> GradNodeCreation + std::string grad_node_creation_body_str = GenerateGradNodeCreationContent( + grad_node_default_attr_maps, fwd_inputs_name_pos_map, + fwd_outputs_name_pos_map, grad_ins_fwd_slotname_map, op_proto); + generated_function_body += grad_node_creation_body_str; + generated_function_body += "\n"; + VLOG(6) << "Generated GradNode Creation codes"; + + // [Generation] Handle return: Tuple/Vector/Tensor + generated_function_body += "\n"; + std::string return_str; + std::string return_type_str = ""; + std::string function_proto_return_type_str = ""; + if (return_contents.size() > 1) { + // Return tuple + std::string return_content_str = ""; + for (const std::string& s : return_contents) { + return_content_str += s + ","; + } + return_content_str.pop_back(); // Remove trailing "," + + for (const std::string& s : return_types) { + return_type_str += s + ","; + } + return_type_str.pop_back(); // Remove trailing "," + + const char* FWD_TUPLE_RETURN_TEMPLATE = " return std::make_tuple(%s);"; + return_str = + paddle::string::Sprintf(FWD_TUPLE_RETURN_TEMPLATE, return_content_str); + + const char* FWD_FUNCTION_PROTO_RETURN_TEMPLATE = "std::tuple<%s>"; + function_proto_return_type_str = paddle::string::Sprintf( + FWD_FUNCTION_PROTO_RETURN_TEMPLATE, return_type_str); + } else { + // Return vector or Tensor + return_type_str = return_types[0]; + const char* FWD_TENSOR_RETURN_TEMPLATE = " return %s;"; + return_str = + paddle::string::Sprintf(FWD_TENSOR_RETURN_TEMPLATE, return_contents[0]); + function_proto_return_type_str = return_type_str; + } + generated_function_body += return_str; + generated_function_body += "\n"; + VLOG(6) << "Generated return codes"; + + // [Generation] Get Full Function + std::string function_name = op_type + "_dygraph_function"; + + const char* FWD_FUNCTION_TEMPLATE = "%s %s(%s) {\n\n%s\n}\n\n"; + std::string fwd_function_str = paddle::string::Sprintf( + FWD_FUNCTION_TEMPLATE, function_proto_return_type_str, function_name, + dygraph_function_args_str, generated_function_body); + + // [Generation] Append USE_OP + fwd_function_str += AppendUseOp(op_type); + + // [Generation] Generate forward functions header + const char* FWD_HEADER_TEMPLATE = "%s %s(%s);\n"; + std::string dygraph_function_declaration_str = paddle::string::Sprintf( + FWD_HEADER_TEMPLATE, function_proto_return_type_str, function_name, + dygraph_function_args_str); + + return {fwd_function_str, dygraph_function_declaration_str}; +} + +/* ---------------------------------------------- */ +/* --------- CodeGen: GradNode::operator() ------ */ +/* ---------------------------------------------- */ +static std::string GenerateGradNodeCCContents( + const std::vector& + grad_node_default_attr_maps, + const std::vector& grad_op_types, + const std::unordered_map& fwd_inputs_name_pos_map, + const std::unordered_map& fwd_outputs_name_pos_map, + const std::map& grad_ins_fwd_slotname_map, + const std::map& grad_ins_grad_slotname_map, + const std::map& grad_outs_slotname_map, + const std::map< + std::string, + std::vector>>& + grad_ins, + const std::map< + std::string, + std::vector>>& + grad_outs, + const proto::OpProto& op_proto) { + VLOG(6) << "Generating Grad Node CC"; + + /* [Outline] + + vector> GradNodeXXX::operator()(vector>& grads) + { + + const std::shared_ptr& tracer = imperative::GetCurrentTracer(); + + // Comes from "grad_ins" + std::map>> ins = + { + "X" : this->"X", "Y" : this->"Y", + "Out0@Grad": + SyncToVars(grads["fwd_outputs_name_pos_map[grad_ins_grad_slotname_map["Out0@Grad"]]"]), + "Out1@Grad": + TensorsToVarBases(grads["fwd_outputs_name_pos_map[grad_ins_grad_slotname_map["Out1@Grad"]]"]) + }; + + // Comes from "grad_outs" + std::map>> outs = + { + "X@Grad" : + ConstructDuplicableOutput(this->OutputMeta()["fwd_inputs_name_pos_map[grad_outs_slotname_map["X@Grad"]]"].Size()), + "Y@Grad" : + ConstructDuplicableOutput(this->OutputMeta()["fwd_inputs_name_pos_map[grad_outs_slotname_map["Y@Grad"]]"].Size()) + }; + + // Visit each OpBase + for(auto iter = "grad_node->begin()"; iter < "grad_node->end()"; iter++) { + // Simply pass entire attribute map to kernels + egr::RunOp("iter->Type()", ins, outs, this->attr_map_, + egr::Controller::Instance().ExpectedPlace(), false, {}); + } + + vector> outputs(outs.size()); + for(auto& kv : outs) { + outputs["fwd_inputs_name_pos_map[grad_outs_slotname_map[kv.first]]"] = + GetOutputs(outs["kv.first"]); + } + + return outputs; + } + */ + + const std::string& op_type = op_proto.type(); + std::string generated_grad_function_body = ""; + + // [Generation] Get Tracer + generated_grad_function_body += "\n"; + generated_grad_function_body += "\n"; + + // [Generation] Get Ins Map + std::string ins_contents_str = ""; + for (auto iter : grad_ins) { + const std::string& grad_input_name = iter.first; + + if (grad_ins_fwd_slotname_map.count(grad_input_name)) { + // Fwd Tensor + std::string struct_fwd_input_name = + grad_ins_fwd_slotname_map.at(grad_input_name) + "_"; + const char* GRAD_INS_FWD_CONTENT_TEMPLATE = + "{ \"%s\", " + "egr::SyncToVars(egr::EagerUtils::RecoverTensorWrapper(&this->%s, " + "nullptr)) },"; + ins_contents_str += + paddle::string::Sprintf(GRAD_INS_FWD_CONTENT_TEMPLATE, + grad_input_name, struct_fwd_input_name); + + } else if (grad_ins_grad_slotname_map.count(grad_input_name)) { + // Fwd Tensor's Grad + size_t fwd_output_position = fwd_outputs_name_pos_map.at( + grad_ins_grad_slotname_map.at(grad_input_name)); + const char* GRAD_INS_GRAD_CONTENT_TEMPLATE = + "{ \"%s\", egr::SyncToVars(grads[%d]) },"; + ins_contents_str += paddle::string::Sprintf( + GRAD_INS_GRAD_CONTENT_TEMPLATE, grad_input_name, fwd_output_position); + + } else { + PADDLE_THROW(platform::errors::Fatal( + "Detected mismatched slot names." + "Unable to find forward slot name that matches %s", + grad_input_name)); + } + } + if (ins_contents_str.size() > 0) + ins_contents_str.pop_back(); // // Remove trailing "," + + const char* BWD_INS_MAP_TEMPLATE = + " std::map>> ins = { " + "%s };\n"; + std::string ins_map_str = + paddle::string::Sprintf(BWD_INS_MAP_TEMPLATE, ins_contents_str); + generated_grad_function_body += ins_map_str; + + VLOG(6) << "Generated Ins Map"; + + // [Generation] Get Outs Map + std::unordered_set duplicable_input_name_set; + for (const auto& in : op_proto.inputs()) { + if (in.duplicable()) duplicable_input_name_set.insert(in.name()); + } + + std::string outs_contents_str = ""; + for (auto iter : grad_outs) { + const std::string& grad_output_name = iter.first; + + if (grad_outs_slotname_map.count(grad_output_name)) { + // Fwd Tensor + const std::string& fwd_input_name = + grad_outs_slotname_map.at(grad_output_name); + size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_input_name); + + if (duplicable_input_name_set.count(fwd_input_name)) { + const char* GRAD_OUTS_CONTENT_TEMPLATE = + "{ \"%s\", egr::ConstructDuplicableOutput( " + "this->OutputMeta()[%d].Size() ) },"; + outs_contents_str += paddle::string::Sprintf( + GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, fwd_input_position); + } else { + const char* GRAD_OUTS_CONTENT_TEMPLATE = + "{ \"%s\", " + "{std::make_shared(egr::Controller::Instance()." + "GenerateUniqueName())}},"; + outs_contents_str += paddle::string::Sprintf(GRAD_OUTS_CONTENT_TEMPLATE, + grad_output_name); + } + } else { + PADDLE_THROW(platform::errors::Fatal( + "Detected mismatched slot names." + "Unable to find forward slot name that matches %s", + grad_output_name)); + } + } + if (outs_contents_str.size() > 0) + outs_contents_str.pop_back(); // // Remove trailing "," + + const char* BWD_OUTS_MAP_TEMPLATE = + " std::map>> outs = { " + "%s };\n"; + std::string outs_map_str = + paddle::string::Sprintf(BWD_OUTS_MAP_TEMPLATE, outs_contents_str); + generated_grad_function_body += outs_map_str; + generated_grad_function_body += "\n"; + + VLOG(6) << "Generated Outs Map"; + + // [Generation] Get Attrs Map + std::string trace_opbase_str = ""; + for (size_t i = 0; i < grad_node_default_attr_maps.size(); i++) { + const std::string& op_base_type = grad_op_types[i]; + + const char* TRACE_OP_TEMPLATE = + " // Pass the entire attribute map to TraceOp\n" + " // The underlying kernel will pickup whatever attribute they need " + "at runtime\n" + " egr::RunOp(\"%s\", ins, outs, this->attr_map_,\n" + " egr::Controller::Instance().GetExpectedPlace(),\n" + " &this->default_attr_map_, false, {});\n"; + trace_opbase_str = paddle::string::Sprintf(TRACE_OP_TEMPLATE, op_base_type); + } + + generated_grad_function_body += trace_opbase_str; + + VLOG(6) << "Generated Attrs Map"; + + // [Generation] Get Return + std::string outputs_str = ""; + for (auto iter : grad_outs) { + const std::string& grad_out_name = iter.first; + size_t fwd_input_position = + fwd_inputs_name_pos_map.at(grad_outs_slotname_map.at(grad_out_name)); + + const char* BWD_OUTPUT_TEMPLATE = + " outputs[%d] = GetOutputs(outs[\"%s\"]);\n"; + outputs_str += paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE, + fwd_input_position, grad_out_name); + } + + const char* BWD_RETURN_TEMPLATE = + " std::vector> " + "outputs(outs.size());\n%s\n " + "return outputs;"; + std::string return_str = + paddle::string::Sprintf(BWD_RETURN_TEMPLATE, outputs_str); + + generated_grad_function_body += "\n"; + generated_grad_function_body += return_str; + + // [Generation] Get Full Grad Function + const char* GRAD_FUNCTION_TEMPLATE = + "std::vector> " + "GradNode%s::operator()(const " + "std::vector>& grads) {\n%s\n}"; + std::string grad_function_str = paddle::string::Sprintf( + GRAD_FUNCTION_TEMPLATE, op_type, generated_grad_function_body); + + VLOG(6) << "Generated returns"; + + return grad_function_str; +} + +/* ----------------------------------------- */ +/* --------- CodeGen: GradNode Header ------ */ +/* ----------------------------------------- */ +static std::string GenerateGradNodeHeaderContents( + const std::vector& + grad_node_default_attr_maps, + const std::map& grad_ins_fwd_slotname_map, + const proto::OpProto& op_proto) { + VLOG(6) << "Generating Grad Node Header"; + + const char* GRAD_NODE_TEMPLATE = + "class GradNode%s : public egr::GradNodeBase {\n" + " public:\n" + " GradNode%s() : egr::GradNodeBase() {}\n" + " GradNode%s(size_t bwd_in_slot_num, size_t bwd_out_slot_num) : " + "egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {}\n" + " ~GradNode%s() override = default;\n" + "\n" + " virtual std::vector> " + "operator()(const " + "std::vector>& grads) " + "override;\n" + "\n" + " // SetX, SetY, ...\n" + "%s\n" + " // SetAttrMap\n" + "%s\n" + "\n" + " private:\n" + " // TensorWrappers\n" + "%s\n" + " // Attribute Map\n" + "%s\n" + "};"; + + const std::string& op_type = op_proto.type(); + + // [Generation] Handle Attributes + std::string set_attr_map_str = + " void SetAttrMap(paddle::framework::AttributeMap&& attr_map) {\n " + "attr_map_ = std::move(attr_map);\n }\n"; + set_attr_map_str += + " void SetDefaultAttrMap(paddle::framework::AttributeMap&& " + "default_attr_map) {\n default_attr_map_ = " + "std::move(default_attr_map);\n }\n"; + std::string attr_members_str = + " paddle::framework::AttributeMap attr_map_;\n"; + attr_members_str += " paddle::framework::AttributeMap default_attr_map_;"; + + VLOG(6) << "Generated SetAttr"; + + // [Generation] Handle TensorWrappers + std::unordered_set duplicable_tensors; + for (const proto::OpProto::Var& input : op_proto.inputs()) { + if (input.duplicable()) { + duplicable_tensors.insert(input.name()); + } + } + for (const proto::OpProto::Var& output : op_proto.outputs()) { + if (output.duplicable()) { + duplicable_tensors.insert(output.name()); + } + } + + std::string set_tensor_wrappers_str = ""; + std::string tensor_wrapper_members_str = ""; + for (const auto& kv : grad_ins_fwd_slotname_map) { + const std::string& tensor_wrapper_name = kv.second; + const std::string& struct_tensor_wrapper_name = kv.second + "_"; + + std::string tensor_wrapper_arg_str; + std::string tensor_wrapper_body_str; + if (duplicable_tensors.count(tensor_wrapper_name)) { + const char* ATTR_TENSOR_WRAPPER_ARG_TEMPLATE = + "const std::vector& %s"; + tensor_wrapper_arg_str = paddle::string::Sprintf( + ATTR_TENSOR_WRAPPER_ARG_TEMPLATE, tensor_wrapper_name); + + const char* TENSOR_WRAPPER_MEMBER_TEMPLATE = + " std::vector %s;\n"; + tensor_wrapper_members_str += paddle::string::Sprintf( + TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name); + + const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE = + "for(const auto& eager_tensor : %s) {\n" + " %s.emplace_back( egr::TensorWrapper(eager_tensor, true " + "/*full_reserved*/) );\n" + " }\n"; + tensor_wrapper_body_str = paddle::string::Sprintf( + SET_TENSOR_WRAPPER_BODY_TEMPLATE, tensor_wrapper_name, + struct_tensor_wrapper_name); + + } else { + const char* ATTR_TENSOR_WRAPPER_ARG_TEMPLATE = + "const egr::EagerTensor& %s"; + tensor_wrapper_arg_str = paddle::string::Sprintf( + ATTR_TENSOR_WRAPPER_ARG_TEMPLATE, tensor_wrapper_name); + + const char* TENSOR_WRAPPER_MEMBER_TEMPLATE = + " egr::TensorWrapper %s;\n"; + tensor_wrapper_members_str += paddle::string::Sprintf( + TENSOR_WRAPPER_MEMBER_TEMPLATE, struct_tensor_wrapper_name); + + const char* SET_TENSOR_WRAPPER_BODY_TEMPLATE = + "%s = egr::TensorWrapper(%s, true /*full_reserved*/);"; + tensor_wrapper_body_str = paddle::string::Sprintf( + SET_TENSOR_WRAPPER_BODY_TEMPLATE, struct_tensor_wrapper_name, + tensor_wrapper_name); + } + + const char* SET_TENSOR_WRAPPER_TEMPLATE = + " void SetTensorWrapper%s(%s) {\n %s\n }\n"; + set_tensor_wrappers_str += paddle::string::Sprintf( + SET_TENSOR_WRAPPER_TEMPLATE, tensor_wrapper_name, + tensor_wrapper_arg_str, tensor_wrapper_body_str); + } + VLOG(6) << "Generated TensorWrapper"; + + std::string grad_node_str = paddle::string::Sprintf( + GRAD_NODE_TEMPLATE, op_type, op_type, op_type, op_type, + set_tensor_wrappers_str, set_attr_map_str, tensor_wrapper_members_str, + attr_members_str); + + return grad_node_str; +} + +/* --------------------------------- */ +/* --------- FileGeneration --------- */ +/* ---------------------------------- */ +static void GenerateForwardHFile(const std::string& output_dir, + const std::string& dygraph_forward_api_str) { + std::string dygraph_forward_api_path = output_dir + "/dygraph_forward_api.h"; + std::ofstream forward_header_stream(dygraph_forward_api_path, std::ios::out); + forward_header_stream << dygraph_forward_api_str; + forward_header_stream.close(); +} + +static void GenerateForwardDygraphFile(const std::string& op_type, + const std::string& output_dir, + const std::string& fwd_function_str) { + std::string forwards_dir = output_dir + "/forwards/"; + std::string node_h_filename = op_type + "_node.h"; + std::string forward_cc_filename = op_type + "_dygraph.cc"; + std::string forward_cc_path = forwards_dir + forward_cc_filename; + const char* FORWARD_INCLUDE_TEMPLATE = + "#include " + "\"paddle/fluid/eager/api/generated/fluid_generated/" + "dygraph_forward_api.h\"\n" + "#include " + "\"paddle/fluid/eager/api/generated/fluid_generated/nodes/%s\"\n\n" + "#include \"paddle/fluid/eager/api/utils/global_utils.h\"\n" + "#include \"paddle/fluid/eager/legacy/op_runner.h\"\n"; + std::string forward_cc_include_str = + paddle::string::Sprintf(FORWARD_INCLUDE_TEMPLATE, node_h_filename); + std::ofstream forward_cc_stream(forward_cc_path, std::ios::out); + forward_cc_stream << forward_cc_include_str; + forward_cc_stream << fwd_function_str; + forward_cc_stream.close(); +} + +static void GenerateNodeHFile(const std::string& op_type, + const std::string& output_dir, + const std::string& grad_node_str) { + std::string nodes_dir = output_dir + "/nodes/"; + std::string node_h_filename = op_type + "_node.h"; + std::string node_h_path = nodes_dir + node_h_filename; + std::string node_h_include_str = + "#pragma once\n" + "#include \"paddle/fluid/eager/tensor_wrapper.h\"\n" + "#include \"paddle/fluid/eager/legacy/op_runner.h\"\n" + "#include \"paddle/fluid/eager/grad_node_info.h\"\n\n"; + std::ofstream node_h_stream(node_h_path, std::ios::out); + node_h_stream << node_h_include_str; + node_h_stream << grad_node_str; + node_h_stream.close(); +} + +static void GenerateNodeCCFile(const std::string& op_type, + const std::string& output_dir, + const std::string& grad_function_str) { + std::string nodes_dir = output_dir + "/nodes/"; + std::string node_h_filename = op_type + "_node.h"; + std::string node_cc_filename = op_type + "_node.cc"; + std::string node_cc_path = nodes_dir + node_cc_filename; + const char* NODE_CC_INCLUDE_TEMPLATE = + "#include \"glog/logging.h\"\n" + "#include \"paddle/pten/api/all.h\"\n" + "#include \"paddle/fluid/imperative/tracer.h\"\n" + "#include \"paddle/fluid/framework/op_registry.h\"\n" + "#include \"paddle/fluid/eager/utils.h\"\n" + "#include \"paddle/fluid/eager/api/utils/global_utils.h\"\n" + "#include " + "\"paddle/fluid/eager/api/generated/fluid_generated/nodes/%s\"\n\n"; + std::string node_cc_include_str = + paddle::string::Sprintf(NODE_CC_INCLUDE_TEMPLATE, node_h_filename); + std::ofstream node_cc_stream(node_cc_path, std::ios::out); + node_cc_stream << node_cc_include_str; + node_cc_stream << grad_function_str; + node_cc_stream.close(); +} + +static std::string GenerateDygraphHFileIncludes() { + std::string dygraph_forward_api_includes_str = + "#pragma once\n" + "#include \"glog/logging.h\"\n" + "#include \"paddle/fluid/eager/autograd_meta.h\"\n" + "#include \"paddle/pten/api/all.h\"\n" + "#include \"paddle/fluid/eager/utils.h\"\n" + "#include \"paddle/fluid/framework/op_registry.h\"\n\n"; + + return dygraph_forward_api_includes_str; +} + +static void DygraphCodeGeneration(const std::string& output_dir) { + std::string dygraph_forward_api_str = GenerateDygraphHFileIncludes(); + + auto& op_info_map = paddle::framework::OpInfoMap::Instance().map(); + + for (auto& pair : op_info_map) { + const OpInfo& op_info = pair.second; + proto::OpProto* op_proto = op_info.proto_; + + if (!CheckOpProto(op_proto)) continue; + const std::string& op_type = op_proto->type(); + + /* ----------------------------- */ + /* ---- Collect Information ---- */ + /* ----------------------------- */ + std::vector grad_node_default_attr_maps; + std::vector grad_op_types; + std::unordered_map fwd_inputs_name_pos_map; + std::unordered_map fwd_outputs_name_pos_map; + std::map grad_outs_slotname_map; + std::map grad_ins_fwd_slotname_map; + std::map grad_ins_grad_slotname_map; + std::map>> + grad_ins; + std::map>> + grad_outs; + + VLOG(6) << "-------- CollectInformationFromOpInfo -------"; + bool is_available = CollectInformationFromOpInfo( + op_info, &grad_node_default_attr_maps, &grad_op_types, + &fwd_inputs_name_pos_map, &fwd_outputs_name_pos_map, + &grad_outs_slotname_map, &grad_ins_fwd_slotname_map, + &grad_ins_grad_slotname_map, &grad_ins, &grad_outs); + + if (!is_available) continue; + + /* --------------------------- */ + /* --------- CodeGen --------- */ + /* --------------------------- */ + /* ---- xxx_dygraph.cc ---- */ + VLOG(6) << "-------- GenerateForwardFunctionContents -------"; + std::pair body_and_declaration = + GenerateForwardFunctionContents( + grad_node_default_attr_maps, fwd_inputs_name_pos_map, + fwd_outputs_name_pos_map, grad_ins_fwd_slotname_map, + grad_ins_grad_slotname_map, grad_outs_slotname_map, grad_ins, + grad_outs, *op_proto); + std::string fwd_function_str = body_and_declaration.first; + GenerateForwardDygraphFile(op_type, output_dir, fwd_function_str); + + /* ---- dygraph_forward_api.h ---- */ + std::string fwd_function_declare_str = body_and_declaration.second; + dygraph_forward_api_str += fwd_function_declare_str; + + /* ---- xxx_node.h ---- */ + VLOG(6) << "-------- GenerateGradNodeHeaderContents -------"; + std::string grad_node_h_str = GenerateGradNodeHeaderContents( + grad_node_default_attr_maps, grad_ins_fwd_slotname_map, *op_proto); + GenerateNodeHFile(op_type, output_dir, grad_node_h_str); + + /* ---- xxx_node.cc ---- */ + VLOG(6) << "-------- GenerateGradNodeCCContents -------"; + std::string grad_node_cc_str = GenerateGradNodeCCContents( + grad_node_default_attr_maps, grad_op_types, fwd_inputs_name_pos_map, + fwd_outputs_name_pos_map, grad_ins_fwd_slotname_map, + grad_ins_grad_slotname_map, grad_outs_slotname_map, grad_ins, grad_outs, + *op_proto); + GenerateNodeCCFile(op_type, output_dir, grad_node_cc_str); + + VLOG(6) << op_type << ": Finished Generation"; + } + + /* ---- dygraph_forward_api.h ---- */ + VLOG(6) << "-------- GenerateForwardHFile -------"; + GenerateForwardHFile(output_dir, dygraph_forward_api_str); +} + +int main(int argc, char* argv[]) { + if (argc != 2) { + std::cerr << "argc must be 2" << std::endl; + return -1; + } + + std::string eager_root = argv[1]; + paddle::framework::DygraphCodeGeneration(eager_root); + + return 0; +} + } // namespace framework } // namespace paddle -- GitLab