From 343bff7b8b3c196d20fd1c53dcda89b2c29ff301 Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Wed, 4 Jan 2023 21:12:11 +0800 Subject: [PATCH] [D2SCinn]Add build_cinn_pass in BuildStrategy (#49496) --- .../fluid/framework/details/build_strategy.cc | 2 +- .../fluid/framework/details/build_strategy.h | 3 +++ paddle/fluid/pybind/parallel_executor.cc | 26 +++++++++++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/framework/details/build_strategy.cc b/paddle/fluid/framework/details/build_strategy.cc index 43f6329083..486770cdbd 100644 --- a/paddle/fluid/framework/details/build_strategy.cc +++ b/paddle/fluid/framework/details/build_strategy.cc @@ -56,7 +56,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder { AppendPrintGraphPass("graph_viz_pass", "_original_graph"); #ifdef PADDLE_WITH_CINN - if (FLAGS_use_cinn) { + if (FLAGS_use_cinn || strategy.build_cinn_pass_) { // Note: This pass is used to enable cinn. AppendPass("build_cinn_pass"); AppendPrintGraphPass("graph_viz_pass", "_build_cinn_graph"); diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h index c1ef2eba64..4d51099529 100644 --- a/paddle/fluid/framework/details/build_strategy.h +++ b/paddle/fluid/framework/details/build_strategy.h @@ -103,6 +103,9 @@ struct BuildStrategy { // Fix the op run order. bool fix_op_run_order_{false}; + // Lowering sub-graph into cinn ops. + bool build_cinn_pass_{false}; + // Operator fusion // TODO(dev-paddle): fuse_elewise_add_act_ops may cause some models have // cycle. diff --git a/paddle/fluid/pybind/parallel_executor.cc b/paddle/fluid/pybind/parallel_executor.cc index d0aea4e76d..962bdd736f 100644 --- a/paddle/fluid/pybind/parallel_executor.cc +++ b/paddle/fluid/pybind/parallel_executor.cc @@ -633,7 +633,33 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT [](BuildStrategy &self, int nranks) { self.hierarchical_allreduce_inter_nranks_ = nranks; }) + .def_property( + "build_cinn_pass", + [](const BuildStrategy &self) { return self.build_cinn_pass_; }, + [](BuildStrategy &self, bool b) { + PADDLE_ENFORCE_NE(self.IsFinalized(), + true, + platform::errors::PreconditionNotMet( + "BuildStrategy has been finlaized, " + "cannot be configured again.")); + self.build_cinn_pass_ = b; + }, + R"DOC((bool, optional): build_cinn_pass indicates whether + to lowering some operators in graph into cinn ops + to execute, which will speed up the process of execution. + Default False. + + Examples: + .. code-block:: python + + import paddle + import paddle.static as static + paddle.enable_static() + + build_strategy = static.BuildStrategy() + build_strategy.build_cinn_pass = True + )DOC") .def_property( "fuse_elewise_add_act_ops", [](const BuildStrategy &self) { -- GitLab