diff --git a/paddle/fluid/framework/paddle2cinn/CMakeLists.txt b/paddle/fluid/framework/paddle2cinn/CMakeLists.txt index e5dac1aa6292d47ad1a2e286d3123a135fc8fc1b..6eef1a00e1e730fcf1552352dffbb8f75feb3ab1 100644 --- a/paddle/fluid/framework/paddle2cinn/CMakeLists.txt +++ b/paddle/fluid/framework/paddle2cinn/CMakeLists.txt @@ -2,7 +2,7 @@ cc_library(cinn_cache_key SRCS cinn_cache_key.cc DEPS boost graph graph_helper l cc_library(build_cinn_pass SRCS build_cinn_pass.cc DEPS pass subgraph_detector graph_pattern_detector cinn_compiler errors enforce) cc_library(transform_desc SRCS transform_desc.cc DEPS proto_desc cinn) cc_library(cinn_graph_symbolization SRCS cinn_graph_symbolization.cc DEPS lod_tensor graph transform_desc cinn) -cc_library(cinn_compiler SRCS cinn_compiler.cc DEPS graph lod_tensor cinn_cache_key cinn_graph_symbolization cinn) +cc_library(cinn_compiler SRCS cinn_compiler.cc DEPS framework_proto graph lod_tensor cinn_cache_key cinn_graph_symbolization cinn) cc_test(cinn_cache_key_test SRCS cinn_cache_key_test.cc DEPS cinn_cache_key) cc_test(build_cinn_pass_test SRCS build_cinn_pass_test.cc DEPS build_cinn_pass cinn_compiler) diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc index 173ba55fd9d1aebed472d294c58c73041d935847..b90dbd7dcd845e8785e7f504f41ce5a81c97046d 100644 --- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc +++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include +#include #include #include #include @@ -25,6 +26,8 @@ limitations under the License. */ #include "cinn/frontend/op_mapper_registry.h" #include "cinn/frontend/op_mappers/use_op_mappers.h" +#include "gflags/gflags.h" +#include "glog/logging.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/node.h" @@ -34,6 +37,9 @@ limitations under the License. */ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/errors.h" +DECLARE_string(allow_cinn_ops); +DECLARE_string(deny_cinn_ops); + namespace paddle { namespace framework { namespace paddle2cinn { @@ -46,6 +52,20 @@ using GraphNodeSet = std::unordered_set; using GraphNodeMap = std::unordered_map; namespace { +// The delim(`;`) that is used to split the FLAGS_allow_cinn_ops +// & FLAGS_deny_cinn_ops. +constexpr char kDelim[] = ";"; + +std::unordered_set StringSplit(const std::string& str, + const std::string& delim) { + std::regex reg(delim); + std::unordered_set elems{ + std::sregex_token_iterator(str.begin(), str.end(), reg, -1), + std::sregex_token_iterator()}; + elems.erase(""); + return elems; +} + int ExtractOpRole(const GraphNodeSet& cluster) { std::unordered_set op_roles; std::string attr_name = OpProtoAndCheckerMaker::OpRoleAttrName(); @@ -339,10 +359,27 @@ void ReplaceSubGraphWithCinnOpNode(const GraphNodeSet& cluster, // all of op node supported by CINN. We using OpMapperRegistry // to check whether the op node supported by CINN. void SearchAllSubgraphs(Graph* graph) { - auto teller = [](const Node* node) { - return ::cinn::frontend::OpMapperRegistry::Global()->Find(node->Name()) != - nullptr; + auto allow_ops = StringSplit(FLAGS_allow_cinn_ops, kDelim); + auto deny_ops = StringSplit(FLAGS_deny_cinn_ops, kDelim); + auto teller = [&allow_ops, &deny_ops](const Node* node) { + bool registered = ::cinn::frontend::OpMapperRegistry::Global()->Find( + node->Name()) != nullptr; + // if the op type is registered in CINN and allow_ops is not empty, return + // true only when it is in allow_ops + if (allow_ops.size()) { + return registered && allow_ops.count(node->Name()); + } + // if the op type is registered in CINN and deny_ops is not empty, return + // true only when it is not in deny_ops + if (deny_ops.size()) { + return registered && !deny_ops.count(node->Name()); + } + // if the user doesn't set FLAGS_allow_cinn_ops and FLAGS_deny_cinn_ops, + // return true only when it is registered in CINN + return registered; }; + VLOG(4) << "The allowed Cinn Ops: " << FLAGS_allow_cinn_ops; + VLOG(4) << "The denied Cinn Ops: " << FLAGS_deny_cinn_ops; std::vector clusters = framework::ir::SubgraphDetector(graph, teller)(); @@ -375,7 +412,8 @@ void SearchAllSubgraphs(Graph* graph) { // save it in CinnCompiler std::string compilation_key = cinn_compiler->AddGraph(CreateNewSubGraph( cluster_set, cluster_internals, cluster_inputs, cluster_outputs)); - VLOG(4) << "Compilation Key: " << compilation_key; + VLOG(4) << "Compilation Key:\n" + << cinn_compiler->ReadableKey(compilation_key); // Replace the found cluster to a new cinn op node ReplaceSubGraphWithCinnOpNode(cluster_set, cluster_inputs, cluster_outputs, diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc index bcff92ec18eda7045b018cfe6bc4df550eafd2f9..f9c28f42776906de2aec4313a7024396dbb24578 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.cc @@ -14,14 +14,15 @@ #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" +#include #include #include #include +#include #include "cinn/common/target.h" #include "cinn/common/type.h" #include "cinn/frontend/decomposer/use_decomposer.h" -#include "cinn/frontend/net_builder.h" // need to remove after #include "cinn/frontend/pass/use_program_pass.h" #include "cinn/frontend/program_pass.h" #include "cinn/frontend/syntax.h" @@ -29,19 +30,26 @@ #include "cinn/hlir/framework/graph_compiler.h" #include "cinn/hlir/framework/pass.h" #include "cinn/hlir/pass/use_pass.h" +#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/graph_helper.h" +#include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.h" #include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/rw_lock.h" #include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/inference/analysis/dot.h" #include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/string/string_helper.h" namespace paddle { namespace framework { namespace paddle2cinn { using ir::Graph; +using ir::Node; +using inference::analysis::Dot; using ::cinn::common::Target; using ::cinn::common::Float; using ::cinn::hlir::framework::GraphCompiler; @@ -54,47 +62,121 @@ CinnCompiler* CinnCompiler::GetInstance() { return &instance; } +const CinnCompiledObject& CinnCompiler::Compile( + const Graph& graph, + const std::map& input_tensors, + const Target& target) { + CinnCacheKey cur_key(graph, input_tensors, target.arch_str()); + bool exist = false; + { + AutoRDLock r_guard{&rwlock_}; + exist = cache_.count(cur_key) != 0; + } + if (!exist) { + real_compiled_num_++; + auto compiled_res = CompileGraph(graph, input_tensors, target); + AutoWRLock w_guard{&rwlock_}; + if (!cache_.count(cur_key)) { + cache_[cur_key] = std::move(compiled_res); + } + } + AutoRDLock guard{&rwlock_}; + const auto& cached_boj = *cache_[cur_key]; + return cached_boj; +} + +const CinnCompiledObject& CinnCompiler::Compile( + const std::string& compilation_key, + const std::map& input_tensors, + const Target& target) { + VLOG(4) << "-- The graph to be compiled is:\n" << VizGraph(compilation_key); + const auto& graph = FindGraph(compilation_key); + return Compile(graph, input_tensors, target); +} + std::string CinnCompiler::AddGraph(std::unique_ptr graph) { std::string graph_key; ProgramDesc program; GraphToProgram(*graph, &program); program.Proto()->SerializeToString(&graph_key); - if (!graphs_.count(graph_key)) { - graphs_[graph_key] = std::move(graph); - } else { - LOG(WARNING) - << "The graph being added is already in CinnCompiler. Its key is:\n" - << graph_key; - } + + PADDLE_ENFORCE_EQ( + graphs_.count(graph_key), 0, + platform::errors::PreconditionNotMet( + "The graph to be added is already in CinnCompiler, which is:\n", + VizGraph(graph_key).c_str())); + graphs_[graph_key] = std::move(graph); + VLOG(4) << "-- Add a graph into CinnCompiler, which is:\n" + << VizGraph(graph_key); return graph_key; } const Graph& CinnCompiler::FindGraph(const std::string& graph_key) const { PADDLE_ENFORCE_NE( graphs_.count(graph_key), 0, - platform::errors::InvalidArgument("Can not find the target graph: %s", - graph_key.c_str())); + platform::errors::PreconditionNotMet( + "Can not find the target graph, of which the key is:\n%s", + ReadableKey(graph_key).c_str())); return *graphs_.at(graph_key); } -const CinnCompiledObject& CinnCompiler::Compile( - const Graph& graph, - const std::map& input_tensors, - const Target& target) { - CinnCacheKey cur_key(graph, input_tensors, target.arch_str()); - if (!cache_.count(cur_key)) { - real_compiled_num_++; - cache_[cur_key] = CompileGraph(graph, input_tensors, target); +std::string CinnCompiler::VizGraph(const std::string& key) const { + Dot dot; + std::unordered_map node2dot; + const Graph& graph = FindGraph(key); + int id = 0; + // Create nodes + for (const Node* n : graph.Nodes()) { + std::string node_id = "Node" + std::to_string(id++); + if (n->IsOp()) { + dot.AddNode( + node_id, + {Dot::Attr("shape", "box"), Dot::Attr("style", "rounded,filled,bold"), + Dot::Attr("color", "#303A3A"), Dot::Attr("fontcolor", "#ffffff")}, + n->Name()); + } else if (n->IsVar()) { + auto label = n->Name(); + if (n->Var() && n->Var()->GetType() == proto::VarType::LOD_TENSOR) { + auto shape = n->Var()->GetShape(); + std::vector shape_str(shape.size()); + std::transform(shape.begin(), shape.end(), shape_str.begin(), + [](const auto& val) { return std::to_string(val); }); + label += "\n" + string::join_strings(shape_str, ','); + } + dot.AddNode( + node_id, + {Dot::Attr("shape", "box"), Dot::Attr("style", "rounded,filled,bold"), + Dot::Attr("color", n->Var()->IsParameter() ? "#148b97" : "#dddddd"), + Dot::Attr("fontcolor", + n->Var()->IsParameter() ? "#ffffff" : "#000000")}, + label); + } + node2dot[n] = node_id; + } + // Create edges + for (const Node* n : graph.Nodes()) { + const auto& src_id = node2dot.at(n); + for (auto* out : n->outputs) { + const auto& dest_id = node2dot.at(out); + dot.AddEdge(src_id, dest_id, {}); + } } - return *cache_[cur_key]; + return dot.Build(); } -const CinnCompiledObject& CinnCompiler::Compile( - const std::string& compilation_key, - const std::map& input_tensors, - const Target& target) { - const auto& graph = FindGraph(compilation_key); - return Compile(graph, input_tensors, target); +std::string CinnCompiler::ReadableKey(const std::string& key) const { + proto::ProgramDesc desc; + desc.ParseFromString(key); + return desc.DebugString(); +} + +void CinnCompiler::Clear() { + { + AutoWRLock guard{&rwlock_}; + graphs_.clear(); + cache_.clear(); + } + real_compiled_num_ = 0; } std::unique_ptr CinnCompiler::CompileGraph( @@ -107,7 +189,7 @@ std::unique_ptr CinnCompiler::CompileGraph( ProgramPass::Apply(&frontend_program, target, {"Decomposer"}); auto cinn_graph = std::make_shared<::cinn::hlir::framework::Graph>( frontend_program, target); - VLOG(4) << "The " << real_compiled_num_ << "-th compilation (" + VLOG(4) << "-- The " << real_compiled_num_ << "-th compilation (" << target.arch_str() << "), and its related graph:\n" << cinn_graph->Visualize(); ApplyPass(cinn_graph.get(), "OpFusion"); diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler.h b/paddle/fluid/framework/paddle2cinn/cinn_compiler.h index 0d6935849696b6ebbe5ee655054e7456b22c5469..3996c62cb943ec32a07db4ba6502f79689527ca3 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler.h +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler.h @@ -25,6 +25,7 @@ #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/paddle2cinn/cinn_cache_key.h" +#include "paddle/fluid/framework/rw_lock.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/platform/macros.h" @@ -64,6 +65,12 @@ class CinnCompiler { const ir::Graph& FindGraph(const std::string& key) const; + std::string VizGraph(const std::string& key) const; + + std::string ReadableKey(const std::string& key) const; + + void Clear(); + std::int64_t real_compiled_num() const { return real_compiled_num_; } ~CinnCompiler() = default; @@ -80,6 +87,7 @@ class CinnCompiler { CinnCacheKey::Hash> cache_; std::atomic_int64_t real_compiled_num_{0}; + mutable RWLock rwlock_; DISABLE_COPY_AND_ASSIGN(CinnCompiler); }; diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc index 22792e0f8c359aa5245f04435fe2b3ca428a48d9..145d3d83d45099c7ef8331efd919d67a110a1c75 100644 --- a/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc +++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc @@ -14,12 +14,20 @@ #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" +#include #include #include +#include #include +#include +#include +#include #include "cinn/common/target.h" +#include "gflags/gflags.h" +#include "glog/logging.h" #include "gtest/gtest.h" +#include "paddle/fluid/framework/ddim.h" #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/lod_tensor.h" @@ -29,13 +37,76 @@ #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/place.h" +DECLARE_string(allow_cinn_ops); +DECLARE_string(deny_cinn_ops); + namespace paddle { namespace framework { namespace paddle2cinn { - using ir::Graph; using ::cinn::common::Target; +namespace { +template > +std::ostream& operator<<(std::ostream& os, const std::vector& vec) { + os << "{ "; + for (auto e : vec) { + os << e << " "; + } + os << "}\n"; + return os; +} + +// Get compilation_key values +std::vector GetCompilationKeys(const Graph& graph) { + std::vector compilation_keys; + for (auto& node : graph.Nodes()) { + if (node->IsOp() && node->Name() == kCinnLaunchOp) { + compilation_keys.emplace_back( + BOOST_GET_CONST(std::string, node->Op()->GetAttr(kCompilationKey))); + } + } + return compilation_keys; +} + +// Extract op types from a graph +std::unordered_set ExtractOpTypes(const Graph& graph) { + std::unordered_set op_types; + for (auto& node : graph.Nodes()) { + if (node->IsOp()) { + op_types.emplace(node->Name()); + } + } + return op_types; +} + +// Get inputs info +std::unordered_map> GetInputsInfo( + const std::string& key, const Graph& graph) { + std::unordered_set inputs; + for (auto& node : graph.Nodes()) { + if (node->IsOp() && node->Name() == kCinnLaunchOp) { + if (BOOST_GET_CONST(std::string, node->Op()->GetAttr(kCompilationKey)) != + key) { + continue; + } + for (auto in_var_name : node->Op()->InputArgumentNames()) { + VLOG(4) << "get an input name: " << in_var_name; + inputs.emplace(in_var_name); + } + } + } + + std::unordered_map> inputs_info; + for (auto& node : graph.Nodes()) { + if (node->IsVar() && inputs.count(node->Name())) { + VLOG(4) << node->Name() << " : " << node->Var()->GetShape(); + inputs_info.emplace(node->Name(), node->Var()->GetShape()); + } + } + return inputs_info; +} + // X - // | -> mul -> MUL_OUT - // Y - | -> elementwise_add -> ADD_OUT -> relu -> RELU_OUT @@ -65,6 +136,9 @@ std::unique_ptr CreateGraph() { auto* mul_out = global_block->Var("MUL_OUT"); mul_out->SetType(proto::VarType::LOD_TENSOR); + mul_out->SetLoDLevel(0); + mul_out->SetDataType(proto::VarType::FP32); + mul_out->SetShape({1000, 100}); mul_op->SetOutput("Out", {mul_out->Name()}); // add @@ -83,6 +157,9 @@ std::unique_ptr CreateGraph() { auto* add_out = global_block->Var("ADD_OUT"); add_out->SetType(proto::VarType::LOD_TENSOR); + add_out->SetLoDLevel(0); + add_out->SetDataType(proto::VarType::FP32); + add_out->SetShape({1000, 100}); add_op->SetOutput("Out", {add_out->Name()}); // relu @@ -92,11 +169,59 @@ std::unique_ptr CreateGraph() { auto* relu_out = global_block->Var("RELU_OUT"); relu_out->SetType(proto::VarType::LOD_TENSOR); + relu_out->SetLoDLevel(0); + relu_out->SetDataType(proto::VarType::FP32); + relu_out->SetShape({1000, 100}); relu_op->SetOutput("Out", {relu_out->Name()}); program.Flush(); return std::make_unique(program); } +} // namespace + +TEST(CinnCompilerTest, FlagController) { + // init + auto* cinn_compiler = CinnCompiler::GetInstance(); + auto cinn_pass = ir::PassRegistry::Instance().Get("build_cinn_pass"); + // apply build_cinn_pass & FLAGS_allow_cinn_ops="add" + { + FLAGS_allow_cinn_ops = "add"; + auto graph = CreateGraph(); + cinn_compiler->Clear(); + cinn_pass->Apply(graph.get()); + auto compilation_keys = GetCompilationKeys(*graph); + ASSERT_EQ(compilation_keys.size(), 0); + } + // apply build_cinn_pass & FLAGS_allow_cinn_ops="mul;relu" + { + FLAGS_allow_cinn_ops = "mul;relu"; + auto graph = CreateGraph(); + cinn_compiler->Clear(); + cinn_pass->Apply(graph.get()); + auto compilation_keys = GetCompilationKeys(*graph); + ASSERT_EQ(compilation_keys.size(), 2); + } + // apply build_cinn_pass & FLAGS_allow_cinn_ops="" & + // FLAGS_deny_cinn_ops="relu" + { + FLAGS_allow_cinn_ops = ""; + FLAGS_deny_cinn_ops = "elementwise_add;relu"; + auto graph = CreateGraph(); + cinn_compiler->Clear(); + cinn_pass->Apply(graph.get()); + auto compilation_keys = GetCompilationKeys(*graph); + ASSERT_EQ(compilation_keys.size(), 1); + const auto& compiling_graph = cinn_compiler->FindGraph(compilation_keys[0]); + auto op_types = ExtractOpTypes(compiling_graph); + ASSERT_EQ(op_types.size(), 2); + ASSERT_EQ(op_types.count("feed"), 1); + ASSERT_EQ(op_types.count("mul"), 1); + } + // recover flags + FLAGS_allow_cinn_ops = ""; + FLAGS_deny_cinn_ops = ""; +} + TEST(CinnCompilerTest, Compile) { auto viz_pass = ir::PassRegistry::Instance().Get("graph_viz_pass"); auto cinn_pass = ir::PassRegistry::Instance().Get("build_cinn_pass"); @@ -113,32 +238,31 @@ TEST(CinnCompilerTest, Compile) { cinn_pass->Apply(graph.get()); viz_graph("processed_graph.dot", graph.get()); // get the compilation_key - std::vector compilation_keys; - for (auto& node : graph->Nodes()) { - if (node->IsOp() && node->Name() == kCinnLaunchOp) { - compilation_keys.emplace_back( - BOOST_GET_CONST(std::string, node->Op()->GetAttr(kCompilationKey))); - } - } + auto compilation_keys = GetCompilationKeys(*graph); ASSERT_EQ(compilation_keys.size(), 1); const auto& compilation_key = compilation_keys[0]; auto* cinn_compiler = CinnCompiler::GetInstance(); + VLOG(4) << "The graph to be compiled:\n" + << cinn_compiler->VizGraph(compilation_key); const auto& compiling_graph = cinn_compiler->FindGraph(compilation_key); - // viz_graph("compiling_graph.dot", const_cast(&compiling_graph)); + viz_graph("compiling_graph.dot", const_cast(&compiling_graph)); EXPECT_THROW(cinn_compiler->FindGraph("no_existed"), paddle::platform::EnforceNotMet); - LoDTensor tensor1, tensor2, tensor3; - tensor1.Resize({1000, 784}); - tensor2.Resize({784, 100}); - tensor3.Resize({100}); - tensor1.mutable_data(platform::CPUPlace()); - tensor2.mutable_data(platform::CPUPlace()); - tensor3.mutable_data(platform::CPUPlace()); - std::map input_tensors = { - {"X", &tensor1}, {"Y", &tensor2}, {"Z", &tensor3}}; + auto inputs_info = GetInputsInfo(compilation_key, *graph); + std::unordered_map create_inputs; + for (const auto& pair : inputs_info) { + auto& tensor = create_inputs[pair.first]; + tensor.Resize(make_ddim(pair.second)); + tensor.mutable_data(platform::CPUPlace()); + } + std::map input_tensors; + std::for_each(create_inputs.begin(), create_inputs.end(), + [&input_tensors](const auto& val) { + input_tensors.emplace(val.first, &val.second); + }); auto compile_fn = [&](const Target& target) { const auto& compiled_obj = diff --git a/paddle/fluid/operators/cinn_launch_op.cc b/paddle/fluid/operators/cinn_launch_op.cc index b81ad11b06c1a30015823eb123d28aec344a7754..a17f1037318cb4c7824f0a35f9537f358cbad599 100644 --- a/paddle/fluid/operators/cinn_launch_op.cc +++ b/paddle/fluid/operators/cinn_launch_op.cc @@ -13,7 +13,6 @@ // limitations under the License. #include "paddle/fluid/operators/cinn_launch_op.h" -#include "cinn/frontend/var_type_utils.h" #include "paddle/fluid/string/string_helper.h" namespace paddle { diff --git a/paddle/fluid/operators/cinn_launch_op.h b/paddle/fluid/operators/cinn_launch_op.h index b3bb050acfe07bb4da7e01b8ebf7c50af95459c5..858baffcac3584744532a29966877976a34fa790 100644 --- a/paddle/fluid/operators/cinn_launch_op.h +++ b/paddle/fluid/operators/cinn_launch_op.h @@ -98,7 +98,8 @@ class CinnLaunchOpKernel : public framework::OpKernel { const auto& compilation_key = ctx.template Attr(kCompilationKey); VLOG(4) << "CinnLaunchOp attribute(" << kCompilationKey << ") " - << "value:" << compilation_key; + << "value:\n" + << CinnCompiler::GetInstance()->ReadableKey(compilation_key); const auto& graph = CinnCompiler::GetInstance()->FindGraph(compilation_key); auto input_variable_names = ctx.InputNames(kX); diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index f6c8ac2dc420f57c6f35e2a15354edb93e6067d9..a674a6a8acdf205f23bfcfcf53f6dbcb054723d9 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -710,6 +710,7 @@ PADDLE_DEFINE_EXPORTED_bool(allreduce_record_one_event, false, "events. Currently, only fuse allreduce supports " "this. Otherwise, the precision may be wrong."); +#ifdef PADDLE_WITH_CINN /* * CINN related FLAG * Name: FLAGS_use_cinn @@ -717,9 +718,31 @@ PADDLE_DEFINE_EXPORTED_bool(allreduce_record_one_event, false, * Value Range: bool, default=false * Example: FLAGS_use_cinn=true would run PaddlePaddle using CINN */ -#ifdef PADDLE_WITH_CINN PADDLE_DEFINE_EXPORTED_bool( use_cinn, false, "It controls whether to run PaddlePaddle using CINN"); + +/* + * CINN related FLAG + * Name: FLAGS_allow_cinn_ops + * Since Version: 2.3 + * Value Range: string, default="" + * Example: FLAGS_allow_cinn_ops="mul;relu" would only cover `mul` and `relu` + * when using CINN + */ +PADDLE_DEFINE_EXPORTED_string(allow_cinn_ops, "", + "It controls the cinn op subset to be used, " + "which has the highest priority."); + +/* + * CINN related FLAG + * Name: FLAGS_deny_cinn_ops + * Since Version: 2.3 + * Value Range: string, default="" + * Example: FLAGS_deny_cinn_ops="mul;relu" would block `mul` and `relu` two ops + * when using CINN + */ +PADDLE_DEFINE_EXPORTED_string(deny_cinn_ops, "", + "It controls the cinn op subset to be not used."); #endif DEFINE_int32(record_pool_max_size, 2000000, diff --git a/python/paddle/fluid/tests/unittests/test_parallel_executor_run_cinn.py b/python/paddle/fluid/tests/unittests/test_parallel_executor_run_cinn.py index 601da32cfb12926660957b5d52c9aa1d8deb2792..d9ae3cf5e757d9b3d78c8f2437083bb326547152 100644 --- a/python/paddle/fluid/tests/unittests/test_parallel_executor_run_cinn.py +++ b/python/paddle/fluid/tests/unittests/test_parallel_executor_run_cinn.py @@ -40,9 +40,9 @@ def set_cinn_flag(val): def reader(limit): - for i in range(limit): - yield np.ones([1, 28]).astype('float32') * (i * 3.14 / (i + 1)), \ - np.array([i + 1]).astype('int64') + for _ in range(limit): + yield np.random.random([1, 28]).astype('float32'), \ + np.random.randint(0, 2, size=[1]).astype('int64') def rand_data(img, label, loop_num=10): @@ -62,7 +62,7 @@ def build_program(main_program, startup_program): shape=[1, 28], dtype="float32", attr=paddle.ParamAttr(initializer=paddle.nn.initializer.Assign( - np.ones([1, 28]).astype(np.float32)))) + np.random.rand(1, 28).astype(np.float32)))) label = paddle.static.data(name="label", shape=[1], dtype='int64') hidden = paddle.add(img, param) @@ -75,7 +75,12 @@ def build_program(main_program, startup_program): return img, label, avg_loss -def do_test(dot_save_dir): +def train(dot_save_dir, prefix, seed=1234): + np.random.seed(seed) + paddle.seed(seed) + if paddle.is_compiled_with_cuda(): + paddle.set_flags({'FLAGS_cudnn_deterministic': 1}) + startup_program = paddle.static.Program() main_program = paddle.static.Program() img, label, loss = build_program(main_program, startup_program) @@ -86,32 +91,35 @@ def do_test(dot_save_dir): exe.run(startup_program) build_strategy = paddle.static.BuildStrategy() - build_strategy.debug_graphviz_path = os.path.join(dot_save_dir, "viz") + build_strategy.debug_graphviz_path = os.path.join(dot_save_dir, prefix) compiled_program = paddle.static.CompiledProgram( main_program, build_strategy).with_data_parallel(loss_name=loss.name) - iters = 1 + iters = 100 feed = rand_data(img.name, label.name, iters) + loss_values = [] for step in range(iters): loss_v = exe.run(compiled_program, feed=feed[step], fetch_list=[loss], return_merged=False) - logger.info("loss value = {}".format(loss_v)) + loss_values.append(loss_v[0][0][0]) + return loss_values @unittest.skipIf(not set_cinn_flag(True), "Paddle is not compiled with CINN.") class TestParallelExecutorRunCinn(unittest.TestCase): def setUp(self): - set_cinn_flag(True) self.tmpdir = tempfile.mkdtemp(prefix="dots_") def tearDown(self): - set_cinn_flag(False) shutil.rmtree(self.tmpdir) def test_run_with_cinn(self): - do_test(self.tmpdir) + cinn_losses = train(self.tmpdir, "paddle") + set_cinn_flag(False) + pd_losses = train(self.tmpdir, "cinn") + self.assertTrue(np.allclose(cinn_losses, pd_losses, atol=1e-5)) if __name__ == '__main__':