未验证 提交 e2173b68 编写于 作者: Z Zhen Wang 提交者: GitHub

Add the macro `-DPADDLE_WITH_CINN`. (#36660)

上级 bbd4bd73
...@@ -363,6 +363,7 @@ endif (WITH_LITE) ...@@ -363,6 +363,7 @@ endif (WITH_LITE)
if (WITH_CINN) if (WITH_CINN)
message(STATUS "Compile Paddle with CINN.") message(STATUS "Compile Paddle with CINN.")
include(external/cinn) include(external/cinn)
add_definitions(-DPADDLE_WITH_CINN)
endif (WITH_CINN) endif (WITH_CINN)
if (WITH_CRYPTO) if (WITH_CRYPTO)
......
...@@ -26,7 +26,9 @@ add_subdirectory(details) ...@@ -26,7 +26,9 @@ add_subdirectory(details)
add_subdirectory(fleet) add_subdirectory(fleet)
add_subdirectory(io) add_subdirectory(io)
add_subdirectory(new_executor) add_subdirectory(new_executor)
add_subdirectory(paddle2cinn) if (WITH_CINN)
add_subdirectory(paddle2cinn)
endif()
#ddim lib #ddim lib
proto_library(framework_proto SRCS framework.proto) proto_library(framework_proto SRCS framework.proto)
proto_library(pass_desc_proto SRCS pass_desc.proto DEPS framework_proto) proto_library(pass_desc_proto SRCS pass_desc.proto DEPS framework_proto)
...@@ -353,7 +355,7 @@ target_link_libraries(executor while_op_helper executor_gc_helper recurrent_op_h ...@@ -353,7 +355,7 @@ target_link_libraries(executor while_op_helper executor_gc_helper recurrent_op_h
cc_library(parallel_executor SRCS parallel_executor.cc DEPS cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor async_ssa_graph_executor threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor async_ssa_graph_executor
graph build_strategy bind_threaded_ssa_graph_executor collective_helper graph build_strategy bind_threaded_ssa_graph_executor collective_helper
fast_threaded_ssa_graph_executor variable_helper cinn_runner) fast_threaded_ssa_graph_executor variable_helper)
cc_library(executor_cache SRCS executor_cache.cc DEPS parallel_executor) cc_library(executor_cache SRCS executor_cache.cc DEPS parallel_executor)
if(WITH_PSCORE) if(WITH_PSCORE)
......
...@@ -139,7 +139,12 @@ set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass ...@@ -139,7 +139,12 @@ set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass
coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
sync_batch_norm_pass runtime_context_cache_pass graph_to_program_pass sync_batch_norm_pass runtime_context_cache_pass graph_to_program_pass
fix_op_run_order_pass build_cinn_pass) fix_op_run_order_pass)
if (WITH_CINN)
set(IR_PASS_DEPS ${IR_PASS_DEPS} build_cinn_pass)
endif()
if(NOT APPLE AND NOT WIN32 AND (WITH_GPU OR WITH_ROCM)) if(NOT APPLE AND NOT WIN32 AND (WITH_GPU OR WITH_ROCM))
set(IR_PASS_DEPS ${IR_PASS_DEPS} fusion_group_pass) set(IR_PASS_DEPS ${IR_PASS_DEPS} fusion_group_pass)
endif() endif()
......
...@@ -20,8 +20,10 @@ limitations under the License. */ ...@@ -20,8 +20,10 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h" #include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
DECLARE_bool(convert_all_blocks); DECLARE_bool(convert_all_blocks);
DECLARE_bool(use_cinn);
DECLARE_bool(use_mkldnn); DECLARE_bool(use_mkldnn);
#ifdef PADDLE_WITH_CINN
DECLARE_bool(use_cinn);
#endif
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -72,10 +74,13 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { ...@@ -72,10 +74,13 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Note: This pass is used to check whether the multi_device_graph is right. // Note: This pass is used to check whether the multi_device_graph is right.
AppendPass("multi_devices_check_pass"); AppendPass("multi_devices_check_pass");
// Note: This pass is used to enable cinn. #ifdef PADDLE_WITH_CINN
if (FLAGS_use_cinn) { if (FLAGS_use_cinn) {
// Note: This pass is used to enable cinn.
AppendPass("build_cinn_pass"); AppendPass("build_cinn_pass");
} }
#endif
SetCollectiveContext(); SetCollectiveContext();
} }
...@@ -486,7 +491,9 @@ USE_PASS(fuse_momentum_op_pass); ...@@ -486,7 +491,9 @@ USE_PASS(fuse_momentum_op_pass);
USE_PASS(fuse_all_reduce_op_pass); USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass); USE_PASS(runtime_context_cache_pass);
USE_PASS(add_reader_dependency_pass); USE_PASS(add_reader_dependency_pass);
#ifdef PADDLE_WITH_CINN
USE_PASS(build_cinn_pass); USE_PASS(build_cinn_pass);
#endif
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
USE_PASS(mkldnn_placement_pass); USE_PASS(mkldnn_placement_pass);
#endif #endif
......
...@@ -705,8 +705,10 @@ PADDLE_DEFINE_EXPORTED_bool(allreduce_record_one_event, false, ...@@ -705,8 +705,10 @@ PADDLE_DEFINE_EXPORTED_bool(allreduce_record_one_event, false,
* Value Range: bool, default=false * Value Range: bool, default=false
* Example: FLAGS_use_cinn=true would run PaddlePaddle using CINN * Example: FLAGS_use_cinn=true would run PaddlePaddle using CINN
*/ */
#ifdef PADDLE_WITH_CINN
PADDLE_DEFINE_EXPORTED_bool( PADDLE_DEFINE_EXPORTED_bool(
use_cinn, false, "It controls whether to run PaddlePaddle using CINN"); use_cinn, false, "It controls whether to run PaddlePaddle using CINN");
#endif
DEFINE_int32(record_pool_max_size, 2000000, DEFINE_int32(record_pool_max_size, 2000000,
"SlotRecordDataset slot record pool max size"); "SlotRecordDataset slot record pool max size");
......
...@@ -14,16 +14,28 @@ ...@@ -14,16 +14,28 @@
from __future__ import print_function from __future__ import print_function
import logging
import numpy as np import numpy as np
import paddle import paddle
import unittest import unittest
paddle.enable_static() paddle.enable_static()
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
def set_cinn_flag(val):
try:
paddle.set_flags({'FLAGS_use_cinn': val})
except ValueError:
logger.warning("The used paddle is not compiled with CINN.")
class TestParallelExecutorRunCinn(unittest.TestCase): class TestParallelExecutorRunCinn(unittest.TestCase):
def test_run_from_cinn(self): def test_run_from_cinn(self):
paddle.set_flags({'FLAGS_use_cinn': False}) set_cinn_flag(False)
main_program = paddle.static.Program() main_program = paddle.static.Program()
startup_program = paddle.static.Program() startup_program = paddle.static.Program()
...@@ -49,7 +61,7 @@ class TestParallelExecutorRunCinn(unittest.TestCase): ...@@ -49,7 +61,7 @@ class TestParallelExecutorRunCinn(unittest.TestCase):
fetch_list=[prediction.name], fetch_list=[prediction.name],
return_merged=False) return_merged=False)
paddle.set_flags({'FLAGS_use_cinn': False}) set_cinn_flag(False)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册