Unverified commit 08482a86, authored by Zhanlue Yang, committed by GitHub

Enabled Eager AutoCodeGen for All Existing Operators & Possible Future Operators (#37969)

* Rearranged Eager AutoCodeGen directory structure

* Removed USE_OP in Eager AutoCodeGen

* Enabled generation for Operators without Grad/Inputs/Outputs

* Resolved operators without input

* Fixed merge conflicts

* Enabled Eager AutoCodeGen for 10+ more operators

* Refactored Eager AutoCodeGen with more organized helper objects

* Enabled Eager AutoCodeGen for operators with multiple OpBases

* Adjusted Eager AutoCodeGen to Enable Passing Output Tensor as Input Argument

* Handled Dispensable Inputs/Outputs in Eager AutoCodeGen

* Enabled Eager AutoCodeGen for All Existing Operators & Possible Future Operators

* Fixed CI issues
Parent bce1e572
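The user-visible effect of this commit is a change to both code generators' command-line contract: the op_list.txt allowlist argument is removed, leaving only the output path. Below is a minimal, self-contained approximation of the new entry point (simplified from the eager_generator.cc changes in this diff; the generation call is stubbed out so the example compiles on its own):

```cpp
// Minimal sketch of the generator's entry point after this change
// (simplified from eager_generator.cc below; the real DygraphCodeGeneration
// is stubbed here so the example compiles standalone).
#include <iostream>
#include <string>

static void DygraphCodeGeneration(const std::string& eager_root) {
  // Placeholder: the real function walks every registered operator proto
  // and emits forward/backward code under eager_root.
  std::cout << "generating into " << eager_root << "\n";
}

int main(int argc, char* argv[]) {
  // Previously argc had to be 3: <eager_root> <op_list_path>.
  // The allowlist path is gone, so only the output root remains.
  if (argc != 2) {
    std::cerr << "argc must be 2" << std::endl;
    return -1;
  }
  std::string eager_root = argv[1];
  DygraphCodeGeneration(eager_root);
  return 0;
}
```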
@@ -47,12 +47,12 @@ if(WIN32)
   endif()
   add_custom_target(eager_codegen
-    COMMAND "${eager_generator_path}/eager_generator.exe" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/op_list.txt"
+    COMMAND "${eager_generator_path}/eager_generator.exe" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated"
     DEPENDS ${EAGER_CODEGEN_DEPS}
     VERBATIM)
 else()
   add_custom_target(eager_codegen
-    COMMAND "${CMAKE_CURRENT_BINARY_DIR}/eager_generator" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/op_list.txt"
+    COMMAND "${CMAKE_CURRENT_BINARY_DIR}/eager_generator" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated"
     DEPENDS eager_generator
     VERBATIM)
 endif()
@@ -33,8 +33,6 @@ namespace framework {
 static std::unordered_map<std::string, paddle::framework::AttributeMap>
     operators_with_attrs = {};
 
-static std::unordered_set<std::string> operators_to_codegen = {};
-
 static std::string LegalizeVariableName(const std::string& var_name) {
   std::string ret = var_name;
   std::replace(ret.begin(), ret.end(), '-', '_');  // replace all '-' to '_'
@@ -469,8 +467,6 @@ static bool CheckOpProto(proto::OpProto* op_proto) {
   // Only handle matmul_v2 for now
   VLOG(1) << "------ Analyzing Op ------: " << op_type;
 
-  if (!operators_to_codegen.count(op_type)) return false;
-
   return true;
 }
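For readers skimming the hunk above: the only behavioral change in CheckOpProto is the removal of the allowlist gate. A self-contained before/after illustration (illustrative names and signatures, not Paddle's actual API):

```cpp
// Before/after illustration of the CheckOpProto gating change
// (illustrative and self-contained; not Paddle's actual signatures).
#include <string>
#include <unordered_set>

// Old behavior: ops absent from the op_list.txt allowlist were skipped.
static const std::unordered_set<std::string> operators_to_codegen = {
    "matmul_v2", "elementwise_add"};

bool CheckOpBefore(const std::string& op_type) {
  if (!operators_to_codegen.count(op_type)) return false;  // allowlist gate
  return true;
}

// New behavior: any op that passes the proto sanity checks is generated,
// so future operators are picked up with no list maintenance.
bool CheckOpAfter(const std::string& /*op_type*/) { return true; }
```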
@@ -2005,33 +2001,17 @@ static void PrepareAttrMapForOps() {
   operators_with_attrs["c_split"]["nranks"] = 1;
 }
 
-static void CollectOperatorsToCodeGen(const std::string& op_list_path) {
-  std::string line;
-  std::ifstream op_list_file(op_list_path);
-  if (op_list_file.is_open()) {
-    while (getline(op_list_file, line)) {
-      operators_to_codegen.insert(line);
-    }
-    op_list_file.close();
-  } else {
-    PADDLE_THROW(
-        paddle::platform::errors::Fatal("Unable to open op_list.txt file"));
-  }
-}
-
 }  // namespace framework
 }  // namespace paddle
 
 int main(int argc, char* argv[]) {
-  if (argc != 3) {
-    std::cerr << "argc must be 3" << std::endl;
+  if (argc != 2) {
+    std::cerr << "argc must be 2" << std::endl;
     return -1;
   }
 
   std::string eager_root = argv[1];
-  std::string op_list_path = argv[2];
-
-  paddle::framework::CollectOperatorsToCodeGen(op_list_path);
 
   paddle::framework::PrepareAttrMapForOps();
   paddle::framework::DygraphCodeGeneration(eager_root);
...

paddle/fluid/eager/auto_code_generator/op_list.txt (deleted):
rsqrt
multihead_matmul
addmm
gru
round
rank_attention
fused_embedding_fc_lstm
where_index
bicubic_interp
arg_min
tile
bilinear_tensor_product
ctc_align
pow2_decay_with_linear_warmup
split
fc
clear_float_status
matmul_v2
load
c_embedding
elementwise_max
adadelta
chunk_eval
check_finite_and_unscale
sparse_momentum
tan
adam
fsp
where
logical_xor
multiclass_nms3
one_hot_v2
sequence_softmax
affine_channel
triangular_solve
sequence_topk_avg_pooling
space_to_depth
reverse
fused_embedding_eltwise_layernorm
expand_v2
lgamma
solve
deformable_psroi_pooling
transfer_layout
instance_norm
decode_jpeg
distributed_push_sparse
gather_nd
reduce_prod
matrix_rank
asin
lstmp
iou_similarity
huber_loss
one_hot
sequence_slice
lookup_table
softplus
depthwise_conv2d
c_allreduce_sum
fused_fc_elementwise_layernorm
sigmoid_cross_entropy_with_logits
exp
scatter
c_allreduce_min
equal_all
searchsorted
fusion_squared_mat_sub
unique
log
conv_shift
smooth_l1_loss
linear_interp_v2
momentum
temporal_shift
nce
mv
global_scatter
proximal_gd
memcpy_h2d
add_position_encoding
cosh
hash
grad_add
sign
prelu
linspace
fill_diagonal
logsigmoid
load_combine
fetch_v2
randperm
sequence_scatter
partial_sum
relu6
partial_allgather
c_scatter
alltoall
conv3d
lstm_unit
not_equal
transpose2
c_sync_comm_stream
uniform_random_batch_size_like
unfold
lrn
isclose
softmax_with_cross_entropy
isfinite_v2
bernoulli
max_pool3d_with_index
gaussian_random
flatten2
matmul
cvm
recv_v2
adamax
masked_select
range
bitwise_not
trace
multinomial
modified_huber_loss
c_reduce_prod
roll
squared_l2_distance
conv3d_transpose
share_data
fake_quantize_abs_max
unique_with_counts
fill
concat
fill_zeros_like
hierarchical_sigmoid
isinf_v2
squeeze
multiclass_nms2
bpr_loss
fft_c2c
bicubic_interp_v2
angle
reshape
coalesce_tensor
dgc
roi_align
reshape2
reduce_any
unstack
scatter_nd_add
sequence_reshape
bilateral_slice
fill_any_like
empty
partial_recv
pad_constant_like
pool2d
size
imag
eigh
stack
dgc_momentum
lamb
generate_proposals_v2
c_sync_calc_stream
bitwise_or
gru_unit
fake_channel_wise_quantize_dequantize_abs_max
sampling_id
unsqueeze2
transfer_dtype
allreduce
average_accumulates
sequence_enumerate
fusion_seqconv_eltadd_relu
bce_loss
generate_proposal_labels
im2sequence
isinf
c_reducescatter
adagrad
linear_chain_crf
retinanet_target_assign
fusion_group
teacher_student_sigmoid_loss
random_crop
lookup_table_v2
detection_map
l1_norm
sqrt
partial_send
fused_elemwise_activation
slogdeterminant
share_buffer
bitwise_and
diag_embed
unbind
dropout
moving_average_abs_max_scale
beam_search
log_loss
greater_than
kron
sigmoid_focal_loss
rmsprop
conv2d
uniform_random_inplace
maxout
linear_interp
auc
logical_or
batch_norm
c_reduce_sum
elementwise_add
acos
send_and_recv
unpool
cumprod
sample_logits
pull_box_extended_sparse
crop_tensor
fill_constant
deformable_conv
generate_mask_labels
locality_aware_nms
expand_as
matrix_power
greater_equal
generate_proposals
bilinear_interp
sigmoid
inplace_abn
softshrink
mul
data_norm
get_tensor_from_selected_rows
spp
floor
gelu
retinanet_detection_output
minus
push_dense
silu
sequence_erase
real
nearest_interp_v2
dgc_clip_by_norm
squeeze2
strided_slice
conj
precision_recall
save
fusion_seqexpand_concat_fc
fake_quantize_range_abs_max
depthwise_conv2d_transpose
positive_negative_pair
square
var_conv_2d
log1p
fused_softmax_mask_upper_triangle
clip_by_norm
atan2
box_decoder_and_assign
fft_r2c
roi_pool
overlap_add
fill_constant_batch_size_like
fill_any
dequantize_log
c_split
barrier
max_pool2d_with_index
pad3d
norm
viterbi_decode
mish
box_coder
flatten
elementwise_mod
margin_cross_entropy
pull_sparse
logical_and
pow
stanh
label_smooth
merged_momentum
c_reduce_min
ascend_trigger
fused_feedforward
rpn_target_assign
roi_perspective_transform
expand
prroi_pool
pool3d
memcpy
distribute_fpn_proposals
frame
bincount
shape
group_norm
c_softmax_with_cross_entropy
resnet_unit
sequence_expand_as
cos_sim
eigvals
save_combine
class_center_sample
read_file
isfinite
arg_max
equal
fake_dequantize_max_abs
qr
anchor_generator
layer_norm
merge_selected_rows
less_equal
rnn
fusion_lstm
lars_momentum
hard_sigmoid
isnan
elementwise_floordiv
correlation
histogram
gather_tree
segment_pool
sync_batch_norm
fusion_repeated_fc_relu
nop
fused_attention
expand_as_v2
filter_by_instag
diag_v2
pull_box_sparse
nll_loss
dot
scale
ncclBcast
shuffle_batch
ncclReduce
diag
multiplex
leaky_relu
allclose
adamw
elementwise_pow
prior_box
p_norm
c_concat
unique_consecutive
lod_reset
pad
sequence_conv
log10
set_value
bitwise_xor
center_loss
randint
attention_lstm
uniform_random
slice
meshgrid
hard_swish
sin
mean_iou
pad2d
inverse
spectral_norm
shuffle_channel
send_v2
psroi_pool
seed
ceil
eig
reduce_min
cos
ncclAllReduce
cudnn_lstm
reduce_sum
digamma
assign_value
increment
tdm_sampler
fused_softmax_mask
sequence_reverse
eigvalsh
diagonal
trunc
log2
marker
tanh
yolov3_loss
graph_send_recv
accuracy
atan
less_than
unsqueeze
crf_decoding
global_gather
c_allreduce_prod
log_softmax
ftrl
matrix_nms
top_k_v2
cast
tanh_shrink
hard_shrink
multiclass_nms
c_broadcast
fusion_transpose_flatten_concat
sequence_unpad
fused_elemwise_add_activation
pull_sparse_v2
frobenius_norm
crop
cross_entropy2
skip_layernorm
tdm_child
fused_embedding_seq_pool
erf
conv2d_inception_fusion
trilinear_interp
logsumexp
fusion_seqpool_concat
alloc_float_status
sequence_concat
fusion_seqpool_cvm_concat
similarity_focus
c_allreduce_max
argsort
sequence_expand
sgd
fused_bn_add_activation
bilinear_interp_v2
clip
deformable_conv_v1
hinge_loss
determinant
conv2d_transpose
memcpy_d2h
softsign
fake_quantize_dequantize_abs_max
broadcast_tensors
grid_sampler
fft_c2r
pyramid_hash
fake_quantize_dequantize_moving_average_abs_max
multi_dot
sequence_pool
broadcast
transpose
top_k
dist
affine_grid
gaussian_random_batch_size_like
fake_channel_wise_dequantize_max_abs
reciprocal
sequence_mask
fill_diagonal_tensor
abs
partial_concat
elu
index_select
row_conv
cross
elementwise_mul
decayed_adagrad
bipartite_match
run_program
fake_quantize_moving_average_abs_max
mine_hard_examples
target_assign
lstm
truncated_gaussian_random
match_matrix_tensor
elementwise_div
kldiv_loss
cumsum
sum
proximal_adagrad
update_loss_scaling
shard_index
selu
mean
gumbel_softmax
sequence_pad
tree_conv
assign
flatten_contiguous_range
tril_triu
brelu
celu
reduce_mean
sinh
rank_loss
reduce_max
fusion_gru
fill_zeros_like2
expm1
squared_l2_norm
elementwise_sub
margin_rank_loss
faster_tokenizer
c_identity
c_reduce_max
relu
is_empty
reduce_all
edit_distance
distributed_lookup_table
bmm
yolo_box
soft_relu
density_prior_box
eye
swish
cross_entropy
dpsgd
cholesky
batch_fc
nearest_interp
gather
trilinear_interp_v2
box_clip
c_allgather
isnan_v2
softmax
conv2d_fusion
fused_batch_norm_act
get_float_status
index_sample
elementwise_min
logical_not
collect_fpn_proposals
pixel_shuffle
thresholded_relu
polygon_box_transform
lookup_table_dequant
warpctc
fake_channel_wise_quantize_abs_max
dequantize_abs_max
svd
flip
@@ -181,7 +181,7 @@ if(WITH_PYTHON)
         ":retry\n"
         "ECHO eager_op_function_generator run %build_times% time\n"
         "taskkill /f /im eager_op_function_generator.exe 2>NUL\n"
-        "${op_impl_path}/eager_op_function_generator.exe ${tmp_eager_impl_file} ${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/op_list.txt\n"
+        "${op_impl_path}/eager_op_function_generator.exe ${tmp_eager_impl_file}\n"
         "if %ERRORLEVEL% NEQ 0 (\n"
         "  set /a build_times=%build_times%+1\n"
         "  if %build_times% GEQ 10 (\n"
@@ -256,7 +256,7 @@ if(WITH_PYTHON)
     add_custom_command(OUTPUT ${eager_impl_file}
       COMMAND ${CMAKE_COMMAND} -E env "LD_LIBRARY_PATH=$ENV{LD_LIBRARY_PATH}:."
               "${CMAKE_CURRENT_BINARY_DIR}/eager_op_function_generator"
-              "${tmp_eager_impl_file}" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/op_list.txt"
+              "${tmp_eager_impl_file}"
       COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_eager_impl_file} ${eager_impl_file}
       COMMENT "copy_if_different ${tmp_eager_impl_file} to ${eager_impl_file}"
       DEPENDS ${EAGER_OP_IMPL_DEPS}
...
@@ -32,8 +32,6 @@
 #endif
 #include "paddle/fluid/pybind/op_function_generator.h"
 
-std::set<std::string> gen_list = {};
-
 // clang-format off
 const char* OUT_INITIALIZER_TEMPLATE =
     R"({"%s", {std::shared_ptr<imperative::VarBase>(new imperative::VarBase("auto_"+std::to_string(VarBaseUniqueNameID++)+"_"))}})";
@@ -313,9 +311,6 @@ GenerateOpFunctions() {
         !pten::KernelFactory::Instance().HasCompatiblePtenKernel(op_type)) {
       continue;
     }
-    if (!gen_list.count(op_type)) {
-      continue;
-    }
     std::string func_name = "eager_api_" + op_type;
     std::string op_function_str = GenerateOpFunctionsBody(op_proto, func_name);
@@ -329,28 +324,12 @@ GenerateOpFunctions() {
   return std::make_tuple(op_function_list, bind_function_list);
 }
 
-static void CollectOperatorsToCodeGen(const std::string& op_list_path) {
-  std::string line;
-  std::ifstream op_list_file(op_list_path);
-  if (op_list_file.is_open()) {
-    while (getline(op_list_file, line)) {
-      gen_list.insert(line);
-    }
-    op_list_file.close();
-  } else {
-    PADDLE_THROW(
-        paddle::platform::errors::Fatal("Unable to open op_list.txt file"));
-  }
-}
-
 int main(int argc, char* argv[]) {
-  if (argc != 3) {
-    std::cerr << "argc must be 3" << std::endl;
+  if (argc != 2) {
+    std::cerr << "argc must be 2" << std::endl;
     return -1;
   }
-  CollectOperatorsToCodeGen(argv[2]);
 
 #ifdef PADDLE_WITH_ASCEND_CL
   auto ascend_ptr = paddle::framework::AscendInstance::GetInstance();
   ascend_ptr->InitGEForUT();
...
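The pybind-side generator follows the same pattern as eager_generator.cc: with gen_list gone, the loop in GenerateOpFunctions keeps only the kernel-compatibility skip. A runnable toy version of that loop (illustrative names, not the real Paddle API):

```cpp
// Toy version of the generation loop after this change (illustrative
// names; the real loop lives in GenerateOpFunctions above). The only
// remaining filter is kernel compatibility, so an eager_api_<op>
// binding is emitted for every compatible operator.
#include <iostream>
#include <string>
#include <vector>

static bool HasCompatibleKernel(const std::string& op) {
  return op != "op_without_kernel";  // stand-in for the real kernel lookup
}

int main() {
  const std::vector<std::string> registered_ops = {
      "matmul_v2", "relu", "op_without_kernel"};
  for (const auto& op_type : registered_ops) {
    if (!HasCompatibleKernel(op_type)) continue;  // sole remaining skip
    std::cout << "emit eager_api_" << op_type << "\n";
  }
  return 0;
}
```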