Unverified commit d572fa27, authored by Zhen Wang and committed by GitHub

Update the `VizGraph` method of CinnCompiler and add more debug info. (#36975)

* Use a more appropriate `Compile` method in cinn_launch_op (the compilation-counter change it builds on is sketched below).

* Update the VizGraph method of CinnCompiler.

* Add resnet50 model training with CINN.
Parent cb6c0e21
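The counter semantics change in this commit: `real_compiled_num_` becomes a one-based atomic counter, `Compile` draws an id with `fetch_add(1)`, and `Clear` resets it with `store(1)`. Below is a minimal standalone sketch of that pattern, using a hypothetical `DemoCompiler` class rather than the real CinnCompiler:

#include <atomic>
#include <cstdint>
#include <iostream>

class DemoCompiler {
 public:
  // fetch_add returns the value held *before* the increment, so with the
  // counter initialized to 1 the first compilation gets id 1.
  std::int64_t Compile() { return real_compiled_num_.fetch_add(1); }

  // Reset to 1, not 0, so numbering restarts at the "1st" compilation.
  void Clear() { real_compiled_num_.store(1); }

  std::int64_t real_compiled_num() const { return real_compiled_num_.load(); }

 private:
  std::atomic_int64_t real_compiled_num_{1};
};

int main() {
  DemoCompiler c;
  std::cout << c.Compile() << "\n";  // prints 1
  std::cout << c.Compile() << "\n";  // prints 2
  c.Clear();
  std::cout << c.Compile() << "\n";  // prints 1 again
  return 0;
}

Handing the drawn id to `CompileGraph` as a parameter, instead of re-reading the member later, also keeps the logged "N-th compilation" stable if another thread starts a compilation concurrently.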
@@ -14,6 +14,7 @@
 #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
+#include <cstdint>
 #include <iterator>
 #include <map>
 #include <memory>
@@ -66,6 +67,7 @@ const CinnCompiledObject& CinnCompiler::Compile(
     const Graph& graph,
     const std::map<std::string, const LoDTensor*>& input_tensors,
     const Target& target) {
+  VLOG(4) << "-- The graph to be compiled is:\n" << VizGraph(graph);
   CinnCacheKey cur_key(graph, input_tensors, target.arch_str());
   bool exist = false;
   {
@@ -73,8 +75,9 @@ const CinnCompiledObject& CinnCompiler::Compile(
     exist = cache_.count(cur_key) != 0;
   }
   if (!exist) {
-    real_compiled_num_++;
-    auto compiled_res = CompileGraph(graph, input_tensors, target);
+    std::int64_t compiled_num = real_compiled_num_.fetch_add(1);
+    auto compiled_res =
+        CompileGraph(graph, input_tensors, target, compiled_num);
     AutoWRLock w_guard{&rwlock_};
     if (!cache_.count(cur_key)) {
       cache_[cur_key] = std::move(compiled_res);
@@ -89,7 +92,6 @@ const CinnCompiledObject& CinnCompiler::Compile(
     const std::string& compilation_key,
     const std::map<std::string, const LoDTensor*>& input_tensors,
     const Target& target) {
-  VLOG(4) << "-- The graph to be compiled is:\n" << VizGraph(compilation_key);
   const auto& graph = FindGraph(compilation_key);
   return Compile(graph, input_tensors, target);
 }
@@ -120,10 +122,14 @@ const Graph& CinnCompiler::FindGraph(const std::string& graph_key) const {
   return *graphs_.at(graph_key);
 }
 
-std::string CinnCompiler::VizGraph(const std::string& key) const {
+std::string CinnCompiler::VizGraph(const std::string& graph_key) const {
+  const Graph& graph = FindGraph(graph_key);
+  return VizGraph(graph);
+}
+
+std::string CinnCompiler::VizGraph(const Graph& graph) const {
   Dot dot;
   std::unordered_map<const Node*, std::string> node2dot;
-  const Graph& graph = FindGraph(key);
   int id = 0;
   // Create nodes
   for (const Node* n : graph.Nodes()) {
@@ -164,9 +170,10 @@ std::string CinnCompiler::VizGraph(const std::string& key) const {
   return dot.Build();
 }
 
-std::string CinnCompiler::ReadableKey(const std::string& key) const {
+std::string CinnCompiler::ReadableKey(
+    const std::string& compilation_key) const {
   proto::ProgramDesc desc;
-  desc.ParseFromString(key);
+  desc.ParseFromString(compilation_key);
   return desc.DebugString();
 }
@@ -176,20 +183,19 @@ void CinnCompiler::Clear() {
     graphs_.clear();
     cache_.clear();
   }
-  real_compiled_num_ = 0;
+  real_compiled_num_.store(1);
 }
 
 std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
     const ir::Graph& graph,
     const std::map<std::string, const LoDTensor*>& input_tensors,
-    const Target& target) const {
-  CinnGraphSymbolization symbol{real_compiled_num_, graph, target,
-                                input_tensors};
+    const Target& target, std::int64_t compiled_num) const {
+  CinnGraphSymbolization symbol{compiled_num, graph, target, input_tensors};
   auto frontend_program = symbol();
   ProgramPass::Apply(&frontend_program, target, {"Decomposer"});
   auto cinn_graph = std::make_shared<::cinn::hlir::framework::Graph>(
       frontend_program, target);
-  VLOG(4) << "-- The " << real_compiled_num_ << "-th compilation ("
+  VLOG(4) << "-- The " << compiled_num << "-th compilation ("
           << target.arch_str() << "), and its related graph:\n"
           << cinn_graph->Visualize();
   ApplyPass(cinn_graph.get(), "OpFusion");
......
@@ -15,6 +15,7 @@
 #pragma once
 
 #include <atomic>
+#include <cstdint>
 #include <map>
 #include <memory>
 #include <string>
@@ -63,15 +64,17 @@ class CinnCompiler {
   std::string AddGraph(std::unique_ptr<ir::Graph> graph);
 
-  const ir::Graph& FindGraph(const std::string& key) const;
+  const ir::Graph& FindGraph(const std::string& graph_key) const;
 
-  std::string VizGraph(const std::string& key) const;
+  std::string VizGraph(const std::string& graph_key) const;
+
+  std::string VizGraph(const ir::Graph& graph) const;
 
-  std::string ReadableKey(const std::string& key) const;
+  std::string ReadableKey(const std::string& compilation_key) const;
 
   void Clear();
 
-  std::int64_t real_compiled_num() const { return real_compiled_num_; }
+  std::int64_t real_compiled_num() const { return real_compiled_num_.load(); }
 
   ~CinnCompiler() = default;
@@ -80,13 +83,13 @@ class CinnCompiler {
   std::unique_ptr<CinnCompiledObject> CompileGraph(
       const ir::Graph& graph,
       const std::map<std::string, const LoDTensor*>& input_tensors,
-      const ::cinn::common::Target& target) const;
+      const ::cinn::common::Target& target, std::int64_t compiled_num) const;
 
   std::unordered_map<std::string, std::unique_ptr<ir::Graph>> graphs_;
   std::unordered_map<CinnCacheKey, std::unique_ptr<CinnCompiledObject>,
                      CinnCacheKey::Hash>
       cache_;
-  std::atomic_int64_t real_compiled_num_{0};
+  std::atomic_int64_t real_compiled_num_{1};
   mutable RWLock rwlock_;
 
   DISABLE_COPY_AND_ASSIGN(CinnCompiler);
......
@@ -101,7 +101,6 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
             << "value:\n"
             << CinnCompiler::GetInstance()->ReadableKey(compilation_key);
 
-    const auto& graph = CinnCompiler::GetInstance()->FindGraph(compilation_key);
     auto input_variable_names = ctx.InputNames(kX);
     const auto& input_tensors = ctx.MultiInput<LoDTensor>(kX);
     std::map<std::string, const LoDTensor*> inputs_name2tensor;
@@ -114,8 +113,8 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
     // Step 2. Get compilation result of the graph
     auto target = details::PlaceToCinnTarget(place);
-    const auto& cinn_compiled_object =
-        CinnCompiler::GetInstance()->Compile(graph, inputs_name2tensor, target);
+    const auto& cinn_compiled_object = CinnCompiler::GetInstance()->Compile(
+        compilation_key, inputs_name2tensor, target);
     details::DebugCinnCompiledResult(cinn_compiled_object);
 
     const auto& cinn_runtime_program = cinn_compiled_object.runtime_program;
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import logging
import numpy as np
import paddle
import unittest

paddle.enable_static()

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)


def set_cinn_flag(val):
    cinn_compiled = False
    try:
        paddle.set_flags({'FLAGS_use_cinn': val})
        cinn_compiled = True
    except ValueError:
        logger.warning("The used paddle is not compiled with CINN.")
    return cinn_compiled


@unittest.skipIf(not set_cinn_flag(True), "Paddle is not compiled with CINN.")
class TestResnet50Accuracy(unittest.TestCase):
    def reader(self, limit):
        for _ in range(limit):
            yield np.random.randint(0, 256, size=[1, 3, 224, 224]).astype('float32'), \
                np.random.randint(0, 1000, size=[1]).astype('int64')

    def generate_random_data(self, loop_num=10):
        feed = []
        data = self.reader(loop_num)
        for _ in range(loop_num):
            x, y = next(data)
            feed.append({'image': x, 'label': y})
        return feed

    def build_program(self, main_program, startup_program):
        with paddle.static.program_guard(main_program, startup_program):
            image = paddle.static.data(
                name='image', shape=[1, 3, 224, 224], dtype='float32')
            label = paddle.static.data(name='label', shape=[1], dtype='int64')

            model = paddle.vision.models.resnet50()
            prediction = model(image)
            loss = paddle.nn.functional.cross_entropy(
                input=prediction, label=label)
            loss = paddle.mean(loss)
            adam = paddle.optimizer.Adam(learning_rate=0.001)
            adam.minimize(loss)
        return loss

    def train(self, place, iters, feed, use_cinn=False, seed=1234):
        np.random.seed(seed)
        paddle.seed(seed)
        if paddle.is_compiled_with_cuda():
            paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
        set_cinn_flag(use_cinn)

        startup_program = paddle.static.Program()
        main_program = paddle.static.Program()

        loss = self.build_program(main_program, startup_program)
        exe = paddle.static.Executor(place)

        parallel_exec = paddle.static.CompiledProgram(
            main_program).with_data_parallel(loss_name=loss.name)
        loss_vals = []
        scope = paddle.static.Scope()
        with paddle.static.scope_guard(scope):
            exe.run(startup_program)
            for step in range(iters):
                loss_v = exe.run(parallel_exec,
                                 feed=feed[step],
                                 fetch_list=[loss],
                                 return_numpy=True)
                loss_vals.append(loss_v[0][0])
        return loss_vals

    def test_check_resnet50_accuracy(self):
        place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda(
        ) else paddle.CPUPlace()

        loop_num = 10
        feed = self.generate_random_data(loop_num)

        loss_c = self.train(place, loop_num, feed, use_cinn=True)
        loss_p = self.train(place, loop_num, feed, use_cinn=False)
        self.assertTrue(np.allclose(loss_c, loss_p, atol=1e-5))


if __name__ == '__main__':
    unittest.main()