Unverified · Commit 2479664a authored by Zhen Wang, committed by GitHub

Add FLAGS_allow_cinn_ops & FLAGS_deny_cinn_ops for controlling op types used in training with CINN. (#36842)

* Update UT test_parallel_executor_run_cinn.py.

* Add FLAGS_allow_cinn_ops & FLAGS_deny_cinn_ops & FLAGS_cinn_ops_delim.

* Use the custom StringSplit function and remove the FLAGS_cinn_ops_delim flag.

* Add FlagController test.

* Apply lock to the cache_ only in CinnCompiler.

* Add VizGraph & ReadableKey method for CinnCompiler.

* Update the dot style of VizGraph in CinnCompiler.
Parent fb394695
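For orientation, here is a minimal sketch of how the two new flags might be exercised from Python once this change lands (a sketch only: it assumes the flags are visible to `paddle.set_flags`, and the op names are purely illustrative):

```python
import paddle

# Run with CINN, but only hand `mul` and `relu` subgraphs to it.
# FLAGS_allow_cinn_ops has the highest priority; entries are separated by ';'.
paddle.set_flags({
    'FLAGS_use_cinn': True,
    'FLAGS_allow_cinn_ops': 'mul;relu',
})

# Alternatively, keep every CINN-registered op except a deny list.
paddle.set_flags({
    'FLAGS_allow_cinn_ops': '',
    'FLAGS_deny_cinn_ops': 'elementwise_add;relu',
})
```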
@@ -2,7 +2,7 @@ cc_library(cinn_cache_key SRCS cinn_cache_key.cc DEPS boost graph graph_helper l
 cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector graph_pattern_detector cinn_compiler errors enforce)
 cc_library(transform_desc SRCS transform_desc.cc DEPS proto_desc cinn)
 cc_library(cinn_graph_symbolization SRCS cinn_graph_symbolization.cc DEPS lod_tensor graph transform_desc cinn)
-cc_library(cinn_compiler SRCS cinn_compiler.cc DEPS graph lod_tensor cinn_cache_key cinn_graph_symbolization cinn)
+cc_library(cinn_compiler SRCS cinn_compiler.cc DEPS framework_proto graph lod_tensor cinn_cache_key cinn_graph_symbolization cinn)
 cc_test(cinn_cache_key_test SRCS cinn_cache_key_test.cc DEPS cinn_cache_key)
 cc_test(build_cinn_pass_test SRCS build_cinn_pass_test.cc DEPS build_cinn_pass cinn_compiler)
......
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <algorithm>
 #include <iterator>
 #include <memory>
+#include <regex>
 #include <string>
 #include <unordered_map>
 #include <unordered_set>
@@ -25,6 +26,8 @@ limitations under the License. */
 #include "cinn/frontend/op_mapper_registry.h"
 #include "cinn/frontend/op_mappers/use_op_mappers.h"
+#include "gflags/gflags.h"
+#include "glog/logging.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/ir/node.h"
@@ -34,6 +37,9 @@ limitations under the License. */
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/errors.h"
+
+DECLARE_string(allow_cinn_ops);
+DECLARE_string(deny_cinn_ops);

 namespace paddle {
 namespace framework {
 namespace paddle2cinn {
@@ -46,6 +52,20 @@ using GraphNodeSet = std::unordered_set<Node*>;
 using GraphNodeMap = std::unordered_map<Node*, Node*>;

 namespace {
+// The delim(`;`) that is used to split the FLAGS_allow_cinn_ops
+// & FLAGS_deny_cinn_ops.
+constexpr char kDelim[] = ";";
+
+std::unordered_set<std::string> StringSplit(const std::string& str,
+                                            const std::string& delim) {
+  std::regex reg(delim);
+  std::unordered_set<std::string> elems{
+      std::sregex_token_iterator(str.begin(), str.end(), reg, -1),
+      std::sregex_token_iterator()};
+  elems.erase("");
+  return elems;
+}
+
 int ExtractOpRole(const GraphNodeSet& cluster) {
   std::unordered_set<int> op_roles;
   std::string attr_name = OpProtoAndCheckerMaker::OpRoleAttrName();
@@ -339,10 +359,27 @@ void ReplaceSubGraphWithCinnOpNode(const GraphNodeSet& cluster,
 // all of op node supported by CINN. We using OpMapperRegistry
 // to check whether the op node supported by CINN.
 void SearchAllSubgraphs(Graph* graph) {
-  auto teller = [](const Node* node) {
-    return ::cinn::frontend::OpMapperRegistry::Global()->Find(node->Name()) !=
-           nullptr;
+  auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim);
+  auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim);
+  auto teller = [&allow_ops, &deny_ops](const Node* node) {
+    bool registered = ::cinn::frontend::OpMapperRegistry::Global()->Find(
+                          node->Name()) != nullptr;
+    // if the op type is registered in CINN and allow_ops is not empty, return
+    // true only when it is in allow_ops
+    if (allow_ops.size()) {
+      return registered && allow_ops.count(node->Name());
+    }
+    // if the op type is registered in CINN and deny_ops is not empty, return
+    // true only when it is not in deny_ops
+    if (deny_ops.size()) {
+      return registered && !deny_ops.count(node->Name());
+    }
+    // if the user doesn't set FLAGS_allow_cinn_ops and FLAGS_deny_cinn_ops,
+    // return true only when it is registered in CINN
+    return registered;
   };
+  VLOG(4) << "The allowed Cinn Ops: " << FLAGS_allow_cinn_ops;
+  VLOG(4) << "The denied Cinn Ops: " << FLAGS_deny_cinn_ops;
   std::vector<GraphNodeVec> clusters =
       framework::ir::SubgraphDetector(graph, teller)();
@@ -375,7 +412,8 @@ void SearchAllSubgraphs(Graph* graph) {
     // save it in CinnCompiler
     std::string compilation_key = cinn_compiler->AddGraph(CreateNewSubGraph(
         cluster_set, cluster_internals, cluster_inputs, cluster_outputs));
-    VLOG(4) << "Compilation Key: " << compilation_key;
+    VLOG(4) << "Compilation Key:\n"
+            << cinn_compiler->ReadableKey(compilation_key);
     // Replace the found cluster to a new cinn op node
     ReplaceSubGraphWithCinnOpNode(cluster_set, cluster_inputs, cluster_outputs,
......
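The teller above decides, per op node, whether a candidate cluster may be handed to CINN: the allow list wins over the deny list, and both fall back to the OpMapperRegistry check. A hedged Python model of that precedence (not Paddle API, purely illustrative):

```python
def accept_op(op_type, registered_in_cinn, allow_ops, deny_ops):
    """Mirrors the C++ teller: FLAGS_allow_cinn_ops wins over FLAGS_deny_cinn_ops."""
    if allow_ops:      # FLAGS_allow_cinn_ops is non-empty
        return registered_in_cinn and op_type in allow_ops
    if deny_ops:       # FLAGS_deny_cinn_ops is non-empty
        return registered_in_cinn and op_type not in deny_ops
    return registered_in_cinn  # neither flag set: accept everything CINN registers

# With allow_ops={"mul", "relu"}, an `elementwise_add` node is rejected even
# though it is registered in CINN.
assert not accept_op("elementwise_add", True, {"mul", "relu"}, set())
```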
@@ -14,14 +14,15 @@
 #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"

+#include <iterator>
 #include <map>
 #include <memory>
 #include <string>
+#include <unordered_map>

 #include "cinn/common/target.h"
 #include "cinn/common/type.h"
 #include "cinn/frontend/decomposer/use_decomposer.h"
-#include "cinn/frontend/net_builder.h"  // need to remove after
 #include "cinn/frontend/pass/use_program_pass.h"
 #include "cinn/frontend/program_pass.h"
 #include "cinn/frontend/syntax.h"
@@ -29,19 +30,26 @@
 #include "cinn/hlir/framework/graph_compiler.h"
 #include "cinn/hlir/framework/pass.h"
 #include "cinn/hlir/pass/use_pass.h"
+#include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/framework/ir/node.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h"
 #include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/framework/rw_lock.h"
 #include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/inference/analysis/dot.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/fluid/string/string_helper.h"

 namespace paddle {
 namespace framework {
 namespace paddle2cinn {

 using ir::Graph;
+using ir::Node;
+using inference::analysis::Dot;
 using ::cinn::common::Target;
 using ::cinn::common::Float;
 using ::cinn::hlir::framework::GraphCompiler;
@@ -54,47 +62,121 @@ CinnCompiler* CinnCompiler::GetInstance() {
   return &instance;
 }

+const CinnCompiledObject& CinnCompiler::Compile(
+    const Graph& graph,
+    const std::map<std::string, const LoDTensor*>& input_tensors,
+    const Target& target) {
+  CinnCacheKey cur_key(graph, input_tensors, target.arch_str());
+  bool exist = false;
+  {
+    AutoRDLock r_guard{&rwlock_};
+    exist = cache_.count(cur_key) != 0;
+  }
+  if (!exist) {
+    real_compiled_num_++;
+    auto compiled_res = CompileGraph(graph, input_tensors, target);
+    AutoWRLock w_guard{&rwlock_};
+    if (!cache_.count(cur_key)) {
+      cache_[cur_key] = std::move(compiled_res);
+    }
+  }
+  AutoRDLock guard{&rwlock_};
+  const auto& cached_boj = *cache_[cur_key];
+  return cached_boj;
+}
+
+const CinnCompiledObject& CinnCompiler::Compile(
+    const std::string& compilation_key,
+    const std::map<std::string, const LoDTensor*>& input_tensors,
+    const Target& target) {
+  VLOG(4) << "-- The graph to be compiled is:\n" << VizGraph(compilation_key);
+  const auto& graph = FindGraph(compilation_key);
+  return Compile(graph, input_tensors, target);
+}
+
 std::string CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
   std::string graph_key;
   ProgramDesc program;
   GraphToProgram(*graph, &program);
   program.Proto()->SerializeToString(&graph_key);
-  if (!graphs_.count(graph_key)) {
-    graphs_[graph_key] = std::move(graph);
-  } else {
-    LOG(WARNING)
-        << "The graph being added is already in CinnCompiler. Its key is:\n"
-        << graph_key;
-  }
+  PADDLE_ENFORCE_EQ(
+      graphs_.count(graph_key), 0,
+      platform::errors::PreconditionNotMet(
+          "The graph to be added is already in CinnCompiler, which is:\n",
+          VizGraph(graph_key).c_str()));
+  graphs_[graph_key] = std::move(graph);
+  VLOG(4) << "-- Add a graph into CinnCompiler, which is:\n"
+          << VizGraph(graph_key);
   return graph_key;
 }

 const Graph& CinnCompiler::FindGraph(const std::string& graph_key) const {
   PADDLE_ENFORCE_NE(
       graphs_.count(graph_key), 0,
-      platform::errors::InvalidArgument("Can not find the target graph: %s",
-                                        graph_key.c_str()));
+      platform::errors::PreconditionNotMet(
+          "Can not find the target graph, of which the key is:\n%s",
+          ReadableKey(graph_key).c_str()));
   return *graphs_.at(graph_key);
 }

-const CinnCompiledObject& CinnCompiler::Compile(
-    const Graph& graph,
-    const std::map<std::string, const LoDTensor*>& input_tensors,
-    const Target& target) {
-  CinnCacheKey cur_key(graph, input_tensors, target.arch_str());
-  if (!cache_.count(cur_key)) {
-    real_compiled_num_++;
-    cache_[cur_key] = CompileGraph(graph, input_tensors, target);
-  }
-  return *cache_[cur_key];
-}
-
-const CinnCompiledObject& CinnCompiler::Compile(
-    const std::string& compilation_key,
-    const std::map<std::string, const LoDTensor*>& input_tensors,
-    const Target& target) {
-  const auto& graph = FindGraph(compilation_key);
-  return Compile(graph, input_tensors, target);
-}
+std::string CinnCompiler::VizGraph(const std::string& key) const {
+  Dot dot;
+  std::unordered_map<const Node*, std::string> node2dot;
+  const Graph& graph = FindGraph(key);
+  int id = 0;
+  // Create nodes
+  for (const Node* n : graph.Nodes()) {
+    std::string node_id = "Node" + std::to_string(id++);
+    if (n->IsOp()) {
+      dot.AddNode(
+          node_id,
+          {Dot::Attr("shape", "box"), Dot::Attr("style", "rounded,filled,bold"),
+           Dot::Attr("color", "#303A3A"), Dot::Attr("fontcolor", "#ffffff")},
+          n->Name());
+    } else if (n->IsVar()) {
+      auto label = n->Name();
+      if (n->Var() && n->Var()->GetType() == proto::VarType::LOD_TENSOR) {
+        auto shape = n->Var()->GetShape();
+        std::vector<std::string> shape_str(shape.size());
+        std::transform(shape.begin(), shape.end(), shape_str.begin(),
+                       [](const auto& val) { return std::to_string(val); });
+        label += "\n" + string::join_strings(shape_str, ',');
+      }
+      dot.AddNode(
+          node_id,
+          {Dot::Attr("shape", "box"), Dot::Attr("style", "rounded,filled,bold"),
+           Dot::Attr("color", n->Var()->IsParameter() ? "#148b97" : "#dddddd"),
+           Dot::Attr("fontcolor",
+                     n->Var()->IsParameter() ? "#ffffff" : "#000000")},
+          label);
+    }
+    node2dot[n] = node_id;
+  }
+  // Create edges
+  for (const Node* n : graph.Nodes()) {
+    const auto& src_id = node2dot.at(n);
+    for (auto* out : n->outputs) {
+      const auto& dest_id = node2dot.at(out);
+      dot.AddEdge(src_id, dest_id, {});
+    }
+  }
+  return dot.Build();
+}
+
+std::string CinnCompiler::ReadableKey(const std::string& key) const {
+  proto::ProgramDesc desc;
+  desc.ParseFromString(key);
+  return desc.DebugString();
+}
+
+void CinnCompiler::Clear() {
+  {
+    AutoWRLock guard{&rwlock_};
+    graphs_.clear();
+    cache_.clear();
+  }
+  real_compiled_num_ = 0;
+}

 std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
@@ -107,7 +189,7 @@ std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
   ProgramPass::Apply(&frontend_program, target, {"Decomposer"});
   auto cinn_graph = std::make_shared<::cinn::hlir::framework::Graph>(
       frontend_program, target);
-  VLOG(4) << "The " << real_compiled_num_ << "-th compilation ("
+  VLOG(4) << "-- The " << real_compiled_num_ << "-th compilation ("
           << target.arch_str() << "), and its related graph:\n"
          << cinn_graph->Visualize();
   ApplyPass(cinn_graph.get(), "OpFusion");
......
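The reworked Compile path probes cache_ under a read lock, compiles outside any lock on a miss, and re-checks under a write lock before inserting, so at most one result per key ever lands in the cache even if two threads race to compile it. A hedged sketch of the same pattern in Python (a plain threading.Lock stands in for Paddle's RWLock; all names here are illustrative):

```python
import threading

class CompileCache:
    """Check-compile-insert cache, modelled on CinnCompiler::Compile."""

    def __init__(self, compile_fn):
        self._compile_fn = compile_fn
        self._cache = {}
        self._lock = threading.Lock()  # stand-in for a reader/writer lock

    def compile(self, key, *args):
        with self._lock:                      # "read" side: probe the cache
            hit = key in self._cache
        if not hit:
            result = self._compile_fn(*args)  # compile outside the lock
            with self._lock:                  # "write" side: re-check, then insert
                self._cache.setdefault(key, result)
        with self._lock:
            return self._cache[key]
```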
@@ -25,6 +25,7 @@
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/paddle2cinn/cinn_cache_key.h"
+#include "paddle/fluid/framework/rw_lock.h"
 #include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/platform/macros.h"
@@ -64,6 +65,12 @@ class CinnCompiler {
   const ir::Graph& FindGraph(const std::string& key) const;

+  std::string VizGraph(const std::string& key) const;
+
+  std::string ReadableKey(const std::string& key) const;
+
+  void Clear();
+
   std::int64_t real_compiled_num() const { return real_compiled_num_; }

   ~CinnCompiler() = default;
@@ -80,6 +87,7 @@ class CinnCompiler {
                      CinnCacheKey::Hash>
       cache_;
   std::atomic_int64_t real_compiled_num_{0};
+  mutable RWLock rwlock_;

   DISABLE_COPY_AND_ASSIGN(CinnCompiler);
 };
......
@@ -14,12 +14,20 @@
 #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"

+#include <algorithm>
 #include <map>
 #include <memory>
+#include <ostream>
 #include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>

 #include "cinn/common/target.h"
+#include "gflags/gflags.h"
+#include "glog/logging.h"
 #include "gtest/gtest.h"
+#include "paddle/fluid/framework/ddim.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/pass.h"
 #include "paddle/fluid/framework/lod_tensor.h"
@@ -29,13 +37,76 @@
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/place.h"

+DECLARE_string(allow_cinn_ops);
+DECLARE_string(deny_cinn_ops);
+
 namespace paddle {
 namespace framework {
 namespace paddle2cinn {

 using ir::Graph;
 using ::cinn::common::Target;

+namespace {
+template <typename T, typename Alloc = std::allocator<T>>
+std::ostream& operator<<(std::ostream& os, const std::vector<T, Alloc>& vec) {
+  os << "{ ";
+  for (auto e : vec) {
+    os << e << " ";
+  }
+  os << "}\n";
+  return os;
+}
+
+// Get compilation_key values
+std::vector<std::string> GetCompilationKeys(const Graph& graph) {
+  std::vector<std::string> compilation_keys;
+  for (auto& node : graph.Nodes()) {
+    if (node->IsOp() && node->Name() == kCinnLaunchOp) {
+      compilation_keys.emplace_back(
+          BOOST_GET_CONST(std::string, node->Op()->GetAttr(kCompilationKey)));
+    }
+  }
+  return compilation_keys;
+}
+
+// Extract op types from a graph
+std::unordered_set<std::string> ExtractOpTypes(const Graph& graph) {
+  std::unordered_set<std::string> op_types;
+  for (auto& node : graph.Nodes()) {
+    if (node->IsOp()) {
+      op_types.emplace(node->Name());
+    }
+  }
+  return op_types;
+}
+
+// Get inputs info
+std::unordered_map<std::string, std::vector<int64_t>> GetInputsInfo(
+    const std::string& key, const Graph& graph) {
+  std::unordered_set<std::string> inputs;
+  for (auto& node : graph.Nodes()) {
+    if (node->IsOp() && node->Name() == kCinnLaunchOp) {
+      if (BOOST_GET_CONST(std::string, node->Op()->GetAttr(kCompilationKey)) !=
+          key) {
+        continue;
+      }
+      for (auto in_var_name : node->Op()->InputArgumentNames()) {
+        VLOG(4) << "get an input name: " << in_var_name;
+        inputs.emplace(in_var_name);
+      }
+    }
+  }
+  std::unordered_map<std::string, std::vector<int64_t>> inputs_info;
+  for (auto& node : graph.Nodes()) {
+    if (node->IsVar() && inputs.count(node->Name())) {
+      VLOG(4) << node->Name() << " : " << node->Var()->GetShape();
+      inputs_info.emplace(node->Name(), node->Var()->GetShape());
+    }
+  }
+  return inputs_info;
+}
+
 // X -
 // | -> mul -> MUL_OUT -
 // Y - | -> elementwise_add -> ADD_OUT -> relu -> RELU_OUT
@@ -65,6 +136,9 @@ std::unique_ptr<Graph> CreateGraph() {
   auto* mul_out = global_block->Var("MUL_OUT");
   mul_out->SetType(proto::VarType::LOD_TENSOR);
+  mul_out->SetLoDLevel(0);
+  mul_out->SetDataType(proto::VarType::FP32);
+  mul_out->SetShape({1000, 100});
   mul_op->SetOutput("Out", {mul_out->Name()});

   // add
@@ -83,6 +157,9 @@ std::unique_ptr<Graph> CreateGraph() {
   auto* add_out = global_block->Var("ADD_OUT");
   add_out->SetType(proto::VarType::LOD_TENSOR);
+  add_out->SetLoDLevel(0);
+  add_out->SetDataType(proto::VarType::FP32);
+  add_out->SetShape({1000, 100});
   add_op->SetOutput("Out", {add_out->Name()});

   // relu
@@ -92,11 +169,59 @@ std::unique_ptr<Graph> CreateGraph() {
   auto* relu_out = global_block->Var("RELU_OUT");
   relu_out->SetType(proto::VarType::LOD_TENSOR);
+  relu_out->SetLoDLevel(0);
+  relu_out->SetDataType(proto::VarType::FP32);
+  relu_out->SetShape({1000, 100});
   relu_op->SetOutput("Out", {relu_out->Name()});

   program.Flush();
   return std::make_unique<Graph>(program);
 }
+}  // namespace
+
+TEST(CinnCompilerTest, FlagController) {
+  // init
+  auto* cinn_compiler = CinnCompiler::GetInstance();
+  auto cinn_pass = ir::PassRegistry::Instance().Get("build_cinn_pass");
+  // apply build_cinn_pass & FLAGS_allow_cinn_ops="add"
+  {
+    FLAGS_allow_cinn_ops = "add";
+    auto graph = CreateGraph();
+    cinn_compiler->Clear();
+    cinn_pass->Apply(graph.get());
+    auto compilation_keys = GetCompilationKeys(*graph);
+    ASSERT_EQ(compilation_keys.size(), 0);
+  }
+  // apply build_cinn_pass & FLAGS_allow_cinn_ops="mul;relu"
+  {
+    FLAGS_allow_cinn_ops = "mul;relu";
+    auto graph = CreateGraph();
+    cinn_compiler->Clear();
+    cinn_pass->Apply(graph.get());
+    auto compilation_keys = GetCompilationKeys(*graph);
+    ASSERT_EQ(compilation_keys.size(), 2);
+  }
+  // apply build_cinn_pass & FLAGS_allow_cinn_ops="" &
+  // FLAGS_deny_cinn_ops="relu"
+  {
+    FLAGS_allow_cinn_ops = "";
+    FLAGS_deny_cinn_ops = "elementwise_add;relu";
+    auto graph = CreateGraph();
+    cinn_compiler->Clear();
+    cinn_pass->Apply(graph.get());
+    auto compilation_keys = GetCompilationKeys(*graph);
+    ASSERT_EQ(compilation_keys.size(), 1);
+    const auto& compiling_graph = cinn_compiler->FindGraph(compilation_keys[0]);
+    auto op_types = ExtractOpTypes(compiling_graph);
+    ASSERT_EQ(op_types.size(), 2);
+    ASSERT_EQ(op_types.count("feed"), 1);
+    ASSERT_EQ(op_types.count("mul"), 1);
+  }
+  // recover flags
+  FLAGS_allow_cinn_ops = "";
+  FLAGS_deny_cinn_ops = "";
+}

 TEST(CinnCompilerTest, Compile) {
   auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass");
   auto cinn_pass = ir::PassRegistry::Instance().Get("build_cinn_pass");
@@ -113,32 +238,31 @@ TEST(CinnCompilerTest, Compile) {
   cinn_pass->Apply(graph.get());
   viz_graph("processed_graph.dot", graph.get());
   // get the compilation_key
-  std::vector<std::string> compilation_keys;
-  for (auto& node : graph->Nodes()) {
-    if (node->IsOp() && node->Name() == kCinnLaunchOp) {
-      compilation_keys.emplace_back(
-          BOOST_GET_CONST(std::string, node->Op()->GetAttr(kCompilationKey)));
-    }
-  }
+  auto compilation_keys = GetCompilationKeys(*graph);
   ASSERT_EQ(compilation_keys.size(), 1);
   const auto& compilation_key = compilation_keys[0];

   auto* cinn_compiler = CinnCompiler::GetInstance();
+  VLOG(4) << "The graph to be compiled:\n"
+          << cinn_compiler->VizGraph(compilation_key);
   const auto& compiling_graph = cinn_compiler->FindGraph(compilation_key);
-  // viz_graph("compiling_graph.dot", const_cast<Graph*>(&compiling_graph));
+  viz_graph("compiling_graph.dot", const_cast<Graph*>(&compiling_graph));

   EXPECT_THROW(cinn_compiler->FindGraph("no_existed"),
                paddle::platform::EnforceNotMet);

-  LoDTensor tensor1, tensor2, tensor3;
-  tensor1.Resize({1000, 784});
-  tensor2.Resize({784, 100});
-  tensor3.Resize({100});
-  tensor1.mutable_data<float>(platform::CPUPlace());
-  tensor2.mutable_data<float>(platform::CPUPlace());
-  tensor3.mutable_data<float>(platform::CPUPlace());
-  std::map<std::string, const LoDTensor*> input_tensors = {
-      {"X", &tensor1}, {"Y", &tensor2}, {"Z", &tensor3}};
+  auto inputs_info = GetInputsInfo(compilation_key, *graph);
+  std::unordered_map<std::string, LoDTensor> create_inputs;
+  for (const auto& pair : inputs_info) {
+    auto& tensor = create_inputs[pair.first];
+    tensor.Resize(make_ddim(pair.second));
+    tensor.mutable_data<float>(platform::CPUPlace());
+  }
+  std::map<std::string, const LoDTensor*> input_tensors;
+  std::for_each(create_inputs.begin(), create_inputs.end(),
+                [&input_tensors](const auto& val) {
+                  input_tensors.emplace(val.first, &val.second);
+                });

   auto compile_fn = [&](const Target& target) {
     const auto& compiled_obj =
......
@@ -13,7 +13,6 @@
 // limitations under the License.

 #include "paddle/fluid/operators/cinn_launch_op.h"
-#include "cinn/frontend/var_type_utils.h"
 #include "paddle/fluid/string/string_helper.h"

 namespace paddle {
......
@@ -98,7 +98,8 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
     const auto& compilation_key =
         ctx.template Attr<std::string>(kCompilationKey);
     VLOG(4) << "CinnLaunchOp attribute(" << kCompilationKey << ") "
-            << "value:" << compilation_key;
+            << "value:\n"
+            << CinnCompiler::GetInstance()->ReadableKey(compilation_key);

     const auto& graph = CinnCompiler::GetInstance()->FindGraph(compilation_key);
     auto input_variable_names = ctx.InputNames(kX);
......
@@ -710,6 +710,7 @@ PADDLE_DEFINE_EXPORTED_bool(allreduce_record_one_event, false,
                             "events. Currently, only fuse allreduce supports "
                             "this. Otherwise, the precision may be wrong.");

+#ifdef PADDLE_WITH_CINN
 /*
  * CINN related FLAG
  * Name: FLAGS_use_cinn
@@ -717,9 +718,31 @@ PADDLE_DEFINE_EXPORTED_bool(allreduce_record_one_event, false,
  * Value Range: bool, default=false
  * Example: FLAGS_use_cinn=true would run PaddlePaddle using CINN
  */
-#ifdef PADDLE_WITH_CINN
 PADDLE_DEFINE_EXPORTED_bool(
     use_cinn, false, "It controls whether to run PaddlePaddle using CINN");
+
+/*
+ * CINN related FLAG
+ * Name: FLAGS_allow_cinn_ops
+ * Since Version: 2.3
+ * Value Range: string, default=""
+ * Example: FLAGS_allow_cinn_ops="mul;relu" would only cover `mul` and `relu`
+ * when using CINN
+ */
+PADDLE_DEFINE_EXPORTED_string(allow_cinn_ops, "",
+                              "It controls the cinn op subset to be used, "
+                              "which has the highest priority.");
+
+/*
+ * CINN related FLAG
+ * Name: FLAGS_deny_cinn_ops
+ * Since Version: 2.3
+ * Value Range: string, default=""
+ * Example: FLAGS_deny_cinn_ops="mul;relu" would block `mul` and `relu` two ops
+ * when using CINN
+ */
+PADDLE_DEFINE_EXPORTED_string(deny_cinn_ops, "",
+                              "It controls the cinn op subset to be not used.");
 #endif

 DEFINE_int32(record_pool_max_size, 2000000,
......
@@ -40,9 +40,9 @@ def set_cinn_flag(val):

 def reader(limit):
-    for i in range(limit):
-        yield np.ones([1, 28]).astype('float32') * (i * 3.14 / (i + 1)), \
-            np.array([i + 1]).astype('int64')
+    for _ in range(limit):
+        yield np.random.random([1, 28]).astype('float32'), \
+            np.random.randint(0, 2, size=[1]).astype('int64')


 def rand_data(img, label, loop_num=10):
@@ -62,7 +62,7 @@ def build_program(main_program, startup_program):
         shape=[1, 28],
         dtype="float32",
         attr=paddle.ParamAttr(initializer=paddle.nn.initializer.Assign(
-            np.ones([1, 28]).astype(np.float32))))
+            np.random.rand(1, 28).astype(np.float32))))
     label = paddle.static.data(name="label", shape=[1], dtype='int64')

     hidden = paddle.add(img, param)
@@ -75,7 +75,12 @@
     return img, label, avg_loss


-def do_test(dot_save_dir):
+def train(dot_save_dir, prefix, seed=1234):
+    np.random.seed(seed)
+    paddle.seed(seed)
+    if paddle.is_compiled_with_cuda():
+        paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
+
     startup_program = paddle.static.Program()
     main_program = paddle.static.Program()
     img, label, loss = build_program(main_program, startup_program)
@@ -86,32 +91,35 @@ def do_test(dot_save_dir):
     exe.run(startup_program)

     build_strategy = paddle.static.BuildStrategy()
-    build_strategy.debug_graphviz_path = os.path.join(dot_save_dir, "viz")
+    build_strategy.debug_graphviz_path = os.path.join(dot_save_dir, prefix)
     compiled_program = paddle.static.CompiledProgram(
         main_program, build_strategy).with_data_parallel(loss_name=loss.name)

-    iters = 1
+    iters = 100
     feed = rand_data(img.name, label.name, iters)
+    loss_values = []
     for step in range(iters):
         loss_v = exe.run(compiled_program,
                          feed=feed[step],
                          fetch_list=[loss],
                          return_merged=False)
-        logger.info("loss value = {}".format(loss_v))
+        loss_values.append(loss_v[0][0][0])
+    return loss_values


 @unittest.skipIf(not set_cinn_flag(True), "Paddle is not compiled with CINN.")
 class TestParallelExecutorRunCinn(unittest.TestCase):
     def setUp(self):
-        set_cinn_flag(True)
         self.tmpdir = tempfile.mkdtemp(prefix="dots_")

     def tearDown(self):
-        set_cinn_flag(False)
         shutil.rmtree(self.tmpdir)

     def test_run_with_cinn(self):
-        do_test(self.tmpdir)
+        cinn_losses = train(self.tmpdir, "paddle")
+        set_cinn_flag(False)
+        pd_losses = train(self.tmpdir, "cinn")
+        self.assertTrue(np.allclose(cinn_losses, pd_losses, atol=1e-5))


 if __name__ == '__main__':
......