From 9ac736c2ad9b4ad90cbfe2a4df001de01ab981b2 Mon Sep 17 00:00:00 2001 From: sneaxiy <32832641+sneaxiy@users.noreply.github.com> Date: Thu, 12 May 2022 16:02:43 +0800 Subject: [PATCH] Add cinn pass to program (#42623) * add cinn pass to program * remove build_cinn_pass ut * polish ut, add ut * guard ut with is_compiled_with_cinn * enable ut test_build_cinn_pass_resnet --- .../framework/paddle2cinn/build_cinn_pass.cc | 12 +++++- .../framework/paddle2cinn/cinn_compiler.cc | 4 +- paddle/fluid/operators/cinn/cinn_launch_op.h | 14 +++++++ python/paddle/distributed/passes/cpp_pass.py | 33 +++++++++++++++ .../distributed_passes/dist_pass_test_base.py | 6 +-- .../unittests/distributed_passes/model_zoo.py | 37 ++++++++++++++++ .../test_build_cinn_pass_resnet.py | 41 ++++++++++++++++++ .../test_build_cinn_pass_simple_net.py | 42 +++++++++++++++++++ 8 files changed, 183 insertions(+), 6 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/distributed_passes/test_build_cinn_pass_resnet.py create mode 100644 python/paddle/fluid/tests/unittests/distributed_passes/test_build_cinn_pass_simple_net.py diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc index e259d6d417a..0de89aaad3b 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -545,6 +545,15 @@ void ReplaceSubGraphWithCinnOpNode( RemoveSubGraphFromGraph(cluster, cluster_internals, graph); } +static bool IsInplaceOp(const OpDesc& op_desc) { + auto inputs = op_desc.InputArgumentNames(); + std::unordered_set input_set(inputs.begin(), inputs.end()); + for (auto& name : op_desc.OutputArgumentNames()) { + if (input_set.count(name) > 0) return true; + } + return false; +} + // Search all subgraphs which all op node supported by CINN, // Here we using SubgraphDetector to detecte the subgraph that // all of op node supported by CINN. We using OpMapperRegistry @@ -565,9 +574,10 @@ void SearchAllSubgraphs(Graph* graph) { if (deny_ops.size()) { return registered && !deny_ops.count(node->Name()); } + // if the user doesn't set FLAGS_allow_cinn_ops and FLAGS_deny_cinn_ops, // return true only when it is registered in CINN - return registered; + return registered && (node->IsOp() && !IsInplaceOp(*node->Op())); }; VLOG(4) << "The allowed Cinn Ops: " << FLAGS_allow_cinn_ops; VLOG(4) << "The denied Cinn Ops: " << FLAGS_deny_cinn_ops; diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc index 51dca93c7c7..549c8549617 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc @@ -61,8 +61,8 @@ using ::cinn::hlir::framework::BuildScope; using ::cinn::hlir::framework::GraphCompiler; CinnCompiler* CinnCompiler::GetInstance() { - static CinnCompiler instance; - return &instance; + static CinnCompiler* instance = new CinnCompiler(); + return instance; } const CinnCompiledObject& CinnCompiler::Compile( diff --git a/paddle/fluid/operators/cinn/cinn_launch_op.h b/paddle/fluid/operators/cinn/cinn_launch_op.h index 024bf2bceb3..6001a4f5c07 100644 --- a/paddle/fluid/operators/cinn/cinn_launch_op.h +++ b/paddle/fluid/operators/cinn/cinn_launch_op.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include #include @@ -101,8 +102,21 @@ class CinnLaunchOpKernel : public framework::OpKernel { // Step 2. Get compilation result of the graph auto target = details::PlaceToCinnTarget(place); + using ClockType = std::chrono::steady_clock; + std::chrono::time_point start_t, end_t; + if (VLOG_IS_ON(1)) { + VLOG(1) << "Starts to compile at thread " << std::this_thread::get_id(); + start_t = ClockType::now(); + } const auto& cinn_compiled_object = CinnCompiler::GetInstance()->Compile( compilation_key, inputs_name2tensor, target, stream); + if (VLOG_IS_ON(1)) { + end_t = ClockType::now(); + auto time_sec = std::chrono::duration_cast( + end_t - start_t); + VLOG(1) << "Ends to compile at thread " << std::this_thread::get_id() + << " , time cost : " << time_sec.count() << " ms"; + } details::DebugCinnCompiledResult(cinn_compiled_object); auto* launch_context = cinn_compiled_object.launch_context.get(); diff --git a/python/paddle/distributed/passes/cpp_pass.py b/python/paddle/distributed/passes/cpp_pass.py index 4a4e5ecbbb4..72525255b7e 100644 --- a/python/paddle/distributed/passes/cpp_pass.py +++ b/python/paddle/distributed/passes/cpp_pass.py @@ -13,6 +13,7 @@ # limitations under the License. from .pass_base import PassType, CPPPassWrapper, register_pass +from paddle.fluid.framework import core, _apply_pass as _apply_cpp_pass @register_pass("fuse_elewise_add_act") @@ -93,3 +94,35 @@ class InplaceAddtoOpPass(CPPPassWrapper): def _type(self): return PassType.CALC_OPT + + +@register_pass("build_cinn") +class BuildCINNPass(CPPPassWrapper): + def __init__(self): + super(BuildCINNPass, self).__init__() + self.set_attr("allow_ops", []) + self.set_attr("deny_ops", []) + + @property + def cpp_name(self): + return "build_cinn_pass" + + def _type(self): + return PassType.CALC_OPT + + def _apply_single_impl(self, main_program, startup_program, context): + allow_ops = ";".join(self.get_attr("allow_ops")) + deny_ops = ";".join(self.get_attr("deny_ops")) + + assert 'FLAGS_allow_cinn_ops' in core.globals( + ), "PaddlePaddle is not compiled with CINN support" + old_allow_ops = core.globals()['FLAGS_allow_cinn_ops'] + old_deny_ops = core.globals()['FLAGS_deny_cinn_ops'] + try: + core.globals()['FLAGS_allow_cinn_ops'] = allow_ops + core.globals()['FLAGS_deny_cinn_ops'] = deny_ops + _apply_cpp_pass(main_program, startup_program, self.cpp_name, {}, + self.cpp_attr_types) + finally: + core.globals()['FLAGS_allow_cinn_ops'] = old_allow_ops + core.globals()['FLAGS_deny_cinn_ops'] = old_deny_ops diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/dist_pass_test_base.py b/python/paddle/fluid/tests/unittests/distributed_passes/dist_pass_test_base.py index f0ed2cdc049..786ee06487f 100644 --- a/python/paddle/fluid/tests/unittests/distributed_passes/dist_pass_test_base.py +++ b/python/paddle/fluid/tests/unittests/distributed_passes/dist_pass_test_base.py @@ -39,7 +39,7 @@ def prepare_python_path_and_return_module(path): paths.append(dirname) python_path = ":".join(paths) else: - python_path = path + python_path = dirname os.environ[env_name] = python_path print('GLOG_v=', os.environ.get('GLOG_v', None), flush=1) return filename[:-len(py_suffix)] @@ -85,9 +85,9 @@ class DistPassTestBase(unittest.TestCase): raise NotImplementedError() def check_main(self, model=None, gpus=None, **kwargs): - no_pass_rets = self._distributed_launch( - model=model, apply_pass=True, gpus=gpus, **kwargs) pass_rets = self._distributed_launch( + model=model, apply_pass=True, gpus=gpus, **kwargs) + no_pass_rets = self._distributed_launch( model=model, apply_pass=False, gpus=gpus, **kwargs) self.check_results(no_pass_rets, pass_rets) diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/model_zoo.py b/python/paddle/fluid/tests/unittests/distributed_passes/model_zoo.py index 0b522b79c4e..7eebee47e59 100644 --- a/python/paddle/fluid/tests/unittests/distributed_passes/model_zoo.py +++ b/python/paddle/fluid/tests/unittests/distributed_passes/model_zoo.py @@ -59,3 +59,40 @@ def resnet_model(place, batch_size, image_shape=[3, 224, 224], main_program = paddle.static.default_main_program() startup_program = paddle.static.default_startup_program() return main_program, startup_program, [image, label], [loss], reader + + +def simple_net(place, batch_size, image_shape=[784], num_classes=10): + image = paddle.static.data( + shape=[batch_size] + image_shape, dtype='float32', name='image') + label = paddle.static.data( + shape=[batch_size, 1], dtype='int64', name='label') + linears = [nn.Linear(784, 784) for _ in range(3)] + hidden = image + for linear in linears: + hidden = linear(hidden) + hidden = nn.ReLU()(hidden) + loss_fn = nn.loss.CrossEntropyLoss() + loss = loss_fn(hidden, label) + optimizer = paddle.optimizer.Adam(learning_rate=1e-3) + + dist_strategy = fleet.DistributedStrategy() + dist_strategy.fuse_all_reduce_ops = False + dist_strategy.without_graph_optimization = True + fleet.init(is_collective=True, strategy=dist_strategy) + optimizer = fleet.distributed_optimizer(optimizer) + optimizer.minimize(loss) + + rank = paddle.distributed.get_rank() + + def reader(): + seed = get_seed_from_env() + np.random.seed(seed + rank) + for _ in range(10): + image_np = np.random.random(size=image.shape).astype('float32') + label_np = np.random.randint( + low=0, high=num_classes, size=label.shape).astype('int64') + yield image_np, label_np + + main_program = paddle.static.default_main_program() + startup_program = paddle.static.default_startup_program() + return main_program, startup_program, [image, label], [loss], reader diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/test_build_cinn_pass_resnet.py b/python/paddle/fluid/tests/unittests/distributed_passes/test_build_cinn_pass_resnet.py new file mode 100644 index 00000000000..8430eb615a2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distributed_passes/test_build_cinn_pass_resnet.py @@ -0,0 +1,41 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.distributed.passes import new_pass, PassManager +import unittest +from dist_pass_test_base import DistPassTestBase +from model_zoo import resnet_model + + +class TestBuildCINNPass(DistPassTestBase): + def init(self): + self.atol = 0.5 + self.rtol = 0.0 + + def apply_passes(self, main_prog, startup_prog): + pass_manager = PassManager([ + new_pass("build_cinn"), + new_pass("fuse_elewise_add_act"), + ]) + pass_manager.apply([main_prog], [startup_prog]) + print(pass_manager.names) + + def test_bs_32(self): + if paddle.is_compiled_with_cinn(): + self.check_main(resnet_model, batch_size=32) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/distributed_passes/test_build_cinn_pass_simple_net.py b/python/paddle/fluid/tests/unittests/distributed_passes/test_build_cinn_pass_simple_net.py new file mode 100644 index 00000000000..e030420d324 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/distributed_passes/test_build_cinn_pass_simple_net.py @@ -0,0 +1,42 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +from paddle.distributed.passes import new_pass, PassManager +import unittest +from dist_pass_test_base import DistPassTestBase +from model_zoo import simple_net + + +class TestBuildCINNPass(DistPassTestBase): + def init(self): + self.atol = 0.0 + self.rtol = 0.0 + + def apply_passes(self, main_prog, startup_prog): + pass_manager = PassManager([ + new_pass("build_cinn"), + new_pass("fuse_elewise_add_act"), + ]) + pass_manager.apply([main_prog], [startup_prog]) + op_types = [op.type for op in main_prog.global_block().ops] + self.assertTrue('cinn_launch' in op_types) + + def test_bs_32(self): + if paddle.is_compiled_with_cinn(): + self.check_main(simple_net, batch_size=32) + + +if __name__ == "__main__": + unittest.main() -- GitLab