diff --git a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py index 650bf0626f1ad22bdda64403fd76dd639fc5efce..3497a1217cfcba95d40482c3e158f3862c16e23b 100644 --- a/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/generator/eager_gen.py @@ -330,7 +330,7 @@ NODE_CC_FILE_TEMPLATE = """ #include "paddle/fluid/eager/nan_inf_utils.h" #include "paddle/phi/api/include/sparse_api.h" #include "paddle/fluid/eager/api/manual/eager_manual/nodes/nodes.h" -#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" +#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h" #include "paddle/fluid/prim/api/all.h" #include "paddle/fluid/prim/utils/utils.h" DECLARE_bool(check_nan_inf); diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cc b/paddle/fluid/operators/elementwise/elementwise_add_op.cc index 48a5d2e433a100061c4d8a903ea045a21828cf84..26fcd53621a2b12ee51a983d410a401ad45e0915 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_add_op.cc @@ -15,7 +15,7 @@ limitations under the License. */ #include #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" +#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h" #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h" namespace paddle { diff --git a/paddle/fluid/operators/elementwise/elementwise_div_op.cc b/paddle/fluid/operators/elementwise/elementwise_div_op.cc index 41549ede1ebc6184b56db0026afe179b290e4281..7d96c3106584ad2cdc761b53f516b128c1a3f4b4 100644 --- a/paddle/fluid/operators/elementwise/elementwise_div_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_div_op.cc @@ -19,7 +19,7 @@ limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/platform/complex.h" -#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" +#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h" #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h" namespace paddle { diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc index 740c9381d92e233ceb2be3de156a9a62e1ac22f5..61467be4c9bd59057a905b3b45133b9c547c20a1 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.cc @@ -19,7 +19,7 @@ limitations under the License. */ #include "paddle/fluid/operators/elementwise/elementwise_op.h" #include "paddle/fluid/platform/complex.h" -#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" +#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h" #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h" diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op.cc b/paddle/fluid/operators/elementwise/elementwise_sub_op.cc index 2a9e14867acf1f3caf105a6b31c69d31f073df39..d19f557bfe3c59b6672eb3e89d6e967e729c67d9 100644 --- a/paddle/fluid/operators/elementwise/elementwise_sub_op.cc +++ b/paddle/fluid/operators/elementwise/elementwise_sub_op.cc @@ -15,7 +15,7 @@ limitations under the License. */ #include #include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" +#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h" #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h" namespace paddle { diff --git a/paddle/fluid/operators/expand_v2_op.cc b/paddle/fluid/operators/expand_v2_op.cc index 3c05ab9295c6769bc7b949bc55bcd2321c063ba4..253c2856063ec7042b79298c5902029c27c99a64 100644 --- a/paddle/fluid/operators/expand_v2_op.cc +++ b/paddle/fluid/operators/expand_v2_op.cc @@ -20,7 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" +#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h" #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h" #include "paddle/phi/core/infermeta_utils.h" diff --git a/paddle/fluid/operators/generator/templates/op.c.j2 b/paddle/fluid/operators/generator/templates/op.c.j2 index 2339822af280fb2050d3e84dea3daa22395913e3..f54f91073da158fc1f144a7bb161dab583f59b09 100644 --- a/paddle/fluid/operators/generator/templates/op.c.j2 +++ b/paddle/fluid/operators/generator/templates/op.c.j2 @@ -5,7 +5,7 @@ #include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" -#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" +#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h" #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h" #include "paddle/phi/core/infermeta_utils.h" diff --git a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc index 25e6ad9b65cc0662fd3ee5f1811cc1d20f2473f3..9af1770a41de69f2973442ff410f7e4c479f5db8 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc +++ b/paddle/fluid/operators/reduce_ops/reduce_sum_op.cc @@ -17,7 +17,7 @@ #include #include "paddle/fluid/framework/infershape_utils.h" -#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" +#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h" #include "paddle/phi/core/infermeta_utils.h" #include "paddle/phi/infermeta/unary.h" diff --git a/paddle/fluid/prim/api/.gitignore b/paddle/fluid/prim/api/.gitignore index 377e800f00a0e08893961c1910c9b479e3143181..2dad6249be3ef8332b3df8788569ef77d8fc385d 100644 --- a/paddle/fluid/prim/api/.gitignore +++ b/paddle/fluid/prim/api/.gitignore @@ -1,3 +1,2 @@ -generated/prim_api/eager_prim_api.cc -generated/prim_api/tmp_eager_prim_api.cc -generated/prim_api/*.h +generated_prim/*.cc +generated_prim/*.h diff --git a/paddle/fluid/prim/api/CMakeLists.txt b/paddle/fluid/prim/api/CMakeLists.txt index 436cecc32582b39cfe08b2a06f9d4dba55387f50..6cf3dacef9f4ec8fc07a28530335b8fc8833be39 100644 --- a/paddle/fluid/prim/api/CMakeLists.txt +++ b/paddle/fluid/prim/api/CMakeLists.txt @@ -1,6 +1,6 @@ add_subdirectory(auto_code_generated) -add_subdirectory(manual) -add_subdirectory(generated) +add_subdirectory(manual_prim) +add_subdirectory(generated_prim) if(NOT (NOT WITH_PYTHON AND ON_INFER)) cc_library( diff --git a/paddle/fluid/prim/api/all.h b/paddle/fluid/prim/api/all.h index 2996d2aa2657c8b8c09cfabd30daa7c2adf707b6..b275e163cbc88d21e128d1ed71991193eeb8e76e 100644 --- a/paddle/fluid/prim/api/all.h +++ b/paddle/fluid/prim/api/all.h @@ -13,6 +13,6 @@ // limitations under the License. #pragma once -#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h" -#include "paddle/fluid/prim/api/manual/backward/composite_backward_api.h" -#include "paddle/fluid/prim/api/manual/utils/utils.h" +#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h" +#include "paddle/fluid/prim/api/manual_prim/prim_manual_api.h" +#include "paddle/fluid/prim/api/manual_prim/utils/utils.h" diff --git a/paddle/fluid/prim/api/auto_code_generated/CMakeLists.txt b/paddle/fluid/prim/api/auto_code_generated/CMakeLists.txt index e36af681bbd89589d58e5a7003beacb83ff08c24..ebff0ec688a7e55a3da8f6a98586fd0d36eae71e 100644 --- a/paddle/fluid/prim/api/auto_code_generated/CMakeLists.txt +++ b/paddle/fluid/prim/api/auto_code_generated/CMakeLists.txt @@ -5,16 +5,17 @@ set(legacy_api_yaml_path "${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parsed_ops/legacy_ops.parsed.yaml" ) set(tmp_eager_prim_api_cc_path - "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated/prim_api/tmp_eager_prim_api.cc" + "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/tmp_eager_prim_api.cc" ) set(tmp_prim_api_h_path - "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated/prim_api/tmp_prim_api.h" + "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/tmp_prim_generated_api.h" ) set(eager_prim_api_cc_path - "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated/prim_api/eager_prim_api.cc" + "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/eager_prim_api.cc" ) set(prim_api_h_path - "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated/prim_api/prim_api.h") + "${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/generated_prim/prim_generated_api.h" +) set(prim_api_gen_file ${PADDLE_SOURCE_DIR}/paddle/fluid/prim/api/auto_code_generated/prim_gen.py) diff --git a/paddle/fluid/prim/api/auto_code_generated/prim_gen.py b/paddle/fluid/prim/api/auto_code_generated/prim_gen.py index 7bc59df4f33d2de7bdbf76737461f0b848865c36..787eeb3e4409f1461e1cd41a351fd0f48e0aa739 100644 --- a/paddle/fluid/prim/api/auto_code_generated/prim_gen.py +++ b/paddle/fluid/prim/api/auto_code_generated/prim_gen.py @@ -28,11 +28,11 @@ def header_include(): """ -def eager_source_include(header_file_path): +def eager_source_include(): return """ #include "paddle/fluid/eager/api/all.h" #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" -#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h" +#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h" """ @@ -73,10 +73,7 @@ def generate_api(api_yaml_path, header_file_path, eager_prim_source_file_path): header_file.write(header_include()) header_file.write(namespace[0]) header_file.write(namespace[1]) - include_header_file = ( - "#include paddle/fluid/prim/api/generated/prim_api/prim_api.h" - ) - eager_prim_source_file.write(eager_source_include(include_header_file)) + eager_prim_source_file.write(eager_source_include()) eager_prim_source_file.write(namespace[0]) for api in apis: @@ -106,13 +103,13 @@ def main(): parser.add_argument( '--prim_api_header_path', help='output of generated prim_api header code file', - default='paddle/fluid/prim/api/generated/prim_api/prim_api.h', + default='paddle/fluid/prim/api/generated_prim/prim_generated_api.h', ) parser.add_argument( '--eager_prim_api_source_path', help='output of generated eager_prim_api source code file', - default='paddle/fluid/prim/api/generated/prim_api/eager_prim_api.cc', + default='paddle/fluid/prim/api/generated_prim/eager_prim_api.cc', ) options = parser.parse_args() diff --git a/paddle/fluid/prim/api/manual/backward/composite_backward_api.h b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h similarity index 98% rename from paddle/fluid/prim/api/manual/backward/composite_backward_api.h rename to paddle/fluid/prim/api/composite_backward/composite_backward_api.h index a9c8953a228f2fd41366e8b770ddfa6b10ad8278..e782d6b65bba62ac47615e5f4de4b339575781ee 100644 --- a/paddle/fluid/prim/api/manual/backward/composite_backward_api.h +++ b/paddle/fluid/prim/api/composite_backward/composite_backward_api.h @@ -13,9 +13,7 @@ // limitations under the License. #pragma once -#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h" -#include "paddle/fluid/prim/api/manual/prim_api/prim_api.h" -#include "paddle/fluid/prim/api/manual/utils/utils.h" +#include "paddle/fluid/prim/api/all.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/core/ddim.h" diff --git a/paddle/fluid/prim/api/generated/CMakeLists.txt b/paddle/fluid/prim/api/generated/CMakeLists.txt deleted file mode 100644 index a1b75527c20b49d688bde9ea120a74046a411123..0000000000000000000000000000000000000000 --- a/paddle/fluid/prim/api/generated/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(prim_api) diff --git a/paddle/fluid/prim/api/generated/prim_api/CMakeLists.txt b/paddle/fluid/prim/api/generated_prim/CMakeLists.txt similarity index 63% rename from paddle/fluid/prim/api/generated/prim_api/CMakeLists.txt rename to paddle/fluid/prim/api/generated_prim/CMakeLists.txt index ee39c73f99f2f935664959f292884c7c95103452..6e030052d77a0fe1d73c06b94401c55650567b95 100644 --- a/paddle/fluid/prim/api/generated/prim_api/CMakeLists.txt +++ b/paddle/fluid/prim/api/generated_prim/CMakeLists.txt @@ -1,8 +1,3 @@ -cc_library( - static_prim_api - SRCS static_prim_api.cc - DEPS proto_desc static_utils) - if(NOT (NOT WITH_PYTHON AND ON_INFER)) cc_library( eager_prim_api diff --git a/paddle/fluid/prim/api/manual/CMakeLists.txt b/paddle/fluid/prim/api/manual/CMakeLists.txt deleted file mode 100644 index 512d2b1553c8c94a06445f3c59c4b77d10d74032..0000000000000000000000000000000000000000 --- a/paddle/fluid/prim/api/manual/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(utils) diff --git a/paddle/fluid/prim/api/manual_prim/CMakeLists.txt b/paddle/fluid/prim/api/manual_prim/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..7437c737a7b7f6bc8a3fec463b0d916dc680b87d --- /dev/null +++ b/paddle/fluid/prim/api/manual_prim/CMakeLists.txt @@ -0,0 +1,5 @@ +add_subdirectory(utils) +cc_library( + static_prim_api + SRCS static_prim_api.cc + DEPS proto_desc static_utils) diff --git a/paddle/fluid/prim/api/manual/prim_api/prim_api.h b/paddle/fluid/prim/api/manual_prim/prim_manual_api.h similarity index 78% rename from paddle/fluid/prim/api/manual/prim_api/prim_api.h rename to paddle/fluid/prim/api/manual_prim/prim_manual_api.h index 65d411d86307ded238a4bc07e6336659663ca406..80d11aed3489e6d781673fd3ef5c3a6f36e9e49b 100644 --- a/paddle/fluid/prim/api/manual/prim_api/prim_api.h +++ b/paddle/fluid/prim/api/manual_prim/prim_manual_api.h @@ -1,4 +1,4 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,15 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -// prim api which can't be generated #pragma once +#include "paddle/fluid/prim/api/generated_prim/prim_generated_api.h" #include "paddle/phi/common/data_type.h" #include "paddle/phi/common/int_array.h" #include "paddle/phi/common/place.h" #include "paddle/phi/common/scalar.h" #include "paddle/utils/optional.h" - +// TODO(jiabin): Make this Header only for handwritten api, instead of include +// prim_generated_api.h namespace paddle { namespace prim {} // namespace prim } // namespace paddle diff --git a/paddle/fluid/prim/api/generated/prim_api/static_prim_api.cc b/paddle/fluid/prim/api/manual_prim/static_prim_api.cc similarity index 98% rename from paddle/fluid/prim/api/generated/prim_api/static_prim_api.cc rename to paddle/fluid/prim/api/manual_prim/static_prim_api.cc index b879ade5a9e8921167fc38bf6b84fe5ef147a59f..71d547c139a1fff7881afdcf85204f1dbaf3ba2b 100644 --- a/paddle/fluid/prim/api/generated/prim_api/static_prim_api.cc +++ b/paddle/fluid/prim/api/manual_prim/static_prim_api.cc @@ -26,9 +26,8 @@ #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/convert_utils.h" -#include "paddle/fluid/prim/api/generated/prim_api/prim_api.h" -#include "paddle/fluid/prim/api/manual/prim_api/prim_api.h" -#include "paddle/fluid/prim/api/manual/utils/utils.h" +#include "paddle/fluid/prim/api/manual_prim/prim_manual_api.h" +#include "paddle/fluid/prim/api/manual_prim/utils/utils.h" #include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h" #include "paddle/phi/api/include/tensor.h" diff --git a/paddle/fluid/prim/api/manual/utils/CMakeLists.txt b/paddle/fluid/prim/api/manual_prim/utils/CMakeLists.txt similarity index 100% rename from paddle/fluid/prim/api/manual/utils/CMakeLists.txt rename to paddle/fluid/prim/api/manual_prim/utils/CMakeLists.txt diff --git a/paddle/fluid/prim/api/manual/utils/eager_utils.cc b/paddle/fluid/prim/api/manual_prim/utils/eager_utils.cc similarity index 97% rename from paddle/fluid/prim/api/manual/utils/eager_utils.cc rename to paddle/fluid/prim/api/manual_prim/utils/eager_utils.cc index 353945557f1d02386645a79c6b2d871fe90fb588..04854428d8e2b9ce0247ce296b09bfb1b515e895 100644 --- a/paddle/fluid/prim/api/manual/utils/eager_utils.cc +++ b/paddle/fluid/prim/api/manual_prim/utils/eager_utils.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h" #include "paddle/fluid/eager/api/utils/global_utils.h" -#include "paddle/fluid/prim/api/manual/utils/utils.h" +#include "paddle/fluid/prim/api/manual_prim/utils/utils.h" #include "paddle/phi/api/include/tensor.h" namespace paddle { diff --git a/paddle/fluid/prim/api/manual/utils/static_utils.cc b/paddle/fluid/prim/api/manual_prim/utils/static_utils.cc similarity index 98% rename from paddle/fluid/prim/api/manual/utils/static_utils.cc rename to paddle/fluid/prim/api/manual_prim/utils/static_utils.cc index 74656cfe7d48d17fe0c3fc2122896ef10f8535b7..8cfcffd92c2ea18a3f0723df282493e0052b01b6 100644 --- a/paddle/fluid/prim/api/manual/utils/static_utils.cc +++ b/paddle/fluid/prim/api/manual_prim/utils/static_utils.cc @@ -17,7 +17,7 @@ #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/prim/api/manual/utils/utils.h" +#include "paddle/fluid/prim/api/manual_prim/utils/utils.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h" #include "paddle/fluid/prim/utils/static/static_global_utils.h" #include "paddle/phi/api/include/tensor.h" diff --git a/paddle/fluid/prim/api/manual/utils/utils.h b/paddle/fluid/prim/api/manual_prim/utils/utils.h similarity index 100% rename from paddle/fluid/prim/api/manual/utils/utils.h rename to paddle/fluid/prim/api/manual_prim/utils/utils.h diff --git a/paddle/fluid/prim/tests/test_static_prim.cc b/paddle/fluid/prim/tests/test_static_prim.cc index 313a3ccc99b74de65305d8d8d1b07f06760e4593..5a53101ab13bb01c8336cf462873517a48596dcc 100644 --- a/paddle/fluid/prim/tests/test_static_prim.cc +++ b/paddle/fluid/prim/tests/test_static_prim.cc @@ -19,7 +19,7 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/prim/api/manual/utils/utils.h" +#include "paddle/fluid/prim/api/manual_prim/utils/utils.h" #include "paddle/fluid/prim/utils/static/desc_tensor.h" #include "paddle/fluid/prim/utils/utils.h" #include "paddle/phi/core/enforce.h" diff --git a/paddle/fluid/prim/utils/static/static_global_utils.h b/paddle/fluid/prim/utils/static/static_global_utils.h index e878c857f26254e470aad155d478ac210c53d67d..e6a8054f1a74784fed248b68034bf659e2d3b9d5 100644 --- a/paddle/fluid/prim/utils/static/static_global_utils.h +++ b/paddle/fluid/prim/utils/static/static_global_utils.h @@ -69,6 +69,18 @@ class StaticCompositeContext { enable_bwd_prim_ = enable_prim; } + size_t CheckSkipCompOps(const std::string& op_type) const { + return skip_comp_ops_.count(op_type); + } + + void AddSkipCompOps(const std::string& op_type) { + skip_comp_ops_.insert(op_type); + } + + void RemoveSkipCompOps(const std::string& op_type) { + skip_comp_ops_.erase(op_type); + } + void SetTargetGradName(const std::map& m) { target_grad_name_ = m; } @@ -79,10 +91,13 @@ class StaticCompositeContext { private: StaticCompositeContext() - : current_block_desc_(nullptr), generator_(new UniqueNameGenerator()) {} + : current_block_desc_(nullptr), + generator_(new UniqueNameGenerator()), + skip_comp_ops_({"matmul_v2"}) {} framework::BlockDesc* current_block_desc_; std::unique_ptr generator_; + std::unordered_set skip_comp_ops_; std::map target_grad_name_; static thread_local bool enable_bwd_prim_; static thread_local bool enable_fwd_prim_; diff --git a/paddle/fluid/prim/utils/utils.cc b/paddle/fluid/prim/utils/utils.cc index a869e5609b91fe1f3fc619d4254a48339b109a9c..e76531616807a63d5880fb2f4ad04cdbe033ac8f 100644 --- a/paddle/fluid/prim/utils/utils.cc +++ b/paddle/fluid/prim/utils/utils.cc @@ -24,7 +24,7 @@ bool PrimCommonUtils::IsBwdPrimEnabled() { } void PrimCommonUtils::SetBwdPrimEnabled(bool enable_prim) { - return StaticCompositeContext::Instance().SetBwdPrimEnabled(enable_prim); + StaticCompositeContext::Instance().SetBwdPrimEnabled(enable_prim); } bool PrimCommonUtils::IsFwdPrimEnabled() { @@ -32,11 +32,23 @@ bool PrimCommonUtils::IsFwdPrimEnabled() { } void PrimCommonUtils::SetFwdPrimEnabled(bool enable_prim) { - return StaticCompositeContext::Instance().SetFwdPrimEnabled(enable_prim); + StaticCompositeContext::Instance().SetFwdPrimEnabled(enable_prim); } void PrimCommonUtils::SetAllPrimEnabled(bool enable_prim) { - return StaticCompositeContext::Instance().SetAllPrimEnabled(enable_prim); + StaticCompositeContext::Instance().SetAllPrimEnabled(enable_prim); +} + +size_t PrimCommonUtils::CheckSkipCompOps(const std::string& op_type) { + return StaticCompositeContext::Instance().CheckSkipCompOps(op_type); +} + +void PrimCommonUtils::AddSkipCompOps(const std::string& op_type) { + StaticCompositeContext::Instance().AddSkipCompOps(op_type); +} + +void PrimCommonUtils::RemoveSkipCompOps(const std::string& op_type) { + StaticCompositeContext::Instance().RemoveSkipCompOps(op_type); } void PrimCommonUtils::SetTargetGradName( diff --git a/paddle/fluid/prim/utils/utils.h b/paddle/fluid/prim/utils/utils.h index 4ede84a947b22f27582b8433fc10d6f119c56cb8..8718496b3f1884a25f7605847250eb2a59a45e58 100644 --- a/paddle/fluid/prim/utils/utils.h +++ b/paddle/fluid/prim/utils/utils.h @@ -13,9 +13,9 @@ // limitations under the License. #pragma once - #include #include +#include namespace paddle { namespace prim { @@ -26,6 +26,9 @@ class PrimCommonUtils { static bool IsFwdPrimEnabled(); static void SetFwdPrimEnabled(bool enabled); static void SetAllPrimEnabled(bool enabled); + static size_t CheckSkipCompOps(const std::string& op_type); + static void AddSkipCompOps(const std::string& op_type); + static void RemoveSkipCompOps(const std::string& op_type); static void SetTargetGradName(const std::map& m); }; } // namespace prim diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 8712e428bdf5ec3ac81d251224812c80fd33c4c0..29bf54823dba404915356e7fe838509b327646c3 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1246,6 +1246,9 @@ All parameter, weight, gradient are variables in Paddle. return static_cast( defalut_val.index() - 1); }); + m.def("_add_skip_comp_ops", &paddle::prim::PrimCommonUtils::AddSkipCompOps); + m.def("_remove_skip_comp_ops", + &paddle::prim::PrimCommonUtils::RemoveSkipCompOps); m.def("get_grad_op_desc", [](const OpDesc &op_desc, const std::unordered_set &no_grad_set, @@ -1277,8 +1280,11 @@ All parameter, weight, gradient are variables in Paddle. // priority of CompGradOpMaker is less than GradCompMaker for better // performance. std::vector> grad_op_descs; + auto need_skip = + paddle::prim::PrimCommonUtils::CheckSkipCompOps(op_desc.Type()); + VLOG(3) << "need skip: " << need_skip << std::endl; if (paddle::prim::PrimCommonUtils::IsBwdPrimEnabled()) { - if (grad_comp_op_maker != nullptr) { + if ((grad_comp_op_maker != nullptr) && (!need_skip)) { VLOG(3) << "Runing composite fun for " << op_desc.Type(); grad_op_descs = grad_comp_op_maker(op_desc, no_grad_set, diff --git a/python/paddle/fluid/core.py b/python/paddle/fluid/core.py index 54b51f0facef9d39ebbe4945cf505f4fb0890eb2..dcdd8847d842426f5d2ca04dc7e4d2a95b7876c1 100644 --- a/python/paddle/fluid/core.py +++ b/python/paddle/fluid/core.py @@ -306,6 +306,8 @@ try: from .libpaddle import _Profiler, _ProfilerResult, _RecordEvent from .libpaddle import _set_current_stream from .libpaddle import _get_phi_kernel_name + from .libpaddle import _add_skip_comp_ops + from .libpaddle import _remove_skip_comp_ops # prim controller flags from .libpaddle import __set_bwd_prim_enabled @@ -409,7 +411,7 @@ def __sync_stat_with_flag(flag): __set_fwd_prim_enabled(True) else: raise TypeError(f"flag {flag} should be true or false.") - logging.debug("forward prim enabled: ", bool(_is_fwd_prim_enabled())) + print("forward prim enabled: ", bool(_is_fwd_prim_enabled())) elif flag is "FLAGS_prim_backward": flag_value = os.getenv("FLAGS_prim_backward") assert flag_value is not None @@ -420,7 +422,7 @@ def __sync_stat_with_flag(flag): __set_bwd_prim_enabled(True) else: raise TypeError(f"flag {flag} should be true or false.") - logging.debug("backward prim enabled: ", bool(_is_bwd_prim_enabled())) + print("backward prim enabled: ", bool(_is_bwd_prim_enabled())) elif flag is "FLAGS_prim_all": flag_value = os.getenv("FLAGS_prim_all") assert flag_value is not None @@ -431,7 +433,7 @@ def __sync_stat_with_flag(flag): __set_all_prim_enabled(True) else: raise TypeError(f"flag {flag} should be true or false.") - logging.debug( + print( "all prim enabled: ", bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()), ) @@ -441,19 +443,24 @@ def __sync_stat_with_flag(flag): ) +# Alert!!! This method is only for test coveraget, user should never use it directly, this may cause serious system errors. +def _test_use_sync(value): + __sync_stat_with_flag(value) + + def _set_prim_backward_enabled(value): __set_bwd_prim_enabled(bool(value)) - logging.debug("backward prim enabled: ", bool(_is_bwd_prim_enabled())) + print("backward prim enabled: ", bool(_is_bwd_prim_enabled())) def _set_prim_forward_enabled(value): __set_fwd_prim_enabled(bool(value)) - logging.debug("forward prim enabled: ", bool(_is_fwd_prim_enabled())) + print("forward prim enabled: ", bool(_is_fwd_prim_enabled())) def _set_prim_all_enabled(value): __set_all_prim_enabled(bool(value)) - logging.debug( + print( "all prim enabled: ", bool(_is_fwd_prim_enabled() and _is_bwd_prim_enabled()), ) @@ -462,7 +469,7 @@ def _set_prim_all_enabled(value): def __sync_prim_backward_status(): flag_value = os.getenv("FLAGS_prim_backward") if flag_value is None: - logging.debug("backward prim enabled: ", bool(_is_bwd_prim_enabled())) + print("backward prim enabled: ", bool(_is_bwd_prim_enabled())) else: __sync_stat_with_flag("FLAGS_prim_backward") @@ -470,7 +477,7 @@ def __sync_prim_backward_status(): def __sync_prim_forward_status(): flag_value = os.getenv("FLAGS_prim_forward") if flag_value is None: - logging.debug("forward prim enabled: ", bool(_is_fwd_prim_enabled())) + print("forward prim enabled: ", bool(_is_fwd_prim_enabled())) else: __sync_stat_with_flag("FLAGS_prim_forward") diff --git a/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags.py b/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags.py index 8f3053af919e926c274a0f5d5bfba7d75ed12805..ef6c2951ff7dc8054a70313d9339986acbc7cb71 100644 --- a/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags.py +++ b/python/paddle/fluid/tests/unittests/prim/prim/flags/test_prim_flags.py @@ -47,6 +47,16 @@ class TestPrimFlags(unittest.TestCase): core.check_and_set_prim_all_enabled() self.assertFalse(core._is_fwd_prim_enabled()) + del os.environ['FLAGS_prim_backward'] + core.check_and_set_prim_all_enabled() + self.assertFalse(core._is_bwd_prim_enabled()) + del os.environ['FLAGS_prim_forward'] + core.check_and_set_prim_all_enabled() + self.assertFalse(core._is_fwd_prim_enabled()) + + with self.assertRaises(TypeError): + core._test_use_sync("aaaa") + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/prim/test_comp_skip_op_set.py b/python/paddle/fluid/tests/unittests/prim/test_comp_skip_op_set.py new file mode 100644 index 0000000000000000000000000000000000000000..15648226e7859a4be83d525b3efbd5e5d55e0bc7 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/prim/test_comp_skip_op_set.py @@ -0,0 +1,75 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import unittest + +import paddle +from paddle.fluid import core, framework + + +class TestGetGradOpDescPrimEnabled(unittest.TestCase): + def setUp(self): + self.fwd_type = 'tanh' + self.inputs = {'X': ['x']} + self.outputs = {'Out': ['y']} + self.no_grad_var = set() + self.grad_sub_block = tuple() + self.desired_ops = 'tanh_grad' + self.desired_ops_no_skip = ('pow', 'scale', 'elementwise_mul') + paddle.enable_static() + block = framework.Block(framework.Program(), 0) + block.append_op( + type=self.fwd_type, + inputs={ + n: [block.create_var(name=v, stop_gradient=False) for v in vs] + for n, vs in self.inputs.items() + }, + outputs={ + n: [block.create_var(name=v, stop_gradient=False) for v in vs] + for n, vs in self.outputs.items() + }, + ) + self.fwd = block.ops[0].desc + + def tearDown(self): + paddle.disable_static() + + def test_get_grad_op_desc_without_skip(self): + core._set_prim_backward_enabled(True) + actual = tuple( + desc.type() + for desc in core.get_grad_op_desc( + self.fwd, self.no_grad_var, self.grad_sub_block + )[0] + ) + self.assertEquals(actual, self.desired_ops_no_skip) + core._set_prim_backward_enabled(False) + + def test_get_grad_op_desc_with_skip(self): + core._set_prim_backward_enabled(True) + core._add_skip_comp_ops("tanh") + actual = tuple( + desc.type() + for desc in core.get_grad_op_desc( + self.fwd, self.no_grad_var, self.grad_sub_block + )[0] + ) + core._remove_skip_comp_ops("tanh") + self.assertEquals(actual[0], self.desired_ops) + core._set_prim_backward_enabled(False) + + +if __name__ == '__main__': + unittest.main()