From cc8a78582223c851c1b9d4d58496971af5a32af6 Mon Sep 17 00:00:00 2001 From: Jiabin Yang <360788950@qq.com> Date: Fri, 3 Feb 2023 10:29:12 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Prim=E3=80=91Blacklist=20bwd=20comp=20?= =?UTF-8?q?(#50148)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor dir for prim * support blacklist for bwd comp * fix type error * remove additional file * fix git ignore * add more test * merge develop --- .../generator/eager_gen.py | 2 +- .../elementwise/elementwise_add_op.cc | 2 +- .../elementwise/elementwise_div_op.cc | 2 +- .../elementwise/elementwise_mul_op.cc | 2 +- .../elementwise/elementwise_sub_op.cc | 2 +- paddle/fluid/operators/expand_v2_op.cc | 2 +- .../operators/generator/templates/op.c.j2 | 2 +- .../operators/reduce_ops/reduce_sum_op.cc | 2 +- paddle/fluid/prim/api/.gitignore | 5 +- paddle/fluid/prim/api/CMakeLists.txt | 4 +- paddle/fluid/prim/api/all.h | 6 +- .../api/auto_code_generated/CMakeLists.txt | 9 ++- .../prim/api/auto_code_generated/prim_gen.py | 13 ++-- .../composite_backward_api.h | 4 +- .../fluid/prim/api/generated/CMakeLists.txt | 1 - .../CMakeLists.txt | 5 -- paddle/fluid/prim/api/manual/CMakeLists.txt | 1 - .../fluid/prim/api/manual_prim/CMakeLists.txt | 5 ++ .../prim_manual_api.h} | 7 +- .../static_prim_api.cc | 5 +- .../utils/CMakeLists.txt | 0 .../utils/eager_utils.cc | 2 +- .../utils/static_utils.cc | 2 +- .../api/{manual => manual_prim}/utils/utils.h | 0 paddle/fluid/prim/tests/test_static_prim.cc | 2 +- .../prim/utils/static/static_global_utils.h | 17 ++++- paddle/fluid/prim/utils/utils.cc | 18 ++++- paddle/fluid/prim/utils/utils.h | 5 +- paddle/fluid/pybind/pybind.cc | 8 +- python/paddle/fluid/core.py | 23 ++++-- .../prim/prim/flags/test_prim_flags.py | 10 +++ .../unittests/prim/test_comp_skip_op_set.py | 75 +++++++++++++++++++ 32 files changed, 182 insertions(+), 61 deletions(-) rename paddle/fluid/prim/api/{manual/backward => composite_backward}/composite_backward_api.h (98%) delete mode 100644 paddle/fluid/prim/api/generated/CMakeLists.txt rename paddle/fluid/prim/api/{generated/prim_api => generated_prim}/CMakeLists.txt (63%) delete mode 100644 paddle/fluid/prim/api/manual/CMakeLists.txt create mode 100644 paddle/fluid/prim/api/manual_prim/CMakeLists.txt rename paddle/fluid/prim/api/{manual/prim_api/prim_api.h => manual_prim/prim_manual_api.h} (78%) rename paddle/fluid/prim/api/{generated/prim_api => manual_prim}/static_prim_api.cc (98%) rename paddle/fluid/prim/api/{manual => manual_prim}/utils/CMakeLists.txt (100%) rename paddle/fluid/prim/api/{manual => manual_prim}/utils/eager_utils.cc (97%) rename paddle/fluid/prim/api/{manual => manual_prim}/utils/static_utils.cc (98%) rename paddle/fluid/prim/api/{manual => manual_prim}/utils/utils.h (100%) create mode 100644 python/paddle/fluid/tests/unittests/prim/test_comp_skip_op_set.py diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index 650bf0626f1..3497a1217cf 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -330,7 +330,7 @@ NODE_CC_FILE_TEMPLATE = """ #include "paddle/fluid/eager/nan_inf_utils.h" #include "paddle/phi/api/include/sparse_api.h" #include "paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h" -#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" +#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h" #include "paddle/fluid/prim/api/all.h" #include "paddle/fluid/prim/utils/utils.h" DECLARE_bool(check_nan_inf); diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cc b/paddle/fluid/operators/elementwise/elementwise_add_op.cc index 48a5d2e433a..26fcd53621a 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cc @@ -15,7 +15,7 @@ limitations under the License. */ #include #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" +#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h" #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h" namespace paddle { diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cc b/paddle/fluid/operators/elementwise/elementwise_div_op.cc index 41549ede1eb..7d96c310658 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cc @@ -19,7 +19,7 @@ limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/platform/complex.h" -#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" +#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h" #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h" namespace paddle { diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc index 740c9381d92..61467be4c9b 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc @@ -19,7 +19,7 @@ limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/platform/complex.h" -#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" +#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h" #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h" diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.cc b/paddle/fluid/operators/elementwise/elementwise_sub_op.cc index 2a9e14867ac..d19f557bfe3 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.cc @@ -15,7 +15,7 @@ limitations under the License. */ #include #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" +#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h" #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h" namespace paddle { diff --git a/paddle/fluid/operators/expand_v2_op.cc b/paddle/fluid/operators/expand_v2_op.cc index 3c05ab9295c..253c2856063 100644 --- a/paddle/fluid/operators/expand_v2_op.cc +++ b/paddle/fluid/operators/expand_v2_op.cc @@ -20,7 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" +#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h" #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h" #include "paddle/phi/core/infermeta_utils.h" diff --git a/paddle/fluid/operators/generator/templates/op.c.j2 b/paddle/fluid/operators/generator/templates/op.c.j2 index 2339822af28..f54f91073da 100644 --- a/paddle/fluid/operators/generator/templates/op.c.j2 +++ b/paddle/fluid/operators/generator/templates/op.c.j2 @@ -5,7 +5,7 @@ #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" +#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h" #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h" #include "paddle/phi/core/infermeta_utils.h" diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc index 25e6ad9b65c..9af1770a41d 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc @@ -17,7 +17,7 @@ #include #include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" +#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h" #include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/infermeta/unary.h" diff --git a/paddle/fluid/prim/api/.gitignore b/paddle/fluid/prim/api/.gitignore index 377e800f00a..2dad6249be3 100644 --- a/paddle/fluid/prim/api/.gitignore +++ b/paddle/fluid/prim/api/.gitignore @@ -1,3 +1,2 @@ -generated/prim_api/eager_prim_api.cc -generated/prim_api/tmp_eager_prim_api.cc -generated/prim_api/*.h +generated_prim/*.cc +generated_prim/*.h diff --git a/paddle/fluid/prim/api/CMakeLists.txt b/paddle/fluid/prim/api/CMakeLists.txt index 436cecc3258..6cf3dacef9f 100644 --- a/paddle/fluid/prim/api/CMakeLists.txt +++ b/paddle/fluid/prim/api/CMakeLists.txt @@ -1,6 +1,6 @@ add_subdirectory(auto_code_generated) -add_subdirectory(manual) -add_subdirectory(generated) +add_subdirectory(manual_prim) +add_subdirectory(generated_prim) if(NOT (NOT WITH_PYTHON AND ON_INFER)) cc_library( diff --git a/paddle/fluid/prim/api/all.h b/paddle/fluid/prim/api/all.h index 2996d2aa265..b275e163cbc 100644 --- a/paddle/fluid/prim/api/all.h +++ b/paddle/fluid/prim/api/all.h @@ -13,6 +13,6 @@ // limitations under the License. #pragma once -#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h" -#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" -#include "paddle/fluid/prim/api/manual/utils/utils.h" +#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h" +#include "paddle/fluid/prim/api/manual_prim/prim_manual_api.h" +#include "paddle/fluid/prim/api/manual_prim/utils/utils.h" diff --git a/paddle/fluid/prim/api/auto_code_generated/CMakeLists.txt b/paddle/fluid/prim/api/auto_code_generated/CMakeLists.txt index e36af681bbd..ebff0ec688a 100644 --- a/paddle/fluid/prim/api/auto_code_generated/CMakeLists.txt +++ b/paddle/fluid/prim/api/auto_code_generated/CMakeLists.txt @@ -5,16 +5,17 @@ set(legacy_api_yaml_path "${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_ops.parsed.yaml" ) set(tmp_eager_prim_api_cc_path - "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated/prim_api/tmp_eager_prim_api.cc" + "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/tmp_eager_prim_api.cc" ) set(tmp_prim_api_h_path - "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated/prim_api/tmp_prim_api.h" + "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/tmp_prim_generated_api.h" ) set(eager_prim_api_cc_path - "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated/prim_api/eager_prim_api.cc" + "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/eager_prim_api.cc" ) set(prim_api_h_path - "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated/prim_api/prim_api.h") + "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/prim_generated_api.h" +) set(prim_api_gen_file ${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated/prim_gen.py) diff --git a/paddle/fluid/prim/api/auto_code_generated/prim_gen.py b/paddle/fluid/prim/api/auto_code_generated/prim_gen.py index 7bc59df4f33..787eeb3e440 100644 --- a/paddle/fluid/prim/api/auto_code_generated/prim_gen.py +++ b/paddle/fluid/prim/api/auto_code_generated/prim_gen.py @@ -28,11 +28,11 @@ def header_include(): """ -def eager_source_include(header_file_path): +def eager_source_include(): return """ #include "paddle/fluid/eager/api/all.h" #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" -#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h" +#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h" """ @@ -73,10 +73,7 @@ def generate_api(api_yaml_path, header_file_path, eager_prim_source_file_path): header_file.write(header_include()) header_file.write(namespace[0]) header_file.write(namespace[1]) - include_header_file = ( - "#include paddle/fluid/prim/api/generated/prim_api/prim_api.h" - ) - eager_prim_source_file.write(eager_source_include(include_header_file)) + eager_prim_source_file.write(eager_source_include()) eager_prim_source_file.write(namespace[0]) for api in apis: @@ -106,13 +103,13 @@ def main(): parser.add_argument( '--prim_api_header_path', help='output of generated prim_api header code file', - default='paddle/fluid/prim/api/generated/prim_api/prim_api.h', + default='paddle/fluid/prim/api/generated_prim/prim_generated_api.h', ) parser.add_argument( '--eager_prim_api_source_path', help='output of generated eager_prim_api source code file', - default='paddle/fluid/prim/api/generated/prim_api/eager_prim_api.cc', + default='paddle/fluid/prim/api/generated_prim/eager_prim_api.cc', ) options = parser.parse_args() diff --git a/paddle/fluid/prim/api/manual/backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h similarity index 98% rename from paddle/fluid/prim/api/manual/backward/composite_backward_api.h rename to paddle/fluid/prim/api/composite_backward/composite_backward_api.h index a9c8953a228..e782d6b65bb 100644 --- a/paddle/fluid/prim/api/manual/backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h @@ -13,9 +13,7 @@ // limitations under the License. #pragma once -#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h" -#include "paddle/fluid/prim/api/manual/prim_api/prim_api.h" -#include "paddle/fluid/prim/api/manual/utils/utils.h" +#include "paddle/fluid/prim/api/all.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/core/ddim.h" diff --git a/paddle/fluid/prim/api/generated/CMakeLists.txt b/paddle/fluid/prim/api/generated/CMakeLists.txt deleted file mode 100644 index a1b75527c20..00000000000 --- a/paddle/fluid/prim/api/generated/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(prim_api) diff --git a/paddle/fluid/prim/api/generated/prim_api/CMakeLists.txt b/paddle/fluid/prim/api/generated_prim/CMakeLists.txt similarity index 63% rename from paddle/fluid/prim/api/generated/prim_api/CMakeLists.txt rename to paddle/fluid/prim/api/generated_prim/CMakeLists.txt index ee39c73f99f..6e030052d77 100644 --- a/paddle/fluid/prim/api/generated/prim_api/CMakeLists.txt +++ b/paddle/fluid/prim/api/generated_prim/CMakeLists.txt @@ -1,8 +1,3 @@ -cc_library( - static_prim_api - SRCS static_prim_api.cc - DEPS proto_desc static_utils) - if(NOT (NOT WITH_PYTHON AND ON_INFER)) cc_library( eager_prim_api diff --git a/paddle/fluid/prim/api/manual/CMakeLists.txt b/paddle/fluid/prim/api/manual/CMakeLists.txt deleted file mode 100644 index 512d2b1553c..00000000000 --- a/paddle/fluid/prim/api/manual/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(utils) diff --git a/paddle/fluid/prim/api/manual_prim/CMakeLists.txt b/paddle/fluid/prim/api/manual_prim/CMakeLists.txt new file mode 100644 index 00000000000..7437c737a7b --- /dev/null +++ b/paddle/fluid/prim/api/manual_prim/CMakeLists.txt @@ -0,0 +1,5 @@ +add_subdirectory(utils) +cc_library( + static_prim_api + SRCS static_prim_api.cc + DEPS proto_desc static_utils) diff --git a/paddle/fluid/prim/api/manual/prim_api/prim_api.h b/paddle/fluid/prim/api/manual_prim/prim_manual_api.h similarity index 78% rename from paddle/fluid/prim/api/manual/prim_api/prim_api.h rename to paddle/fluid/prim/api/manual_prim/prim_manual_api.h index 65d411d8630..80d11aed348 100644 --- a/paddle/fluid/prim/api/manual/prim_api/prim_api.h +++ b/paddle/fluid/prim/api/manual_prim/prim_manual_api.h @@ -1,4 +1,4 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,15 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -// prim api which can't be generated #pragma once +#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/common/place.h" #include "paddle/phi/common/scalar.h" #include "paddle/utils/optional.h" - +// TODO(jiabin): Make this Header only for handwritten api, instead of include +// prim_generated_api.h namespace paddle { namespace prim {} // namespace prim } // namespace paddle diff --git a/paddle/fluid/prim/api/generated/prim_api/static_prim_api.cc b/paddle/fluid/prim/api/manual_prim/static_prim_api.cc similarity index 98% rename from paddle/fluid/prim/api/generated/prim_api/static_prim_api.cc rename to paddle/fluid/prim/api/manual_prim/static_prim_api.cc index b879ade5a9e..71d547c139a 100644 --- a/paddle/fluid/prim/api/generated/prim_api/static_prim_api.cc +++ b/paddle/fluid/prim/api/manual_prim/static_prim_api.cc @@ -26,9 +26,8 @@ #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/convert_utils.h" -#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h" -#include "paddle/fluid/prim/api/manual/prim_api/prim_api.h" -#include "paddle/fluid/prim/api/manual/utils/utils.h" +#include "paddle/fluid/prim/api/manual_prim/prim_manual_api.h" +#include "paddle/fluid/prim/api/manual_prim/utils/utils.h" #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h" #include "paddle/phi/api/include/tensor.h" diff --git a/paddle/fluid/prim/api/manual/utils/CMakeLists.txt b/paddle/fluid/prim/api/manual_prim/utils/CMakeLists.txt similarity index 100% rename from paddle/fluid/prim/api/manual/utils/CMakeLists.txt rename to paddle/fluid/prim/api/manual_prim/utils/CMakeLists.txt diff --git a/paddle/fluid/prim/api/manual/utils/eager_utils.cc b/paddle/fluid/prim/api/manual_prim/utils/eager_utils.cc similarity index 97% rename from paddle/fluid/prim/api/manual/utils/eager_utils.cc rename to paddle/fluid/prim/api/manual_prim/utils/eager_utils.cc index 353945557f1..04854428d8e 100644 --- a/paddle/fluid/prim/api/manual/utils/eager_utils.cc +++ b/paddle/fluid/prim/api/manual_prim/utils/eager_utils.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" #include "paddle/fluid/eager/api/utils/global_utils.h" -#include "paddle/fluid/prim/api/manual/utils/utils.h" +#include "paddle/fluid/prim/api/manual_prim/utils/utils.h" #include "paddle/phi/api/include/tensor.h" namespace paddle { diff --git a/paddle/fluid/prim/api/manual/utils/static_utils.cc b/paddle/fluid/prim/api/manual_prim/utils/static_utils.cc similarity index 98% rename from paddle/fluid/prim/api/manual/utils/static_utils.cc rename to paddle/fluid/prim/api/manual_prim/utils/static_utils.cc index 74656cfe7d4..8cfcffd92c2 100644 --- a/paddle/fluid/prim/api/manual/utils/static_utils.cc +++ b/paddle/fluid/prim/api/manual_prim/utils/static_utils.cc @@ -17,7 +17,7 @@ #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/prim/api/manual/utils/utils.h" +#include "paddle/fluid/prim/api/manual_prim/utils/utils.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h" #include "paddle/fluid/prim/utils/static/static_global_utils.h" #include "paddle/phi/api/include/tensor.h" diff --git a/paddle/fluid/prim/api/manual/utils/utils.h b/paddle/fluid/prim/api/manual_prim/utils/utils.h similarity index 100% rename from paddle/fluid/prim/api/manual/utils/utils.h rename to paddle/fluid/prim/api/manual_prim/utils/utils.h diff --git a/paddle/fluid/prim/tests/test_static_prim.cc b/paddle/fluid/prim/tests/test_static_prim.cc index 313a3ccc99b..5a53101ab13 100644 --- a/paddle/fluid/prim/tests/test_static_prim.cc +++ b/paddle/fluid/prim/tests/test_static_prim.cc @@ -19,7 +19,7 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/prim/api/manual/utils/utils.h" +#include "paddle/fluid/prim/api/manual_prim/utils/utils.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h" #include "paddle/fluid/prim/utils/utils.h" #include "paddle/phi/core/enforce.h" diff --git a/paddle/fluid/prim/utils/static/static_global_utils.h b/paddle/fluid/prim/utils/static/static_global_utils.h index e878c857f26..e6a8054f1a7 100644 --- a/paddle/fluid/prim/utils/static/static_global_utils.h +++ b/paddle/fluid/prim/utils/static/static_global_utils.h @@ -69,6 +69,18 @@ class StaticCompositeContext { enable_bwd_prim_ = enable_prim; } + size_t CheckSkipCompOps(const std::string& op_type) const { + return skip_comp_ops_.count(op_type); + } + + void AddSkipCompOps(const std::string& op_type) { + skip_comp_ops_.insert(op_type); + } + + void RemoveSkipCompOps(const std::string& op_type) { + skip_comp_ops_.erase(op_type); + } + void SetTargetGradName(const std::map& m) { target_grad_name_ = m; } @@ -79,10 +91,13 @@ class StaticCompositeContext { private: StaticCompositeContext() - : current_block_desc_(nullptr), generator_(new UniqueNameGenerator()) {} + : current_block_desc_(nullptr), + generator_(new UniqueNameGenerator()), + skip_comp_ops_({"matmul_v2"}) {} framework::BlockDesc* current_block_desc_; std::unique_ptr generator_; + std::unordered_set skip_comp_ops_; std::map target_grad_name_; static thread_local bool enable_bwd_prim_; static thread_local bool enable_fwd_prim_; diff --git a/paddle/fluid/prim/utils/utils.cc b/paddle/fluid/prim/utils/utils.cc index a869e5609b9..e7653161680 100644 --- a/paddle/fluid/prim/utils/utils.cc +++ b/paddle/fluid/prim/utils/utils.cc @@ -24,7 +24,7 @@ bool PrimCommonUtils::IsBwdPrimEnabled() { } void PrimCommonUtils::SetBwdPrimEnabled(bool enable_prim) { - return StaticCompositeContext::Instance().SetBwdPrimEnabled(enable_prim); + StaticCompositeContext::Instance().SetBwdPrimEnabled(enable_prim); } bool PrimCommonUtils::IsFwdPrimEnabled() { @@ -32,11 +32,23 @@ bool PrimCommonUtils::IsFwdPrimEnabled() { } void PrimCommonUtils::SetFwdPrimEnabled(bool enable_prim) { - return StaticCompositeContext::Instance().SetFwdPrimEnabled(enable_prim); + StaticCompositeContext::Instance().SetFwdPrimEnabled(enable_prim); } void PrimCommonUtils::SetAllPrimEnabled(bool enable_prim) { - return StaticCompositeContext::Instance().SetAllPrimEnabled(enable_prim); + StaticCompositeContext::Instance().SetAllPrimEnabled(enable_prim); +} + +size_t PrimCommonUtils::CheckSkipCompOps(const std::string& op_type) { + return StaticCompositeContext::Instance().CheckSkipCompOps(op_type); +} + +void PrimCommonUtils::AddSkipCompOps(const std::string& op_type) { + StaticCompositeContext::Instance().AddSkipCompOps(op_type); +} + +void PrimCommonUtils::RemoveSkipCompOps(const std::string& op_type) { + StaticCompositeContext::Instance().RemoveSkipCompOps(op_type); } void PrimCommonUtils::SetTargetGradName( diff --git a/paddle/fluid/prim/utils/utils.h b/paddle/fluid/prim/utils/utils.h index 4ede84a947b..8718496b3f1 100644 --- a/paddle/fluid/prim/utils/utils.h +++ b/paddle/fluid/prim/utils/utils.h @@ -13,9 +13,9 @@ // limitations under the License. #pragma once - #include #include +#include namespace paddle { namespace prim { @@ -26,6 +26,9 @@ class PrimCommonUtils { static bool IsFwdPrimEnabled(); static void SetFwdPrimEnabled(bool enabled); static void SetAllPrimEnabled(bool enabled); + static size_t CheckSkipCompOps(const std::string& op_type); + static void AddSkipCompOps(const std::string& op_type); + static void RemoveSkipCompOps(const std::string& op_type); static void SetTargetGradName(const std::map& m); }; } // namespace prim diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 8712e428bdf..29bf54823db 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1246,6 +1246,9 @@ All parameter, weight, gradient are variables in Paddle. return static_cast( defalut_val.index() - 1); }); + m.def("_add_skip_comp_ops", &paddle::prim::PrimCommonUtils::AddSkipCompOps); + m.def("_remove_skip_comp_ops", + &paddle::prim::PrimCommonUtils::RemoveSkipCompOps); m.def("get_grad_op_desc", [](const OpDesc &op_desc, const std::unordered_set &no_grad_set, @@ -1277,8 +1280,11 @@ All parameter, weight, gradient are variables in Paddle. // priority of CompGradOpMaker is less than GradCompMaker for better // performance. std::vector> grad_op_descs; + auto need_skip = + paddle::prim::PrimCommonUtils::CheckSkipCompOps(op_desc.Type()); + VLOG(3) << "need skip: " << need_skip << std::endl; if (paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) { - if (grad_comp_op_maker != nullptr) { + if ((grad_comp_op_maker != nullptr) && (!need_skip)) { VLOG(3) << "Runing composite fun for " << op_desc.Type(); grad_op_descs = grad_comp_op_maker(op_desc, no_grad_set, diff --git a/python/paddle/fluid/core.py b/python/paddle/fluid/core.py index 54b51f0face..dcdd8847d84 100644 --- a/python/paddle/fluid/core.py +++ b/python/paddle/fluid/core.py @@ -306,6 +306,8 @@ try: from .libpaddle import _Profiler, _ProfilerResult, _RecordEvent from .libpaddle import _set_current_stream from .libpaddle import _get_phi_kernel_name + from .libpaddle import _add_skip_comp_ops + from .libpaddle import _remove_skip_comp_ops # prim controller flags from .libpaddle import __set_bwd_prim_enabled @@ -409,7 +411,7 @@ def __sync_stat_with_flag(flag): __set_fwd_prim_enabled(True) else: raise TypeError(f"flag {flag} should be true or false.") - logging.debug("forward prim enabled: ", bool(_is_fwd_prim_enabled())) + print("forward prim enabled: ", bool(_is_fwd_prim_enabled())) elif flag is "FLAGS_prim_backward": flag_value = os.getenv("FLAGS_prim_backward") assert flag_value is not None @@ -420,7 +422,7 @@ def __sync_stat_with_flag(flag): __set_bwd_prim_enabled(True) else: raise TypeError(f"flag {flag} should be true or false.") - logging.debug("backward prim enabled: ", bool(_is_bwd_prim_enabled())) + print("backward prim enabled: ", bool(_is_bwd_prim_enabled())) elif flag is "FLAGS_prim_all": flag_value = os.getenv("FLAGS_prim_all") assert flag_value is not None @@ -431,7 +433,7 @@ def __sync_stat_with_flag(flag): __set_all_prim_enabled(True) else: raise TypeError(f"flag {flag} should be true or false.") - logging.debug( + print( "all prim enabled: ", bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()), ) @@ -441,19 +443,24 @@ def __sync_stat_with_flag(flag): ) +# Alert!!! This method is only for test coveraget, user should never use it directly, this may cause serious system errors. +def _test_use_sync(value): + __sync_stat_with_flag(value) + + def _set_prim_backward_enabled(value): __set_bwd_prim_enabled(bool(value)) - logging.debug("backward prim enabled: ", bool(_is_bwd_prim_enabled())) + print("backward prim enabled: ", bool(_is_bwd_prim_enabled())) def _set_prim_forward_enabled(value): __set_fwd_prim_enabled(bool(value)) - logging.debug("forward prim enabled: ", bool(_is_fwd_prim_enabled())) + print("forward prim enabled: ", bool(_is_fwd_prim_enabled())) def _set_prim_all_enabled(value): __set_all_prim_enabled(bool(value)) - logging.debug( + print( "all prim enabled: ", bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()), ) @@ -462,7 +469,7 @@ def _set_prim_all_enabled(value): def __sync_prim_backward_status(): flag_value = os.getenv("FLAGS_prim_backward") if flag_value is None: - logging.debug("backward prim enabled: ", bool(_is_bwd_prim_enabled())) + print("backward prim enabled: ", bool(_is_bwd_prim_enabled())) else: __sync_stat_with_flag("FLAGS_prim_backward") @@ -470,7 +477,7 @@ def __sync_prim_backward_status(): def __sync_prim_forward_status(): flag_value = os.getenv("FLAGS_prim_forward") if flag_value is None: - logging.debug("forward prim enabled: ", bool(_is_fwd_prim_enabled())) + print("forward prim enabled: ", bool(_is_fwd_prim_enabled())) else: __sync_stat_with_flag("FLAGS_prim_forward") diff --git a/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags.py b/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags.py index 8f3053af919..ef6c2951ff7 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags.py @@ -47,6 +47,16 @@ class TestPrimFlags(unittest.TestCase): core.check_and_set_prim_all_enabled() self.assertFalse(core._is_fwd_prim_enabled()) + del os.environ['FLAGS_prim_backward'] + core.check_and_set_prim_all_enabled() + self.assertFalse(core._is_bwd_prim_enabled()) + del os.environ['FLAGS_prim_forward'] + core.check_and_set_prim_all_enabled() + self.assertFalse(core._is_fwd_prim_enabled()) + + with self.assertRaises(TypeError): + core._test_use_sync("aaaa") + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/test_comp_skip_op_set.py b/python/paddle/fluid/tests/unittests/prim/test_comp_skip_op_set.py new file mode 100644 index 00000000000..15648226e78 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/test_comp_skip_op_set.py @@ -0,0 +1,75 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import paddle +from paddle.fluid import core, framework + + +class TestGetGradOpDescPrimEnabled(unittest.TestCase): + def setUp(self): + self.fwd_type = 'tanh' + self.inputs = {'X': ['x']} + self.outputs = {'Out': ['y']} + self.no_grad_var = set() + self.grad_sub_block = tuple() + self.desired_ops = 'tanh_grad' + self.desired_ops_no_skip = ('pow', 'scale', 'elementwise_mul') + paddle.enable_static() + block = framework.Block(framework.Program(), 0) + block.append_op( + type=self.fwd_type, + inputs={ + n: [block.create_var(name=v, stop_gradient=False) for v in vs] + for n, vs in self.inputs.items() + }, + outputs={ + n: [block.create_var(name=v, stop_gradient=False) for v in vs] + for n, vs in self.outputs.items() + }, + ) + self.fwd = block.ops[0].desc + + def tearDown(self): + paddle.disable_static() + + def test_get_grad_op_desc_without_skip(self): + core._set_prim_backward_enabled(True) + actual = tuple( + desc.type() + for desc in core.get_grad_op_desc( + self.fwd, self.no_grad_var, self.grad_sub_block + )[0] + ) + self.assertEquals(actual, self.desired_ops_no_skip) + core._set_prim_backward_enabled(False) + + def test_get_grad_op_desc_with_skip(self): + core._set_prim_backward_enabled(True) + core._add_skip_comp_ops("tanh") + actual = tuple( + desc.type() + for desc in core.get_grad_op_desc( + self.fwd, self.no_grad_var, self.grad_sub_block + )[0] + ) + core._remove_skip_comp_ops("tanh") + self.assertEquals(actual[0], self.desired_ops) + core._set_prim_backward_enabled(False) + + +if __name__ == '__main__': + unittest.main() -- GitLab