未验证 提交 9ecb7461 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Enabled Eager Dygraph AutoCodeGen for 500+ existing ops (#37753)

* Handled dispensable tensors in AutoCodeGen for Eager Dygraph

* Enabled Eager Dygraph AutoCodeGen for 500+ existing ops
上级 7094251b
...@@ -47,12 +47,12 @@ if(WIN32) ...@@ -47,12 +47,12 @@ if(WIN32)
endif() endif()
add_custom_target(eager_codegen add_custom_target(eager_codegen
COMMAND "${eager_generator_path}/eager_generator.exe" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated" COMMAND "${eager_generator_path}/eager_generator.exe" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/op_list.txt"
DEPENDS ${EAGER_CODEGEN_DEPS} DEPENDS ${EAGER_CODEGEN_DEPS}
VERBATIM) VERBATIM)
else() else()
add_custom_target(eager_codegen add_custom_target(eager_codegen
COMMAND "${CMAKE_CURRENT_BINARY_DIR}/eager_generator" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated" COMMAND "${CMAKE_CURRENT_BINARY_DIR}/eager_generator" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/op_list.txt"
DEPENDS eager_generator DEPENDS eager_generator
VERBATIM) VERBATIM)
endif() endif()
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <gflags/gflags.h>
#include <algorithm> #include <algorithm>
#include <fstream> #include <fstream>
#include <iostream> #include <iostream>
...@@ -26,6 +27,9 @@ ...@@ -26,6 +27,9 @@
#include "paddle/fluid/pybind/pybind.h" #include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
DEFINE_bool(generate_all, false,
"Generate all operators currently registered in Paddle");
static std::unordered_set<std::string> operators_to_skip = { static std::unordered_set<std::string> operators_to_skip = {
"fused_elemwise_add_activation", // No Default Attr "fused_elemwise_add_activation", // No Default Attr
"fused_elemwise_activation", // No Default Attr "fused_elemwise_activation", // No Default Attr
...@@ -40,12 +44,10 @@ static std::unordered_set<std::string> operators_to_skip = { ...@@ -40,12 +44,10 @@ static std::unordered_set<std::string> operators_to_skip = {
"pull_box_sparse", "pull_box_sparse",
"fused_attention", "fused_attention",
"diag_v2", "diag_v2",
}; "transfer_dtype",
"c_split"};
static std::unordered_set<std::string> operators_to_codegen = {
"sigmoid", "matmul_v2", "reduce_sum", "elementwise_add",
"share_buffer", "var_conv_2d", "split"};
static std::unordered_set<std::string> operators_to_codegen = {};
static std::unordered_set<std::string> skipped_operators = {}; static std::unordered_set<std::string> skipped_operators = {};
namespace paddle { namespace paddle {
...@@ -353,7 +355,10 @@ static bool CheckOpProto(proto::OpProto* op_proto) { ...@@ -353,7 +355,10 @@ static bool CheckOpProto(proto::OpProto* op_proto) {
// Only handle matmul_v2 for now // Only handle matmul_v2 for now
VLOG(1) << "------ Analyzing Op ------: " << op_type; VLOG(1) << "------ Analyzing Op ------: " << op_type;
if (!FLAGS_generate_all) {
if (!operators_to_codegen.count(op_type)) return false; if (!operators_to_codegen.count(op_type)) return false;
}
if (operators_to_skip.count(op_type)) return false; if (operators_to_skip.count(op_type)) return false;
return true; return true;
...@@ -976,7 +981,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -976,7 +981,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, outnum); paddle::string::Sprintf(FWD_NUM_ARG_TEMPLATE, outnum);
dygraph_function_args_str += arg_str; dygraph_function_args_str += arg_str;
const char* FWD_OUTS_CONTENT_TEMPLATE = const char* FWD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::ConstructDuplicableOutput(%s) },"; "{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput(%s) },";
outs_contents_str += paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE, outs_contents_str += paddle::string::Sprintf(FWD_OUTS_CONTENT_TEMPLATE,
output_name, outnum); output_name, outnum);
} else { } else {
...@@ -1253,7 +1258,7 @@ static std::string GenerateGradNodeCCContents( ...@@ -1253,7 +1258,7 @@ static std::string GenerateGradNodeCCContents(
if (duplicable_input_name_set.count(fwd_input_name)) { if (duplicable_input_name_set.count(fwd_input_name)) {
const char* GRAD_OUTS_CONTENT_TEMPLATE = const char* GRAD_OUTS_CONTENT_TEMPLATE =
"{ \"%s\", egr::ConstructDuplicableOutput( " "{ \"%s\", egr::EagerUtils::ConstructDuplicableOutput( "
"this->OutputMeta()[%d].Size() ) },"; "this->OutputMeta()[%d].Size() ) },";
outs_contents_str += paddle::string::Sprintf( outs_contents_str += paddle::string::Sprintf(
GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, fwd_input_position); GRAD_OUTS_CONTENT_TEMPLATE, grad_output_name, fwd_input_position);
...@@ -1639,13 +1644,30 @@ static void DygraphCodeGeneration(const std::string& output_dir) { ...@@ -1639,13 +1644,30 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
static void CollectOperatorsToCodeGen(const std::string& op_list_path) {
std::string line;
std::ifstream op_list_file(op_list_path);
if (op_list_file.is_open()) {
while (getline(op_list_file, line)) {
operators_to_codegen.insert(line);
}
op_list_file.close();
} else {
PADDLE_THROW(
paddle::platform::errors::Fatal("Unable to open op_list.txt file"));
}
}
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
if (argc != 2) { if (argc != 3) {
std::cerr << "argc must be 2" << std::endl; std::cerr << "argc must be 2" << std::endl;
return -1; return -1;
} }
std::string eager_root = argv[1]; std::string eager_root = argv[1];
std::string op_list_path = argv[2];
CollectOperatorsToCodeGen(op_list_path);
paddle::framework::DygraphCodeGeneration(eager_root); paddle::framework::DygraphCodeGeneration(eager_root);
return 0; return 0;
......
...@@ -2,3 +2,504 @@ sigmoid ...@@ -2,3 +2,504 @@ sigmoid
matmul_v2 matmul_v2
reduce_sum reduce_sum
elementwise_add elementwise_add
rsqrt
multihead_matmul
addmm
gru
round
rank_attention
fused_embedding_fc_lstm
where_index
bicubic_interp
arg_min
tile
bilinear_tensor_product
ctc_align
pow2_decay_with_linear_warmup
split
fc
clear_float_status
load
elementwise_max
adadelta
chunk_eval
check_finite_and_unscale
sparse_momentum
tan
adam
fsp
where
logical_xor
multiclass_nms3
one_hot_v2
sequence_softmax
affine_channel
triangular_solve
sequence_topk_avg_pooling
space_to_depth
reverse
fused_embedding_eltwise_layernorm
expand_v2
lgamma
solve
deformable_psroi_pooling
instance_norm
decode_jpeg
gather_nd
reduce_prod
matrix_rank
asin
lstmp
iou_similarity
huber_loss
one_hot
sequence_slice
lookup_table
softplus
depthwise_conv2d
fused_fc_elementwise_layernorm
sigmoid_cross_entropy_with_logits
exp
scatter
equal_all
searchsorted
fusion_squared_mat_sub
unique
log
conv_shift
smooth_l1_loss
linear_interp_v2
momentum
temporal_shift
nce
mv
proximal_gd
memcpy_h2d
add_position_encoding
cosh
hash
grad_add
sign
prelu
linspace
fill_diagonal
logsigmoid
load_combine
fetch_v2
randperm
sequence_scatter
partial_sum
relu6
conv3d
lstm_unit
not_equal
transpose2
uniform_random_batch_size_like
unfold
lrn
softmax_with_cross_entropy
isfinite_v2
bernoulli
max_pool3d_with_index
gaussian_random
flatten2
matmul
cvm
adamax
masked_select
range
bitwise_not
trace
multinomial
modified_huber_loss
roll
squared_l2_distance
conv3d_transpose
share_data
fake_quantize_abs_max
unique_with_counts
fill
concat
fill_zeros_like
hierarchical_sigmoid
isinf_v2
squeeze
multiclass_nms2
bpr_loss
fft_c2c
bicubic_interp_v2
reshape
coalesce_tensor
roi_align
reshape2
reduce_any
unstack
scatter_nd_add
sequence_reshape
bilateral_slice
fill_any_like
empty
pad_constant_like
pool2d
size
imag
eigh
stack
dgc_momentum
lamb
generate_proposals_v2
bitwise_or
gru_unit
fake_channel_wise_quantize_dequantize_abs_max
sampling_id
unsqueeze2
average_accumulates
sequence_enumerate
fusion_seqconv_eltadd_relu
bce_loss
generate_proposal_labels
im2sequence
isinf
adagrad
linear_chain_crf
retinanet_target_assign
fusion_group
teacher_student_sigmoid_loss
random_crop
lookup_table_v2
detection_map
l1_norm
sqrt
fused_elemwise_activation
slogdeterminant
share_buffer
bitwise_and
diag_embed
unbind
dropout
moving_average_abs_max_scale
beam_search
log_loss
greater_than
kron
sigmoid_focal_loss
rmsprop
conv2d
uniform_random_inplace
maxout
linear_interp
auc
logical_or
batch_norm
acos
unpool
cumprod
sample_logits
pull_box_extended_sparse
crop_tensor
fill_constant
deformable_conv
generate_mask_labels
locality_aware_nms
expand_as
matrix_power
greater_equal
generate_proposals
bilinear_interp
inplace_abn
softshrink
mul
data_norm
get_tensor_from_selected_rows
spp
floor
gelu
retinanet_detection_output
minus
push_dense
silu
sequence_erase
real
nearest_interp_v2
dgc_clip_by_norm
squeeze2
strided_slice
conj
precision_recall
save
fusion_seqexpand_concat_fc
fake_quantize_range_abs_max
depthwise_conv2d_transpose
positive_negative_pair
square
var_conv_2d
log1p
fused_softmax_mask_upper_triangle
clip_by_norm
atan2
box_decoder_and_assign
fft_r2c
roi_pool
overlap_add
fill_constant_batch_size_like
fill_any
dequantize_log
max_pool2d_with_index
pad3d
norm
viterbi_decode
mish
box_coder
flatten
elementwise_mod
margin_cross_entropy
pull_sparse
logical_and
pow
stanh
label_smooth
merged_momentum
ascend_trigger
fused_feedforward
rpn_target_assign
roi_perspective_transform
expand
prroi_pool
pool3d
memcpy
distribute_fpn_proposals
frame
bincount
shape
group_norm
resnet_unit
sequence_expand_as
cos_sim
eigvals
save_combine
class_center_sample
read_file
isfinite
arg_max
equal
fake_dequantize_max_abs
qr
anchor_generator
layer_norm
merge_selected_rows
less_equal
rnn
fusion_lstm
lars_momentum
hard_sigmoid
isnan
elementwise_floordiv
correlation
histogram
gather_tree
segment_pool
sync_batch_norm
fusion_repeated_fc_relu
nop
fused_attention
expand_as_v2
filter_by_instag
diag_v2
pull_box_sparse
nll_loss
dot
scale
ncclBcast
shuffle_batch
ncclReduce
diag
multiplex
leaky_relu
allclose
adamw
elementwise_pow
prior_box
p_norm
unique_consecutive
lod_reset
pad
sequence_conv
log10
set_value
bitwise_xor
center_loss
randint
attention_lstm
uniform_random
slice
meshgrid
hard_swish
sin
mean_iou
pad2d
inverse
spectral_norm
shuffle_channel
psroi_pool
seed
ceil
eig
reduce_min
cos
ncclAllReduce
cudnn_lstm
digamma
assign_value
increment
tdm_sampler
fused_softmax_mask
sequence_reverse
eigvalsh
diagonal
trunc
log2
marker
tanh
yolov3_loss
graph_send_recv
accuracy
atan
less_than
unsqueeze
crf_decoding
log_softmax
ftrl
matrix_nms
top_k_v2
cast
tanh_shrink
hard_shrink
multiclass_nms
fusion_transpose_flatten_concat
sequence_unpad
fused_elemwise_add_activation
pull_sparse_v2
frobenius_norm
crop
cross_entropy2
skip_layernorm
tdm_child
fused_embedding_seq_pool
erf
conv2d_inception_fusion
trilinear_interp
logsumexp
fusion_seqpool_concat
alloc_float_status
sequence_concat
fusion_seqpool_cvm_concat
similarity_focus
argsort
sequence_expand
sgd
fused_bn_add_activation
bilinear_interp_v2
clip
deformable_conv_v1
hinge_loss
determinant
conv2d_transpose
memcpy_d2h
softsign
fake_quantize_dequantize_abs_max
broadcast_tensors
grid_sampler
fft_c2r
pyramid_hash
fake_quantize_dequantize_moving_average_abs_max
multi_dot
sequence_pool
transpose
top_k
dist
affine_grid
gaussian_random_batch_size_like
fake_channel_wise_dequantize_max_abs
reciprocal
sequence_mask
fill_diagonal_tensor
abs
partial_concat
elu
index_select
row_conv
cross
elementwise_mul
decayed_adagrad
bipartite_match
run_program
fake_quantize_moving_average_abs_max
mine_hard_examples
target_assign
lstm
truncated_gaussian_random
match_matrix_tensor
elementwise_div
kldiv_loss
cumsum
sum
proximal_adagrad
update_loss_scaling
shard_index
selu
mean
gumbel_softmax
sequence_pad
tree_conv
assign
flatten_contiguous_range
tril_triu
brelu
celu
reduce_mean
sinh
rank_loss
reduce_max
fusion_gru
fill_zeros_like2
expm1
squared_l2_norm
elementwise_sub
margin_rank_loss
faster_tokenizer
relu
is_empty
reduce_all
edit_distance
bmm
yolo_box
soft_relu
density_prior_box
eye
swish
cross_entropy
dpsgd
cholesky
batch_fc
nearest_interp
gather
trilinear_interp_v2
box_clip
isnan_v2
softmax
conv2d_fusion
fused_batch_norm_act
get_float_status
index_sample
elementwise_min
logical_not
collect_fpn_proposals
pixel_shuffle
thresholded_relu
polygon_box_transform
lookup_table_dequant
warpctc
fake_channel_wise_quantize_abs_max
dequantize_abs_max
svd
flip
...@@ -60,7 +60,7 @@ TEST(EagerUtils, AutoGradMeta) { ...@@ -60,7 +60,7 @@ TEST(EagerUtils, AutoGradMeta) {
std::vector<AutogradMeta*> autograd_metas = std::vector<AutogradMeta*> autograd_metas =
EagerUtils::multi_autograd_meta(&ets); EagerUtils::multi_autograd_meta(&ets);
std::vector<AutogradMeta*> unsafe_autograd_metas = std::vector<AutogradMeta*> unsafe_autograd_metas =
EagerUtils::unsafe_autograd_meta(&ets); EagerUtils::unsafe_autograd_meta(ets);
CHECK_NOTNULL(unsafe_autograd_metas[0]); CHECK_NOTNULL(unsafe_autograd_metas[0]);
CHECK_NOTNULL(unsafe_autograd_metas[1]); CHECK_NOTNULL(unsafe_autograd_metas[1]);
......
...@@ -48,9 +48,9 @@ AutogradMeta* EagerUtils::unsafe_autograd_meta(const egr::EagerTensor& target) { ...@@ -48,9 +48,9 @@ AutogradMeta* EagerUtils::unsafe_autograd_meta(const egr::EagerTensor& target) {
} }
std::vector<AutogradMeta*> EagerUtils::unsafe_autograd_meta( std::vector<AutogradMeta*> EagerUtils::unsafe_autograd_meta(
std::vector<egr::EagerTensor>* targets) { const std::vector<egr::EagerTensor>& targets) {
std::vector<AutogradMeta*> metas; std::vector<AutogradMeta*> metas;
for (const egr::EagerTensor& t : *targets) { for (const egr::EagerTensor& t : targets) {
metas.push_back(unsafe_autograd_meta(t)); metas.push_back(unsafe_autograd_meta(t));
} }
return metas; return metas;
......
...@@ -114,7 +114,7 @@ class EagerUtils { ...@@ -114,7 +114,7 @@ class EagerUtils {
// This method will return an AutogradMeta pointer unsafely. // This method will return an AutogradMeta pointer unsafely.
static AutogradMeta* unsafe_autograd_meta(const egr::EagerTensor& target); static AutogradMeta* unsafe_autograd_meta(const egr::EagerTensor& target);
static std::vector<AutogradMeta*> unsafe_autograd_meta( static std::vector<AutogradMeta*> unsafe_autograd_meta(
std::vector<egr::EagerTensor>* targets); const std::vector<egr::EagerTensor>& targets);
template <typename T, typename... Args> template <typename T, typename... Args>
static bool ComputeRequireGrad(T trace_backward, Args&&... args) { static bool ComputeRequireGrad(T trace_backward, Args&&... args) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册