diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index 0049311a9315ae9bb5f5c6d1626ffec423d7b1ff..7cdbee1746a8ff4610911bc09f1ca882289fd299 100644 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake @@ -363,6 +363,7 @@ endif (WITH_LITE) if (WITH_CINN) message(STATUS "Compile Paddle with CINN.") include(external/cinn) + add_definitions(-DPADDLE_WITH_CINN) endif (WITH_CINN) if (WITH_CRYPTO) diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 4dfcf0985b85e1010a8ea4390fe337e7b98703f1..edb43b8d38c27698a7abb1bccb9dd42607791efa 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -26,7 +26,9 @@ add_subdirectory(details) add_subdirectory(fleet) add_subdirectory(io) add_subdirectory(new_executor) -add_subdirectory(paddle2cinn) +if (WITH_CINN) + add_subdirectory(paddle2cinn) +endif() #ddim lib proto_library(framework_proto SRCS framework.proto) proto_library(pass_desc_proto SRCS pass_desc.proto DEPS framework_proto) @@ -353,7 +355,7 @@ target_link_libraries(executor while_op_helper executor_gc_helper recurrent_op_h cc_library(parallel_executor SRCS parallel_executor.cc DEPS threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor async_ssa_graph_executor graph build_strategy bind_threaded_ssa_graph_executor collective_helper - fast_threaded_ssa_graph_executor variable_helper cinn_runner) + fast_threaded_ssa_graph_executor variable_helper) cc_library(executor_cache SRCS executor_cache.cc DEPS parallel_executor) if(WITH_PSCORE) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 5e2fd08406fa75f6fc1234869a04d730dd72bec8..87f77ec2fff3a6e5dd5bfb38a110a527b6d35c8f 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -139,7 +139,12 @@ set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass sync_batch_norm_pass runtime_context_cache_pass graph_to_program_pass - fix_op_run_order_pass build_cinn_pass) + fix_op_run_order_pass) + +if (WITH_CINN) + set(IR_PASS_DEPS ${IR_PASS_DEPS} build_cinn_pass) +endif() + if(NOT APPLE AND NOT WIN32 AND (WITH_GPU OR WITH_ROCM)) set(IR_PASS_DEPS ${IR_PASS_DEPS} fusion_group_pass) endif() diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 6b6ee4083312327d8841b797c8517c9e383be991..1bb1ae0ea675581b304d3d781ffecc65632202c9 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -20,8 +20,10 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h" DECLARE_bool(convert_all_blocks); -DECLARE_bool(use_cinn); DECLARE_bool(use_mkldnn); +#ifdef PADDLE_WITH_CINN +DECLARE_bool(use_cinn); +#endif namespace paddle { namespace framework { @@ -72,10 +74,13 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { // Note: This pass is used to check whether the multi_device_graph is right. AppendPass("multi_devices_check_pass"); - // Note: This pass is used to enable cinn. +#ifdef PADDLE_WITH_CINN if (FLAGS_use_cinn) { + // Note: This pass is used to enable cinn. AppendPass("build_cinn_pass"); } +#endif + SetCollectiveContext(); } @@ -486,7 +491,9 @@ USE_PASS(fuse_momentum_op_pass); USE_PASS(fuse_all_reduce_op_pass); USE_PASS(runtime_context_cache_pass); USE_PASS(add_reader_dependency_pass); +#ifdef PADDLE_WITH_CINN USE_PASS(build_cinn_pass); +#endif #ifdef PADDLE_WITH_MKLDNN USE_PASS(mkldnn_placement_pass); #endif diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index dd65d743fad31a12925a8f7454eafea02d7ba9f3..ef908be8462ed6cc852322cb33eff93ad6025621 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -705,8 +705,10 @@ PADDLE_DEFINE_EXPORTED_bool(allreduce_record_one_event, false, * Value Range: bool, default=false * Example: FLAGS_use_cinn=true would run PaddlePaddle using CINN */ +#ifdef PADDLE_WITH_CINN PADDLE_DEFINE_EXPORTED_bool( use_cinn, false, "It controls whether to run PaddlePaddle using CINN"); +#endif DEFINE_int32(record_pool_max_size, 2000000, "SlotRecordDataset slot record pool max size"); diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_run_cinn.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_run_cinn.py index d4722c2e1819f9964f7e57474d47c661ab3d5634..bc0652b165eb654081d07bfc503440e8542d91e8 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_run_cinn.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_run_cinn.py @@ -14,16 +14,28 @@ from __future__ import print_function +import logging import numpy as np import paddle import unittest paddle.enable_static() +logging.basicConfig( + format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO) +logger = logging.getLogger(__name__) + + +def set_cinn_flag(val): + try: + paddle.set_flags({'FLAGS_use_cinn': val}) + except ValueError: + logger.warning("The used paddle is not compiled with CINN.") + class TestParallelExecutorRunCinn(unittest.TestCase): def test_run_from_cinn(self): - paddle.set_flags({'FLAGS_use_cinn': False}) + set_cinn_flag(False) main_program = paddle.static.Program() startup_program = paddle.static.Program() @@ -49,7 +61,7 @@ class TestParallelExecutorRunCinn(unittest.TestCase): fetch_list=[prediction.name], return_merged=False) - paddle.set_flags({'FLAGS_use_cinn': False}) + set_cinn_flag(False) if __name__ == '__main__':