Unverified commit 4c460378, authored by Zhen Wang and committed by GitHub

Create CinnCompiler class for compiling subgraphs found by build_cinn_pass. (#36562)

* Init the functions of CinnCompiler.

* Add the unit test for CinnCompiler.

* Fix some compilation errors.

* Update the UT of cinn_compiler.

* Use Decomposer&OpFusion passes in CinnCompiler::CompileGraph.

* Update some comments.

* Uncomment some includes in build_cinn_pass.cc.

* Use refs instead of ptrs as the return types of FindGraph & Compile in CinnCompiler.

* Use the merged CinnGraphSymbolization functions in CinnCompiler.
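
For orientation, here is a minimal sketch of how the pieces introduced by this commit fit together, using only the interfaces visible in this diff; the caller, variable names, and the omission of error handling are illustrative, not part of the commit:

#include <map>
#include <memory>
#include <string>

#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"

namespace p2c = paddle::framework::paddle2cinn;

void CompileSubgraphSketch(
    std::unique_ptr<paddle::framework::ir::Graph> subgraph,
    const std::map<std::string, const paddle::framework::LoDTensor*>& inputs,
    const ::cinn::common::Target& target) {
  auto* compiler = p2c::CinnCompiler::GetInstance();
  // build_cinn_pass stores each detected subgraph here and records the
  // returned key on the kCinnLaunchOp node as the kCompilationKey attribute.
  const std::string key = compiler->AddGraph(std::move(subgraph));
  // The subgraph can be looked up again by its key ...
  const auto& graph = compiler->FindGraph(key);
  (void)graph;
  // ... and compiled; repeated calls with the same graph, input shapes, and
  // target architecture hit the CinnCacheKey-based cache.
  const auto& compiled = compiler->Compile(key, inputs, target);
  (void)compiled;
}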
Parent 59d8b8cb
cc_library(cinn_cache_key SRCS cinn_cache_key.cc DEPS boost graph graph_helper lod_tensor proto_desc)
cc_library(cinn_compiled_object SRCS cinn_compiled_object.cc DEPS feed_fetch_method graph lod_tensor proto_desc)
cc_library(cinn_runner SRCS cinn_runner.cc DEPS cinn_cache_key cinn_compiled_object feed_fetch_method graph lod_tensor scope)
cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector)
if (WITH_CINN)
cc_library(transform_desc SRCS transform_desc.cc DEPS proto_desc cinn)
cc_library(cinn_graph_symbolization SRCS cinn_graph_symbolization.cc DEPS lod_tensor graph graph_helper transform_desc cinn)
cc_test(test_transform_desc SRCS transform_desc_test.cc DEPS transform_desc)
cc_test(test_cinn_graph_symbolization SRCS cinn_graph_symbolization_test.cc DEPS cinn_graph_symbolization)
endif()
cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector cinn_compiler)
cc_library(transform_desc SRCS transform_desc.cc DEPS proto_desc cinn)
cc_library(cinn_graph_symbolization SRCS cinn_graph_symbolization.cc DEPS lod_tensor graph graph_helper transform_desc cinn)
cc_library(cinn_compiler SRCS cinn_compiler.cc DEPS graph lod_tensor cinn_cache_key cinn_graph_symbolization cinn)
cc_test(cinn_cache_key_test SRCS cinn_cache_key_test.cc DEPS cinn_cache_key)
cc_test(cinn_runner_test SRCS cinn_runner_test.cc DEPS cinn_runner proto_desc)
cc_test(cinn_compiled_object_test SRCS cinn_compiled_object_test.cc DEPS cinn_compiled_object)
cc_test(test_build_cinn_pass SRCS build_cinn_pass_test.cc DEPS build_cinn_pass)
cc_test(build_cinn_pass_test SRCS build_cinn_pass_test.cc DEPS build_cinn_pass cinn_compiler)
cc_test(transform_desc_test SRCS transform_desc_test.cc DEPS transform_desc)
cc_test(cinn_graph_symbolization_test SRCS cinn_graph_symbolization_test.cc DEPS cinn_graph_symbolization)
cc_test(cinn_compiler_test SRCS cinn_compiler_test.cc DEPS cinn_compiler place proto_desc graph_viz_pass build_cinn_pass cinn)
......@@ -14,45 +14,21 @@ limitations under the License. */
#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h"
#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "cinn/frontend/op_mapper_registry.h"
#include "cinn/frontend/op_mappers/use_op_mappers.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/subgraph_detector.h"
// #include "cinn/frontend/op_mapper_registry.h"
// #include "cinn/frontend/op_mappers/use_op_mappers.h"
// TODO(jiangcheng05): just for local compile, remove after
// paddle and CINN have been binded
// The APIs are the same as CINN:
// https://github.com/PaddlePaddle/CINN/blob/develop/cinn/utils/registry.h
namespace cinn {
namespace frontend {
class OpMapperRegistry {
public:
static OpMapperRegistry* Global() {
static OpMapperRegistry inst;
return &inst;
}
inline const OpMapperRegistry* Find(const std::string& name) {
std::unordered_set<std::string> fmap_ = {"mul", "add", "relu", "sigmoid",
"softmax"};
auto p = fmap_.find(name);
if (p != fmap_.end()) {
return this;
} else {
return nullptr;
}
}
};
} // namespace frontend
} // namespace cinn
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
namespace paddle {
namespace framework {
......@@ -141,17 +117,17 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
const GraphNodeSet& cluster_inputs) {
// Graph's constructor must have one parameter, and in our code,
// the ProgramDesc is useless, so here we pass a temporary object.
auto sub_graph = std::make_unique<Graph>(framework::ProgramDesc());
auto subgraph = std::make_unique<Graph>(framework::ProgramDesc());
std::unordered_map<Node*, Node*> old_op2new_op;
for (auto* op : cluster) {
auto sub_node = sub_graph->CreateOpNode(op->Op());
auto sub_node = subgraph->CreateOpNode(op->Op());
old_op2new_op[op] = sub_node;
}
std::unordered_map<Node*, Node*> old_var2new_var;
for (auto* var : cluster_internals) {
auto sub_node = sub_graph->CreateVarNode(var->Var());
auto sub_node = subgraph->CreateVarNode(var->Var());
old_var2new_var[var] = sub_node;
}
......@@ -190,9 +166,9 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
}
}
AddFeedOpAndVar(need_feed_vars, cluster, old_op2new_op, sub_graph.get());
AddParamVar(param_vars, cluster, old_op2new_op, sub_graph.get());
AddOutputVar(output_vars, cluster, old_op2new_op, sub_graph.get());
AddFeedOpAndVar(need_feed_vars, cluster, old_op2new_op, subgraph.get());
AddParamVar(param_vars, cluster, old_op2new_op, subgraph.get());
AddOutputVar(output_vars, cluster, old_op2new_op, subgraph.get());
for (auto* var : cluster_internals) {
for (auto* op : var->inputs) {
......@@ -207,7 +183,7 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
}
}
return sub_graph;
return subgraph;
}
// This interface is used to classify all variables involved in a cluster into
......@@ -256,11 +232,24 @@ void AnalyseClusterVariables(const GraphNodeSet& cluster,
}
}
Node* AddSpecialOpToGraph(Graph* graph, const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs) {
Node* AddSpecialOpToGraph(const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs,
const std::string& compilation_key, Graph* graph) {
// add special cinn op
framework::OpDesc special_op_desc;
special_op_desc.SetType(kCinnLaunchOp);
std::vector<std::string> input_names;
std::transform(cluster_inputs.begin(), cluster_inputs.end(),
std::back_inserter(input_names),
[](Node* n) { return n->Name(); });
special_op_desc.SetInput("X", input_names);
std::vector<std::string> output_names;
std::transform(cluster_outputs.begin(), cluster_outputs.end(),
std::back_inserter(output_names),
[](Node* n) { return n->Name(); });
special_op_desc.SetOutput("Out", output_names);
special_op_desc.SetAttr(kCompilationKey, compilation_key);
special_op_desc.Flush();
auto* special_op_node = graph->CreateOpNode(&special_op_desc);
special_op_node->inputs.assign(cluster_inputs.begin(), cluster_inputs.end());
special_op_node->outputs.assign(cluster_outputs.begin(),
......@@ -268,9 +257,9 @@ Node* AddSpecialOpToGraph(Graph* graph, const GraphNodeSet& cluster_inputs,
return special_op_node;
}
void AddLinkToSpecialOp(Node* special_op_node,
const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs) {
void AddLinkToSpecialOp(const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs,
Node* special_op_node) {
// add new link from cluster_inputs to special_op_node
for (auto* var_node : cluster_inputs) {
var_node->outputs.push_back(special_op_node);
......@@ -338,14 +327,15 @@ void ReplaceSubGraphWithSpecialOpNode(const GraphNodeSet& cluster,
const GraphNodeSet& cluster_inputs,
const GraphNodeSet& cluster_outputs,
const GraphNodeSet& cluster_internals,
const std::string& compilation_key,
Graph* graph) {
// First, add the special op node whose name is "kCinnLaunchOp" into graph
auto special_op_node =
AddSpecialOpToGraph(graph, cluster_inputs, cluster_outputs);
auto special_op_node = AddSpecialOpToGraph(cluster_inputs, cluster_outputs,
compilation_key, graph);
// Second, remove all graph's links which are from or to cluster nodes
RemoveLinkFromCluster(cluster, cluster_inputs, cluster_outputs);
// Third, add new links from or to the special op node
AddLinkToSpecialOp(special_op_node, cluster_inputs, cluster_outputs);
AddLinkToSpecialOp(cluster_inputs, cluster_outputs, special_op_node);
// Finally, remove the cinn sub graph from graph
RemoveSubGraphFromGraph(cluster, cluster_internals, graph);
}
......@@ -354,8 +344,7 @@ void ReplaceSubGraphWithSpecialOpNode(const GraphNodeSet& cluster,
// Here we use SubgraphDetector to detect subgraphs in which all op nodes
// are supported by CINN, and use OpMapperRegistry to check whether an op
// node is supported by CINN.
void SearchAllSubgraphs(Graph* graph,
std::vector<std::unique_ptr<Graph>>* cinn_subgraphs) {
void SearchAllSubgraphs(Graph* graph) {
auto teller = [](const Node* node) {
return ::cinn::frontend::OpMapperRegistry::Global()->Find(node->Name()) !=
nullptr;
......@@ -363,29 +352,26 @@ void SearchAllSubgraphs(Graph* graph,
std::vector<GraphNodeVec> clusters =
framework::ir::SubgraphDetector(graph, teller)();
cinn_subgraphs->clear();
auto* cinn_compiler = CinnCompiler::GetInstance();
for (const auto& node_vec : clusters) {
// classify var node to inputs, outputs, and internals.
// Classify var node to inputs, outputs, and internals.
GraphNodeSet cluster_set(node_vec.begin(), node_vec.end());
GraphNodeSet cluster_inputs, cluster_outputs, cluster_internals;
AnalyseClusterVariables(cluster_set, &cluster_inputs, &cluster_outputs,
&cluster_internals);
cinn_subgraphs->emplace_back(
// Create a new subgraph according to the found cluster and
// save it in CinnCompiler
std::string compilation_key = cinn_compiler->AddGraph(
CreateNewSubGraph(cluster_set, cluster_internals, cluster_inputs));
// replacing subgraph to a new special op node
// Replace the found cluster with a new special op node
ReplaceSubGraphWithSpecialOpNode(cluster_set, cluster_inputs,
cluster_outputs, cluster_internals, graph);
cluster_outputs, cluster_internals,
compilation_key, graph);
}
}
void BuildCinnPass::ApplyImpl(Graph* graph) const {
auto& cinn_subgraphs =
Get<std::vector<std::unique_ptr<Graph>>>("cinn_subgraphs");
SearchAllSubgraphs(graph, &cinn_subgraphs);
}
void BuildCinnPass::ApplyImpl(Graph* graph) const { SearchAllSubgraphs(graph); }
} // namespace paddle2cinn
} // namespace framework
......
......@@ -21,6 +21,7 @@ namespace framework {
namespace paddle2cinn {
constexpr char kCinnLaunchOp[] = "CinnLaunchOp";
constexpr char kCompilationKey[] = "compilation_key";
// A pass named BuildCinnPass, the function of this pass is:
//
......@@ -39,12 +40,13 @@ constexpr char kCinnLaunchOp[] = "CinnLaunchOp";
// Firstly, both op nodes should be compile supported.
// Secondly, there should be a direct path between the two op nodes through a
// var node.
// Thirdly, there should be no extral path between the two op nodes through
// Thirdly, there should be no extra path between the two op nodes through
// unsupported op nodes.
// Lastly, if op nodes a and b can be divied into a cluster, op nodes b and c
// can be devided into a cluster, a and c can also be devided into a cluster.
// The implementation of cluster detection is enclosured in class
// SubGraphDetector.
// can be divided into a cluster, a and c can also be divided into a cluster.
// The implementation of cluster detection is encapsulated in the
// SubGraphDetector class.
//
// b) How to deal with the links between the var nodes in global graph and the
// op nodes in a cluster?
......
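
The last condition above makes cluster membership transitive. As a minimal illustration of that closure property only (SubGraphDetector's actual implementation is not shown in this diff), a union-find over op-node ids behaves the same way:

#include <numeric>
#include <vector>

// Illustration only: merging (a, b) and (b, c) automatically places a and c
// in the same set, matching the transitivity rule described above.
struct ClusterUnionFind {
  std::vector<int> parent;
  explicit ClusterUnionFind(int n) : parent(n) {
    std::iota(parent.begin(), parent.end(), 0);  // each node starts alone
  }
  int Find(int x) { return parent[x] == x ? x : parent[x] = Find(parent[x]); }
  void Union(int a, int b) { parent[Find(a)] = Find(b); }
};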
......@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
......@@ -83,6 +84,18 @@ inline bool CheckGraphIndependence(const std::unordered_set<Node*>& nodes) {
return true;
}
// Get compilation_key values
std::vector<std::string> GetCompilationKeys(const Graph& graph) {
std::vector<std::string> compilation_keys;
for (auto& node : graph.Nodes()) {
if (node->IsOp() && node->Name() == kCinnLaunchOp) {
compilation_keys.emplace_back(
BOOST_GET_CONST(std::string, node->Op()->GetAttr(kCompilationKey)));
}
}
return compilation_keys;
}
std::unique_ptr<Graph> BuildNoCinnSubgraph() {
ProgramDesc prog;
auto g = std::make_unique<Graph>(prog);
......@@ -133,17 +146,14 @@ TEST(BuildCinnPassTest, NoCinnSubgraph) {
auto pass =
paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
std::vector<std::unique_ptr<Graph>> cinn_subgraphs;
pass->SetNotOwned<std::vector<std::unique_ptr<Graph>>>("cinn_subgraphs",
&cinn_subgraphs);
pass->Apply(g.get());
// After search, the origin graph should not change
ASSERT_EQ(previous_nodes, g->Nodes());
ASSERT_TRUE(CheckGraphIndependence(g->Nodes()));
// After search, there should one cinn subgraph
ASSERT_TRUE(cinn_subgraphs.empty());
// After search, there should be no cinn subgraph
ASSERT_TRUE(GetCompilationKeys(*g).empty());
}
std::unique_ptr<Graph> BuildAllOpSupportCinnGraph() {
......@@ -212,9 +222,6 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
auto pass =
paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
std::vector<std::unique_ptr<Graph>> cinn_subgraphs;
pass->SetNotOwned<std::vector<std::unique_ptr<Graph>>>("cinn_subgraphs",
&cinn_subgraphs);
pass->Apply(g.get());
// After search, the graph should be as follows
......@@ -250,10 +257,12 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
// | --> mul --> v3 --
// v2 -- | --> add --> v5 --> relu --> v6
// feed --> v4 --
ASSERT_EQ(cinn_subgraphs.size(), static_cast<size_t>(1));
const auto& subgraph = cinn_subgraphs.back();
auto compilation_keys = GetCompilationKeys(*g);
ASSERT_EQ(compilation_keys.size(), static_cast<size_t>(1));
auto* cinn_compiler = CinnCompiler::GetInstance();
const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);
const auto& subnodes = subgraph->Nodes();
const auto& subnodes = subgraph.Nodes();
ASSERT_EQ(subnodes.size(), static_cast<size_t>(11));
ASSERT_TRUE(CheckGraphIndependence(subnodes));
......@@ -338,9 +347,6 @@ TEST(BuildCinnPassTest, OneCinnSubgraph) {
auto pass =
paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
std::vector<std::unique_ptr<Graph>> cinn_subgraphs;
pass->SetNotOwned<std::vector<std::unique_ptr<Graph>>>("cinn_subgraphs",
&cinn_subgraphs);
pass->Apply(g.get());
// After search, the graph should be as follows
......@@ -366,10 +372,12 @@ TEST(BuildCinnPassTest, OneCinnSubgraph) {
// feed --> v1 --
// | --> mul --> v3 --> relu --> v4
// v2 --
ASSERT_EQ(cinn_subgraphs.size(), static_cast<size_t>(1));
const auto& subgraph = cinn_subgraphs.back();
auto compilation_keys = GetCompilationKeys(*g);
ASSERT_EQ(compilation_keys.size(), static_cast<size_t>(1));
auto* cinn_compiler = CinnCompiler::GetInstance();
const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);
const auto& subnodes = subgraph->Nodes();
const auto& subnodes = subgraph.Nodes();
ASSERT_EQ(subnodes.size(), static_cast<size_t>(7));
ASSERT_TRUE(CheckGraphIndependence(subnodes));
......@@ -450,9 +458,6 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) {
auto pass =
paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
std::vector<std::unique_ptr<Graph>> cinn_subgraphs;
pass->SetNotOwned<std::vector<std::unique_ptr<Graph>>>("cinn_subgraphs",
&cinn_subgraphs);
pass->Apply(g.get());
// After search, the graph should be as follows
......@@ -478,7 +483,8 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) {
// After search, there should be two cinn subgraphs,
// and each of the subgraphs has just one node.
ASSERT_EQ(cinn_subgraphs.size(), static_cast<size_t>(2));
auto compilation_keys = GetCompilationKeys(*g);
ASSERT_EQ(compilation_keys.size(), static_cast<size_t>(2));
// subgraph1:
// feed --> v4 --> relu --> v5
......@@ -486,12 +492,13 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) {
// feed --> v1 --
// | --> mul --> v3
// v2 --
const auto& subgraph1 = cinn_subgraphs[0];
const auto& subnodes1 = subgraph1->Nodes();
auto* cinn_compiler = CinnCompiler::GetInstance();
const auto& subgraph1 = cinn_compiler->FindGraph(compilation_keys[0]);
const auto& subnodes1 = subgraph1.Nodes();
ASSERT_TRUE(CheckGraphIndependence(subnodes1));
const auto& subgraph2 = cinn_subgraphs[1];
const auto& subnodes2 = subgraph2->Nodes();
const auto& subgraph2 = cinn_compiler->FindGraph(compilation_keys[1]);
const auto& subnodes2 = subgraph2.Nodes();
ASSERT_TRUE(CheckGraphIndependence(subnodes2));
if (CheckNodeExisted(subnodes1, "relu")) {
......
......@@ -28,32 +28,38 @@ namespace paddle2cinn {
CinnCacheKey::CinnCacheKey(
const ir::Graph& graph,
const std::map<std::string, const LoDTensor*>& feed_tensors) {
this->SetKey(graph, feed_tensors);
const std::map<std::string, const LoDTensor*>& input_tensors,
const std::string& arch_str) {
this->SetKey(graph, input_tensors, arch_str);
}
CinnCacheKey::CinnCacheKey(const ir::Graph& graph,
const std::map<std::string, DDim>& feed_shapes) {
this->SetKey(graph, feed_shapes);
const std::map<std::string, DDim>& input_shapes,
const std::string& arch_str) {
this->SetKey(graph, input_shapes, arch_str);
}
void CinnCacheKey::SetKey(
const ir::Graph& graph,
const std::map<std::string, const LoDTensor*>& feed_tensors) {
const std::map<std::string, const LoDTensor*>& input_tensors,
const std::string& arch_str) {
ProgramDesc program;
GraphToProgram(graph, &program);
program.Proto()->SerializeToString(&graph_serialize_str_);
for (const auto& name_tensor : feed_tensors) {
feed_shapes_[name_tensor.first] = name_tensor.second->dims();
for (const auto& name_tensor : input_tensors) {
input_shapes_[name_tensor.first] = name_tensor.second->dims();
}
arch_str_ = arch_str;
}
void CinnCacheKey::SetKey(const ir::Graph& graph,
const std::map<std::string, DDim>& feed_shapes) {
const std::map<std::string, DDim>& input_shapes,
const std::string& arch_str) {
ProgramDesc program;
GraphToProgram(graph, &program);
program.Proto()->SerializeToString(&graph_serialize_str_);
feed_shapes_ = feed_shapes;
input_shapes_ = input_shapes;
arch_str_ = arch_str;
}
bool CinnCacheKey::operator!=(const CinnCacheKey& other) const {
......@@ -62,7 +68,7 @@ bool CinnCacheKey::operator!=(const CinnCacheKey& other) const {
bool CinnCacheKey::operator==(const CinnCacheKey& other) const {
return graph_serialize_str_ == other.graph_serialize_str_ &&
feed_shapes_ == other.feed_shapes_;
input_shapes_ == other.input_shapes_ && arch_str_ == other.arch_str_;
}
size_t CinnCacheKey::Hash::hash_combine(size_t seed, size_t value) {
......@@ -73,12 +79,13 @@ size_t CinnCacheKey::Hash::operator()(const CinnCacheKey& key) const {
std::size_t ret = 0;
std::hash<std::string> string_hasher;
for (const auto& name_shape : key.feed_shapes_) {
for (const auto& name_shape : key.input_shapes_) {
ret = hash_combine(ret, string_hasher(name_shape.first));
ret = hash_combine(ret, string_hasher(name_shape.second.to_str()));
}
ret = hash_combine(ret, string_hasher(key.graph_serialize_str_));
ret = hash_combine(ret, string_hasher(key.arch_str_));
return ret;
}
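
The body of hash_combine is collapsed by the diff view above; a conventional boost-style combiner (an assumption about the elided code, not a verbatim copy of the commit) looks like this:

#include <cstddef>

// Assumed, boost-style shape of the collapsed hash_combine body:
// mixes a new value into an accumulated seed.
std::size_t HashCombine(std::size_t seed, std::size_t value) {
  return seed ^ (value + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}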
......
......@@ -26,24 +26,28 @@ namespace paddle2cinn {
// Class to store the keys for compiling CINN.
//
// CINN cannot handle changable shape now, so CinnRunner keeps a cache mapping
// CINN cannot handle changeable shape now, so CinnCompiler keeps a cache mapping
// from CinnCacheKey to CinnCompiledObject.
//
// The CinnCacheKey contains a graph serialized string and the feeded tensor
// The CinnCacheKey contains a graph serialized string and the input tensor
// shapes.
class CinnCacheKey {
public:
CinnCacheKey(const ir::Graph& graph,
const std::map<std::string, const LoDTensor*>& feed_tensors);
const std::map<std::string, const LoDTensor*>& input_tensors,
const std::string& arch_str);
CinnCacheKey(const ir::Graph& graph,
const std::map<std::string, DDim>& feed_shapes);
const std::map<std::string, DDim>& input_shapes,
const std::string& arch_str);
~CinnCacheKey() {}
void SetKey(const ir::Graph& graph,
const std::map<std::string, const LoDTensor*>& feed_tensors);
const std::map<std::string, const LoDTensor*>& input_tensors,
const std::string& arch_str);
void SetKey(const ir::Graph& graph,
const std::map<std::string, DDim>& feed_shapes);
const std::map<std::string, DDim>& input_shapes,
const std::string& arch_str);
bool operator==(const CinnCacheKey& other) const;
bool operator!=(const CinnCacheKey& other) const;
......@@ -55,7 +59,8 @@ class CinnCacheKey {
private:
std::string graph_serialize_str_;
std::map<std::string, DDim> feed_shapes_;
std::map<std::string, DDim> input_shapes_;
std::string arch_str_;
};
} // namespace paddle2cinn
......
......@@ -47,17 +47,19 @@ TEST(CinnCacheKeyTest, TestAsUnorderedKey) {
DDim ddim = paddle::framework::make_ddim({1, 2, 3});
std::map<std::string, DDim> feed_shapes = {{"X", ddim}};
CinnCacheKey cache_key1(empty_graph, feed_tensors);
CinnCacheKey cache_key2(empty_graph, feed_shapes);
EXPECT_EQ(cache_key1, cache_key2);
CinnCacheKey cache_key3(graph, feed_shapes);
CinnCacheKey cache_key4(graph, feed_tensors);
CinnCacheKey cache_key0(empty_graph, feed_tensors, "x86");
CinnCacheKey cache_key1(empty_graph, feed_shapes, "x86");
EXPECT_EQ(cache_key0, cache_key1);
CinnCacheKey cache_key2(graph, feed_shapes, "x86");
CinnCacheKey cache_key3(graph, feed_shapes, "nvgpu");
CinnCacheKey cache_key4(graph, feed_tensors, "nvgpu");
EXPECT_NE(cache_key2, cache_key3);
EXPECT_EQ(cache_key3, cache_key4);
CinnCacheKey cache_key5(empty_graph,
std::map<std::string, const LoDTensor *>());
CinnCacheKey cache_key6(empty_graph, std::map<std::string, DDim>());
std::map<std::string, const LoDTensor *>(), "unk");
CinnCacheKey cache_key6(empty_graph, std::map<std::string, DDim>(), "unk");
EXPECT_EQ(cache_key5, cache_key6);
EXPECT_NE(cache_key1, cache_key3);
......@@ -69,19 +71,19 @@ TEST(CinnCacheKeyTest, TestAsUnorderedKey) {
EXPECT_NE(cache_key5, cache_key1);
EXPECT_NE(cache_key2, cache_key6);
test_set.insert(cache_key0);
test_set.insert(cache_key1);
test_set.insert(cache_key2);
test_set.insert(cache_key3);
test_set.insert(cache_key4);
test_set.insert(cache_key5);
test_set.insert(cache_key6);
EXPECT_EQ(test_set.size(), 3U);
auto iter = test_set.find(cache_key1);
auto iter = test_set.find(cache_key0);
EXPECT_NE(iter, test_set.end());
test_set.erase(iter);
EXPECT_EQ(test_set.size(), 2U);
EXPECT_EQ(test_set.find(cache_key2), test_set.end());
EXPECT_EQ(test_set.find(cache_key1), test_set.end());
iter = test_set.find(cache_key3);
EXPECT_NE(iter, test_set.end());
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/paddle2cinn/cinn_compiled_object.h"
#include <map>
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
namespace paddle2cinn {
CinnCompiledObject::CinnCompiledObject() {
// TODO(zhhsplendid): complete this function after CINN interface is ready
}
CinnCompiledObject::~CinnCompiledObject() {
// TODO(zhhsplendid): complete this function after CINN interface is ready
}
void CinnCompiledObject::Compile(
const ir::Graph& graph,
std::map<std::string, const LoDTensor*>* feed_targets) {
// TODO(zhhsplendid): complete this function after CINN interface is ready
}
std::map<std::string, FetchType*> CinnCompiledObject::Run(
Scope* scope, std::map<std::string, const LoDTensor*>* feed_targets) {
// TODO(zhhsplendid): complete this function after CINN interface is ready
return std::map<std::string, FetchType*>();
}
} // namespace paddle2cinn
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
namespace paddle2cinn {
// Class to store and call the CINN compiled object
class CinnCompiledObject {
public:
CinnCompiledObject();
~CinnCompiledObject();
// Compiles using CINN. CINN compilation needs the model graph, input names,
// and input shapes.
void Compile(const ir::Graph& graph,
std::map<std::string, const LoDTensor*>* feed_targets);
// Feed LoDTensors to run the CINN compiled object and return fetched results
std::map<std::string, FetchType*> Run(
Scope* scope, std::map<std::string, const LoDTensor*>* feed_targets);
// Converts compiled object to Paddle Graph
// To be discussed
// ir::Graph ToGraph();
};
} // namespace paddle2cinn
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <map>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiled_object.h"
#include "paddle/fluid/framework/program_desc.h"
namespace paddle {
namespace framework {
namespace paddle2cinn {
TEST(CinnCompiledObjecctTest, TodoTest) {
ProgramDesc empty_program;
ir::Graph empty_graph(empty_program);
std::map<std::string, const LoDTensor*> empty_feed;
Scope empty_scope;
CinnCompiledObject compiled_obj;
compiled_obj.Compile(empty_graph, &empty_feed);
auto fetch = compiled_obj.Run(&empty_scope, &empty_feed);
}
} // namespace paddle2cinn
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include <map>
#include <memory>
#include <string>
#include "cinn/common/target.h"
#include "cinn/common/type.h"
#include "cinn/frontend/decomposer/use_decomposer.h"
#include "cinn/frontend/net_builder.h" // need to remove after
#include "cinn/frontend/pass/use_program_pass.h"
#include "cinn/frontend/program_pass.h"
#include "cinn/frontend/syntax.h"
#include "cinn/hlir/framework/graph.h"
#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/hlir/framework/pass.h"
#include "cinn/hlir/pass/use_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace paddle2cinn {
using ir::Graph;
using ::cinn::common::Target;
using ::cinn::common::Float;
using ::cinn::hlir::framework::GraphCompiler;
using ::cinn::hlir::framework::BuildScope;
using ::cinn::frontend::ProgramPass;
using ::cinn::hlir::framework::ApplyPass;
CinnCompiler* CinnCompiler::GetInstance() {
static CinnCompiler instance;
return &instance;
}
std::string CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
std::string graph_key;
ProgramDesc program;
GraphToProgram(*graph, &program);
program.Proto()->SerializeToString(&graph_key);
if (!graphs_.count(graph_key)) {
graphs_[graph_key] = std::move(graph);
} else {
LOG(WARNING)
<< "The graph being added is already in CinnCompiler. Its key is:\n"
<< graph_key;
}
return graph_key;
}
const Graph& CinnCompiler::FindGraph(const std::string& graph_key) const {
PADDLE_ENFORCE_NE(
graphs_.count(graph_key), 0,
platform::errors::InvalidArgument("Can not find the target graph: %s",
graph_key.c_str()));
return *graphs_.at(graph_key);
}
const CinnCompiledObject& CinnCompiler::Compile(
const Graph& graph,
const std::map<std::string, const LoDTensor*>& input_tensors,
const Target& target) {
CinnCacheKey cur_key(graph, input_tensors, target.arch_str());
if (!cache_.count(cur_key)) {
real_compiled_num_++;
cache_[cur_key] = CompileGraph(graph, input_tensors, target);
}
return *cache_[cur_key];
}
const CinnCompiledObject& CinnCompiler::Compile(
const std::string& compilation_key,
const std::map<std::string, const LoDTensor*>& input_tensors,
const Target& target) {
const auto& graph = FindGraph(compilation_key);
return Compile(graph, input_tensors, target);
}
std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
const ir::Graph& graph,
const std::map<std::string, const LoDTensor*>& input_tensors,
const Target& target) const {
CinnGraphSymbolization symbol{real_compiled_num_, graph, target,
input_tensors};
auto frontend_program = symbol();
ProgramPass::Apply(&frontend_program, target, {"Decomposer"});
auto cinn_graph = std::make_shared<::cinn::hlir::framework::Graph>(
frontend_program, target);
VLOG(4) << "The " << real_compiled_num_ << "-th compilation ("
<< target.arch_str() << "), and its related graph:\n"
<< cinn_graph->Visualize();
ApplyPass(cinn_graph.get(), "OpFusion");
auto scope = BuildScope(target, cinn_graph);
GraphCompiler graph_compiler(target, scope, cinn_graph);
GraphCompiler::CompileOptions options;
options.with_instantiate_variables = false;
auto compiled_res = graph_compiler.Build(options);
auto compiled_obj = std::make_unique<CinnCompiledObject>();
*compiled_obj = {std::move(compiled_res.runtime_program), scope,
symbol.var_model_to_program_map()};
return compiled_obj;
}
} // namespace paddle2cinn
} // namespace framework
} // namespace paddle
......@@ -14,50 +14,73 @@
#pragma once
#include <atomic>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include "cinn/common/target.h"
#include "cinn/hlir/framework/graph_compiler.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_cache_key.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiled_object.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace framework {
namespace paddle2cinn {
// Entrance to run CINN.
struct CinnCompiledObject {
std::unique_ptr<::cinn::hlir::framework::Program> runtime_program;
std::shared_ptr<::cinn::hlir::framework::Scope> scope;
std::unordered_map<std::string, std::string> paddle2cinn_varmap;
};
// Entrance to use CINN.
//
// CINN cannot handle changable shape now, so CinnRunner keeps a cache mapping
// CINN cannot handle changeable shape now, so CinnCompiler keeps a cache
// mapping from CinnCacheKey to CinnCompiledObject. If the cache hits, we
// reuse the cached CinnCompiledObject; otherwise we compile again and put
// the result into the cache.
class CinnRunner {
class CinnCompiler {
public:
~CinnRunner() {}
// Singleton
static std::shared_ptr<CinnRunner> GetInstance();
static CinnCompiler* GetInstance();
const CinnCompiledObject& Compile(
const ir::Graph& graph,
const std::map<std::string, const LoDTensor*>& input_tensors,
const ::cinn::common::Target& target);
// Replace Paddle graph with some CINN subgraphs/ops
void ReplaceWithCinn(ir::Graph* graph);
const CinnCompiledObject& Compile(
const std::string& compilation_key,
const std::map<std::string, const LoDTensor*>& input_tensors,
const ::cinn::common::Target& target);
// Feed LoDTensors to run the CINN compiled object and return fetched results
std::map<std::string, FetchType*> Run(
const ir::Graph& graph, Scope* scope,
std::map<std::string, const LoDTensor*>* feed_targets);
std::string AddGraph(std::unique_ptr<ir::Graph> graph);
const ir::Graph& FindGraph(const std::string& key) const;
std::int64_t real_compiled_num() const { return real_compiled_num_; }
~CinnCompiler() = default;
private:
CinnRunner() {}
CinnCompiler() = default;
std::unique_ptr<CinnCompiledObject> CompileGraph(
const ir::Graph& graph,
const std::map<std::string, const LoDTensor*>& input_tensors,
const ::cinn::common::Target& target) const;
static std::once_flag get_instance_once_flag_;
static std::shared_ptr<CinnRunner> instance_;
std::unordered_map<CinnCacheKey, std::shared_ptr<CinnCompiledObject>,
std::unordered_map<std::string, std::unique_ptr<ir::Graph>> graphs_;
std::unordered_map<CinnCacheKey, std::unique_ptr<CinnCompiledObject>,
CinnCacheKey::Hash>
cache_;
std::atomic_int64_t real_compiled_num_{0};
DISABLE_COPY_AND_ASSIGN(CinnCompiler);
};
} // namespace paddle2cinn
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include <map>
#include <memory>
#include <string>
#include "cinn/common/target.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
namespace paddle2cinn {
using ir::Graph;
using ::cinn::common::Target;
// X -
// | -> mul -> MUL_OUT -
// Y - | -> elementwise_add -> ADD_OUT -> relu -> RELU_OUT
// Z -
std::unique_ptr<Graph> CreateGraph() {
ProgramDesc program;
auto* global_block = program.MutableBlock(0);
// mul
auto* x = global_block->Var("X");
x->SetType(proto::VarType::LOD_TENSOR);
x->SetLoDLevel(0);
x->SetDataType(proto::VarType::FP32);
x->SetShape({1000, 784});
auto* y = global_block->Var("Y");
y->SetType(proto::VarType::LOD_TENSOR);
y->SetLoDLevel(0);
y->SetDataType(proto::VarType::FP32);
y->SetShape({784, 100});
y->SetPersistable(true);
y->SetIsParameter(true);
auto* mul_op = global_block->AppendOp();
mul_op->SetType("mul");
mul_op->SetInput("X", {x->Name()});
mul_op->SetInput("Y", {y->Name()});
auto* mul_out = global_block->Var("MUL_OUT");
mul_out->SetType(proto::VarType::LOD_TENSOR);
mul_op->SetOutput("Out", {mul_out->Name()});
// add
auto* z = global_block->Var("Z");
z->SetType(proto::VarType::LOD_TENSOR);
z->SetLoDLevel(0);
z->SetDataType(proto::VarType::FP32);
z->SetShape({100});
z->SetPersistable(true);
z->SetIsParameter(true);
auto* add_op = global_block->AppendOp();
add_op->SetType("elementwise_add");
add_op->SetInput("X", {mul_out->Name()});
add_op->SetInput("Y", {z->Name()});
auto* add_out = global_block->Var("ADD_OUT");
add_out->SetType(proto::VarType::LOD_TENSOR);
add_op->SetOutput("Out", {add_out->Name()});
// relu
auto* relu_op = global_block->AppendOp();
relu_op->SetType("relu");
relu_op->SetInput("X", {add_out->Name()});
auto* relu_out = global_block->Var("RELU_OUT");
relu_out->SetType(proto::VarType::LOD_TENSOR);
relu_op->SetOutput("Out", {relu_out->Name()});
program.Flush();
return std::make_unique<Graph>(program);
}
TEST(CinnCompilerTest, Compile) {
auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
auto cinn_pass = ir::PassRegistry::Instance().Get("build_cinn_pass");
auto viz_graph = [&viz_pass](const std::string& viz_path, Graph* graph) {
viz_pass->Erase("graph_viz_path");
viz_pass->Set("graph_viz_path", new std::string(viz_path));
viz_pass->Apply(graph);
};
// create a graph
auto graph = CreateGraph();
viz_graph("origin_graph.dot", graph.get());
// apply build_cinn_pass
cinn_pass->Apply(graph.get());
viz_graph("processed_graph.dot", graph.get());
// get the compilation_key
std::vector<std::string> compilation_keys;
for (auto& node : graph->Nodes()) {
if (node->IsOp() && node->Name() == kCinnLaunchOp) {
compilation_keys.emplace_back(
BOOST_GET_CONST(std::string, node->Op()->GetAttr(kCompilationKey)));
}
}
ASSERT_EQ(compilation_keys.size(), 1);
const auto& compilation_key = compilation_keys[0];
auto* cinn_compiler = CinnCompiler::GetInstance();
const auto& compiling_graph = cinn_compiler->FindGraph(compilation_key);
// viz_graph("compiling_graph.dot", const_cast<Graph*>(&compiling_graph));
EXPECT_THROW(cinn_compiler->FindGraph("no_existed"),
paddle::platform::EnforceNotMet);
LoDTensor tensor1, tensor2, tensor3;
tensor1.Resize({1000, 784});
tensor2.Resize({784, 100});
tensor3.Resize({100});
tensor1.mutable_data<float>(platform::CPUPlace());
tensor2.mutable_data<float>(platform::CPUPlace());
tensor3.mutable_data<float>(platform::CPUPlace());
std::map<std::string, const LoDTensor*> input_tensors = {
{"X", &tensor1}, {"Y", &tensor2}, {"Z", &tensor3}};
auto compile_fn = [&](const Target& target) {
const auto& compiled_obj =
cinn_compiler->Compile(compiling_graph, input_tensors, target);
ASSERT_NE(compiled_obj.runtime_program, nullptr);
ASSERT_NE(compiled_obj.scope, nullptr);
ASSERT_FALSE(compiled_obj.paddle2cinn_varmap.empty());
const auto& cached_obj =
cinn_compiler->Compile(compilation_key, input_tensors, target);
ASSERT_EQ(reinterpret_cast<std::uint64_t>(&compiled_obj),
reinterpret_cast<std::uint64_t>(&cached_obj));
};
// GPU Compilation
compile_fn(::cinn::common::DefaultNVGPUTarget());
ASSERT_EQ(cinn_compiler->real_compiled_num(), 1);
// CPU Compilation
compile_fn(::cinn::common::DefaultHostTarget());
ASSERT_EQ(cinn_compiler->real_compiled_num(), 2);
}
} // namespace paddle2cinn
} // namespace framework
} // namespace paddle
USE_PASS(build_cinn_pass);
USE_PASS(graph_viz_pass);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/paddle2cinn/cinn_runner.h"
#include <map>
#include <memory>
#include <mutex>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
namespace paddle {
namespace framework {
namespace paddle2cinn {
using ir::Graph;
std::once_flag CinnRunner::get_instance_once_flag_;
std::shared_ptr<CinnRunner> CinnRunner::instance_;
std::shared_ptr<CinnRunner> CinnRunner::GetInstance() {
std::call_once(get_instance_once_flag_,
[&]() { instance_.reset(new CinnRunner()); });
return instance_;
}
void CinnRunner::ReplaceWithCinn(Graph* graph) {
// TODO(zhhsplendid): call CINN Api when it is ready
}
std::map<std::string, FetchType*> CinnRunner::Run(
const Graph& graph, Scope* scope,
std::map<std::string, const LoDTensor*>* feed_targets) {
CinnCacheKey cur_key(graph, *feed_targets);
std::shared_ptr<CinnCompiledObject> obj_to_run;
if (cache_.find(cur_key) != cache_.end()) {
obj_to_run = cache_[cur_key];
} else {
obj_to_run = std::make_shared<CinnCompiledObject>();
obj_to_run->Compile(graph, feed_targets);
cache_[cur_key] = obj_to_run;
}
return obj_to_run->Run(scope, feed_targets);
}
} // namespace paddle2cinn
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/paddle2cinn/cinn_runner.h"
#include <memory>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
namespace paddle2cinn {
using ir::Graph;
TEST(CinnRunnerTest, TodoTest) {
ProgramDesc empty_program;
Graph empty_graph(empty_program);
Scope empty_scope;
std::map<std::string, const LoDTensor*> empty_feed;
std::shared_ptr<CinnRunner> cinn_runner = CinnRunner::GetInstance();
cinn_runner->ReplaceWithCinn(&empty_graph);
cinn_runner->Run(empty_graph, &empty_scope, &empty_feed);
}
} // namespace paddle2cinn
} // namespace framework
} // namespace paddle
......@@ -27,16 +27,18 @@ logger = logging.getLogger(__name__)
def set_cinn_flag(val):
    cinn_compiled = False
    try:
        paddle.set_flags({'FLAGS_use_cinn': val})
        cinn_compiled = True
    except ValueError:
        logger.warning("The used paddle is not compiled with CINN.")
    return cinn_compiled
@unittest.skipIf(not set_cinn_flag(True), "Paddle is not compiled with CINN.")
class TestParallelExecutorRunCinn(unittest.TestCase):
    def test_run_from_cinn(self):
        set_cinn_flag(False)
        main_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        with paddle.static.program_guard(main_program, startup_program):
......