diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc index f9c28f42776906de2aec4313a7024396dbb24578..96c052ef505b5799de202b9d06c983c6044747ae 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc @@ -14,6 +14,7 @@ #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" +#include #include #include #include @@ -66,6 +67,7 @@ const CinnCompiledObject& CinnCompiler::Compile( const Graph& graph, const std::map& input_tensors, const Target& target) { + VLOG(4) << "-- The graph to be compiled is:\n" << VizGraph(graph); CinnCacheKey cur_key(graph, input_tensors, target.arch_str()); bool exist = false; { @@ -73,8 +75,9 @@ const CinnCompiledObject& CinnCompiler::Compile( exist = cache_.count(cur_key) != 0; } if (!exist) { - real_compiled_num_++; - auto compiled_res = CompileGraph(graph, input_tensors, target); + std::int64_t compiled_num = real_compiled_num_.fetch_add(1); + auto compiled_res = + CompileGraph(graph, input_tensors, target, compiled_num); AutoWRLock w_guard{&rwlock_}; if (!cache_.count(cur_key)) { cache_[cur_key] = std::move(compiled_res); @@ -89,7 +92,6 @@ const CinnCompiledObject& CinnCompiler::Compile( const std::string& compilation_key, const std::map& input_tensors, const Target& target) { - VLOG(4) << "-- The graph to be compiled is:\n" << VizGraph(compilation_key); const auto& graph = FindGraph(compilation_key); return Compile(graph, input_tensors, target); } @@ -120,10 +122,14 @@ const Graph& CinnCompiler::FindGraph(const std::string& graph_key) const { return *graphs_.at(graph_key); } -std::string CinnCompiler::VizGraph(const std::string& key) const { +std::string CinnCompiler::VizGraph(const std::string& graph_key) const { + const Graph& graph = FindGraph(graph_key); + return VizGraph(graph); +} + +std::string CinnCompiler::VizGraph(const Graph& graph) const { Dot dot; std::unordered_map node2dot; - const Graph& graph = FindGraph(key); int id = 0; // Create nodes for (const Node* n : graph.Nodes()) { @@ -164,9 +170,10 @@ std::string CinnCompiler::VizGraph(const std::string& key) const { return dot.Build(); } -std::string CinnCompiler::ReadableKey(const std::string& key) const { +std::string CinnCompiler::ReadableKey( + const std::string& compilation_key) const { proto::ProgramDesc desc; - desc.ParseFromString(key); + desc.ParseFromString(compilation_key); return desc.DebugString(); } @@ -176,20 +183,19 @@ void CinnCompiler::Clear() { graphs_.clear(); cache_.clear(); } - real_compiled_num_ = 0; + real_compiled_num_.store(1); } std::unique_ptr CinnCompiler::CompileGraph( const ir::Graph& graph, const std::map& input_tensors, - const Target& target) const { - CinnGraphSymbolization symbol{real_compiled_num_, graph, target, - input_tensors}; + const Target& target, std::int64_t compiled_num) const { + CinnGraphSymbolization symbol{compiled_num, graph, target, input_tensors}; auto frontend_program = symbol(); ProgramPass::Apply(&frontend_program, target, {"Decomposer"}); auto cinn_graph = std::make_shared<::cinn::hlir::framework::Graph>( frontend_program, target); - VLOG(4) << "-- The " << real_compiled_num_ << "-th compilation (" + VLOG(4) << "-- The " << compiled_num << "-th compilation (" << target.arch_str() << "), and its related graph:\n" << cinn_graph->Visualize(); ApplyPass(cinn_graph.get(), "OpFusion"); diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.h b/paddle/fluid/framework/paddle2cinn/cinn_compiler.h index 3996c62cb943ec32a07db4ba6502f79689527ca3..29ec1e424cc230b85bec6a08a18df96ec03445a0 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.h +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.h @@ -15,6 +15,7 @@ #pragma once #include +#include #include #include #include @@ -63,15 +64,17 @@ class CinnCompiler { std::string AddGraph(std::unique_ptr graph); - const ir::Graph& FindGraph(const std::string& key) const; + const ir::Graph& FindGraph(const std::string& graph_key) const; - std::string VizGraph(const std::string& key) const; + std::string VizGraph(const std::string& graph_key) const; - std::string ReadableKey(const std::string& key) const; + std::string VizGraph(const ir::Graph& graph) const; + + std::string ReadableKey(const std::string& compilation_key) const; void Clear(); - std::int64_t real_compiled_num() const { return real_compiled_num_; } + std::int64_t real_compiled_num() const { return real_compiled_num_.load(); } ~CinnCompiler() = default; @@ -80,13 +83,13 @@ class CinnCompiler { std::unique_ptr CompileGraph( const ir::Graph& graph, const std::map& input_tensors, - const ::cinn::common::Target& target) const; + const ::cinn::common::Target& target, std::int64_t compiled_num) const; std::unordered_map> graphs_; std::unordered_map, CinnCacheKey::Hash> cache_; - std::atomic_int64_t real_compiled_num_{0}; + std::atomic_int64_t real_compiled_num_{1}; mutable RWLock rwlock_; DISABLE_COPY_AND_ASSIGN(CinnCompiler); diff --git a/paddle/fluid/operators/cinn_launch_op.h b/paddle/fluid/operators/cinn_launch_op.h index e1b24293f840ff904dcdeea2062841e2525e5e96..b2bbbc50ae222338acf9c713904cfb926859c4fc 100644 --- a/paddle/fluid/operators/cinn_launch_op.h +++ b/paddle/fluid/operators/cinn_launch_op.h @@ -101,7 +101,6 @@ class CinnLaunchOpKernel : public framework::OpKernel { << "value:\n" << CinnCompiler::GetInstance()->ReadableKey(compilation_key); - const auto& graph = CinnCompiler::GetInstance()->FindGraph(compilation_key); auto input_variable_names = ctx.InputNames(kX); const auto& input_tensors = ctx.MultiInput(kX); std::map inputs_name2tensor; @@ -114,8 +113,8 @@ class CinnLaunchOpKernel : public framework::OpKernel { // Step 2. Get compilation result of the graph auto target = details::PlaceToCinnTarget(place); - const auto& cinn_compiled_object = - CinnCompiler::GetInstance()->Compile(graph, inputs_name2tensor, target); + const auto& cinn_compiled_object = CinnCompiler::GetInstance()->Compile( + compilation_key, inputs_name2tensor, target); details::DebugCinnCompiledResult(cinn_compiled_object); const auto& cinn_runtime_program = cinn_compiled_object.runtime_program; diff --git a/python/paddle/fluid/tests/unittests/test_resnet50_with_cinn.py b/python/paddle/fluid/tests/unittests/test_resnet50_with_cinn.py new file mode 100644 index 0000000000000000000000000000000000000000..58c080fc0ccc7f88938c5a3e2c7d2d15c6672876 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_resnet50_with_cinn.py @@ -0,0 +1,111 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import logging +import numpy as np +import paddle +import unittest + +paddle.enable_static() + +logging.basicConfig( + format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO) +logger = logging.getLogger(__name__) + + +def set_cinn_flag(val): + cinn_compiled = False + try: + paddle.set_flags({'FLAGS_use_cinn': val}) + cinn_compiled = True + except ValueError: + logger.warning("The used paddle is not compiled with CINN.") + return cinn_compiled + + +@unittest.skipIf(not set_cinn_flag(True), "Paddle is not compiled with CINN.") +class TestResnet50Accuracy(unittest.TestCase): + def reader(self, limit): + for _ in range(limit): + yield np.random.randint(0, 256, size=[1, 3, 224, 224]).astype('float32'), \ + np.random.randint(0, 1000, size=[1]).astype('int64') + + def generate_random_data(self, loop_num=10): + feed = [] + data = self.reader(loop_num) + for _ in range(loop_num): + x, y = next(data) + feed.append({'image': x, 'label': y}) + return feed + + def build_program(self, main_program, startup_program): + with paddle.static.program_guard(main_program, startup_program): + image = paddle.static.data( + name='image', shape=[1, 3, 224, 224], dtype='float32') + label = paddle.static.data(name='label', shape=[1], dtype='int64') + + model = paddle.vision.models.resnet50() + prediction = model(image) + + loss = paddle.nn.functional.cross_entropy( + input=prediction, label=label) + loss = paddle.mean(loss) + adam = paddle.optimizer.Adam(learning_rate=0.001) + adam.minimize(loss) + return loss + + def train(self, place, iters, feed, use_cinn=False, seed=1234): + np.random.seed(seed) + paddle.seed(seed) + if paddle.is_compiled_with_cuda(): + paddle.set_flags({'FLAGS_cudnn_deterministic': 1}) + set_cinn_flag(use_cinn) + + startup_program = paddle.static.Program() + main_program = paddle.static.Program() + + loss = self.build_program(main_program, startup_program) + exe = paddle.static.Executor(place) + + parallel_exec = paddle.static.CompiledProgram( + main_program).with_data_parallel(loss_name=loss.name) + loss_vals = [] + scope = paddle.static.Scope() + + with paddle.static.scope_guard(scope): + exe.run(startup_program) + for step in range(iters): + loss_v = exe.run(parallel_exec, + feed=feed[step], + fetch_list=[loss], + return_numpy=True) + loss_vals.append(loss_v[0][0]) + return loss_vals + + def test_check_resnet50_accuracy(self): + place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda( + ) else paddle.CPUPlace() + + loop_num = 10 + feed = self.generate_random_data(loop_num) + + loss_c = self.train(place, loop_num, feed, use_cinn=True) + loss_p = self.train(place, loop_num, feed, use_cinn=False) + self.assertTrue(np.allclose(loss_c, loss_p, atol=1e-5)) + + +if __name__ == '__main__': + unittest.main()