未验证 提交 9ac736c2 编写于 作者: S sneaxiy 提交者: GitHub

Add cinn pass to program (#42623)

* add cinn pass to program

* remove build_cinn_pass ut

* polish ut, add ut

* guard ut with is_compiled_with_cinn

* enable ut test_build_cinn_pass_resnet
上级 cc343a41
...@@ -545,6 +545,15 @@ void ReplaceSubGraphWithCinnOpNode( ...@@ -545,6 +545,15 @@ void ReplaceSubGraphWithCinnOpNode(
RemoveSubGraphFromGraph(cluster, cluster_internals, graph); RemoveSubGraphFromGraph(cluster, cluster_internals, graph);
} }
// Reports whether `op_desc` writes to one of its own inputs, i.e. whether
// at least one of its output argument names is also an input argument name.
static bool IsInplaceOp(const OpDesc& op_desc) {
  // Collect the input names once so each output lookup is a hash probe.
  const auto input_names = op_desc.InputArgumentNames();
  const std::unordered_set<std::string> known_inputs(input_names.begin(),
                                                     input_names.end());
  for (const auto& output_name : op_desc.OutputArgumentNames()) {
    if (known_inputs.find(output_name) != known_inputs.end()) {
      return true;
    }
  }
  // No output name matched any input name.
  return false;
}
// Search all subgraphs in which every op node is supported by CINN. // Search all subgraphs in which every op node is supported by CINN.
// Here we use SubgraphDetector to detect the subgraphs in which // Here we use SubgraphDetector to detect the subgraphs in which
// all op nodes are supported by CINN. We use OpMapperRegistry // all op nodes are supported by CINN. We use OpMapperRegistry
...@@ -565,9 +574,10 @@ void SearchAllSubgraphs(Graph* graph) { ...@@ -565,9 +574,10 @@ void SearchAllSubgraphs(Graph* graph) {
if (deny_ops.size()) { if (deny_ops.size()) {
return registered && !deny_ops.count(node->Name()); return registered && !deny_ops.count(node->Name());
} }
// if the user doesn't set FLAGS_allow_cinn_ops and FLAGS_deny_cinn_ops, // if the user doesn't set FLAGS_allow_cinn_ops and FLAGS_deny_cinn_ops,
// return true only when it is registered in CINN // return true only when it is registered in CINN
return registered; return registered && (node->IsOp() && !IsInplaceOp(*node->Op()));
}; };
VLOG(4) << "The allowed Cinn Ops: " << FLAGS_allow_cinn_ops; VLOG(4) << "The allowed Cinn Ops: " << FLAGS_allow_cinn_ops;
VLOG(4) << "The denied Cinn Ops: " << FLAGS_deny_cinn_ops; VLOG(4) << "The denied Cinn Ops: " << FLAGS_deny_cinn_ops;
......
...@@ -61,8 +61,8 @@ using ::cinn::hlir::framework::BuildScope; ...@@ -61,8 +61,8 @@ using ::cinn::hlir::framework::BuildScope;
using ::cinn::hlir::framework::GraphCompiler; using ::cinn::hlir::framework::GraphCompiler;
CinnCompiler* CinnCompiler::GetInstance() { CinnCompiler* CinnCompiler::GetInstance() {
static CinnCompiler instance; static CinnCompiler* instance = new CinnCompiler();
return &instance; return instance;
} }
const CinnCompiledObject& CinnCompiler::Compile( const CinnCompiledObject& CinnCompiler::Compile(
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <chrono>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
...@@ -101,8 +102,21 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> { ...@@ -101,8 +102,21 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
// Step 2. Get compilation result of the graph // Step 2. Get compilation result of the graph
auto target = details::PlaceToCinnTarget(place); auto target = details::PlaceToCinnTarget(place);
using ClockType = std::chrono::steady_clock;
std::chrono::time_point<ClockType> start_t, end_t;
if (VLOG_IS_ON(1)) {
VLOG(1) << "Starts to compile at thread " << std::this_thread::get_id();
start_t = ClockType::now();
}
const auto& cinn_compiled_object = CinnCompiler::GetInstance()->Compile( const auto& cinn_compiled_object = CinnCompiler::GetInstance()->Compile(
compilation_key, inputs_name2tensor, target, stream); compilation_key, inputs_name2tensor, target, stream);
if (VLOG_IS_ON(1)) {
end_t = ClockType::now();
auto time_sec = std::chrono::duration_cast<std::chrono::milliseconds>(
end_t - start_t);
VLOG(1) << "Ends to compile at thread " << std::this_thread::get_id()
<< " , time cost : " << time_sec.count() << " ms";
}
details::DebugCinnCompiledResult(cinn_compiled_object); details::DebugCinnCompiledResult(cinn_compiled_object);
auto* launch_context = cinn_compiled_object.launch_context.get(); auto* launch_context = cinn_compiled_object.launch_context.get();
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
from .pass_base import PassType, CPPPassWrapper, register_pass from .pass_base import PassType, CPPPassWrapper, register_pass
from paddle.fluid.framework import core, _apply_pass as _apply_cpp_pass
@register_pass("fuse_elewise_add_act") @register_pass("fuse_elewise_add_act")
...@@ -93,3 +94,35 @@ class InplaceAddtoOpPass(CPPPassWrapper): ...@@ -93,3 +94,35 @@ class InplaceAddtoOpPass(CPPPassWrapper):
def _type(self): def _type(self):
return PassType.CALC_OPT return PassType.CALC_OPT
@register_pass("build_cinn")
class BuildCINNPass(CPPPassWrapper):
    """Pass registered as "build_cinn" that runs the C++ "build_cinn_pass".

    Two attributes, "allow_ops" and "deny_ops", both default to an empty
    list. When the pass is applied they are joined with ';' and written to
    the global flags FLAGS_allow_cinn_ops / FLAGS_deny_cinn_ops for the
    duration of the C++ pass; the previous flag values are put back
    afterwards, including when the pass raises.
    """

    def __init__(self):
        super(BuildCINNPass, self).__init__()
        self.set_attr("allow_ops", [])
        self.set_attr("deny_ops", [])

    @property
    def cpp_name(self):
        """Name of the underlying C++ pass."""
        return "build_cinn_pass"

    def _type(self):
        return PassType.CALC_OPT

    def _apply_single_impl(self, main_program, startup_program, context):
        """Apply the C++ pass with the CINN allow/deny flags overridden.

        `context` is not used by this implementation.

        Raises:
            AssertionError: if FLAGS_allow_cinn_ops is not a known global
                flag (reported as "not compiled with CINN support").
        """
        # Flag name -> value requested through this pass's attributes.
        requested = {
            'FLAGS_allow_cinn_ops': ";".join(self.get_attr("allow_ops")),
            'FLAGS_deny_cinn_ops': ";".join(self.get_attr("deny_ops")),
        }

        assert 'FLAGS_allow_cinn_ops' in core.globals(
        ), "PaddlePaddle is not compiled with CINN support"

        # Remember the current values so they can be restored afterwards.
        previous = {}
        for flag_name in requested:
            previous[flag_name] = core.globals()[flag_name]

        try:
            for flag_name, flag_value in requested.items():
                core.globals()[flag_name] = flag_value
            _apply_cpp_pass(main_program, startup_program, self.cpp_name, {},
                            self.cpp_attr_types)
        finally:
            # Always undo the override, even if the pass failed.
            for flag_name, flag_value in previous.items():
                core.globals()[flag_name] = flag_value
...@@ -39,7 +39,7 @@ def prepare_python_path_and_return_module(path): ...@@ -39,7 +39,7 @@ def prepare_python_path_and_return_module(path):
paths.append(dirname) paths.append(dirname)
python_path = ":".join(paths) python_path = ":".join(paths)
else: else:
python_path = path python_path = dirname
os.environ[env_name] = python_path os.environ[env_name] = python_path
print('GLOG_v=', os.environ.get('GLOG_v', None), flush=1) print('GLOG_v=', os.environ.get('GLOG_v', None), flush=1)
return filename[:-len(py_suffix)] return filename[:-len(py_suffix)]
...@@ -85,9 +85,9 @@ class DistPassTestBase(unittest.TestCase): ...@@ -85,9 +85,9 @@ class DistPassTestBase(unittest.TestCase):
raise NotImplementedError() raise NotImplementedError()
def check_main(self, model=None, gpus=None, **kwargs): def check_main(self, model=None, gpus=None, **kwargs):
no_pass_rets = self._distributed_launch(
model=model, apply_pass=True, gpus=gpus, **kwargs)
pass_rets = self._distributed_launch( pass_rets = self._distributed_launch(
model=model, apply_pass=True, gpus=gpus, **kwargs)
no_pass_rets = self._distributed_launch(
model=model, apply_pass=False, gpus=gpus, **kwargs) model=model, apply_pass=False, gpus=gpus, **kwargs)
self.check_results(no_pass_rets, pass_rets) self.check_results(no_pass_rets, pass_rets)
......
...@@ -59,3 +59,40 @@ def resnet_model(place, batch_size, image_shape=[3, 224, 224], ...@@ -59,3 +59,40 @@ def resnet_model(place, batch_size, image_shape=[3, 224, 224],
main_program = paddle.static.default_main_program() main_program = paddle.static.default_main_program()
startup_program = paddle.static.default_startup_program() startup_program = paddle.static.default_startup_program()
return main_program, startup_program, [image, label], [loss], reader return main_program, startup_program, [image, label], [loss], reader
def simple_net(place, batch_size, image_shape=[784], num_classes=10):
    """Build a small static-graph MLP wrapped in a fleet distributed optimizer.

    The network is three Linear(784, 784) layers, each followed by ReLU,
    trained with cross-entropy loss and Adam (learning rate 1e-3).
    `place` is accepted but not used inside this function.

    Returns:
        A tuple (main_program, startup_program, [image, label], [loss],
        reader), where `reader` is a generator function yielding 10
        (image, label) numpy batches seeded by get_seed_from_env() plus the
        current distributed rank.
    """
    image = paddle.static.data(
        name='image', shape=[batch_size] + image_shape, dtype='float32')
    label = paddle.static.data(
        name='label', shape=[batch_size, 1], dtype='int64')

    # Create every Linear layer first, then chain them with ReLU.
    layers = []
    for _ in range(3):
        layers.append(nn.Linear(784, 784))

    features = image
    for layer in layers:
        projected = layer(features)
        features = nn.ReLU()(projected)

    criterion = nn.loss.CrossEntropyLoss()
    loss = criterion(features, label)

    # Configure fleet before wrapping the optimizer.
    strategy = fleet.DistributedStrategy()
    strategy.fuse_all_reduce_ops = False
    strategy.without_graph_optimization = True
    fleet.init(is_collective=True, strategy=strategy)

    adam = paddle.optimizer.Adam(learning_rate=1e-3)
    dist_optimizer = fleet.distributed_optimizer(adam)
    dist_optimizer.minimize(loss)

    rank = paddle.distributed.get_rank()

    def reader():
        # Seed differs per rank, so each rank draws its own random batches.
        np.random.seed(get_seed_from_env() + rank)
        for _ in range(10):
            image_batch = np.random.random(size=image.shape).astype('float32')
            label_batch = np.random.randint(
                low=0, high=num_classes, size=label.shape).astype('int64')
            yield image_batch, label_batch

    return (paddle.static.default_main_program(),
            paddle.static.default_startup_program(), [image, label], [loss],
            reader)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.distributed.passes import new_pass, PassManager
import unittest
from dist_pass_test_base import DistPassTestBase
from model_zoo import resnet_model
class TestBuildCINNPass(DistPassTestBase):
    """Runs resnet_model through check_main with the "build_cinn" and
    "fuse_elewise_add_act" passes applied."""

    def init(self):
        # Tolerances stored on the test instance.
        self.atol, self.rtol = 0.5, 0.0

    def apply_passes(self, main_prog, startup_prog):
        """Apply both passes to the given programs and print their names."""
        passes = [
            new_pass(pass_name)
            for pass_name in ("build_cinn", "fuse_elewise_add_act")
        ]
        manager = PassManager(passes)
        manager.apply([main_prog], [startup_prog])
        print(manager.names)

    def test_bs_32(self):
        # Nothing to check when Paddle was built without CINN.
        if not paddle.is_compiled_with_cinn():
            return
        self.check_main(resnet_model, batch_size=32)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.distributed.passes import new_pass, PassManager
import unittest
from dist_pass_test_base import DistPassTestBase
from model_zoo import simple_net
class TestBuildCINNPass(DistPassTestBase):
    """Runs simple_net through check_main with the "build_cinn" and
    "fuse_elewise_add_act" passes applied, and checks that a 'cinn_launch'
    op shows up in the main program afterwards."""

    def init(self):
        # Both tolerances are zero for this test.
        self.atol = self.rtol = 0.0

    def apply_passes(self, main_prog, startup_prog):
        """Apply both passes, then assert the program has a cinn_launch op."""
        passes = [
            new_pass(pass_name)
            for pass_name in ("build_cinn", "fuse_elewise_add_act")
        ]
        manager = PassManager(passes)
        manager.apply([main_prog], [startup_prog])
        self.assertTrue(
            any(op.type == 'cinn_launch'
                for op in main_prog.global_block().ops))

    def test_bs_32(self):
        # Nothing to check when Paddle was built without CINN.
        if not paddle.is_compiled_with_cinn():
            return
        self.check_main(simple_net, batch_size=32)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册