From dcfe2f1adc8cd9c610c2c162dc197fff6d3002a9 Mon Sep 17 00:00:00 2001 From: Sonder <55493212+AndSonder@users.noreply.github.com> Date: Thu, 17 Aug 2023 13:16:17 +0800 Subject: [PATCH] Support control flow for static build [Step 1: support subgraph] (#56185) * remove execution_config.used_for_control_flow_op * update * update * open static build flag * close static build flag * open static build flag * add searchsorted to analyze dtype list * recover and add test_searchsorted_op to static build list * Update CMakeLists.txt * Update CMakeLists.txt --- .../framework/new_executor/interpreter/static_build.cc | 7 +++++++ paddle/fluid/framework/new_executor/program_interpreter.cc | 1 - test/legacy_test/CMakeLists.txt | 2 ++ 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/new_executor/interpreter/static_build.cc b/paddle/fluid/framework/new_executor/interpreter/static_build.cc index 2a9987874e4..10d75f1be6f 100644 --- a/paddle/fluid/framework/new_executor/interpreter/static_build.cc +++ b/paddle/fluid/framework/new_executor/interpreter/static_build.cc @@ -474,6 +474,13 @@ void FakeInitializeOutputsForFunctionKernel( ? DataType::INT64 : in_dtype; } + } else if (op_type == "searchsorted") { + bool out_int32 = op.Attr("out_int32"); + if (out_int32) { + dtype = DataType::INT32; + } else { + dtype = DataType::INT64; + } } else { VLOG(4) << "Get dtype result from InferMeta"; RuntimeInferShapeContext infer_shape_ctx(op, runtime_ctx); diff --git a/paddle/fluid/framework/new_executor/program_interpreter.cc b/paddle/fluid/framework/new_executor/program_interpreter.cc index 64ffdbef619..e288804e09a 100644 --- a/paddle/fluid/framework/new_executor/program_interpreter.cc +++ b/paddle/fluid/framework/new_executor/program_interpreter.cc @@ -47,7 +47,6 @@ ProgramInterpreter::ProgramInterpreter(const platform::Place& place, static_build_ = FLAGS_new_executor_static_build && !FLAGS_new_executor_use_cuda_graph && - !execution_config.used_for_control_flow_op && interpreter::BlockCanBeStaticBuilt(block); exception_notifier_ = main_thread_blocker_.RegisterEvent(kExceptionCaught); diff --git a/test/legacy_test/CMakeLists.txt b/test/legacy_test/CMakeLists.txt index b5ad632433d..7aded1aa469 100644 --- a/test/legacy_test/CMakeLists.txt +++ b/test/legacy_test/CMakeLists.txt @@ -1273,6 +1273,7 @@ set(STATIC_BUILD_TESTS test_adamw_op test_arg_min_max_op test_assign_pos_op + test_bucketize_api test_bincount_op test_c_embedding_op test_decayed_adagrad_op @@ -1303,6 +1304,7 @@ set(STATIC_BUILD_TESTS test_prune_gate_by_capacity_op test_random_routing_op test_reduce_op + test_searchsorted_op test_segment_ops test_sparse_momentum_op test_sgd_op_bf16 -- GitLab