diff --git a/paddle/fluid/framework/paddle2cinn/CMakeLists.txt b/paddle/fluid/framework/paddle2cinn/CMakeLists.txt index 42716d4c45c63eae36200d1b44e41319ab321eba..04931c7c4b35e114d02f9b374cdeeb853d526386 100644 --- a/paddle/fluid/framework/paddle2cinn/CMakeLists.txt +++ b/paddle/fluid/framework/paddle2cinn/CMakeLists.txt @@ -1,17 +1,11 @@ cc_library(cinn_cache_key SRCS cinn_cache_key.cc DEPS boost graph graph_helper lod_tensor proto_desc) -cc_library(cinn_compiled_object SRCS cinn_compiled_object.cc DEPS feed_fetch_method graph lod_tensor proto_desc) -cc_library(cinn_runner SRCS cinn_runner.cc DEPS cinn_cache_key cinn_compiled_object feed_fetch_method graph lod_tensor scope) -cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector) - -if (WITH_CINN) - cc_library(transform_desc SRCS transform_desc.cc DEPS proto_desc cinn) - cc_library(cinn_graph_symbolization SRCS cinn_graph_symbolization.cc DEPS lod_tensor graph graph_helper transform_desc cinn) - - cc_test(test_transform_desc SRCS transform_desc_test.cc DEPS transform_desc) - cc_test(test_cinn_graph_symbolization SRCS cinn_graph_symbolization_test.cc DEPS cinn_graph_symbolization) -endif() +cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector cinn_compiler) +cc_library(transform_desc SRCS transform_desc.cc DEPS proto_desc cinn) +cc_library(cinn_graph_symbolization SRCS cinn_graph_symbolization.cc DEPS lod_tensor graph graph_helper transform_desc cinn) +cc_library(cinn_compiler SRCS cinn_compiler.cc DEPS graph lod_tensor cinn_cache_key cinn_graph_symbolization cinn) cc_test(cinn_cache_key_test SRCS cinn_cache_key_test.cc DEPS cinn_cache_key) -cc_test(cinn_runner_test SRCS cinn_runner_test.cc DEPS cinn_runner proto_desc) -cc_test(cinn_compiled_object_test SRCS cinn_compiled_object_test.cc DEPS cinn_compiled_object) -cc_test(test_build_cinn_pass SRCS build_cinn_pass_test.cc DEPS build_cinn_pass) +cc_test(build_cinn_pass_test SRCS build_cinn_pass_test.cc DEPS 
build_cinn_pass cinn_compiler) +cc_test(transform_desc_test SRCS transform_desc_test.cc DEPS transform_desc) +cc_test(cinn_graph_symbolization_test SRCS cinn_graph_symbolization_test.cc DEPS cinn_graph_symbolization) +cc_test(cinn_compiler_test SRCS cinn_compiler_test.cc DEPS cinn_compiler place proto_desc graph_viz_pass build_cinn_pass cinn) diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc index caddc8fbb7381d1cf242d233d1f58db4e516dbc4..e86a475e59add09cdfe46ad2aa8f48931ee5cf6c 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -14,45 +14,21 @@ limitations under the License. */ #include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h" +#include +#include #include #include #include #include +#include #include +#include "cinn/frontend/op_mapper_registry.h" +#include "cinn/frontend/op_mappers/use_op_mappers.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/ir/subgraph_detector.h" -// #include "cinn/frontend/op_mapper_registry.h" -// #include "cinn/frontend/op_mappers/use_op_mappers.h" - -// TODO(jiangcheng05): just for local compile, remove after -// paddle and CINN have been binded -// The APIs are the same as CINN: -// https://github.com/PaddlePaddle/CINN/blob/develop/cinn/utils/registry.h -namespace cinn { -namespace frontend { -class OpMapperRegistry { - public: - static OpMapperRegistry* Global() { - static OpMapperRegistry inst; - return &inst; - } - - inline const OpMapperRegistry* Find(const std::string& name) { - std::unordered_set fmap_ = {"mul", "add", "relu", "sigmoid", - "softmax"}; - auto p = fmap_.find(name); - if (p != fmap_.end()) { - return this; - } else { - return nullptr; - } - } -}; - -} // namespace frontend -} // namespace cinn +#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" namespace paddle { 
namespace framework { @@ -141,17 +117,17 @@ std::unique_ptr CreateNewSubGraph(const GraphNodeSet& cluster, const GraphNodeSet& cluster_inputs) { // Graph's constructor must has one parameter, and in our code, // the ProgramDesc is useless, so here we pass a temporary object. - auto sub_graph = std::make_unique(framework::ProgramDesc()); + auto subgraph = std::make_unique(framework::ProgramDesc()); std::unordered_map old_op2new_op; for (auto* op : cluster) { - auto sub_node = sub_graph->CreateOpNode(op->Op()); + auto sub_node = subgraph->CreateOpNode(op->Op()); old_op2new_op[op] = sub_node; } std::unordered_map old_var2new_var; for (auto* var : cluster_internals) { - auto sub_node = sub_graph->CreateVarNode(var->Var()); + auto sub_node = subgraph->CreateVarNode(var->Var()); old_var2new_var[var] = sub_node; } @@ -190,9 +166,9 @@ std::unique_ptr CreateNewSubGraph(const GraphNodeSet& cluster, } } - AddFeedOpAndVar(need_feed_vars, cluster, old_op2new_op, sub_graph.get()); - AddParamVar(param_vars, cluster, old_op2new_op, sub_graph.get()); - AddOutputVar(output_vars, cluster, old_op2new_op, sub_graph.get()); + AddFeedOpAndVar(need_feed_vars, cluster, old_op2new_op, subgraph.get()); + AddParamVar(param_vars, cluster, old_op2new_op, subgraph.get()); + AddOutputVar(output_vars, cluster, old_op2new_op, subgraph.get()); for (auto* var : cluster_internals) { for (auto* op : var->inputs) { @@ -207,7 +183,7 @@ std::unique_ptr CreateNewSubGraph(const GraphNodeSet& cluster, } } - return sub_graph; + return subgraph; } // This interface is used to classify all variables involved in a cluster into @@ -256,11 +232,24 @@ void AnalyseClusterVariables(const GraphNodeSet& cluster, } } -Node* AddSpecialOpToGraph(Graph* graph, const GraphNodeSet& cluster_inputs, - const GraphNodeSet& cluster_outputs) { +Node* AddSpecialOpToGraph(const GraphNodeSet& cluster_inputs, + const GraphNodeSet& cluster_outputs, + const std::string& compilation_key, Graph* graph) { // add special cinn op 
framework::OpDesc special_op_desc; special_op_desc.SetType(kCinnLaunchOp); + std::vector input_names; + std::transform(cluster_inputs.begin(), cluster_inputs.end(), + std::back_inserter(input_names), + [](Node* n) { return n->Name(); }); + special_op_desc.SetInput("X", input_names); + std::vector output_names; + std::transform(cluster_outputs.begin(), cluster_outputs.end(), + std::back_inserter(output_names), + [](Node* n) { return n->Name(); }); + special_op_desc.SetOutput("Out", output_names); + special_op_desc.SetAttr(kCompilationKey, compilation_key); + special_op_desc.Flush(); auto* special_op_node = graph->CreateOpNode(&special_op_desc); special_op_node->inputs.assign(cluster_inputs.begin(), cluster_inputs.end()); special_op_node->outputs.assign(cluster_outputs.begin(), @@ -268,9 +257,9 @@ Node* AddSpecialOpToGraph(Graph* graph, const GraphNodeSet& cluster_inputs, return special_op_node; } -void AddLinkToSpecialOp(Node* special_op_node, - const GraphNodeSet& cluster_inputs, - const GraphNodeSet& cluster_outputs) { +void AddLinkToSpecialOp(const GraphNodeSet& cluster_inputs, + const GraphNodeSet& cluster_outputs, + Node* special_op_node) { // add new link from cluster_inputs to special_op_node for (auto* var_node : cluster_inputs) { var_node->outputs.push_back(special_op_node); @@ -338,14 +327,15 @@ void ReplaceSubGraphWithSpecialOpNode(const GraphNodeSet& cluster, const GraphNodeSet& cluster_inputs, const GraphNodeSet& cluster_outputs, const GraphNodeSet& cluster_internals, + const std::string& compilation_key, Graph* graph) { // First, add the special op node whose name is "kCinnLaunchOp" into graph - auto special_op_node = - AddSpecialOpToGraph(graph, cluster_inputs, cluster_outputs); + auto special_op_node = AddSpecialOpToGraph(cluster_inputs, cluster_outputs, + compilation_key, graph); // Second, remove all graph's links which are from or to cluster nodes RemoveLinkFromCluster(cluster, cluster_inputs, cluster_outputs); // Third, add new links from or to 
the the special op node - AddLinkToSpecialOp(special_op_node, cluster_inputs, cluster_outputs); + AddLinkToSpecialOp(cluster_inputs, cluster_outputs, special_op_node); // Finally, remove the cinn sub graph from graph RemoveSubGraphFromGraph(cluster, cluster_internals, graph); } @@ -354,8 +344,7 @@ void ReplaceSubGraphWithSpecialOpNode(const GraphNodeSet& cluster, // Here we using SubgraphDetector to detecte the subgraph that // all of op node supported by CINN. We using OpMapperRegistry // to check whether the op node supported by CINN. -void SearchAllSubgraphs(Graph* graph, - std::vector>* cinn_subgraphs) { +void SearchAllSubgraphs(Graph* graph) { auto teller = [](const Node* node) { return ::cinn::frontend::OpMapperRegistry::Global()->Find(node->Name()) != nullptr; @@ -363,29 +352,26 @@ void SearchAllSubgraphs(Graph* graph, std::vector clusters = framework::ir::SubgraphDetector(graph, teller)(); - cinn_subgraphs->clear(); + auto* cinn_compiler = CinnCompiler::GetInstance(); for (const auto& node_vec : clusters) { - // classify var node to inputs, outputs, and internals. + // Classify var node to inputs, outputs, and internals. 
GraphNodeSet cluster_set(node_vec.begin(), node_vec.end()); GraphNodeSet cluster_inputs, cluster_outputs, cluster_internals; AnalyseClusterVariables(cluster_set, &cluster_inputs, &cluster_outputs, &cluster_internals); - - cinn_subgraphs->emplace_back( + // Create a new subgraph according to the found cluster and + // save it in CinnCompiler + std::string compilation_key = cinn_compiler->AddGraph( CreateNewSubGraph(cluster_set, cluster_internals, cluster_inputs)); - - // replacing subgraph to a new special op node + // Replace the found cluster to a new special op node ReplaceSubGraphWithSpecialOpNode(cluster_set, cluster_inputs, - cluster_outputs, cluster_internals, graph); + cluster_outputs, cluster_internals, + compilation_key, graph); } } -void BuildCinnPass::ApplyImpl(Graph* graph) const { - auto& cinn_subgraphs = - Get>>("cinn_subgraphs"); - SearchAllSubgraphs(graph, &cinn_subgraphs); -} +void BuildCinnPass::ApplyImpl(Graph* graph) const { SearchAllSubgraphs(graph); } } // namespace paddle2cinn } // namespace framework diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h index e71160ba108ecf4bf349291d2e8669b11a5df827..556ff228915e4dc772cff8a7ba562d0d5a9117ab 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.h @@ -21,6 +21,7 @@ namespace framework { namespace paddle2cinn { constexpr char kCinnLaunchOp[] = "CinnLaunchOp"; +constexpr char kCompilationKey[] = "compilation_key"; // A pass named BuildCinnPass, the function of this pass is: // @@ -39,12 +40,13 @@ constexpr char kCinnLaunchOp[] = "CinnLaunchOp"; // Firstly, both op nodes should be compile supported. // Secondly, there should be a direct path between the two op nodes through a // var node. -// Thirdly, there should be no extral path between the two op nodes through +// Thirdly, there should be no extra path between the two op nodes through // unsupported op nodes. 
// Lastly, if op nodes a and b can be divied into a cluster, op nodes b and c -// can be devided into a cluster, a and c can also be devided into a cluster. -// The implementation of cluster detection is enclosured in class -// SubGraphDetector. +// can be divided into a cluster, a and c can also be divided into a cluster. +// The implementation of cluster detection is encapsulated in the +// SubGraphDetector +// class. // // b) How to deal with the links between the var nodes in global graph and the // op nodes in a cluster? diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc index bf68a2b554b7f1d62cdc8e31dbba7aa050df9008..ab5768e0b2be357962abe560bc1548756da463b0 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/var_desc.h" @@ -83,6 +84,18 @@ inline bool CheckGraphIndependence(const std::unordered_set& nodes) { return true; } +// Get compilation_key values +std::vector GetCompilationKeys(const Graph& graph) { + std::vector compilation_keys; + for (auto& node : graph.Nodes()) { + if (node->IsOp() && node->Name() == kCinnLaunchOp) { + compilation_keys.emplace_back( + BOOST_GET_CONST(std::string, node->Op()->GetAttr(kCompilationKey))); + } + } + return compilation_keys; +} + std::unique_ptr BuildNoCinnSubgraph() { ProgramDesc prog; auto g = std::make_unique(prog); @@ -133,17 +146,14 @@ TEST(BuildCinnPassTest, NoCinnSubgraph) { auto pass = paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass"); - std::vector> cinn_subgraphs; - pass->SetNotOwned>>("cinn_subgraphs", 
- &cinn_subgraphs); pass->Apply(g.get()); // After search, origin graph should no change ASSERT_EQ(previous_nodes, g->Nodes()); ASSERT_TRUE(CheckGraphIndependence(g->Nodes())); - // After search, there should one cinn subgraph - ASSERT_TRUE(cinn_subgraphs.empty()); + // After search, there should be no cinn subgraph + ASSERT_TRUE(GetCompilationKeys(*g).empty()); } std::unique_ptr BuildAllOpSupportCinnGraph() { @@ -212,9 +222,6 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) { auto pass = paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass"); - std::vector> cinn_subgraphs; - pass->SetNotOwned>>("cinn_subgraphs", - &cinn_subgraphs); pass->Apply(g.get()); // After search, the graph should as following @@ -250,10 +257,12 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) { // | --> mul --> v3 -- // v2 -- | --> add --> v5 --> relu --> v6 // feed --> v4 -- - ASSERT_EQ(cinn_subgraphs.size(), static_cast(1)); - const auto& subgraph = cinn_subgraphs.back(); + auto compilation_keys = GetCompilationKeys(*g); + ASSERT_EQ(compilation_keys.size(), static_cast(1)); + auto* cinn_compiler = CinnCompiler::GetInstance(); + const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]); - const auto& subnodes = subgraph->Nodes(); + const auto& subnodes = subgraph.Nodes(); ASSERT_EQ(subnodes.size(), static_cast(11)); ASSERT_TRUE(CheckGraphIndependence(subnodes)); @@ -338,9 +347,6 @@ TEST(BuildCinnPassTest, OneCinnSubgraph) { auto pass = paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass"); - std::vector> cinn_subgraphs; - pass->SetNotOwned>>("cinn_subgraphs", - &cinn_subgraphs); pass->Apply(g.get()); // After search, the graph should as following @@ -366,10 +372,12 @@ TEST(BuildCinnPassTest, OneCinnSubgraph) { // feed --> v1 -- // | --> mul --> v3 --> relu --> v4 // v2 -- - ASSERT_EQ(cinn_subgraphs.size(), static_cast(1)); - const auto& subgraph = cinn_subgraphs.back(); + auto compilation_keys = GetCompilationKeys(*g); + 
ASSERT_EQ(compilation_keys.size(), static_cast(1)); + auto* cinn_compiler = CinnCompiler::GetInstance(); + const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]); - const auto& subnodes = subgraph->Nodes(); + const auto& subnodes = subgraph.Nodes(); ASSERT_EQ(subnodes.size(), static_cast(7)); ASSERT_TRUE(CheckGraphIndependence(subnodes)); @@ -450,9 +458,6 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) { auto pass = paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass"); - std::vector> cinn_subgraphs; - pass->SetNotOwned>>("cinn_subgraphs", - &cinn_subgraphs); pass->Apply(g.get()); // After search, the graph should as following @@ -478,7 +483,8 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) { // After search, there should has two cinn subgraphs, // and each of subgraphs just has one node. - ASSERT_EQ(cinn_subgraphs.size(), static_cast(2)); + auto compilation_keys = GetCompilationKeys(*g); + ASSERT_EQ(compilation_keys.size(), static_cast(2)); // subgraph1: // feed --> v4 --> relu --> v5 @@ -486,12 +492,13 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) { // feed --> v1 -- // | --> mul --> v3 // v2 -- - const auto& subgraph1 = cinn_subgraphs[0]; - const auto& subnodes1 = subgraph1->Nodes(); + auto* cinn_compiler = CinnCompiler::GetInstance(); + const auto& subgraph1 = cinn_compiler->FindGraph(compilation_keys[0]); + const auto& subnodes1 = subgraph1.Nodes(); ASSERT_TRUE(CheckGraphIndependence(subnodes1)); - const auto& subgraph2 = cinn_subgraphs[1]; - const auto& subnodes2 = subgraph2->Nodes(); + const auto& subgraph2 = cinn_compiler->FindGraph(compilation_keys[1]); + const auto& subnodes2 = subgraph2.Nodes(); ASSERT_TRUE(CheckGraphIndependence(subnodes2)); if (CheckNodeExisted(subnodes1, "relu")) { diff --git a/paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc b/paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc index ac6c83be4fae3c944a285ccce99fe3285280ed09..923282c59e2d4aa35770b0f134137d1cfc2d24d2 100644 --- 
a/paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_cache_key.cc @@ -28,32 +28,38 @@ namespace paddle2cinn { CinnCacheKey::CinnCacheKey( const ir::Graph& graph, - const std::map& feed_tensors) { - this->SetKey(graph, feed_tensors); + const std::map& input_tensors, + const std::string& arch_str) { + this->SetKey(graph, input_tensors, arch_str); } CinnCacheKey::CinnCacheKey(const ir::Graph& graph, - const std::map& feed_shapes) { - this->SetKey(graph, feed_shapes); + const std::map& input_shapes, + const std::string& arch_str) { + this->SetKey(graph, input_shapes, arch_str); } void CinnCacheKey::SetKey( const ir::Graph& graph, - const std::map& feed_tensors) { + const std::map& input_tensors, + const std::string& arch_str) { ProgramDesc program; GraphToProgram(graph, &program); program.Proto()->SerializeToString(&graph_serialize_str_); - for (const auto& name_tensor : feed_tensors) { - feed_shapes_[name_tensor.first] = name_tensor.second->dims(); + for (const auto& name_tensor : input_tensors) { + input_shapes_[name_tensor.first] = name_tensor.second->dims(); } + arch_str_ = arch_str; } void CinnCacheKey::SetKey(const ir::Graph& graph, - const std::map& feed_shapes) { + const std::map& input_shapes, + const std::string& arch_str) { ProgramDesc program; GraphToProgram(graph, &program); program.Proto()->SerializeToString(&graph_serialize_str_); - feed_shapes_ = feed_shapes; + input_shapes_ = input_shapes; + arch_str_ = arch_str; } bool CinnCacheKey::operator!=(const CinnCacheKey& other) const { @@ -62,7 +68,7 @@ bool CinnCacheKey::operator!=(const CinnCacheKey& other) const { bool CinnCacheKey::operator==(const CinnCacheKey& other) const { return graph_serialize_str_ == other.graph_serialize_str_ && - feed_shapes_ == other.feed_shapes_; + input_shapes_ == other.input_shapes_ && arch_str_ == other.arch_str_; } size_t CinnCacheKey::Hash::hash_combine(size_t seed, size_t value) { @@ -73,12 +79,13 @@ size_t 
CinnCacheKey::Hash::operator()(const CinnCacheKey& key) const { std::size_t ret = 0; std::hash string_hasher; - for (const auto& name_shape : key.feed_shapes_) { + for (const auto& name_shape : key.input_shapes_) { ret = hash_combine(ret, string_hasher(name_shape.first)); ret = hash_combine(ret, string_hasher(name_shape.second.to_str())); } ret = hash_combine(ret, string_hasher(key.graph_serialize_str_)); + ret = hash_combine(ret, string_hasher(key.arch_str_)); return ret; } diff --git a/paddle/fluid/framework/paddle2cinn/cinn_cache_key.h b/paddle/fluid/framework/paddle2cinn/cinn_cache_key.h index 9627ae92aaba25fac9bae404d2ebfffba912db21..02b152a681c446fd6ffbbc82c5c675bc82297057 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_cache_key.h +++ b/paddle/fluid/framework/paddle2cinn/cinn_cache_key.h @@ -26,24 +26,28 @@ namespace paddle2cinn { // Class to store the keys for compiling CINN. // -// CINN cannot handle changable shape now, so CinnRunner keeps a cache mapping +// CINN cannot handle dynamic shapes now, so CinnCompiler keeps a cache mapping // from CinnCacheKey to CinnCompiledObject. // -// The CinnCacheKey contains a graph serialized string and the feeded tensor +// The CinnCacheKey contains a graph serialized string and the input tensor // shapes. 
class CinnCacheKey { public: CinnCacheKey(const ir::Graph& graph, - const std::map& feed_tensors); + const std::map& input_tensors, + const std::string& arch_str); CinnCacheKey(const ir::Graph& graph, - const std::map& feed_shapes); + const std::map& input_shapes, + const std::string& arch_str); ~CinnCacheKey() {} void SetKey(const ir::Graph& graph, - const std::map& feed_tensors); + const std::map& input_tensors, + const std::string& arch_str); void SetKey(const ir::Graph& graph, - const std::map& feed_shapes); + const std::map& input_shapes, + const std::string& arch_str); bool operator==(const CinnCacheKey& other) const; bool operator!=(const CinnCacheKey& other) const; @@ -55,7 +59,8 @@ class CinnCacheKey { private: std::string graph_serialize_str_; - std::map feed_shapes_; + std::map input_shapes_; + std::string arch_str_; }; } // namespace paddle2cinn diff --git a/paddle/fluid/framework/paddle2cinn/cinn_cache_key_test.cc b/paddle/fluid/framework/paddle2cinn/cinn_cache_key_test.cc index a84ade26bfd1248cd44793fa0e1a12fdfea809ee..f13f44998211f4d2a45088766c175e79f09853fa 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_cache_key_test.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_cache_key_test.cc @@ -47,17 +47,19 @@ TEST(CinnCacheKeyTest, TestAsUnorderedKey) { DDim ddim = paddle::framework::make_ddim({1, 2, 3}); std::map feed_shapes = {{"X", ddim}}; - CinnCacheKey cache_key1(empty_graph, feed_tensors); - CinnCacheKey cache_key2(empty_graph, feed_shapes); - EXPECT_EQ(cache_key1, cache_key2); - - CinnCacheKey cache_key3(graph, feed_shapes); - CinnCacheKey cache_key4(graph, feed_tensors); + CinnCacheKey cache_key0(empty_graph, feed_tensors, "x86"); + CinnCacheKey cache_key1(empty_graph, feed_shapes, "x86"); + EXPECT_EQ(cache_key0, cache_key1); + + CinnCacheKey cache_key2(graph, feed_shapes, "x86"); + CinnCacheKey cache_key3(graph, feed_shapes, "nvgpu"); + CinnCacheKey cache_key4(graph, feed_tensors, "nvgpu"); + EXPECT_NE(cache_key2, cache_key3); 
EXPECT_EQ(cache_key3, cache_key4); CinnCacheKey cache_key5(empty_graph, - std::map()); - CinnCacheKey cache_key6(empty_graph, std::map()); + std::map(), "unk"); + CinnCacheKey cache_key6(empty_graph, std::map(), "unk"); EXPECT_EQ(cache_key5, cache_key6); EXPECT_NE(cache_key1, cache_key3); @@ -69,19 +71,19 @@ TEST(CinnCacheKeyTest, TestAsUnorderedKey) { EXPECT_NE(cache_key5, cache_key1); EXPECT_NE(cache_key2, cache_key6); + test_set.insert(cache_key0); test_set.insert(cache_key1); - test_set.insert(cache_key2); test_set.insert(cache_key3); test_set.insert(cache_key4); test_set.insert(cache_key5); test_set.insert(cache_key6); EXPECT_EQ(test_set.size(), 3U); - auto iter = test_set.find(cache_key1); + auto iter = test_set.find(cache_key0); EXPECT_NE(iter, test_set.end()); test_set.erase(iter); EXPECT_EQ(test_set.size(), 2U); - EXPECT_EQ(test_set.find(cache_key2), test_set.end()); + EXPECT_EQ(test_set.find(cache_key1), test_set.end()); iter = test_set.find(cache_key3); EXPECT_NE(iter, test_set.end()); diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiled_object.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiled_object.cc deleted file mode 100644 index a90494bafe9bb6fdefca6378ef03b9ee5fda1e62..0000000000000000000000000000000000000000 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiled_object.cc +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/fluid/framework/paddle2cinn/cinn_compiled_object.h" - -#include - -#include "paddle/fluid/framework/feed_fetch_type.h" -#include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/framework/scope.h" - -namespace paddle { -namespace framework { -namespace paddle2cinn { - -CinnCompiledObject::CinnCompiledObject() { - // TODO(zhhsplendid): complete this function after CINN interface is ready -} -CinnCompiledObject::~CinnCompiledObject() { - // TODO(zhhsplendid): complete this function after CINN interface is ready -} - -void CinnCompiledObject::Compile( - const ir::Graph& graph, - std::map* feed_targets) { - // TODO(zhhsplendid): complete this function after CINN interface is ready -} - -std::map CinnCompiledObject::Run( - Scope* scope, std::map* feed_targets) { - // TODO(zhhsplendid): complete this function after CINN interface is ready - return std::map(); -} - -} // namespace paddle2cinn -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiled_object.h b/paddle/fluid/framework/paddle2cinn/cinn_compiled_object.h deleted file mode 100644 index 21191d44345877dec2520d9ae848283f63017b27..0000000000000000000000000000000000000000 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiled_object.h +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -#include "paddle/fluid/framework/feed_fetch_type.h" -#include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/scope.h" - -namespace paddle { -namespace framework { -namespace paddle2cinn { - -// Class to store and call CINN complied object -class CinnCompiledObject { - public: - CinnCompiledObject(); - ~CinnCompiledObject(); - - // Compiles use CINN. CINN compilation needs model graph, input names, and - // input_shapes - void Compile(const ir::Graph& graph, - std::map* feed_targets); - - // Feed LoDTensors to tun CINN compiled object and return fetched result - std::map Run( - Scope* scope, std::map* feed_targets); - - // Converts compiled object to Paddle Graph - // To be discussed - // ir::Graph ToGraph(); -}; - -} // namespace paddle2cinn -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiled_object_test.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiled_object_test.cc deleted file mode 100644 index 5a7861edf210c4d8cba5213d10ac6e2811e1d223..0000000000000000000000000000000000000000 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiled_object_test.cc +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include - -#include "gtest/gtest.h" - -#include "paddle/fluid/framework/feed_fetch_type.h" -#include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/paddle2cinn/cinn_compiled_object.h" -#include "paddle/fluid/framework/program_desc.h" - -namespace paddle { -namespace framework { -namespace paddle2cinn { - -TEST(CinnCompiledObjecctTest, TodoTest) { - ProgramDesc empty_program; - ir::Graph empty_graph(empty_program); - std::map empty_feed; - Scope empty_scope; - - CinnCompiledObject compiled_obj; - compiled_obj.Compile(empty_graph, &empty_feed); - auto fetch = compiled_obj.Run(&empty_scope, &empty_feed); -} - -} // namespace paddle2cinn -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc new file mode 100644 index 0000000000000000000000000000000000000000..44cea60bdcb8e42d2448041eb475cb9514cc332d --- /dev/null +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc @@ -0,0 +1,127 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
+
+#include <map>
+#include <memory>
+#include <string>
+
+#include "cinn/common/target.h"
+#include "cinn/common/type.h"
+#include "cinn/frontend/decomposer/use_decomposer.h"
+#include "cinn/frontend/net_builder.h"  // need to remove after
+#include "cinn/frontend/pass/use_program_pass.h"
+#include "cinn/frontend/program_pass.h"
+#include "cinn/frontend/syntax.h"
+#include "cinn/hlir/framework/graph.h"
+#include "cinn/hlir/framework/graph_compiler.h"
+#include "cinn/hlir/framework/pass.h"
+#include "cinn/hlir/pass/use_pass.h"
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace framework {
+namespace paddle2cinn {
+
+using ir::Graph;
+using ::cinn::common::Target;
+using ::cinn::common::Float;
+using ::cinn::hlir::framework::GraphCompiler;
+using ::cinn::hlir::framework::BuildScope;
+using ::cinn::frontend::ProgramPass;
+using ::cinn::hlir::framework::ApplyPass;
+
+// Returns the single CinnCompiler instance. It is a function-local static,
+// so it is constructed on the first call.
+CinnCompiler* CinnCompiler::GetInstance() {
+  static CinnCompiler instance;
+  return &instance;
+}
+
+// Takes ownership of `graph` and stores it under a key, which is returned.
+// The key is the serialized ProgramDesc proto obtained by converting the
+// graph with GraphToProgram. If a graph with the same key is already stored,
+// the stored one is kept, the argument is discarded and a warning is logged.
+std::string CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
+  std::string graph_key;
+  ProgramDesc program;
+  GraphToProgram(*graph, &program);
+  program.Proto()->SerializeToString(&graph_key);
+  // emplace() probes and inserts in one step and never overwrites an
+  // existing entry, which matches the "first graph wins" behaviour.
+  const bool inserted = graphs_.emplace(graph_key, std::move(graph)).second;
+  if (!inserted) {
+    LOG(WARNING)
+        << "The graph being added is already in CinnCompiler. Its key is:\n"
+        << graph_key;
+  }
+  return graph_key;
+}
+
+// Returns the graph stored under `graph_key`. Fails the enforce check with
+// an InvalidArgument error when no graph was added under that key.
+const Graph& CinnCompiler::FindGraph(const std::string& graph_key) const {
+  PADDLE_ENFORCE_NE(
+      graphs_.count(graph_key), 0,
+      platform::errors::InvalidArgument("Can not find the target graph: %s",
+                                        graph_key.c_str()));
+  return *graphs_.at(graph_key);
+}
+
+// Returns the compiled object for (graph, input_tensors, target arch),
+// compiling it only when the cache has no entry for that CinnCacheKey.
+// The returned reference points into the cache owned by this compiler.
+const CinnCompiledObject& CinnCompiler::Compile(
+    const Graph& graph,
+    const std::map<std::string, const LoDTensor*>& input_tensors,
+    const Target& target) {
+  CinnCacheKey cur_key(graph, input_tensors, target.arch_str());
+  auto iter = cache_.find(cur_key);
+  if (iter == cache_.end()) {
+    // The counter is bumped first because CompileGraph reads it to number
+    // the compilation.
+    ++real_compiled_num_;
+    // CompileGraph runs before emplace() is entered, so a compilation that
+    // throws cannot leave an empty (null) slot behind in the cache.
+    iter = cache_.emplace(cur_key, CompileGraph(graph, input_tensors, target))
+               .first;
+  }
+  return *iter->second;
+}
+
+// Same as above, for a graph previously registered through AddGraph and
+// addressed by the key AddGraph returned.
+const CinnCompiledObject& CinnCompiler::Compile(
+    const std::string& compilation_key,
+    const std::map<std::string, const LoDTensor*>& input_tensors,
+    const Target& target) {
+  const auto& graph = FindGraph(compilation_key);
+  return Compile(graph, input_tensors, target);
+}
+
+// Performs one real compilation: symbolizes the Paddle graph into a CINN
+// frontend program, applies the "Decomposer" program pass and the
+// "OpFusion" graph pass, builds the runtime program with GraphCompiler and
+// packs the result into a new CinnCompiledObject.
+std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
+    const ir::Graph& graph,
+    const std::map<std::string, const LoDTensor*>& input_tensors,
+    const Target& target) const {
+  CinnGraphSymbolization symbol{real_compiled_num_, graph, target,
+                                input_tensors};
+  auto frontend_program = symbol();
+  ProgramPass::Apply(&frontend_program, target, {"Decomposer"});
+  auto cinn_graph = std::make_shared<::cinn::hlir::framework::Graph>(
+      frontend_program, target);
+  VLOG(4) << "The " << real_compiled_num_ << "-th compilation ("
+          << target.arch_str() << "), and its related graph:\n"
+          << cinn_graph->Visualize();
+  ApplyPass(cinn_graph.get(), "OpFusion");
+  auto scope = BuildScope(target, cinn_graph);
+  GraphCompiler graph_compiler(target, scope, cinn_graph);
+  GraphCompiler::CompileOptions options;
+  options.with_instantiate_variables = false;
+  auto compiled_res = graph_compiler.Build(options);
+  auto compiled_obj = std::make_unique<CinnCompiledObject>();
+  *compiled_obj = {std::move(compiled_res.runtime_program), scope,
+                   symbol.var_model_to_program_map()};
+  return compiled_obj;
+}
+
+}  // namespace paddle2cinn
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.h b/paddle/fluid/framework/paddle2cinn/cinn_compiler.h
new file mode 100644
index 0000000000000000000000000000000000000000..3b0fb5cf6965f499650e47a5768323e1d98c4b19
--- /dev/null
+++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.h
@@ -0,0 +1,88 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <atomic>
+#include <map>
+#include <memory>
+#include <string>
+#include <unordered_map>
+
+#include "cinn/common/target.h"
+#include "cinn/hlir/framework/graph_compiler.h"
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/paddle2cinn/cinn_cache_key.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/platform/macros.h"
+
+namespace paddle {
+namespace framework {
+namespace paddle2cinn {
+
+// Result of one CINN compilation, as filled in by CinnCompiler::CompileGraph.
+struct CinnCompiledObject {
+  // Runtime program returned by the CINN GraphCompiler build.
+  std::unique_ptr<::cinn::hlir::framework::Program> runtime_program;
+  // CINN scope built for the compiled graph.
+  std::shared_ptr<::cinn::hlir::framework::Scope> scope;
+  // Copy of CinnGraphSymbolization::var_model_to_program_map(); presumably
+  // maps Paddle variable names to their CINN names -- TODO confirm.
+  std::unordered_map<std::string, std::string> paddle2cinn_varmap;
+};
+
+// Entrance to use CINN.
+//
+// CINN cannot handle changeable shape now, so CinnCompiler keeps a cache
+// mapping from CinnCacheKey to CinnCompiledObject. If cache hits, we will
+// re-use cache stored CinnCompiledObject, otherwise we will compile again and
+// put into cache.
+class CinnCompiler {
+ public:
+  // Singleton
+  static CinnCompiler* GetInstance();
+
+  // Returns the compiled object for `graph` with the given inputs and
+  // target. A real compilation happens only when the cache has no entry for
+  // the CinnCacheKey built from these arguments; the returned reference
+  // refers to the object stored in the cache.
+  const CinnCompiledObject& Compile(
+      const ir::Graph& graph,
+      const std::map<std::string, const LoDTensor*>& input_tensors,
+      const ::cinn::common::Target& target);
+
+  // Same as above, for a graph registered earlier with AddGraph and looked
+  // up through the key AddGraph returned.
+  const CinnCompiledObject& Compile(
+      const std::string& compilation_key,
+      const std::map<std::string, const LoDTensor*>& input_tensors,
+      const ::cinn::common::Target& target);
+
+  // Takes ownership of `graph`, stores it and returns the key under which
+  // it can be found again.
+  std::string AddGraph(std::unique_ptr<ir::Graph> graph);
+
+  // Returns the graph stored under `key`; it is an error to pass a key that
+  // was not returned by AddGraph.
+  const ir::Graph& FindGraph(const std::string& key) const;
+
+  // Number of real (cache-missing) compilations performed so far.
+  std::int64_t real_compiled_num() const { return real_compiled_num_; }
+
+  ~CinnCompiler() = default;
+
+ private:
+  CinnCompiler() = default;
+  // Compiles `graph` unconditionally, without consulting the cache.
+  std::unique_ptr<CinnCompiledObject> CompileGraph(
+      const ir::Graph& graph,
+      const std::map<std::string, const LoDTensor*>& input_tensors,
+      const ::cinn::common::Target& target) const;
+
+  // NOTE(review): only the counter below is atomic; no lock guards the two
+  // maps in this class -- confirm that callers serialize access.
+  // Graphs registered through AddGraph, indexed by their key.
+  std::unordered_map<std::string, std::unique_ptr<ir::Graph>> graphs_;
+  // Compilation results, indexed by graph / input / target cache key.
+  std::unordered_map<CinnCacheKey, std::unique_ptr<CinnCompiledObject>,
+                     CinnCacheKey::Hash>
+      cache_;
+  std::atomic_int64_t real_compiled_num_{0};
+
+  DISABLE_COPY_AND_ASSIGN(CinnCompiler);
+};
+
+}  // namespace paddle2cinn
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..22792e0f8c359aa5245f04435fe2b3ca428a48d9
--- /dev/null
+++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc
@@ -0,0 +1,168 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
+
+#include <map>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "cinn/common/target.h"
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/lod_tensor.h"
+#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/platform/place.h"
+
+namespace paddle {
+namespace framework {
+namespace paddle2cinn {
+
+using ir::Graph;
+using ::cinn::common::Target;
+
+// Builds the graph used by the test:
+//
+// X -
+//     | -> mul -> MUL_OUT -
+// Y -                      | -> elementwise_add -> ADD_OUT -> relu -> RELU_OUT
+//                       Z -
+//
+// X is a [1000, 784] FP32 input; Y ([784, 100]) and Z ([100]) are persistable
+// FP32 parameters.
+std::unique_ptr<Graph> CreateGraph() {
+  ProgramDesc program;
+  auto* global_block = program.MutableBlock(0);
+  // mul
+  auto* x = global_block->Var("X");
+  x->SetType(proto::VarType::LOD_TENSOR);
+  x->SetLoDLevel(0);
+  x->SetDataType(proto::VarType::FP32);
+  x->SetShape({1000, 784});
+
+  auto* y = global_block->Var("Y");
+  y->SetType(proto::VarType::LOD_TENSOR);
+  y->SetLoDLevel(0);
+  y->SetDataType(proto::VarType::FP32);
+  y->SetShape({784, 100});
+  y->SetPersistable(true);
+  y->SetIsParameter(true);
+
+  auto* mul_op = global_block->AppendOp();
+  mul_op->SetType("mul");
+  mul_op->SetInput("X", {x->Name()});
+  mul_op->SetInput("Y", {y->Name()});
+
+  auto* mul_out = global_block->Var("MUL_OUT");
+  mul_out->SetType(proto::VarType::LOD_TENSOR);
+  mul_op->SetOutput("Out", {mul_out->Name()});
+
+  // add
+  auto* z = global_block->Var("Z");
+  z->SetType(proto::VarType::LOD_TENSOR);
+  z->SetLoDLevel(0);
+  z->SetDataType(proto::VarType::FP32);
+  z->SetShape({100});
+  z->SetPersistable(true);
+  z->SetIsParameter(true);
+
+  auto* add_op = global_block->AppendOp();
+  add_op->SetType("elementwise_add");
+  add_op->SetInput("X", {mul_out->Name()});
+  add_op->SetInput("Y", {z->Name()});
+
+  auto* add_out = global_block->Var("ADD_OUT");
+  add_out->SetType(proto::VarType::LOD_TENSOR);
+  add_op->SetOutput("Out", {add_out->Name()});
+
+  // relu
+  auto* relu_op = global_block->AppendOp();
+  relu_op->SetType("relu");
+  relu_op->SetInput("X", {add_out->Name()});
+
+  auto* relu_out = global_block->Var("RELU_OUT");
+  relu_out->SetType(proto::VarType::LOD_TENSOR);
+  relu_op->SetOutput("Out", {relu_out->Name()});
+  program.Flush();
+  return std::make_unique<Graph>(program);
+}
+
+// Runs build_cinn_pass on the sample graph, then checks that the registered
+// sub-graph can be found by its compilation key, compiled for the GPU and
+// host targets, and that a second request for the same key is served from
+// the cache.
+TEST(CinnCompilerTest, Compile) {
+  auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
+  auto cinn_pass = ir::PassRegistry::Instance().Get("build_cinn_pass");
+  auto viz_graph = [&viz_pass](const std::string& viz_path, Graph* graph) {
+    viz_pass->Erase("graph_viz_path");
+    viz_pass->Set("graph_viz_path", new std::string(viz_path));
+    viz_pass->Apply(graph);
+  };
+
+  // create a graph
+  auto graph = CreateGraph();
+  viz_graph("origin_graph.dot", graph.get());
+  // apply build_cinn_pass
+  cinn_pass->Apply(graph.get());
+  viz_graph("processed_graph.dot", graph.get());
+  // get the compilation_key
+  std::vector<std::string> compilation_keys;
+  for (auto& node : graph->Nodes()) {
+    if (node->IsOp() && node->Name() == kCinnLaunchOp) {
+      compilation_keys.emplace_back(
+          BOOST_GET_CONST(std::string, node->Op()->GetAttr(kCompilationKey)));
+    }
+  }
+  // size() is unsigned, so compare against an unsigned literal.
+  ASSERT_EQ(compilation_keys.size(), 1UL);
+
+  const auto& compilation_key = compilation_keys[0];
+  auto* cinn_compiler = CinnCompiler::GetInstance();
+  const auto& compiling_graph = cinn_compiler->FindGraph(compilation_key);
+  // viz_graph("compiling_graph.dot", const_cast<Graph*>(&compiling_graph));
+
+  EXPECT_THROW(cinn_compiler->FindGraph("no_existed"),
+               paddle::platform::EnforceNotMet);
+
+  LoDTensor tensor1, tensor2, tensor3;
+  tensor1.Resize({1000, 784});
+  tensor2.Resize({784, 100});
+  tensor3.Resize({100});
+  tensor1.mutable_data<float>(platform::CPUPlace());
+  tensor2.mutable_data<float>(platform::CPUPlace());
+  tensor3.mutable_data<float>(platform::CPUPlace());
+  std::map<std::string, const LoDTensor*> input_tensors = {
+      {"X", &tensor1}, {"Y", &tensor2}, {"Z", &tensor3}};
+
+  auto compile_fn = [&](const Target& target) {
+    const auto& compiled_obj =
+        cinn_compiler->Compile(compiling_graph, input_tensors, target);
+    ASSERT_NE(compiled_obj.runtime_program, nullptr);
+    ASSERT_NE(compiled_obj.scope, nullptr);
+    ASSERT_FALSE(compiled_obj.paddle2cinn_varmap.empty());
+    const auto& cached_obj =
+        cinn_compiler->Compile(compilation_key, input_tensors, target);
+    // A cache hit must hand back the very same object, so compare the
+    // addresses directly instead of casting them to integers.
+    ASSERT_EQ(&compiled_obj, &cached_obj);
+  };
+
+  // GPU Compilation
+  compile_fn(::cinn::common::DefaultNVGPUTarget());
+  ASSERT_EQ(cinn_compiler->real_compiled_num(), 1);
+  // CPU Compilation
+  compile_fn(::cinn::common::DefaultHostTarget());
+  ASSERT_EQ(cinn_compiler->real_compiled_num(), 2);
+}
+
+}  // namespace paddle2cinn
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(build_cinn_pass);
+USE_PASS(graph_viz_pass);
diff --git a/paddle/fluid/framework/paddle2cinn/cinn_runner.cc b/paddle/fluid/framework/paddle2cinn/cinn_runner.cc
deleted file mode 100644
index ba90095cae6799b91b5f14a904f4cd960083d524..0000000000000000000000000000000000000000
--- a/paddle/fluid/framework/paddle2cinn/cinn_runner.cc
+++ /dev/null
@@ -1,61 +0,0 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
- -#include "paddle/fluid/framework/paddle2cinn/cinn_runner.h" - -#include -#include -#include - -#include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/scope.h" -#include "paddle/fluid/framework/tensor.h" - -namespace paddle { -namespace framework { -namespace paddle2cinn { - -using ir::Graph; - -std::once_flag CinnRunner::get_instance_once_flag_; -std::shared_ptr CinnRunner::instance_; - -std::shared_ptr CinnRunner::GetInstance() { - std::call_once(get_instance_once_flag_, - [&]() { instance_.reset(new CinnRunner()); }); - return instance_; -} - -void CinnRunner::ReplaceWithCinn(Graph* graph) { - // TODO(zhhsplendid): call CINN Api when it is ready -} - -std::map CinnRunner::Run( - const Graph& graph, Scope* scope, - std::map* feed_targets) { - CinnCacheKey cur_key(graph, *feed_targets); - std::shared_ptr obj_to_run; - if (cache_.find(cur_key) != cache_.end()) { - obj_to_run = cache_[cur_key]; - } else { - obj_to_run = std::make_shared(); - obj_to_run->Compile(graph, feed_targets); - cache_[cur_key] = obj_to_run; - } - return obj_to_run->Run(scope, feed_targets); -} - -} // namespace paddle2cinn -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/paddle2cinn/cinn_runner.h b/paddle/fluid/framework/paddle2cinn/cinn_runner.h deleted file mode 100644 index 23d9565d2f3926de33bab4a3c7fa5ac320763840..0000000000000000000000000000000000000000 --- a/paddle/fluid/framework/paddle2cinn/cinn_runner.h +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include -#include -#include -#include - -#include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/paddle2cinn/cinn_cache_key.h" -#include "paddle/fluid/framework/paddle2cinn/cinn_compiled_object.h" -#include "paddle/fluid/framework/scope.h" - -namespace paddle { -namespace framework { -namespace paddle2cinn { - -// Entrance to run CINN. -// -// CINN cannot handle changable shape now, so CinnRunner keeps a cache mapping -// from CinnCacheKey to CinnCompiledObject. If cache hits, we will re-use cache -// stored CinnCompiledObject, otherwise we will compile again and put into -// cache. 
-class CinnRunner { - public: - ~CinnRunner() {} - - // Singleton - static std::shared_ptr GetInstance(); - - // Replace Paddle graph with some CINN subgraphs/ops - void ReplaceWithCinn(ir::Graph* graph); - - // Feed LoDTensors to tun CINN compiled object and return fetched result - std::map Run( - const ir::Graph& graph, Scope* scope, - std::map* feed_targets); - - private: - CinnRunner() {} - - static std::once_flag get_instance_once_flag_; - static std::shared_ptr instance_; - std::unordered_map, - CinnCacheKey::Hash> - cache_; -}; - -} // namespace paddle2cinn -} // namespace framework -} // namespace paddle diff --git a/paddle/fluid/framework/paddle2cinn/cinn_runner_test.cc b/paddle/fluid/framework/paddle2cinn/cinn_runner_test.cc deleted file mode 100644 index c02b994c147ca11518e7d0f3a2cd7a2e1e875f94..0000000000000000000000000000000000000000 --- a/paddle/fluid/framework/paddle2cinn/cinn_runner_test.cc +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include "paddle/fluid/framework/paddle2cinn/cinn_runner.h" - -#include - -#include "gtest/gtest.h" -#include "paddle/fluid/framework/ir/graph.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/framework/scope.h" - -namespace paddle { -namespace framework { -namespace paddle2cinn { - -using ir::Graph; - -TEST(CinnRunnerTest, TodoTest) { - ProgramDesc empty_program; - Graph empty_graph(empty_program); - Scope empty_scope; - std::map empty_feed; - - std::shared_ptr cinn_runner = CinnRunner::GetInstance(); - cinn_runner->ReplaceWithCinn(&empty_graph); - cinn_runner->Run(empty_graph, &empty_scope, &empty_feed); -} - -} // namespace paddle2cinn -} // namespace framework -} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_run_cinn.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_run_cinn.py index bc0652b165eb654081d07bfc503440e8542d91e8..d26c7a1bb441edf7f10d5ce3b982411804906426 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_run_cinn.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_run_cinn.py @@ -27,16 +27,18 @@ logger = logging.getLogger(__name__) def set_cinn_flag(val): + cinn_compiled = False try: paddle.set_flags({'FLAGS_use_cinn': val}) + cinn_compiled = True except ValueError: logger.warning("The used paddle is not compiled with CINN.") + return cinn_compiled +@unittest.skipIf(not set_cinn_flag(True), "Paddle is not compiled with CINN.") class TestParallelExecutorRunCinn(unittest.TestCase): def test_run_from_cinn(self): - set_cinn_flag(False) - main_program = paddle.static.Program() startup_program = paddle.static.Program() with paddle.static.program_guard(main_program, startup_program):