Unverified commit cf873c39, authored by Zhanlue Yang, committed by GitHub

Enabled Eager AutoCodeGen for 40+ more operators (#37910)

* Rearranged Eager AutoCodeGen directory structure

* Removed USE_OP in Eager AutoCodeGen

* Enabled generation for Operators without Grad/Inputs/Outputs

* Resolved operators without input

* Fixed merge conflicts

* Enabled Eager AutoCodeGen for 10+ more operators
Parent: b5dd12fb
@@ -33,14 +33,18 @@ static std::unordered_map<std::string, paddle::framework::AttributeMap>
     operators_with_attrs = {};

 static std::unordered_set<std::string> operators_to_skip = {
-    "chunk_eval",  // Stupid tensor name
-    "minus",        "pull_sparse",    "pull_box_extended_sparse",
-    "pull_sparse_v2", "pull_box_sparse", "fused_attention",
-    "diag_v2", "c_split"};
+    "minus",
+};

 static std::unordered_set<std::string> operators_to_codegen = {};
 static std::unordered_set<std::string> skipped_operators = {};

+static std::string LegalizeVariableName(const std::string& var_name) {
+  std::string ret = var_name;
+  std::replace(ret.begin(), ret.end(), '-', '_');  // replace all '-' with '_'
+  return ret;
+}
+
 static std::string AttrTypeToString(const proto::AttrType& type) {
   std::string ret;
   switch (type) {
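The new LegalizeVariableName helper exists because the generated forward functions name local C++ variables after operator output slots, and a slot name containing '-' is not a legal C++ identifier. A minimal, self-contained sketch of its behavior (the slot name "X-Out" is hypothetical, not taken from any real op):

// Hedged sketch, not from the Paddle source: shows what
// LegalizeVariableName does to a made-up slot name.
#include <algorithm>
#include <iostream>
#include <string>

static std::string LegalizeVariableName(const std::string& var_name) {
  std::string ret = var_name;
  std::replace(ret.begin(), ret.end(), '-', '_');  // replace all '-' with '_'
  return ret;
}

int main() {
  std::cout << LegalizeVariableName("X-Out") << "\n";  // prints "X_Out"
  return 0;
}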
@@ -608,6 +612,9 @@ static bool CollectGradInformationFromOpInfo(
   }
   VLOG(6) << "Prepared Default Attributes Map, size = " << default_attrs.size();
+  for (const auto& iter : default_attrs) {
+    VLOG(6) << iter.first;
+  }

   /* ---------------------------- */
   /* --------- Backward --------- */
@@ -1052,24 +1059,25 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
     const std::string& output_name = output.name();
     std::string out_tensor_str;
     size_t return_position = fwd_outputs_name_pos_map.at(output_name);
+    std::string output_varname = LegalizeVariableName(output_name);

     if (output.duplicable()) {
       const char* FWD_OUT_TENSORS_TEMPLATE =
           "  std::vector<egr::EagerTensor> %s = "
           "egr::EagerUtils::GetOutputs(outs[\"%s\"]);\n";
       out_tensor_str = paddle::string::Sprintf(FWD_OUT_TENSORS_TEMPLATE,
-                                               output_name, output_name);
+                                               output_varname, output_name);
       return_types[return_position] = "std::vector<egr::EagerTensor>";
     } else {
       const char* FWD_OUT_TENSOR_TEMPLATE =
           "  egr::EagerTensor %s = "
           "egr::EagerUtils::GetOutput(outs[\"%s\"][0]);\n";
       out_tensor_str = paddle::string::Sprintf(FWD_OUT_TENSOR_TEMPLATE,
-                                               output_name, output_name);
+                                               output_varname, output_name);
       return_types[return_position] = "egr::EagerTensor";
     }

-    return_contents[return_position] = output_name;
+    return_contents[return_position] = output_varname;
     generated_function_body += out_tensor_str;
   }
   generated_function_body += "\n";
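To see the effect of the change, this sketch expands FWD_OUT_TENSOR_TEMPLATE the way the generator does, assuming a hypothetical non-duplicable output slot named "X-Out"; std::printf stands in for paddle::string::Sprintf, which is printf-style:

// Hedged sketch: the variable name comes from LegalizeVariableName while
// the outs map stays keyed by the original slot name.
#include <cstdio>

int main() {
  const char* FWD_OUT_TENSOR_TEMPLATE =
      "  egr::EagerTensor %s = "
      "egr::EagerUtils::GetOutput(outs[\"%s\"][0]);\n";
  const char* output_varname = "X_Out";  // LegalizeVariableName("X-Out")
  const char* output_name = "X-Out";     // hypothetical slot name
  std::printf(FWD_OUT_TENSOR_TEMPLATE, output_varname, output_name);
  // Prints:   egr::EagerTensor X_Out = egr::EagerUtils::GetOutput(outs["X-Out"][0]);
  return 0;
}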
@@ -1280,23 +1288,76 @@ static std::string GenerateGradNodeCCContents(
     if (grad_outs_slotname_map.count(grad_output_name)) {
       // Fwd Tensor
-      const std::string& fwd_input_name =
-          grad_outs_slotname_map.at(grad_output_name);
-      size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_input_name);
-
-      if (duplicable_input_name_set.count(fwd_input_name)) {
-        const char* GRAD_OUTS_CONTENT_TEMPLATE =
-            "{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput( "
-            "this->OutputMeta()[%d].Size() ) },";
-        outs_contents_str += paddle::string::Sprintf(
-            GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, fwd_input_position);
-      } else {
-        const char* GRAD_OUTS_CONTENT_TEMPLATE =
-            "{ \"%s\", "
-            "{std::make_shared<egr::EagerTensor>(egr::Controller::Instance()."
-            "GenerateUniqueName())}},";
-        outs_contents_str += paddle::string::Sprintf(GRAD_OUTS_CONTENT_TEMPLATE,
-                                                     grad_output_name);
-      }
+      const std::string& fwd_name = grad_outs_slotname_map.at(grad_output_name);
+
+      /* Handle Special Case: "PullSparseOp", etc.
+
+         Forward:
+
+            Ids   W
+             |    |
+           PullSparseOp
+                |
+               Out
+
+         Backward:
+
+            Ids  GradOut  W
+             |      |     |
+            PullSparseGradOp
+                    |
+                 GradOut
+
+         Its grad output "GradOut" corresponds to forward output "Out",
+         where there is a hidden inplace involved. So we find "GradOut"'s
+         index in grads, and perform the inplace operation by constructing
+         outs = {{"Out", grads[i]}}
+
+         GradOut -> Out -> fwd_output_pos -> grads position -> grads[i]
+         outs = {{"Out", grads[i]}}
+
+         For returns, append "GradOut" to the very end of the return list.
+      */
+      if (!fwd_inputs_name_pos_map.count(fwd_name)) {
+        PADDLE_ENFORCE(fwd_outputs_name_pos_map.count(fwd_name),
+                       paddle::platform::errors::Fatal(
+                           "fwd_name not found in fwd_inputs_name_pos_map nor "
+                           "fwd_outputs_name_pos_map"));
+
+        size_t grads_position = fwd_outputs_name_pos_map.at(fwd_name);
+
+        std::string grad_ptr_name = fwd_name + "_ptrs";
+        const char* GET_GRADS_PTR_TEMPLATE =
+            "  std::vector<std::shared_ptr<egr::EagerTensor>> %s;\n"
+            "  for(const auto& t : grads[%d]) {\n    "
+            "%s.emplace_back(std::move(std::make_shared<egr::EagerTensor>(t)));"
+            "\n  }\n";
+        std::string grads_ptr_str =
+            paddle::string::Sprintf(GET_GRADS_PTR_TEMPLATE, grad_ptr_name,
+                                    grads_position, grad_ptr_name);
+        generated_grad_function_body += grads_ptr_str;
+        generated_grad_function_body += "\n";
+
+        const char* GRAD_OUTS_CONTENT_TEMPLATE = "{ \"%s\", %s },";
+        outs_contents_str += paddle::string::Sprintf(
+            GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, grad_ptr_name);
+      } else {
+        size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
+        if (duplicable_input_name_set.count(fwd_name)) {
+          const char* GRAD_OUTS_CONTENT_TEMPLATE =
+              "{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput( "
+              "this->OutputMeta()[%d].Size() ) },";
+          outs_contents_str += paddle::string::Sprintf(
+              GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, fwd_input_position);
+        } else {
+          const char* GRAD_OUTS_CONTENT_TEMPLATE =
+              "{ \"%s\", "
+              "{std::make_shared<egr::EagerTensor>(egr::Controller::Instance()."
+              "GenerateUniqueName())}},";
+          outs_contents_str += paddle::string::Sprintf(
+              GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name);
+        }
+      }
     } else {
       PADDLE_THROW(platform::errors::Fatal(
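To make the special case concrete: for PullSparseGradOp, fwd_name is the forward output "Out" and, assuming it sits at grads position 0, GET_GRADS_PTR_TEMPLATE above would emit roughly the following into the generated grad function (a sketch of generator output expanded from the templates in this hunk, not standalone library code):

// Hedged sketch of generated code for fwd_name == "Out", grads_position == 0:
std::vector<std::shared_ptr<egr::EagerTensor>> Out_ptrs;
for (const auto& t : grads[0]) {
  Out_ptrs.emplace_back(std::move(std::make_shared<egr::EagerTensor>(t)));
}
// The outs map then wires the grad output onto the forward slot in place:
//   { "GradOut", Out_ptrs },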
@@ -1340,15 +1401,39 @@ static std::string GenerateGradNodeCCContents(
   // [Generation] Get Return
   std::string outputs_str = "";
+  size_t num_appended_outputs = 0;
   for (auto iter : grad_outs) {
     const std::string& grad_out_name = iter.first;
-    size_t fwd_input_position =
-        fwd_inputs_name_pos_map.at(grad_outs_slotname_map.at(grad_out_name));
+    const std::string& fwd_name = grad_outs_slotname_map.at(grad_out_name);
+
+    if (fwd_inputs_name_pos_map.count(fwd_name)) {
+      size_t fwd_input_position = fwd_inputs_name_pos_map.at(fwd_name);
+      const char* BWD_OUTPUT_TEMPLATE =
+          "  outputs[%d] = egr::EagerUtils::GetOutputs(outs[\"%s\"]);\n";
+      outputs_str += paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE,
+                                             fwd_input_position, grad_out_name);
+      num_appended_outputs++;
+    } else {
+      PADDLE_ENFORCE(fwd_outputs_name_pos_map.count(fwd_name),
+                     paddle::platform::errors::Fatal(
+                         "fwd_name not found in fwd_inputs_name_pos_map nor "
+                         "fwd_outputs_name_pos_map"));
+    }
+  }

-    const char* BWD_OUTPUT_TEMPLATE =
-        "  outputs[%d] = egr::EagerUtils::GetOutputs(outs[\"%s\"]);\n";
-    outputs_str += paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE,
-                                           fwd_input_position, grad_out_name);
+  /* Handle Special Case: "PullSparseOp", etc.
+     For returns, append "GradOut" to the very end of the return list. */
+  for (auto iter : grad_outs) {
+    const std::string& grad_out_name = iter.first;
+    const std::string& fwd_name = grad_outs_slotname_map.at(grad_out_name);
+
+    if (fwd_outputs_name_pos_map.count(fwd_name)) {
+      const char* BWD_OUTPUT_TEMPLATE =
+          "  outputs[%d] = egr::EagerUtils::GetOutputs(outs[\"%s\"]);\n";
+      outputs_str += paddle::string::Sprintf(
+          BWD_OUTPUT_TEMPLATE, num_appended_outputs, grad_out_name);
+      num_appended_outputs++;
+    }
   }

   const char* BWD_RETURN_TEMPLATE =
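The two passes above fix the return layout: grad outputs that map back to forward inputs land at those inputs' positions first, and special-case grad outputs (those mapping to forward outputs) are appended afterwards via num_appended_outputs. Assuming a hypothetical grad op whose forward inputs are {"Ids": 0, "W": 1} plus the special "GradOut" -> "Out" mapping, the generated returns would read roughly as follows (all slot names here are assumptions for illustration):

// Hedged sketch of generated output-gathering code; slot names are assumed.
outputs[0] = egr::EagerUtils::GetOutputs(outs["Ids@GRAD"]);  // fwd input "Ids"
outputs[1] = egr::EagerUtils::GetOutputs(outs["W@GRAD"]);    // fwd input "W"
// Special case, appended after the two regular outputs:
outputs[2] = egr::EagerUtils::GetOutputs(outs["GradOut"]);   // fwd output "Out"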
@@ -1722,6 +1807,10 @@ static void PrepareAttrMapForOps() {
   operators_with_attrs["transfer_dtype"] = {};
   operators_with_attrs["transfer_dtype"]["out_dtype"] = 5;
   operators_with_attrs["transfer_dtype"]["in_dtype"] = 5;
+
+  // Handle "c_split"
+  operators_with_attrs["c_split"] = {};
+  operators_with_attrs["c_split"]["nranks"] = 1;
 }

 static void CollectOperatorsToCodeGen(const std::string& op_list_path) {
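operators_with_attrs supplies hand-written default attributes for operators whose proto defaults are not enough for the generator to construct a test instance; c_split needs a concrete nranks. (The value 5 in the transfer_dtype entries above is, to our understanding, the framework::proto::VarType value for FP32.) A sketch of the same registration pattern for a made-up operator:

// Hedged sketch of the pattern used above, for a hypothetical op
// "my_collective_op" that cannot be default-constructed without "ring_id":
operators_with_attrs["my_collective_op"] = {};
operators_with_attrs["my_collective_op"]["ring_id"] = 0;  // int attribute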
......
-sigmoid
-matmul_v2
-reduce_sum
-elementwise_add
 rsqrt
 multihead_matmul
 addmm
@@ -19,7 +15,9 @@ pow2_decay_with_linear_warmup
 split
 fc
 clear_float_status
+matmul_v2
 load
+c_embedding
 elementwise_max
 adadelta
 chunk_eval
@@ -43,8 +41,10 @@ expand_v2
 lgamma
 solve
 deformable_psroi_pooling
+transfer_layout
 instance_norm
 decode_jpeg
+distributed_push_sparse
 gather_nd
 reduce_prod
 matrix_rank
@@ -57,10 +57,12 @@ sequence_slice
 lookup_table
 softplus
 depthwise_conv2d
+c_allreduce_sum
 fused_fc_elementwise_layernorm
 sigmoid_cross_entropy_with_logits
 exp
 scatter
+c_allreduce_min
 equal_all
 searchsorted
 fusion_squared_mat_sub
@@ -73,6 +75,7 @@ momentum
 temporal_shift
 nce
 mv
+global_scatter
 proximal_gd
 memcpy_h2d
 add_position_encoding
@@ -90,13 +93,18 @@ randperm
 sequence_scatter
 partial_sum
 relu6
+partial_allgather
+c_scatter
+alltoall
 conv3d
 lstm_unit
 not_equal
 transpose2
+c_sync_comm_stream
 uniform_random_batch_size_like
 unfold
 lrn
+isclose
 softmax_with_cross_entropy
 isfinite_v2
 bernoulli
@@ -105,6 +113,7 @@ gaussian_random
 flatten2
 matmul
 cvm
+recv_v2
 adamax
 masked_select
 range
@@ -112,6 +121,7 @@ bitwise_not
 trace
 multinomial
 modified_huber_loss
+c_reduce_prod
 roll
 squared_l2_distance
 conv3d_transpose
@@ -128,8 +138,10 @@ multiclass_nms2
 bpr_loss
 fft_c2c
 bicubic_interp_v2
+angle
 reshape
 coalesce_tensor
+dgc
 roi_align
 reshape2
 reduce_any
@@ -139,6 +151,7 @@ sequence_reshape
 bilateral_slice
 fill_any_like
 empty
+partial_recv
 pad_constant_like
 pool2d
 size
@@ -148,11 +161,14 @@ stack
 dgc_momentum
 lamb
 generate_proposals_v2
+c_sync_calc_stream
 bitwise_or
 gru_unit
 fake_channel_wise_quantize_dequantize_abs_max
 sampling_id
 unsqueeze2
+transfer_dtype
+allreduce
 average_accumulates
 sequence_enumerate
 fusion_seqconv_eltadd_relu
@@ -160,6 +176,7 @@ bce_loss
 generate_proposal_labels
 im2sequence
 isinf
+c_reducescatter
 adagrad
 linear_chain_crf
 retinanet_target_assign
@@ -170,6 +187,7 @@ lookup_table_v2
 detection_map
 l1_norm
 sqrt
+partial_send
 fused_elemwise_activation
 slogdeterminant
 share_buffer
@@ -191,7 +209,10 @@ linear_interp
 auc
 logical_or
 batch_norm
+c_reduce_sum
+elementwise_add
 acos
+send_and_recv
 unpool
 cumprod
 sample_logits
@@ -206,6 +227,7 @@ matrix_power
 greater_equal
 generate_proposals
 bilinear_interp
+sigmoid
 inplace_abn
 softshrink
 mul
@@ -243,6 +265,8 @@ overlap_add
 fill_constant_batch_size_like
 fill_any
 dequantize_log
+c_split
+barrier
 max_pool2d_with_index
 pad3d
 norm
@@ -258,6 +282,7 @@ pow
 stanh
 label_smooth
 merged_momentum
+c_reduce_min
 ascend_trigger
 fused_feedforward
 rpn_target_assign
@@ -271,6 +296,7 @@ frame
 bincount
 shape
 group_norm
+c_softmax_with_cross_entropy
 resnet_unit
 sequence_expand_as
 cos_sim
@@ -319,6 +345,7 @@ adamw
 elementwise_pow
 prior_box
 p_norm
+c_concat
 unique_consecutive
 lod_reset
 pad
@@ -339,6 +366,7 @@ pad2d
 inverse
 spectral_norm
 shuffle_channel
+send_v2
 psroi_pool
 seed
 ceil
@@ -347,6 +375,7 @@ reduce_min
 cos
 ncclAllReduce
 cudnn_lstm
+reduce_sum
 digamma
 assign_value
 increment
@@ -366,6 +395,8 @@ atan
 less_than
 unsqueeze
 crf_decoding
+global_gather
+c_allreduce_prod
 log_softmax
 ftrl
 matrix_nms
@@ -374,6 +405,7 @@ cast
 tanh_shrink
 hard_shrink
 multiclass_nms
+c_broadcast
 fusion_transpose_flatten_concat
 sequence_unpad
 fused_elemwise_add_activation
@@ -393,6 +425,7 @@ alloc_float_status
 sequence_concat
 fusion_seqpool_cvm_concat
 similarity_focus
+c_allreduce_max
 argsort
 sequence_expand
 sgd
@@ -413,6 +446,7 @@ pyramid_hash
 fake_quantize_dequantize_moving_average_abs_max
 multi_dot
 sequence_pool
+broadcast
 transpose
 top_k
 dist
@@ -466,10 +500,13 @@ squared_l2_norm
 elementwise_sub
 margin_rank_loss
 faster_tokenizer
+c_identity
+c_reduce_max
 relu
 is_empty
 reduce_all
 edit_distance
+distributed_lookup_table
 bmm
 yolo_box
 soft_relu
@@ -484,6 +521,7 @@ nearest_interp
 gather
 trilinear_interp_v2
 box_clip
+c_allgather
 isnan_v2
 softmax
 conv2d_fusion
......