未验证 提交 343bff7b 编写于 作者: A Aurelius84 提交者: GitHub

[D2SCinn]Add build_cinn_pass in BuildStrategy (#49496)

上级 257e6c99
......@@ -56,7 +56,7 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
AppendPrintGraphPass("graph_viz_pass", "_original_graph");
#ifdef PADDLE_WITH_CINN
if (FLAGS_use_cinn) {
if (FLAGS_use_cinn || strategy.build_cinn_pass_) {
// Note: This pass is used to enable cinn.
AppendPass("build_cinn_pass");
AppendPrintGraphPass("graph_viz_pass", "_build_cinn_graph");
......
......@@ -103,6 +103,9 @@ struct BuildStrategy {
// Fix the op run order.
bool fix_op_run_order_{false};
// Lowering sub-graph into cinn ops.
bool build_cinn_pass_{false};
// Operator fusion
// TODO(dev-paddle): fuse_elewise_add_act_ops may cause some models have
// cycle.
......
......@@ -633,7 +633,33 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT
[](BuildStrategy &self, int nranks) {
self.hierarchical_allreduce_inter_nranks_ = nranks;
})
.def_property(
"build_cinn_pass",
[](const BuildStrategy &self) { return self.build_cinn_pass_; },
[](BuildStrategy &self, bool b) {
PADDLE_ENFORCE_NE(self.IsFinalized(),
true,
platform::errors::PreconditionNotMet(
"BuildStrategy has been finlaized, "
"cannot be configured again."));
self.build_cinn_pass_ = b;
},
R"DOC((bool, optional): build_cinn_pass indicates whether
to lowering some operators in graph into cinn ops
to execute, which will speed up the process of execution.
Default False.
Examples:
.. code-block:: python
import paddle
import paddle.static as static
paddle.enable_static()
build_strategy = static.BuildStrategy()
build_strategy.build_cinn_pass = True
)DOC")
.def_property(
"fuse_elewise_add_act_ops",
[](const BuildStrategy &self) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册