diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index c0714775da852cd7f13dfce8203c946455b502e8..136eaebe2cc4bfe25c507878240efdefb4341f19 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -22,6 +22,7 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/variable.h" +#include "paddle/fluid/pybind/op_function_generator.h" #include "paddle/fluid/pybind/pybind.h" #include "paddle/fluid/string/string_helper.h" @@ -358,18 +359,149 @@ static bool CheckOpProto(proto::OpProto* op_proto) { return true; } +/* --------------------------------------- */ +/* --------- Preprocess Ins/Outs --------- */ +/* --------------------------------------- */ +static void PurifyOpProto( + const proto::OpProto& op_proto, + std::unordered_map* fwd_inputs_name_pos_map, + std::unordered_map* fwd_outputs_name_pos_map, + std::map* grad_outs_slotname_map, + std::map* grad_ins_fwd_slotname_map, + std::map* grad_ins_grad_slotname_map, + std::vector* in_vars, + std::vector* out_vars, + std::map>>* + grad_ins, + std::map>>* + grad_outs) { + // Op Name + const std::string op_name = op_proto.type(); + + // Handle dispensable inputs + for (const proto::OpProto::Var& input : op_proto.inputs()) { + std::string input_name = input.name(); + + // Delete dispensable tensor unless specified in op_ins_map + if (input.dispensable()) { + if (!op_ins_map.count(op_name) || + !op_ins_map[op_name].count(input_name)) { + VLOG(6) << "Removing Dispensable Input: " << input_name; + + // in_vars + auto iter = in_vars->begin(); + for (iter = in_vars->begin(); iter != in_vars->end(); iter++) { + if (iter->name() == input_name) { + break; + } + } + in_vars->erase(iter); + + // grad_outs_slotname_map + auto grad_outs_slotname_map_purified = *grad_outs_slotname_map; + for (const auto& iter : *grad_outs_slotname_map) { + const std::string& grad_output_name = iter.first; + const std::string& matched_input_name = iter.second; + if (matched_input_name == input_name) { + grad_outs_slotname_map_purified.erase(grad_output_name); + + PADDLE_ENFORCE( + grad_outs->count(grad_output_name) > 0, + paddle::platform::errors::Fatal( + "Unable to find gradient output name in grad_outs.")); + // grad_outs + grad_outs->erase(grad_output_name); + } + } + *grad_outs_slotname_map = grad_outs_slotname_map_purified; + + // grad_ins_fwd_slotname_map: output as tensorwrapper + if (grad_ins_fwd_slotname_map->count(input_name)) + grad_ins_fwd_slotname_map->erase(input_name); + + // grad_ins: output as tensorwrapper + if (grad_ins->count(input_name)) grad_ins->erase(input_name); + } + } + } + + for (const proto::OpProto::Var& output : op_proto.outputs()) { + std::string output_name = output.name(); + + // Delete dispensable tensor unless specified in op_outs_map + if (output.dispensable()) { + if (!op_outs_map.count(op_name) || + !op_outs_map[op_name].count(output_name)) { + VLOG(6) << "Removing Dispensable Output: " << output_name; + + // out_vars + auto iter = out_vars->begin(); + for (iter = out_vars->begin(); iter != out_vars->end(); iter++) { + if (iter->name() == output_name) { + break; + } + } + out_vars->erase(iter); + + // grad_ins_grad_slotname_map + auto grad_ins_grad_slotname_map_purified = *grad_ins_grad_slotname_map; + for (const auto& iter : *grad_ins_grad_slotname_map) { + const std::string& grad_input_name = iter.first; + const std::string& matched_output_name = iter.second; + if (matched_output_name == output_name) { + grad_ins_grad_slotname_map_purified.erase(grad_input_name); + + PADDLE_ENFORCE( + grad_ins->count(grad_input_name) > 0, + paddle::platform::errors::Fatal( + "Unable to find gradient input name in grad_ins.")); + // grad_ins + grad_ins->erase(grad_input_name); + } + } + *grad_ins_grad_slotname_map = grad_ins_grad_slotname_map_purified; + + // grad_ins_fwd_slotname_map: output as tensorwrapper + if (grad_ins_fwd_slotname_map->count(output_name)) + grad_ins_fwd_slotname_map->erase(output_name); + + // grad_ins: output as tensorwrapper + if (grad_ins->count(output_name)) grad_ins->erase(output_name); + } + } + } + + /* ------ Maping forward slot name to fwd position ------ */ + size_t in_pos = 0; + for (const auto& var : *in_vars) { + VLOG(6) << "Mapping input tensor: " << var.name() + << " To position: " << in_pos; + (*fwd_inputs_name_pos_map)[var.name()] = in_pos; + in_pos++; + } + + size_t out_pos = 0; + for (const auto& var : *out_vars) { + VLOG(6) << "Mapping output tensor: " << var.name() + << " To position: " << out_pos; + (*fwd_outputs_name_pos_map)[var.name()] = out_pos; + out_pos++; + } +} + /* -------------------------------- */ /* --------- Collect Info --------- */ /* -------------------------------- */ static bool CollectInformationFromOpInfo( const paddle::framework::OpInfo& op_info, - std::vector* grad_node_default_attr_maps, std::vector* grad_op_types, - std::unordered_map* fwd_inputs_name_pos_map, - std::unordered_map* fwd_outputs_name_pos_map, std::map* grad_outs_slotname_map, std::map* grad_ins_fwd_slotname_map, std::map* grad_ins_grad_slotname_map, + std::vector* in_vars, + std::vector* out_vars, std::map>>* grad_ins, @@ -380,6 +512,13 @@ static bool CollectInformationFromOpInfo( const std::string& op_type = op_proto.type(); std::vector dims = {1, 1, 1, 1}; + for (const proto::OpProto::Var& input : op_proto.inputs()) { + in_vars->push_back(input); + } + for (const proto::OpProto::Var& output : op_proto.outputs()) { + out_vars->push_back(output); + } + /* ------ Prepare "ins" ------ */ std::map>> @@ -494,7 +633,6 @@ static bool CollectInformationFromOpInfo( for (auto iter = grad_node->begin(); iter < grad_node->end(); iter++) { // Each OpBase paddle::imperative::OpBase& op_base = *iter; - grad_node_default_attr_maps->push_back(op_base.DefaultAttrsMap()); grad_op_types->push_back(op_base.Type()); } @@ -538,22 +676,6 @@ static bool CollectInformationFromOpInfo( grad_outs_slotname_map); VLOG(6) << "Finished Slotname Matching for Grad_Outs"; - /* ------ Maping forward slot name to fwd position ------ */ - size_t in_pos = 0; - for (const auto& iter : ins) { - VLOG(6) << "Mapping input tensor: " << iter.first - << " To position: " << in_pos; - (*fwd_inputs_name_pos_map)[iter.first] = in_pos; - in_pos++; - } - size_t out_pos = 0; - for (const auto& iter : outs) { - VLOG(6) << "Mapping output tensor: " << iter.first - << " To position: " << out_pos; - (*fwd_outputs_name_pos_map)[iter.first] = out_pos; - out_pos++; - } - return true; } @@ -561,16 +683,13 @@ static bool CollectInformationFromOpInfo( /* --------- CodeGen: Forward GradNode Creation ------ */ /* --------------------------------------------------- */ static std::string GenerateGradNodeCreationContent( - const std::vector& - grad_node_default_attr_maps, const std::unordered_map& fwd_inputs_name_pos_map, const std::unordered_map& fwd_outputs_name_pos_map, const std::map& grad_ins_fwd_slotname_map, - const proto::OpProto& op_proto) { + const std::string& op_type, const std::vector& in_vars, + const std::vector& out_vars) { VLOG(6) << "Generating GradNode Creation codes"; - const std::string& op_type = op_proto.type(); - // [Generation] Construct GradOpNode // Run ComputeRequiredGrad @@ -578,7 +697,7 @@ static std::string GenerateGradNodeCreationContent( // then generate: "egr::AutogradMeta* p_autograd_out = // egr::EagerUtils::autograd_meta("op_proto->outputs()[0].name()")" std::string get_autograd_meta_str = " // Prepare Autograd Meta \n"; - for (const proto::OpProto::Var& input : op_proto.inputs()) { + for (const proto::OpProto::Var& input : in_vars) { const std::string& input_name = input.name(); const std::string& input_autograd_name = "p_autograd_" + input_name; @@ -602,7 +721,7 @@ static std::string GenerateGradNodeCreationContent( // If single output slotname and not duplicable, // then generate: "egr::AutogradMeta* p_autograd_out = // egr::EagerUtils::autograd_meta("op_proto.outputs()[0].name()")" - for (const proto::OpProto::Var& output : op_proto.outputs()) { + for (const proto::OpProto::Var& output : out_vars) { const std::string& output_name = output.name(); const std::string& output_autograd_name = "p_autograd_" + output_name; @@ -636,8 +755,8 @@ static std::string GenerateGradNodeCreationContent( // [GradOpNode] Generation std::string grad_node_creation_str = ""; - size_t bwd_in_slot_num = op_proto.outputs().size(); - size_t bwd_out_slot_num = op_proto.inputs().size(); + size_t bwd_in_slot_num = out_vars.size(); + size_t bwd_out_slot_num = in_vars.size(); const char* GRAD_OP_NODE_TEMPLATE = " auto grad_node = std::make_shared(%d, %d);\n"; grad_node_creation_str += " // Create GradOpNode\n"; @@ -669,7 +788,7 @@ static std::string GenerateGradNodeCreationContent( // [GradOpNode] SetGradOutMeta // [GradOpNode] Add Edges std::string compute_require_grad_args = "trace_backward"; - for (const proto::OpProto::Var& input : op_proto.inputs()) { + for (const proto::OpProto::Var& input : in_vars) { const std::string& input_name = input.name(); const std::string& input_autograd_name = "p_autograd_" + input_name; compute_require_grad_args += ", &" + input_autograd_name; @@ -689,7 +808,7 @@ static std::string GenerateGradNodeCreationContent( // [AutogradMeta] SetOutRank // [AutogradMeta] SetHistory std::string pass_stop_gradient_args = "false"; - for (const proto::OpProto::Var& output : op_proto.outputs()) { + for (const proto::OpProto::Var& output : out_vars) { const std::string& output_name = output.name(); const std::string& output_autograd_name = "p_autograd_" + output_name; pass_stop_gradient_args += ", &" + output_autograd_name; @@ -743,8 +862,6 @@ static std::string AppendUseOp(const std::string& op_type) { /* --------- CodeGen: Forward ----- */ /* -------------------------------- */ static std::pair GenerateForwardFunctionContents( - const std::vector& - grad_node_default_attr_maps, const std::unordered_map& fwd_inputs_name_pos_map, const std::unordered_map& fwd_outputs_name_pos_map, const std::map& grad_ins_fwd_slotname_map, @@ -758,7 +875,8 @@ static std::pair GenerateForwardFunctionContents( std::string, std::vector>>& grad_outs, - const proto::OpProto& op_proto) { + const std::string& op_type, const std::vector& in_vars, + const std::vector& out_vars) { /* // Forward Function Example: std::tuple, Tensor, vector> @@ -779,6 +897,7 @@ static std::pair GenerateForwardFunctionContents( ,ConstructDuplicableOutput(Out1Num)} }; // According to op_proto->attrs() + egr::legacy::RunOp("op_type", ins, outs, attr_map, Controller.Instance().GetExpectedPlace(), {}); @@ -795,8 +914,6 @@ static std::pair GenerateForwardFunctionContents( */ VLOG(6) << "Generating Dygraph Forward Function"; - const std::string& op_type = op_proto.type(); - std::string generated_function_body = ""; std::string dygraph_function_args_str = ""; @@ -806,8 +923,8 @@ static std::pair GenerateForwardFunctionContents( // [Generation] Get Ins Map std::string ins_contents_str = ""; - std::vector input_args_str_list(op_proto.inputs().size()); - for (const proto::OpProto::Var& input : op_proto.inputs()) { + std::vector input_args_str_list(in_vars.size()); + for (const proto::OpProto::Var& input : in_vars) { const std::string& input_name = input.name(); size_t input_position = fwd_inputs_name_pos_map.at(input_name); if (input.duplicable()) { @@ -848,7 +965,7 @@ static std::pair GenerateForwardFunctionContents( // [Generation] Get Outs Map std::string outs_contents_str = ""; - for (const proto::OpProto::Var& output : op_proto.outputs()) { + for (const proto::OpProto::Var& output : out_vars) { const std::string& output_name = output.name(); std::string outnum = "1"; if (output.duplicable()) { @@ -898,17 +1015,17 @@ static std::pair GenerateForwardFunctionContents( " egr::Controller::Instance().GetExpectedPlace(),\n" " &default_attrs, true, {});\n"; std::string trace_op_str = - paddle::string::Sprintf(FWD_TRACE_OP_TEMPLATE, op_proto.type()); + paddle::string::Sprintf(FWD_TRACE_OP_TEMPLATE, op_type); generated_function_body += trace_op_str; generated_function_body += "\n"; VLOG(6) << "Generated AttrMap & TraceOp"; // [Generation] Convert output VarBase to Vector/Tensor - size_t output_size = op_proto.outputs().size(); + size_t output_size = out_vars.size(); std::vector return_contents(output_size); std::vector return_types(output_size); - for (const proto::OpProto::Var& output : op_proto.outputs()) { + for (const proto::OpProto::Var& output : out_vars) { const std::string& output_name = output.name(); std::string out_tensor_str; size_t return_position = fwd_outputs_name_pos_map.at(output_name); @@ -937,8 +1054,8 @@ static std::pair GenerateForwardFunctionContents( // [Generation] ComputeRequireGrad -> GradNodeCreation std::string grad_node_creation_body_str = GenerateGradNodeCreationContent( - grad_node_default_attr_maps, fwd_inputs_name_pos_map, - fwd_outputs_name_pos_map, grad_ins_fwd_slotname_map, op_proto); + fwd_inputs_name_pos_map, fwd_outputs_name_pos_map, + grad_ins_fwd_slotname_map, op_type, in_vars, out_vars); generated_function_body += grad_node_creation_body_str; generated_function_body += "\n"; VLOG(6) << "Generated GradNode Creation codes"; @@ -1004,8 +1121,6 @@ static std::pair GenerateForwardFunctionContents( /* --------- CodeGen: GradNode::operator() ------ */ /* ---------------------------------------------- */ static std::string GenerateGradNodeCCContents( - const std::vector& - grad_node_default_attr_maps, const std::vector& grad_op_types, const std::unordered_map& fwd_inputs_name_pos_map, const std::unordered_map& fwd_outputs_name_pos_map, @@ -1020,7 +1135,8 @@ static std::string GenerateGradNodeCCContents( std::string, std::vector>>& grad_outs, - const proto::OpProto& op_proto) { + const std::string& op_type, const std::vector& in_vars, + const std::vector& out_vars) { VLOG(6) << "Generating Grad Node CC"; /* [Outline] @@ -1066,7 +1182,6 @@ static std::string GenerateGradNodeCCContents( } */ - const std::string& op_type = op_proto.type(); std::string generated_grad_function_body = ""; // [Generation] Get Tracer @@ -1122,7 +1237,7 @@ static std::string GenerateGradNodeCCContents( // [Generation] Get Outs Map std::unordered_set duplicable_input_name_set; - for (const auto& in : op_proto.inputs()) { + for (const auto& in : in_vars) { if (in.duplicable()) duplicable_input_name_set.insert(in.name()); } @@ -1173,7 +1288,7 @@ static std::string GenerateGradNodeCCContents( // [Generation] Get Attrs Map std::string trace_opbase_str = ""; - for (size_t i = 0; i < grad_node_default_attr_maps.size(); i++) { + for (size_t i = 0; i < grad_op_types.size(); i++) { const std::string& op_base_type = grad_op_types[i]; const char* TRACE_OP_TEMPLATE = @@ -1230,10 +1345,9 @@ static std::string GenerateGradNodeCCContents( /* --------- CodeGen: GradNode Header ------ */ /* ----------------------------------------- */ static std::string GenerateGradNodeHeaderContents( - const std::vector& - grad_node_default_attr_maps, const std::map& grad_ins_fwd_slotname_map, - const proto::OpProto& op_proto) { + const std::string& op_type, const std::vector& in_vars, + const std::vector& out_vars) { VLOG(6) << "Generating Grad Node Header"; const char* GRAD_NODE_TEMPLATE = @@ -1261,8 +1375,6 @@ static std::string GenerateGradNodeHeaderContents( "%s\n" "};"; - const std::string& op_type = op_proto.type(); - // [Generation] Handle Attributes std::string set_attr_map_str = " void SetAttrMap(paddle::framework::AttributeMap&& attr_map) {\n " @@ -1279,12 +1391,12 @@ static std::string GenerateGradNodeHeaderContents( // [Generation] Handle TensorWrappers std::unordered_set duplicable_tensors; - for (const proto::OpProto::Var& input : op_proto.inputs()) { + for (const proto::OpProto::Var& input : in_vars) { if (input.duplicable()) { duplicable_tensors.insert(input.name()); } } - for (const proto::OpProto::Var& output : op_proto.outputs()) { + for (const proto::OpProto::Var& output : out_vars) { if (output.duplicable()) { duplicable_tensors.insert(output.name()); } @@ -1454,13 +1566,12 @@ static void DygraphCodeGeneration(const std::string& output_dir) { /* ----------------------------- */ /* ---- Collect Information ---- */ /* ----------------------------- */ - std::vector grad_node_default_attr_maps; std::vector grad_op_types; - std::unordered_map fwd_inputs_name_pos_map; - std::unordered_map fwd_outputs_name_pos_map; std::map grad_outs_slotname_map; std::map grad_ins_fwd_slotname_map; std::map grad_ins_grad_slotname_map; + std::vector in_vars; + std::vector out_vars; std::map>> grad_ins; @@ -1470,13 +1581,20 @@ static void DygraphCodeGeneration(const std::string& output_dir) { VLOG(6) << "-------- CollectInformationFromOpInfo -------"; bool is_available = CollectInformationFromOpInfo( - op_info, &grad_node_default_attr_maps, &grad_op_types, - &fwd_inputs_name_pos_map, &fwd_outputs_name_pos_map, - &grad_outs_slotname_map, &grad_ins_fwd_slotname_map, - &grad_ins_grad_slotname_map, &grad_ins, &grad_outs); + op_info, &grad_op_types, &grad_outs_slotname_map, + &grad_ins_fwd_slotname_map, &grad_ins_grad_slotname_map, &in_vars, + &out_vars, &grad_ins, &grad_outs); if (!is_available) continue; + VLOG(6) << "-------- PurifyOpProto -------"; + std::unordered_map fwd_inputs_name_pos_map; + std::unordered_map fwd_outputs_name_pos_map; + PurifyOpProto(*op_proto, &fwd_inputs_name_pos_map, + &fwd_outputs_name_pos_map, &grad_outs_slotname_map, + &grad_ins_fwd_slotname_map, &grad_ins_grad_slotname_map, + &in_vars, &out_vars, &grad_ins, &grad_outs); + /* --------------------------- */ /* --------- CodeGen --------- */ /* --------------------------- */ @@ -1484,10 +1602,10 @@ static void DygraphCodeGeneration(const std::string& output_dir) { VLOG(6) << "-------- GenerateForwardFunctionContents -------"; std::pair body_and_declaration = GenerateForwardFunctionContents( - grad_node_default_attr_maps, fwd_inputs_name_pos_map, - fwd_outputs_name_pos_map, grad_ins_fwd_slotname_map, - grad_ins_grad_slotname_map, grad_outs_slotname_map, grad_ins, - grad_outs, *op_proto); + fwd_inputs_name_pos_map, fwd_outputs_name_pos_map, + grad_ins_fwd_slotname_map, grad_ins_grad_slotname_map, + grad_outs_slotname_map, grad_ins, grad_outs, op_type, in_vars, + out_vars); std::string fwd_function_str = body_and_declaration.first; GenerateForwardDygraphFile(op_type, output_dir, fwd_function_str); @@ -1498,16 +1616,16 @@ static void DygraphCodeGeneration(const std::string& output_dir) { /* ---- xxx_node.h ---- */ VLOG(6) << "-------- GenerateGradNodeHeaderContents -------"; std::string grad_node_h_str = GenerateGradNodeHeaderContents( - grad_node_default_attr_maps, grad_ins_fwd_slotname_map, *op_proto); + grad_ins_fwd_slotname_map, op_type, in_vars, out_vars); GenerateNodeHFile(op_type, output_dir, grad_node_h_str); /* ---- xxx_node.cc ---- */ VLOG(6) << "-------- GenerateGradNodeCCContents -------"; std::string grad_node_cc_str = GenerateGradNodeCCContents( - grad_node_default_attr_maps, grad_op_types, fwd_inputs_name_pos_map, - fwd_outputs_name_pos_map, grad_ins_fwd_slotname_map, - grad_ins_grad_slotname_map, grad_outs_slotname_map, grad_ins, grad_outs, - *op_proto); + grad_op_types, fwd_inputs_name_pos_map, fwd_outputs_name_pos_map, + grad_ins_fwd_slotname_map, grad_ins_grad_slotname_map, + grad_outs_slotname_map, grad_ins, grad_outs, op_type, in_vars, + out_vars); GenerateNodeCCFile(op_type, output_dir, grad_node_cc_str); VLOG(6) << op_type << ": Finished Generation"; diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc index 850f208359e0509a96fc1ba422fd04e30b41aa2b..749782f2413e5dbe0c448c62c25df4b696c1cd00 100644 --- a/paddle/fluid/pybind/op_function_generator.cc +++ b/paddle/fluid/pybind/op_function_generator.cc @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "paddle/fluid/pybind/op_function_generator.h" + #include #include #include @@ -30,108 +32,6 @@ #include "paddle/fluid/framework/fleet/ascend_wrapper.h" #endif -// NOTE(zhiqiu): Commonly, the inputs in auto-generated OP function are -// determined by the OP`s proto automatically, i.e., all the inputs registered -// in OpMaker. -// However, some OPs have dispensable inputs, which means the input can -// be none for some conditions. It is discovered that most dispensable inputs -// is not used in imperative mode, so we drop those inputs when generating OP -// functions. While, for very few OPs, the dispensable inputs are used, we -// need to manually specify them in this map. -std::map> op_ins_map = { - {"layer_norm", {"X", "Scale", "Bias"}}, - {"bincount", {"X", "Weights"}}, - {"fused_attention", - {"X", "LnScale", "LnBias", "QKVW", "QKVBias", "SrcMask", "OutLinearW", - "OutLinearBias", "Ln2Scale", "Ln2Bias"}}, - {"instance_norm", {"X", "Scale", "Bias"}}, - {"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}}, - {"label_smooth", {"X", "PriorDist"}}, - {"assign", {"X"}}, - {"reshape2", {"X", "Shape"}}, - {"expand", {"X", "ExpandTimes"}}, - {"slice", {"Input", "StartsTensor", "EndsTensor"}}, - {"fake_quantize_dequantize_moving_average_abs_max", - {"X", "InScale", "InAccum", "InState"}}, - {"nll_loss", {"X", "Label", "Weight"}}, - {"bilinear_tensor_product", {"X", "Y", "Weight", "Bias"}}, - {"gather", {"X", "Index", "Axis"}}, - {"roi_pool", {"X", "ROIs", "RoisNum"}}, - {"roi_align", {"X", "ROIs", "RoisNum"}}, - {"psroi_pool", {"X", "ROIs", "RoisNum"}}, - {"collect_fpn_proposals", - {"MultiLevelRois", "MultiLevelScores", "MultiLevelRoIsNum"}}, - {"distribute_fpn_proposals", {"FpnRois", "RoisNum"}}, - {"warpctc", {"Logits", "Label", "LogitsLength", "LabelLength"}}, - {"hierarchical_sigmoid", - {"X", "W", "Label", "PathTable", "PathCode", "Bias"}}, - {"moving_average_abs_max_scale", {"X", "InAccum", "InState"}}, - {"multiclass_nms3", {"BBoxes", "Scores", "RoisNum"}}, - {"box_coder", {"PriorBox", "PriorBoxVar", "TargetBox"}}, - {"momentum", {"Param", "Grad", "Velocity", "LearningRate", "MasterParam"}}, - {"sparse_momentum", {"Param", "Grad", "Velocity", "Index", "LearningRate"}}, - {"rnn", {"Input", "PreState", "WeightList", "SequenceLength"}}, - {"run_program", {"X", "Params"}}, - {"fused_feedforward", - {"Dropout1Seed", "Dropout2Seed", "Linear1Bias", "Linear2Bias", "Ln1Scale", - "Ln1Bias", "Ln2Scale", "Ln2Bias"}}, - {"faster_tokenizer", {"Text", "Vocab", "TextPair"}}, - {"matrix_rank", {"X", "TolTensor"}}, - {"adam", - {"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow", - "Beta2Pow", "MasterParam"}}, - {"adamw", - {"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow", - "Beta2Pow", "MasterParam"}}, -}; - -// NOTE(zhiqiu): Like op_ins_map. -// Commonly, the outputs in auto-generated OP function are determined by the -// OP`s proto automatically, i.e., all the outputs registered in OpMaker. -// However, some OPs have dispensable outputs, which means the output can -// be none for some conditions. It is discovered that most dispensable outputs -// is not used in imperative mode, so we drop those outputs when generating OP -// functions. While, for very few OPs, the dispensable outputs are used, we -// need to manually specify them in this map. -std::map> op_outs_map = { - {"fake_quantize_dequantize_moving_average_abs_max", - {"Out", "OutScale", "OutAccum", "OutState"}}, - {"batch_norm", - {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance", - "ReserveSpace"}}, - {"fused_attention", - {"LnMean", "LnVariance", "LnOut", "QKVOut", "QKVBiasOut", "TransposeOut2", - "QKOut", "QKTVOut", "SoftmaxOut", "AttnDropoutMaskOut", "AttnDropoutOut", - "SrcMaskOut", "FMHAOut", "OutLinearOut", "DropoutMaskOut", "Ln2Mean", - "Ln2Variance", "BiasDropoutResidualOut", "Y"}}, - {"sync_batch_norm", - {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance", - "ReserveSpace"}}, - {"unique", {"Out", "Index", "Indices", "Counts"}}, - {"unique_consecutive", {"Out", "Index", "Counts"}}, - {"generate_proposals", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}}, - {"collect_fpn_proposals", {"FpnRois", "RoisNum"}}, - {"matrix_nms", {"Out", "Index", "RoisNum"}}, - {"distribute_fpn_proposals", - {"MultiFpnRois", "RestoreIndex", "MultiLevelRoIsNum"}}, - {"moving_average_abs_max_scale", - {"Out", "OutScale", "OutAccum", "OutState"}}, - {"multiclass_nms3", {"Out", "NmsRoisNum"}}, - {"generate_proposals_v2", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}}, - {"momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}}, - {"sparse_momentum", {"ParamOut", "VelocityOut"}}, - {"rnn", {"DropoutState", "Reserve", "Out", "State"}}, - {"lamb", - {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}}, - {"run_program", {"DOut"}}, - {"adam", - {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut", - "MasterParamOut"}}, - {"adamw", - {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut", - "MasterParamOut"}}, -}; - // NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are // generated in C++ automatically. // However, some OPs need to pass the outputs from Python instead of generating diff --git a/paddle/fluid/pybind/op_function_generator.h b/paddle/fluid/pybind/op_function_generator.h new file mode 100644 index 0000000000000000000000000000000000000000..ad7fa780976d7db91551b4e7f3f93d3f16c676dd --- /dev/null +++ b/paddle/fluid/pybind/op_function_generator.h @@ -0,0 +1,121 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include + +// NOTE(zhiqiu): Commonly, the inputs in auto-generated OP function are +// determined by the OP`s proto automatically, i.e., all the inputs registered +// in OpMaker. +// However, some OPs have dispensable inputs, which means the input can +// be none for some conditions. It is discovered that most dispensable inputs +// is not used in imperative mode, so we drop those inputs when generating OP +// functions. While, for very few OPs, the dispensable inputs are used, we +// need to manually specify them in this map. +std::map> op_ins_map = { + {"layer_norm", {"X", "Scale", "Bias"}}, + {"bincount", {"X", "Weights"}}, + {"fused_attention", + {"X", "LnScale", "LnBias", "QKVW", "QKVBias", "SrcMask", "OutLinearW", + "OutLinearBias", "Ln2Scale", "Ln2Bias"}}, + {"instance_norm", {"X", "Scale", "Bias"}}, + {"gru_unit", {"Input", "HiddenPrev", "Weight", "Bias"}}, + {"label_smooth", {"X", "PriorDist"}}, + {"assign", {"X"}}, + {"reshape2", {"X", "Shape"}}, + {"expand", {"X", "ExpandTimes"}}, + {"slice", {"Input", "StartsTensor", "EndsTensor"}}, + {"fake_quantize_dequantize_moving_average_abs_max", + {"X", "InScale", "InAccum", "InState"}}, + {"nll_loss", {"X", "Label", "Weight"}}, + {"bilinear_tensor_product", {"X", "Y", "Weight", "Bias"}}, + {"gather", {"X", "Index", "Axis"}}, + {"roi_pool", {"X", "ROIs", "RoisNum"}}, + {"roi_align", {"X", "ROIs", "RoisNum"}}, + {"psroi_pool", {"X", "ROIs", "RoisNum"}}, + {"collect_fpn_proposals", + {"MultiLevelRois", "MultiLevelScores", "MultiLevelRoIsNum"}}, + {"distribute_fpn_proposals", {"FpnRois", "RoisNum"}}, + {"warpctc", {"Logits", "Label", "LogitsLength", "LabelLength"}}, + {"hierarchical_sigmoid", + {"X", "W", "Label", "PathTable", "PathCode", "Bias"}}, + {"moving_average_abs_max_scale", {"X", "InAccum", "InState"}}, + {"multiclass_nms3", {"BBoxes", "Scores", "RoisNum"}}, + {"box_coder", {"PriorBox", "PriorBoxVar", "TargetBox"}}, + {"momentum", {"Param", "Grad", "Velocity", "LearningRate", "MasterParam"}}, + {"sparse_momentum", {"Param", "Grad", "Velocity", "Index", "LearningRate"}}, + {"rnn", {"Input", "PreState", "WeightList", "SequenceLength"}}, + {"run_program", {"X", "Params"}}, + {"fused_feedforward", + {"Dropout1Seed", "Dropout2Seed", "Linear1Bias", "Linear2Bias", "Ln1Scale", + "Ln1Bias", "Ln2Scale", "Ln2Bias"}}, + {"faster_tokenizer", {"Text", "Vocab", "TextPair"}}, + {"matrix_rank", {"X", "TolTensor"}}, + {"adam", + {"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow", + "Beta2Pow", "MasterParam"}}, + {"adamw", + {"Param", "Grad", "LearningRate", "Moment1", "Moment2", "Beta1Pow", + "Beta2Pow", "MasterParam"}}, +}; + +// NOTE(zhiqiu): Like op_ins_map. +// Commonly, the outputs in auto-generated OP function are determined by the +// OP`s proto automatically, i.e., all the outputs registered in OpMaker. +// However, some OPs have dispensable outputs, which means the output can +// be none for some conditions. It is discovered that most dispensable outputs +// is not used in imperative mode, so we drop those outputs when generating OP +// functions. While, for very few OPs, the dispensable outputs are used, we +// need to manually specify them in this map. +std::map> op_outs_map = { + {"fake_quantize_dequantize_moving_average_abs_max", + {"Out", "OutScale", "OutAccum", "OutState"}}, + {"batch_norm", + {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance", + "ReserveSpace"}}, + {"fused_attention", + {"LnMean", "LnVariance", "LnOut", "QKVOut", "QKVBiasOut", "TransposeOut2", + "QKOut", "QKTVOut", "SoftmaxOut", "AttnDropoutMaskOut", "AttnDropoutOut", + "SrcMaskOut", "FMHAOut", "OutLinearOut", "DropoutMaskOut", "Ln2Mean", + "Ln2Variance", "BiasDropoutResidualOut", "Y"}}, + {"sync_batch_norm", + {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance", + "ReserveSpace"}}, + {"unique", {"Out", "Index", "Indices", "Counts"}}, + {"unique_consecutive", {"Out", "Index", "Counts"}}, + {"generate_proposals", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}}, + {"collect_fpn_proposals", {"FpnRois", "RoisNum"}}, + {"matrix_nms", {"Out", "Index", "RoisNum"}}, + {"distribute_fpn_proposals", + {"MultiFpnRois", "RestoreIndex", "MultiLevelRoIsNum"}}, + {"moving_average_abs_max_scale", + {"Out", "OutScale", "OutAccum", "OutState"}}, + {"multiclass_nms3", {"Out", "NmsRoisNum"}}, + {"generate_proposals_v2", {"RpnRois", "RpnRoiProbs", "RpnRoisNum"}}, + {"momentum", {"ParamOut", "VelocityOut", "MasterParamOut"}}, + {"sparse_momentum", {"ParamOut", "VelocityOut"}}, + {"rnn", {"DropoutState", "Reserve", "Out", "State"}}, + {"lamb", + {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}}, + {"run_program", {"DOut"}}, + {"adam", + {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut", + "MasterParamOut"}}, + {"adamw", + {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut", + "MasterParamOut"}}, +};