diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc
index b3657a9894f82b05748ae5d27dbf4a9d68e5c12a..fe29792b6e75c48bd40ab4d261afaae5b722c2ab 100644
--- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc
+++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc
@@ -33,14 +33,18 @@
 static std::unordered_map<std::string, paddle::framework::AttributeMap> operators_with_attrs = {};

 static std::unordered_set<std::string> operators_to_skip = {
-    "chunk_eval",  // Stupid tensor name
-    "minus", "pull_sparse", "pull_box_extended_sparse",
-    "pull_sparse_v2", "pull_box_sparse", "fused_attention",
-    "diag_v2", "c_split"};
+    "minus",
+};

 static std::unordered_set<std::string> operators_to_codegen = {};
 static std::unordered_set<std::string> skipped_operators = {};

+static std::string LegalizeVariableName(const std::string& var_name) {
+  std::string ret = var_name;
+  std::replace(ret.begin(), ret.end(), '-', '_');  // replace all '-' with '_'
+  return ret;
+}
+
 static std::string AttrTypeToString(const proto::AttrType& type) {
   std::string ret;
   switch (type) {
@@ -608,6 +612,9 @@ static bool CollectGradInformationFromOpInfo(
   }
   VLOG(6) << "Prepared Default Attributes Map, size = "
           << default_attrs.size();
+  for (const auto& iter : default_attrs) {
+    VLOG(6) << iter.first;
+  }

   /* ---------------------------- */
   /* --------- Backward --------- */
@@ -1052,24 +1059,25 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
     const std::string& output_name = output.name();
     std::string out_tensor_str;
     size_t return_position = fwd_outputs_name_pos_map.at(output_name);
+    std::string output_varname = LegalizeVariableName(output_name);

     if (output.duplicable()) {
       const char* FWD_OUT_TENSORS_TEMPLATE =
           "  std::vector<egr::EagerTensor> %s = "
           "egr::EagerUtils::GetOutputs(outs[\"%s\"]);\n";
       out_tensor_str = paddle::string::Sprintf(FWD_OUT_TENSORS_TEMPLATE,
-                                               output_name, output_name);
+                                               output_varname, output_name);
       return_types[return_position] = "std::vector<egr::EagerTensor>";
     } else {
       const char* FWD_OUT_TENSOR_TEMPLATE =
           "  egr::EagerTensor %s = "
           "egr::EagerUtils::GetOutput(outs[\"%s\"][0]);\n";
       out_tensor_str = paddle::string::Sprintf(FWD_OUT_TENSOR_TEMPLATE,
-                                               output_name, output_name);
+                                               output_varname, output_name);
       return_types[return_position] = "egr::EagerTensor";
     }

-    return_contents[return_position] = output_name;
+    return_contents[return_position] = output_varname;
     generated_function_body += out_tensor_str;
   }
   generated_function_body += "\n";
@@ -1280,23 +1288,76 @@ static std::string GenerateGradNodeCCContents(
     if (grad_outs_slotname_map.count(grad_output_name)) {
       // Fwd Tensor
-      const std::string& fwd_input_name =
-          grad_outs_slotname_map.at(grad_output_name);
-      size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_input_name);
-
-      if (duplicable_input_name_set.count(fwd_input_name)) {
-        const char* GRAD_OUTS_CONTENT_TEMPLATE =
-            "{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput( "
-            "this->OutputMeta()[%d].Size() ) },";
+      const std::string& fwd_name = grad_outs_slotname_map.at(grad_output_name);
+
+      /* Handle Special Case: "PullSparseOp", etc
+
+         Forward:
+
+              Ids   W
+               |    |
+             PullSparseOp
+                  |
+                 Out
+
+         Backward:
+
+              Ids   GradOut   W
+               |       |      |
+              PullSparseGradOp
+                       |
+                    GradOut
+
+         Its grad output "GradOut" corresponds to forward output "Out",
+         where there is a hidden inplace involved.
So we find "GradOut"'s index + in + grads, and perform the inplace operation by constructing outs = + {{"Out", grads[i]}} + + GradOut -> Out -> fwd_output_pos -> grads position -> grads[i] + outs = {{"Out", grads[i]}} + + For returns, append "GradOut" to the very end of return list. + */ + if (!fwd_inputs_name_pos_map.count(fwd_name)) { + PADDLE_ENFORCE(fwd_outputs_name_pos_map.count(fwd_name), + paddle::platform::errors::Fatal( + "fwd_name not found in fwd_inputs_name_pos_map nor " + "fwd_outputs_name_pos_map")); + + size_t grads_position = fwd_outputs_name_pos_map.at(fwd_name); + std::string grad_ptr_name = fwd_name + "_ptrs"; + const char* GET_GRADS_PTR_TEMPLATE = + " std::vector> %s;\n" + " for(const auto& t : grads[%d]) {\n " + "%s.emplace_back(std::move(std::make_shared(t)));" + "\n }\n"; + std::string grads_ptr_str = + paddle::string::Sprintf(GET_GRADS_PTR_TEMPLATE, grad_ptr_name, + grads_position, grad_ptr_name); + generated_grad_function_body += grads_ptr_str; + generated_grad_function_body += "\n"; + + const char* GRAD_OUTS_CONTENT_TEMPLATE = "{ \"%s\", %s },"; outs_contents_str += paddle::string::Sprintf( - GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, fwd_input_position); + GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, grad_ptr_name); + } else { - const char* GRAD_OUTS_CONTENT_TEMPLATE = - "{ \"%s\", " - "{std::make_shared(egr::Controller::Instance()." - "GenerateUniqueName())}},"; - outs_contents_str += paddle::string::Sprintf(GRAD_OUTS_CONTENT_TEMPLATE, - grad_output_name); + size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name); + if (duplicable_input_name_set.count(fwd_name)) { + const char* GRAD_OUTS_CONTENT_TEMPLATE = + "{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput( " + "this->OutputMeta()[%d].Size() ) },"; + outs_contents_str += paddle::string::Sprintf( + GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, fwd_input_position); + } else { + const char* GRAD_OUTS_CONTENT_TEMPLATE = + "{ \"%s\", " + "{std::make_shared(egr::Controller::Instance()." + "GenerateUniqueName())}},"; + outs_contents_str += paddle::string::Sprintf( + GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name); + } } } else { PADDLE_THROW(platform::errors::Fatal( @@ -1340,15 +1401,39 @@ static std::string GenerateGradNodeCCContents( // [Generation] Get Return std::string outputs_str = ""; + size_t num_appended_outputs = 0; for (auto iter : grad_outs) { const std::string& grad_out_name = iter.first; - size_t fwd_input_position = - fwd_inputs_name_pos_map.at(grad_outs_slotname_map.at(grad_out_name)); + const std::string& fwd_name = grad_outs_slotname_map.at(grad_out_name); + + if (fwd_inputs_name_pos_map.count(fwd_name)) { + size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name); + const char* BWD_OUTPUT_TEMPLATE = + " outputs[%d] = egr::EagerUtils::GetOutputs(outs[\"%s\"]);\n"; + outputs_str += paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE, + fwd_input_position, grad_out_name); + num_appended_outputs++; + } else { + PADDLE_ENFORCE(fwd_outputs_name_pos_map.count(fwd_name), + paddle::platform::errors::Fatal( + "fwd_name not found in fwd_inputs_name_pos_map nor " + "fwd_outputs_name_pos_map")); + } + } - const char* BWD_OUTPUT_TEMPLATE = - " outputs[%d] = egr::EagerUtils::GetOutputs(outs[\"%s\"]);\n"; - outputs_str += paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE, - fwd_input_position, grad_out_name); + /* Handle Special Case: "PullSparseOp", etc + For returns, append "GradOut" to the very end of return list. 
+  */
+  for (auto iter : grad_outs) {
+    const std::string& grad_out_name = iter.first;
+    const std::string& fwd_name = grad_outs_slotname_map.at(grad_out_name);
+
+    if (fwd_outputs_name_pos_map.count(fwd_name)) {
+      const char* BWD_OUTPUT_TEMPLATE =
+          "  outputs[%d] = egr::EagerUtils::GetOutputs(outs[\"%s\"]);\n";
+      outputs_str += paddle::string::Sprintf(
+          BWD_OUTPUT_TEMPLATE, num_appended_outputs, grad_out_name);
+      num_appended_outputs++;
+    }
   }

   const char* BWD_RETURN_TEMPLATE =
@@ -1722,6 +1807,10 @@ static void PrepareAttrMapForOps() {
   operators_with_attrs["transfer_dtype"] = {};
   operators_with_attrs["transfer_dtype"]["out_dtype"] = 5;
   operators_with_attrs["transfer_dtype"]["in_dtype"] = 5;
+
+  // Handle "c_split"
+  operators_with_attrs["c_split"] = {};
+  operators_with_attrs["c_split"]["nranks"] = 1;
 }

 static void CollectOperatorsToCodeGen(const std::string& op_list_path) {
diff --git a/paddle/fluid/eager/auto_code_generator/op_list.txt b/paddle/fluid/eager/auto_code_generator/op_list.txt
index 2456a7a1846d1e2861dbd686670d83a66f9cfbc5..699a84169d70022a17839d336ff0f52ce334ee33 100644
--- a/paddle/fluid/eager/auto_code_generator/op_list.txt
+++ b/paddle/fluid/eager/auto_code_generator/op_list.txt
@@ -1,7 +1,3 @@
-sigmoid
-matmul_v2
-reduce_sum
-elementwise_add
 rsqrt
 multihead_matmul
 addmm
@@ -19,7 +15,9 @@
 split
 fc
 clear_float_status
+matmul_v2
 load
+c_embedding
 elementwise_max
 adadelta
 chunk_eval
@@ -43,8 +41,10 @@ expand_v2
 lgamma
 solve
 deformable_psroi_pooling
+transfer_layout
 instance_norm
 decode_jpeg
+distributed_push_sparse
 gather_nd
 reduce_prod
 matrix_rank
@@ -57,10 +57,12 @@ sequence_slice
 lookup_table
 softplus
 depthwise_conv2d
+c_allreduce_sum
 fused_fc_elementwise_layernorm
 sigmoid_cross_entropy_with_logits
 exp
 scatter
+c_allreduce_min
 equal_all
 searchsorted
 fusion_squared_mat_sub
@@ -73,6 +75,7 @@ momentum
 temporal_shift
 nce
 mv
+global_scatter
 proximal_gd
 memcpy_h2d
 add_position_encoding
@@ -90,13 +93,18 @@ randperm
 sequence_scatter
 partial_sum
 relu6
+partial_allgather
+c_scatter
+alltoall
 conv3d
 lstm_unit
 not_equal
 transpose2
+c_sync_comm_stream
 uniform_random_batch_size_like
 unfold
 lrn
+isclose
 softmax_with_cross_entropy
 isfinite_v2
 bernoulli
@@ -105,6 +113,7 @@ gaussian_random
 flatten2
 matmul
 cvm
+recv_v2
 adamax
 masked_select
 range
@@ -112,6 +121,7 @@ bitwise_not
 trace
 multinomial
 modified_huber_loss
+c_reduce_prod
 roll
 squared_l2_distance
 conv3d_transpose
@@ -128,8 +138,10 @@ multiclass_nms2
 bpr_loss
 fft_c2c
 bicubic_interp_v2
+angle
 reshape
 coalesce_tensor
+dgc
 roi_align
 reshape2
 reduce_any
@@ -139,6 +151,7 @@ sequence_reshape
 bilateral_slice
 fill_any_like
 empty
+partial_recv
 pad_constant_like
 pool2d
 size
@@ -148,11 +161,14 @@ stack
 dgc_momentum
 lamb
 generate_proposals_v2
+c_sync_calc_stream
 bitwise_or
 gru_unit
 fake_channel_wise_quantize_dequantize_abs_max
 sampling_id
 unsqueeze2
+transfer_dtype
+allreduce
 average_accumulates
 sequence_enumerate
 fusion_seqconv_eltadd_relu
@@ -160,6 +176,7 @@ bce_loss
 generate_proposal_labels
 im2sequence
 isinf
+c_reducescatter
 adagrad
 linear_chain_crf
 retinanet_target_assign
@@ -170,6 +187,7 @@ lookup_table_v2
 detection_map
 l1_norm
 sqrt
+partial_send
 fused_elemwise_activation
 slogdeterminant
 share_buffer
@@ -191,7 +209,10 @@ linear_interp
 auc
 logical_or
 batch_norm
+c_reduce_sum
+elementwise_add
 acos
+send_and_recv
 unpool
 cumprod
 sample_logits
@@ -206,6 +227,7 @@ matrix_power
 greater_equal
 generate_proposals
 bilinear_interp
+sigmoid
 inplace_abn
 softshrink
 mul
@@ -243,6 +265,8 @@ overlap_add
 fill_constant_batch_size_like
 fill_any
 dequantize_log
+c_split
+barrier
 max_pool2d_with_index
 pad3d
 norm
@@ -258,6 +282,7 @@ pow
 stanh
 label_smooth
 merged_momentum
+c_reduce_min
 ascend_trigger
 fused_feedforward
 rpn_target_assign
@@ -271,6 +296,7 @@ frame
 bincount
 shape
 group_norm
+c_softmax_with_cross_entropy
 resnet_unit
 sequence_expand_as
 cos_sim
@@ -319,6 +345,7 @@ adamw
 elementwise_pow
 prior_box
 p_norm
+c_concat
 unique_consecutive
 lod_reset
 pad
@@ -339,6 +366,7 @@ pad2d
 inverse
 spectral_norm
 shuffle_channel
+send_v2
 psroi_pool
 seed
 ceil
@@ -347,6 +375,7 @@ reduce_min
 cos
 ncclAllReduce
 cudnn_lstm
+reduce_sum
 digamma
 assign_value
 increment
@@ -366,6 +395,8 @@ atan
 less_than
 unsqueeze
 crf_decoding
+global_gather
+c_allreduce_prod
 log_softmax
 ftrl
 matrix_nms
@@ -374,6 +405,7 @@ cast
 tanh_shrink
 hard_shrink
 multiclass_nms
+c_broadcast
 fusion_transpose_flatten_concat
 sequence_unpad
 fused_elemwise_add_activation
@@ -393,6 +425,7 @@ alloc_float_status
 sequence_concat
 fusion_seqpool_cvm_concat
 similarity_focus
+c_allreduce_max
 argsort
 sequence_expand
 sgd
@@ -413,6 +446,7 @@ pyramid_hash
 fake_quantize_dequantize_moving_average_abs_max
 multi_dot
 sequence_pool
+broadcast
 transpose
 top_k
 dist
@@ -466,10 +500,13 @@ squared_l2_norm
 elementwise_sub
 margin_rank_loss
 faster_tokenizer
+c_identity
+c_reduce_max
 relu
 is_empty
 reduce_all
 edit_distance
+distributed_lookup_table
 bmm
 yolo_box
 soft_relu
@@ -484,6 +521,7 @@ nearest_interp
 gather
 trilinear_interp_v2
 box_clip
+c_allgather
 isnan_v2
 softmax
 conv2d_fusion
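
Note: the snippet below is a minimal, self-contained sketch of what the new LegalizeVariableName helper does; the main() driver and the sample output name are illustrative assumptions, not part of this patch. Some forward output names contain '-' (chunk_eval's "F1-Score"-style outputs are an example, which is why chunk_eval previously had to sit in operators_to_skip). Such names are not valid C++ identifiers, so the generator now rewrites '-' to '_' before emitting the local variable name and the return list entry, while the original name is still used as the key into outs.

```cpp
// Illustrative sketch only: mirrors the LegalizeVariableName helper added in
// eager_generator.cc and shows its effect on a dash-containing output name.
#include <algorithm>
#include <iostream>
#include <string>

static std::string LegalizeVariableName(const std::string& var_name) {
  std::string ret = var_name;
  std::replace(ret.begin(), ret.end(), '-', '_');  // replace all '-' with '_'
  return ret;
}

int main() {
  // Hypothetical op output name used only for demonstration.
  std::cout << LegalizeVariableName("F1-Score") << std::endl;  // prints F1_Score
  return 0;
}
```

As the Sprintf calls in GenerateForwardFunctionContents show, only the generated C++ variable (%s on the left-hand side) and return_contents use the legalized name; outs["%s"] keeps the original operator output name.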