未验证 提交 e2173b68 编写于 作者: Z Zhen Wang 提交者: GitHub

Add the macro `-DPADDLE_WITH_CINN`. (#36660)

上级 bbd4bd73
......@@ -363,6 +363,7 @@ endif (WITH_LITE)
if (WITH_CINN)
message(STATUS "Compile Paddle with CINN.")
include(external/cinn)
add_definitions(-DPADDLE_WITH_CINN)
endif (WITH_CINN)
if (WITH_CRYPTO)
......
......@@ -26,7 +26,9 @@ add_subdirectory(details)
add_subdirectory(fleet)
add_subdirectory(io)
add_subdirectory(new_executor)
add_subdirectory(paddle2cinn)
if (WITH_CINN)
add_subdirectory(paddle2cinn)
endif()
#ddim lib
proto_library(framework_proto SRCS framework.proto)
proto_library(pass_desc_proto SRCS pass_desc.proto DEPS framework_proto)
......@@ -353,7 +355,7 @@ target_link_libraries(executor while_op_helper executor_gc_helper recurrent_op_h
cc_library(parallel_executor SRCS parallel_executor.cc DEPS
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor async_ssa_graph_executor
graph build_strategy bind_threaded_ssa_graph_executor collective_helper
fast_threaded_ssa_graph_executor variable_helper cinn_runner)
fast_threaded_ssa_graph_executor variable_helper)
cc_library(executor_cache SRCS executor_cache.cc DEPS parallel_executor)
if(WITH_PSCORE)
......
......@@ -139,7 +139,12 @@ set(IR_PASS_DEPS graph_viz_pass multi_devices_graph_pass
coalesce_grad_tensor_pass fuse_all_reduce_op_pass backward_optimizer_op_deps_pass
fuse_adam_op_pass fuse_sgd_op_pass fuse_momentum_op_pass
sync_batch_norm_pass runtime_context_cache_pass graph_to_program_pass
fix_op_run_order_pass build_cinn_pass)
fix_op_run_order_pass)
if (WITH_CINN)
set(IR_PASS_DEPS ${IR_PASS_DEPS} build_cinn_pass)
endif()
if(NOT APPLE AND NOT WIN32 AND (WITH_GPU OR WITH_ROCM))
set(IR_PASS_DEPS ${IR_PASS_DEPS} fusion_group_pass)
endif()
......
......@@ -20,8 +20,10 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.h"
DECLARE_bool(convert_all_blocks);
DECLARE_bool(use_cinn);
DECLARE_bool(use_mkldnn);
#ifdef PADDLE_WITH_CINN
DECLARE_bool(use_cinn);
#endif
namespace paddle {
namespace framework {
......@@ -72,10 +74,13 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
// Note: This pass is used to check whether the multi_device_graph is right.
AppendPass("multi_devices_check_pass");
// Note: This pass is used to enable cinn.
#ifdef PADDLE_WITH_CINN
if (FLAGS_use_cinn) {
// Note: This pass is used to enable cinn.
AppendPass("build_cinn_pass");
}
#endif
SetCollectiveContext();
}
......@@ -486,7 +491,9 @@ USE_PASS(fuse_momentum_op_pass);
USE_PASS(fuse_all_reduce_op_pass);
USE_PASS(runtime_context_cache_pass);
USE_PASS(add_reader_dependency_pass);
#ifdef PADDLE_WITH_CINN
USE_PASS(build_cinn_pass);
#endif
#ifdef PADDLE_WITH_MKLDNN
USE_PASS(mkldnn_placement_pass);
#endif
......
......@@ -705,8 +705,10 @@ PADDLE_DEFINE_EXPORTED_bool(allreduce_record_one_event, false,
* Value Range: bool, default=false
* Example: FLAGS_use_cinn=true would run PaddlePaddle using CINN
*/
#ifdef PADDLE_WITH_CINN
PADDLE_DEFINE_EXPORTED_bool(
use_cinn, false, "It controls whether to run PaddlePaddle using CINN");
#endif
DEFINE_int32(record_pool_max_size, 2000000,
"SlotRecordDataset slot record pool max size");
......
......@@ -14,16 +14,28 @@
from __future__ import print_function
import logging
import numpy as np
import paddle
import unittest
paddle.enable_static()
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
def set_cinn_flag(val):
try:
paddle.set_flags({'FLAGS_use_cinn': val})
except ValueError:
logger.warning("The used paddle is not compiled with CINN.")
class TestParallelExecutorRunCinn(unittest.TestCase):
def test_run_from_cinn(self):
paddle.set_flags({'FLAGS_use_cinn': False})
set_cinn_flag(False)
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
......@@ -49,7 +61,7 @@ class TestParallelExecutorRunCinn(unittest.TestCase):
fetch_list=[prediction.name],
return_merged=False)
paddle.set_flags({'FLAGS_use_cinn': False})
set_cinn_flag(False)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册