diff --git a/paddle/fluid/framework/paddle2cinn/CMakeLists.txt b/paddle/fluid/framework/paddle2cinn/CMakeLists.txt
index 8d2ee2f01008bf20d43ddf9242f2bf56effb0db7..a1d4eb20ffa6ab3a6d9597d9444b01eed24d063d 100644
--- a/paddle/fluid/framework/paddle2cinn/CMakeLists.txt
+++ b/paddle/fluid/framework/paddle2cinn/CMakeLists.txt
@@ -11,7 +11,7 @@ if (WITH_TESTING)
   cc_test(cinn_cache_key_test SRCS cinn_cache_key_test.cc DEPS cinn_cache_key)
   set_tests_properties(cinn_cache_key_test PROPERTIES LABELS "RUN_TYPE=CINN")
 
-  cc_test(build_cinn_pass_test SRCS build_cinn_pass_test.cc DEPS build_cinn_pass cinn_compiler)
+  cc_test(build_cinn_pass_test SRCS build_cinn_pass_test.cc DEPS build_cinn_pass cinn_compiler op_registry mul_op activation_op elementwise_add_op)
   set_tests_properties(build_cinn_pass_test PROPERTIES LABELS "RUN_TYPE=CINN")
 
   cc_test(transform_desc_test SRCS transform_desc_test.cc DEPS transform_desc)
@@ -20,6 +20,6 @@ if (WITH_TESTING)
   cc_test(cinn_graph_symbolization_test SRCS cinn_graph_symbolization_test.cc DEPS cinn_graph_symbolization)
   set_tests_properties(cinn_graph_symbolization_test PROPERTIES LABELS "RUN_TYPE=CINN")
 
-  cc_test(cinn_compiler_test SRCS cinn_compiler_test.cc DEPS cinn_compiler place proto_desc graph_viz_pass build_cinn_pass cinn)
+  cc_test(cinn_compiler_test SRCS cinn_compiler_test.cc DEPS cinn_compiler place proto_desc graph_viz_pass build_cinn_pass cinn mul_op activation_op elementwise_add_op)
   set_tests_properties(cinn_compiler_test PROPERTIES LABELS "RUN_TYPE=CINN")
 endif()
diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
index c15744fc1650db2048688bb183c9cff30d779a0c..4abe3a55b298f1f006a129873b544cf55d252daa 100644
--- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
+++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass.cc
@@ -32,6 +32,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/ir/node.h"
 #include "paddle/fluid/framework/ir/subgraph_detector.h"
+#include "paddle/fluid/framework/op_info.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
 #include "paddle/fluid/operators/cinn/cinn_launch_op.h"
@@ -214,6 +215,73 @@ void AddOutputVar(const GraphNodeSet& output_vars, const GraphNodeSet& cluster,
   }
 }
 
+std::unordered_set<std::string> ExtractNoNeedBufferFeeds(
+    const GraphNodeSet& cluster, const GraphNodeSet& cluster_inputs) {
+  // 1. Find ops with a NoNeedBufferVarsInferer defined and collect their
+  // no_need_buffer input nodes
+  std::unordered_map<Node*, std::unordered_set<Node*>>
+      op_node2no_need_buffer_nodes;
+  for (auto* op_node : cluster) {
+    auto& inferer =
+        OpInfoMap::Instance().Get(op_node->Name()).NoNeedBufferVarsInferer();
+    if (!inferer) {
+      continue;
+    }
+    auto* op_desc = op_node->Op();
+    PADDLE_ENFORCE_NOT_NULL(
+        op_desc, platform::errors::PreconditionNotMet(
+                     "The op desc of node in cluster shouldn't be null."));
+    auto inferred_params =
+        inferer(op_desc->Inputs(), op_desc->Inputs(), op_desc->GetAttrMap());
+    std::unordered_set<std::string> inferred_args;
+    std::for_each(inferred_params.begin(), inferred_params.end(),
+                  [&op_desc, &inferred_args](const std::string& param) {
+                    const auto& args = op_desc->Input(param);
+                    inferred_args.insert(args.begin(), args.end());
+                  });
+    auto& no_need_buffer_nodes = op_node2no_need_buffer_nodes[op_node];
+    for (auto* input_node : op_node->inputs) {
+      if (input_node->Var() && inferred_args.count(input_node->Name())) {
+        VLOG(4) << "Input node(" << input_node->Name() << ") of op("
+                << op_node->Name() << ") is no_need_buffer";
+        no_need_buffer_nodes.insert(input_node);
+      }
+    }
+  }
+
+  // 2. Extract the no_need_buffer nodes from cluster_inputs by checking
+  // that all of their outputs are op nodes with a NoNeedBufferVarsInferer
+  // and that they are used only as no_need_buffer inputs.
+  auto check_all_used_as_no_need_buffer_fn =
+      [&op_node2no_need_buffer_nodes](Node* var_node) -> bool {
+    for (auto* output_node : var_node->outputs) {
+      auto it = op_node2no_need_buffer_nodes.find(output_node);
+      if (it == op_node2no_need_buffer_nodes.end()) {
+        VLOG(4) << "Var node(" << var_node->Name() << ")'s output node("
+                << output_node->Name()
+                << ") doesn't have NoNeedBufferVarsInferer";
+        return false;
+      }
+      if (it->second.count(var_node) == 0) {
+        VLOG(4) << "Var node(" << var_node->Name()
+                << ") is not used as no_need_buffer inputs";
+        return false;
+      }
+    }
+    return true;
+  };
+  std::unordered_set<std::string> result;
+  for (const auto& op2inputs_pair : op_node2no_need_buffer_nodes) {
+    for (auto* input_node : op2inputs_pair.second) {
+      if (cluster_inputs.count(input_node) &&
+          check_all_used_as_no_need_buffer_fn(input_node)) {
+        VLOG(4) << "Input node(" << input_node->Name()
+                << ") is declared as no_need_buffer cluster_inputs";
+        result.insert(input_node->Name());
+      }
+    }
+  }
+  return result;
+}
+
 // Create new subgraph with and op nodes are cluster nodes, and all
 // var node are from internal nodes
 std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
@@ -295,7 +363,12 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
                subgraph.get());
   AddOutputVar(output_vars, cluster, old_op2new_op, old_var2new_var,
                subgraph.get());
-
+  // Store the input variables whose buffers are not needed as an
+  // attribute of the graph.
+  auto no_need_buffer_feeds = std::make_unique<std::unordered_set<std::string>>(
+      ExtractNoNeedBufferFeeds(cluster, cluster_inputs));
+  subgraph->Set<std::unordered_set<std::string>>(
+      kNoNeedBufferFeeds, no_need_buffer_feeds.release());
   return subgraph;
 }
 
@@ -374,15 +447,26 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster,
   // Add the cinn launch op
   framework::OpDesc cinn_op_desc;
   cinn_op_desc.SetType(kCinnLaunchOp);
-  std::vector<std::string> input_names;
-  std::for_each(cluster_inputs.begin(), cluster_inputs.end(),
-                [&input_names, &deny_var_set](Node* n) {
-                  if (n->Var() != nullptr && !deny_var_set.count(n->Name())) {
-                    input_names.emplace_back(n->Name());
-                  }
-                });
-  cinn_op_desc.SetInput(operators::kX, input_names);
+
+  // Divide the input variables into two parts: the ones whose data
+  // buffers are not needed, and the remaining ones
+  std::vector<std::string> op_kx_inputs, no_need_buffer_inputs;
+  const auto& subgraph =
+      CinnCompiler::GetInstance()->FindGraph(compilation_key);
+  auto& no_need_buffer_feeds =
+      subgraph.Get<std::unordered_set<std::string>>(kNoNeedBufferFeeds);
+  for (const auto* n : cluster_inputs) {
+    const auto& var_name = n->Name();
+    if (no_need_buffer_feeds.count(var_name)) {
+      no_need_buffer_inputs.emplace_back(var_name);
+    } else {
+      op_kx_inputs.emplace_back(var_name);
+    }
+  }
+
+  cinn_op_desc.SetInput(operators::kX, op_kx_inputs);
+  cinn_op_desc.SetInput(operators::kNoNeedBufferX, no_need_buffer_inputs);
+
   std::vector<std::string> output_names;
   std::for_each(cluster_outputs.begin(), cluster_outputs.end(),
                 [&output_names, &deny_var_set](Node* n) {
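The rule implemented by ExtractNoNeedBufferFeeds above: a cluster input may be fed without a data buffer only if every op inside the cluster that consumes it marks it through a NoNeedBufferVarsInferer; a single ordinary consumer keeps the buffer alive. The following standalone restatement of that "all consumers must agree" check is an illustrative sketch only, with hypothetical names and data, not Paddle code:

    // Sketch: the "all consumers must agree" rule of ExtractNoNeedBufferFeeds.
    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <unordered_set>
    #include <vector>

    int main() {
      // consumers_of[v]: ops reading v; no_need_of[op]: inputs op never reads
      std::unordered_map<std::string, std::vector<std::string>> consumers_of =
          {{"var1", {"add_grad"}}, {"var2", {"add_grad"}}, {"var3", {"add_grad"}}};
      std::unordered_map<std::string, std::unordered_set<std::string>>
          no_need_of = {{"add_grad", {"var2", "var3"}}};

      for (const auto& [var, ops] : consumers_of) {
        bool all_agree = !ops.empty();
        for (const auto& op : ops) {
          auto it = no_need_of.find(op);
          all_agree = all_agree && it != no_need_of.end() && it->second.count(var);
        }
        // var1 keeps its buffer (read as a real value); var2/var3 do not
        std::cout << var << (all_agree ? ": no_need_buffer\n" : ": keep buffer\n");
      }
      return 0;
    }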
*/ #include "paddle/fluid/framework/ir/graph.h" #include "paddle/fluid/framework/ir/node.h" #include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/var_desc.h" @@ -169,7 +171,7 @@ std::unique_ptr BuildAllOpSupportCinnGraph() { // v4 -- OpDesc add_op; - add_op.SetType("add"); + add_op.SetType("elementwise_add"); OpDesc mul_op; mul_op.SetType("mul"); OpDesc relu_op; @@ -259,7 +261,7 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) { // previous op (mul, add, relu) should all removed ASSERT_FALSE(CheckNodeExisted(nodes, "mul")); - ASSERT_FALSE(CheckNodeExisted(nodes, "add")); + ASSERT_FALSE(CheckNodeExisted(nodes, "elementwise_add")); ASSERT_FALSE(CheckNodeExisted(nodes, "relu")); // After search, there should has just one cinn subgraph @@ -277,7 +279,7 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) { ASSERT_TRUE(CheckGraphIndependence(subnodes)); ASSERT_TRUE(CheckNodeExisted(subnodes, "mul")); - ASSERT_TRUE(CheckNodeExisted(subnodes, "add")); + ASSERT_TRUE(CheckNodeExisted(subnodes, "elementwise_add")); ASSERT_TRUE(CheckNodeExisted(subnodes, "relu")); ASSERT_EQ(CountNode(subnodes, "feed"), 2); ASSERT_EQ(CountNode(subnodes, "fetch"), 1); @@ -529,8 +531,136 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) { } } +std::unique_ptr BuildGraphWithNoNeedBufferInput() { + ProgramDesc prog; + auto g = std::make_unique(prog); + + // fake1 --> v1 -- --> v4 --> relu_grad --> v6 + // v2 -- | --> add_grad | + // v3 -- --> v5 --> fake2 + + OpDesc fake1_op; + fake1_op.SetType("fake1"); + OpDesc add_grad_op; + add_grad_op.SetType("elementwise_add_grad"); + add_grad_op.SetInput(::paddle::framework::GradVarName("Out"), {"var1"}); + add_grad_op.SetInput("X", {"var2"}); + add_grad_op.SetInput("Y", {"var3"}); + OpDesc relu_grad_op; + relu_grad_op.SetType("relu_grad"); + OpDesc fake2_op; + fake2_op.SetType("fake2"); + + VarDesc var1("var1"); + VarDesc var2("var2"); + VarDesc var3("var3"); + VarDesc var4("var4"); + VarDesc var5("var5"); + VarDesc var6("var6"); + + ir::Node* fake1 = g->CreateOpNode(&fake1_op); + ir::Node* add_grad = g->CreateOpNode(&add_grad_op); + ir::Node* relu_grad = g->CreateOpNode(&relu_grad_op); + ir::Node* fake2 = g->CreateOpNode(&fake2_op); + + ir::Node* v1 = g->CreateVarNode(&var1); + ir::Node* v2 = g->CreateVarNode(&var2); + ir::Node* v3 = g->CreateVarNode(&var3); + ir::Node* v4 = g->CreateVarNode(&var4); + ir::Node* v5 = g->CreateVarNode(&var5); + ir::Node* v6 = g->CreateVarNode(&var6); + + // fill op node + fake1->outputs = {v1}; + add_grad->inputs = {v1, v2, v3}; + add_grad->outputs = {v4, v5}; + relu_grad->inputs = {v4}; + relu_grad->outputs = {v6}; + fake2->inputs = {v5}; + + // fill variable node + v1->inputs = {fake1}; + v1->outputs = {add_grad}; + + v2->outputs = {add_grad}; + v3->outputs = {add_grad}; + + v4->inputs = {add_grad}; + v4->outputs = {relu_grad}; + v5->inputs = {add_grad}; + v5->outputs = {fake2}; + + v6->inputs = {relu_grad}; + + return g; +} + +TEST(BuildCinnPassTest, NoNeedBufferInput) { + auto g = BuildGraphWithNoNeedBufferInput(); + + auto pass = + paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass"); + pass->Apply(g.get()); + + // After search, the graph should as following + // fake1 --> v1 -- --> v6 + // v2 -- | -->kCinnLaunchOp | + // v3 -- --> v5 --> fake2 + const auto& nodes = g->Nodes(); + ASSERT_EQ(nodes.size(), 
diff --git a/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc
index 586b59a05ecef82b30f5df3c3f2122c683dd5412..bca6a0a4cb8e0d61574f2b7be00e1f67b70ec035 100644
--- a/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc
+++ b/paddle/fluid/framework/paddle2cinn/build_cinn_pass_test.cc
@@ -24,6 +24,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/node.h"
 #include "paddle/fluid/framework/op_desc.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
 #include "paddle/fluid/framework/program_desc.h"
 #include "paddle/fluid/framework/var_desc.h"
@@ -169,7 +171,7 @@ std::unique_ptr<Graph> BuildAllOpSupportCinnGraph() {
   //  v4 --
   OpDesc add_op;
-  add_op.SetType("add");
+  add_op.SetType("elementwise_add");
   OpDesc mul_op;
   mul_op.SetType("mul");
   OpDesc relu_op;
@@ -259,7 +261,7 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
   // previous op (mul, add, relu) should all removed
   ASSERT_FALSE(CheckNodeExisted(nodes, "mul"));
-  ASSERT_FALSE(CheckNodeExisted(nodes, "add"));
+  ASSERT_FALSE(CheckNodeExisted(nodes, "elementwise_add"));
   ASSERT_FALSE(CheckNodeExisted(nodes, "relu"));
 
   // After search, there should has just one cinn subgraph
@@ -277,7 +279,7 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
   ASSERT_TRUE(CheckGraphIndependence(subnodes));
 
   ASSERT_TRUE(CheckNodeExisted(subnodes, "mul"));
-  ASSERT_TRUE(CheckNodeExisted(subnodes, "add"));
+  ASSERT_TRUE(CheckNodeExisted(subnodes, "elementwise_add"));
   ASSERT_TRUE(CheckNodeExisted(subnodes, "relu"));
   ASSERT_EQ(CountNode(subnodes, "feed"), 2);
   ASSERT_EQ(CountNode(subnodes, "fetch"), 1);
@@ -529,8 +531,136 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) {
   }
 }
 
+std::unique_ptr<Graph> BuildGraphWithNoNeedBufferInput() {
+  ProgramDesc prog;
+  auto g = std::make_unique<Graph>(prog);
+
+  // fake1 --> v1 --                 --> v4 --> relu_grad --> v6
+  //           v2 -- | --> add_grad |
+  //           v3 --                 --> v5 --> fake2
+  OpDesc fake1_op;
+  fake1_op.SetType("fake1");
+  OpDesc add_grad_op;
+  add_grad_op.SetType("elementwise_add_grad");
+  add_grad_op.SetInput(::paddle::framework::GradVarName("Out"), {"var1"});
+  add_grad_op.SetInput("X", {"var2"});
+  add_grad_op.SetInput("Y", {"var3"});
+  OpDesc relu_grad_op;
+  relu_grad_op.SetType("relu_grad");
+  OpDesc fake2_op;
+  fake2_op.SetType("fake2");
+
+  VarDesc var1("var1");
+  VarDesc var2("var2");
+  VarDesc var3("var3");
+  VarDesc var4("var4");
+  VarDesc var5("var5");
+  VarDesc var6("var6");
+
+  ir::Node* fake1 = g->CreateOpNode(&fake1_op);
+  ir::Node* add_grad = g->CreateOpNode(&add_grad_op);
+  ir::Node* relu_grad = g->CreateOpNode(&relu_grad_op);
+  ir::Node* fake2 = g->CreateOpNode(&fake2_op);
+
+  ir::Node* v1 = g->CreateVarNode(&var1);
+  ir::Node* v2 = g->CreateVarNode(&var2);
+  ir::Node* v3 = g->CreateVarNode(&var3);
+  ir::Node* v4 = g->CreateVarNode(&var4);
+  ir::Node* v5 = g->CreateVarNode(&var5);
+  ir::Node* v6 = g->CreateVarNode(&var6);
+
+  // fill op node
+  fake1->outputs = {v1};
+  add_grad->inputs = {v1, v2, v3};
+  add_grad->outputs = {v4, v5};
+  relu_grad->inputs = {v4};
+  relu_grad->outputs = {v6};
+  fake2->inputs = {v5};
+
+  // fill variable node
+  v1->inputs = {fake1};
+  v1->outputs = {add_grad};
+
+  v2->outputs = {add_grad};
+  v3->outputs = {add_grad};
+
+  v4->inputs = {add_grad};
+  v4->outputs = {relu_grad};
+  v5->inputs = {add_grad};
+  v5->outputs = {fake2};
+
+  v6->inputs = {relu_grad};
+
+  return g;
+}
+
+TEST(BuildCinnPassTest, NoNeedBufferInput) {
+  auto g = BuildGraphWithNoNeedBufferInput();
+
+  auto pass =
+      paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
+  pass->Apply(g.get());
+
+  // After search, the graph should be as follows:
+  // fake1 --> v1 --                    --> v6
+  //           v2 -- | --> kCinnLaunchOp |
+  //           v3 --                    --> v5 --> fake2
+  const auto& nodes = g->Nodes();
+  ASSERT_EQ(nodes.size(), static_cast<size_t>(8));
+  ASSERT_TRUE(CheckGraphIndependence(nodes));
+
+  // A new op named kCinnLaunchOp should be added and
+  // its input arguments are set correctly
+  ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));
+  ASSERT_EQ(CountNode(nodes, kCinnLaunchOp), 1);
+  auto* cinn_op_node = GetNode(nodes, kCinnLaunchOp);
+  ASSERT_EQ(cinn_op_node->Op()->Input(operators::kX),
+            std::vector<std::string>({"var1"}));
+  auto& no_need_buffer_x =
+      cinn_op_node->Op()->Input(operators::kNoNeedBufferX);
+  ASSERT_EQ(std::unordered_set<std::string>(no_need_buffer_x.begin(),
+                                            no_need_buffer_x.end()),
+            std::unordered_set<std::string>({"var2", "var3"}));
+
+  // previous op (add_grad, relu_grad) should be removed
+  ASSERT_FALSE(CheckNodeExisted(nodes, "add_grad"));
+  ASSERT_FALSE(CheckNodeExisted(nodes, "relu_grad"));
+
+  // previous op (fake1, fake2) should be preserved
+  ASSERT_TRUE(CheckNodeExisted(nodes, "fake1"));
+  ASSERT_TRUE(CheckNodeExisted(nodes, "fake2"));
+
+  // After search, there should be just one cinn subgraph
+  // feed --> v1 --                                      --> v6 --> fetch
+  // feed --> v2 -- | --> add_grad --> v4 --> relu_grad |
+  // feed --> v3 --                                      --> v5 --> fetch
+  auto compilation_keys = GetCompilationKeys(*g);
+  ASSERT_EQ(compilation_keys.size(), static_cast<size_t>(1));
+  auto* cinn_compiler = CinnCompiler::GetInstance();
+  const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);
+
+  const auto& subnodes = subgraph.Nodes();
+  ASSERT_EQ(subnodes.size(), static_cast<size_t>(13));
+  ASSERT_TRUE(CheckGraphIndependence(subnodes));
+
+  ASSERT_TRUE(CheckNodeExisted(subnodes, "elementwise_add_grad"));
+  ASSERT_TRUE(CheckNodeExisted(subnodes, "relu_grad"));
+  ASSERT_EQ(CountNode(subnodes, "feed"), 3);
+  ASSERT_EQ(CountNode(subnodes, "fetch"), 2);
+
+  const auto& no_need_buffer_feeds =
+      subgraph.Get<std::unordered_set<std::string>>(kNoNeedBufferFeeds);
+  ASSERT_EQ(no_need_buffer_feeds.size(), 2);
+  ASSERT_EQ(no_need_buffer_feeds,
+            std::unordered_set<std::string>({"var2", "var3"}));
+}
+
 }  // namespace paddle2cinn
 }  // namespace framework
 }  // namespace paddle
 
 USE_PASS(build_cinn_pass);
+USE_OP(mul);
+USE_OP(relu);
+USE_OP(elementwise_add);
+USE_OP(relu_grad);
+USE_OP(elementwise_add_grad);
diff --git a/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc b/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc
index db20e423c4a40f51184f921ee3e8b9be0ad276ac..6769413d99bafd7a26a3486da6928d06ad920ace 100644
--- a/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc
+++ b/paddle/fluid/framework/paddle2cinn/cinn_compiler_test.cc
@@ -293,3 +293,6 @@ TEST(CinnCompilerTest, Compile) {
 
 USE_PASS(build_cinn_pass);
 USE_PASS(graph_viz_pass);
+USE_OP(mul);
+USE_OP(relu);
+USE_OP(elementwise_add);
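The trailing USE_OP declarations (and the new op_registry/mul_op/activation_op/elementwise_add_op test dependencies in the CMake change) are needed because the pass now looks up each op in OpInfoMap for its NoNeedBufferVarsInferer: the real operators must be linked and registered into the test binaries, or the lookup finds nothing. A miniature of the static-registration pattern at work, using a hypothetical registry rather than Paddle's:

    // Sketch: why USE_OP matters — registration happens via static objects.
    #include <functional>
    #include <iostream>
    #include <string>
    #include <unordered_map>

    using Inferer = std::function<bool(const std::string&)>;

    std::unordered_map<std::string, Inferer>& Registry() {
      static std::unordered_map<std::string, Inferer> registry;
      return registry;
    }

    // Constructing this object at namespace scope registers the op; if its
    // translation unit is never linked in, the lookup below comes back empty.
    struct Registrar {
      Registrar(const std::string& op, Inferer fn) {
        Registry()[op] = std::move(fn);
      }
    };

    static Registrar reg_add_grad("elementwise_add_grad",
                                  [](const std::string& in) {
                                    return in == "X" || in == "Y";
                                  });

    int main() {
      auto it = Registry().find("elementwise_add_grad");
      std::cout << (it != Registry().end() && it->second("X")
                        ? "X is no_need_buffer\n"
                        : "op not registered\n");
      return 0;
    }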
diff --git a/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.cc b/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.cc
index 9bdaf61858f45d9670716a99f84c5c63c6244657..43d62e8d8df3bb71e471598ff61a444b338f762e 100644
--- a/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.cc
+++ b/paddle/fluid/framework/paddle2cinn/cinn_graph_symbolization.cc
@@ -21,6 +21,7 @@ limitations under the License. */
 #include
 #include
 
+#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h"
 #include "paddle/fluid/framework/paddle2cinn/transform_desc.h"
 #include "paddle/fluid/framework/variable.h"
 
@@ -42,20 +43,35 @@ using FeedInfoMap = CinnGraphSymbolization::FeedInfoMap;
 
 namespace utils {
 
-OpMapperContext::FeedInfo GetCinnFeedInfoFromTensor(const Tensor& tensor) {
+OpMapperContext::FeedInfo GetCinnFeedInfoFromTensor(
+    const Tensor& tensor, bool skip_trans_type = false) {
   OpMapperContext::FeedInfo info;
   const auto& dim = tensor.dims();
   for (int i = 0; i < dim.size(); i++) {
     info.shape.emplace_back(static_cast<int>(dim[i]));
   }
 
-  auto cinn_var_type = TransformVarDataTypeToCinn(tensor.type());
+  // Use FP32 as the default type if skip_trans_type=true, to pass CINN's
+  // enforce check that the shape and type of each input must be filled;
+  // we ensure such feeds are never accessed during execution of the
+  // cinn_launch op.
+  auto tensor_type = ::paddle::framework::proto::VarType::FP32;
+  if (!skip_trans_type) {
+    tensor_type = tensor.type();
+  }
+  auto cinn_var_type = TransformVarDataTypeToCinn(tensor_type);
   info.type = ::cinn::frontend::utils::CppVarType2CommonType(cinn_var_type);
   return info;
 }
 }  // namespace utils
 
 FeedInfoMap CinnGraphSymbolization::GetFeedInfoMapFromInput() const {
+  const std::unordered_set<std::string>* no_need_buffer_feeds = nullptr;
+  if (graph_.Has(kNoNeedBufferFeeds)) {
+    no_need_buffer_feeds =
+        &graph_.Get<std::unordered_set<std::string>>(kNoNeedBufferFeeds);
+  }
+
   FeedInfoMap feed_map;
   for (auto& feed_pair : input_tensors_) {
     const auto& feed_name = feed_pair.first;
@@ -67,7 +83,14 @@ FeedInfoMap CinnGraphSymbolization::GetFeedInfoMapFromInput() const {
                           feed_name.c_str()));
 
     VLOG(4) << "Get feed info from input: " << feed_name;
-    feed_map[feed_name] = utils::GetCinnFeedInfoFromTensor(*tensor);
+    // If this feed is declared as no-need-buffer, we cannot access its
+    // type, so pass skip_trans_type=true.
+    if (no_need_buffer_feeds) {
+      feed_map[feed_name] = utils::GetCinnFeedInfoFromTensor(
+          *tensor, no_need_buffer_feeds->count(feed_name) > 0);
+    } else {
+      feed_map[feed_name] = utils::GetCinnFeedInfoFromTensor(*tensor);
+    }
 
     PADDLE_ENFORCE_NE(
         feed_map[feed_name].shape.size(), 0UL,
diff --git a/paddle/fluid/operators/cinn/CMakeLists.txt b/paddle/fluid/operators/cinn/CMakeLists.txt
index 1df49a957438d02eb1cd78afe0560c6ea9536118..ed3a7598bdab6b288d280c13af79f16ff0a84e46 100644
--- a/paddle/fluid/operators/cinn/CMakeLists.txt
+++ b/paddle/fluid/operators/cinn/CMakeLists.txt
@@ -2,7 +2,7 @@ include(operators)
 register_operators(EXCLUDES cinn_launch_op)
 
 cc_library(cinn_launch_context SRCS cinn_launch_context.cc DEPS ddim lod_tensor scope cinn)
-op_library(cinn_launch_op SRCS cinn_launch_op.cc cinn_launch_op.cu.cc DEPS cinn cinn_compiler cinn_launch_context)
+op_library(cinn_launch_op SRCS cinn_launch_op.cc cinn_launch_op.cu.cc DEPS string_helper cinn cinn_compiler cinn_launch_context)
 
 if (WITH_TESTING)
   cc_test(cinn_launch_context_test SRCS cinn_launch_context_test.cc DEPS ddim lod_tensor scope cinn_launch_context)
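The skip_trans_type flag exists because a no-need-buffer feed may arrive with an uninitialized buffer whose dtype cannot be trusted, yet CINN still enforces a complete {shape, type} record per feed; FP32 is substituted as a harmless placeholder since the value is never read. A simplified sketch of that decision, with stand-in types rather than the real FeedInfo:

    // Sketch: placeholder typing for a feed whose buffer is never read.
    #include <iostream>
    #include <vector>

    enum class DType { kUndefined, kFP32, kINT64 };

    struct FeedInfoSketch {
      std::vector<int> shape;
      DType type;
    };

    FeedInfoSketch MakeFeedInfo(const std::vector<int>& dims, DType real_type,
                                bool skip_trans_type) {
      FeedInfoSketch info;
      info.shape = dims;
      // When the buffer is declared unused, never touch real_type: it may be
      // garbage. FP32 satisfies the "type must be filled" check instead.
      info.type = skip_trans_type ? DType::kFP32 : real_type;
      return info;
    }

    int main() {
      auto info = MakeFeedInfo({4, 8}, DType::kUndefined, /*skip_trans_type=*/true);
      std::cout << "placeholder FP32 used: " << (info.type == DType::kFP32)
                << "\n";
      return 0;
    }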
diff --git a/paddle/fluid/operators/cinn/cinn_launch_op.cc b/paddle/fluid/operators/cinn/cinn_launch_op.cc
index 4b86d9b2b3d41540d70122362215bd6a77ef0184..cd17c947228d6201b551410172246498f75f3b12 100644
--- a/paddle/fluid/operators/cinn/cinn_launch_op.cc
+++ b/paddle/fluid/operators/cinn/cinn_launch_op.cc
@@ -86,7 +86,9 @@ class CinnLaunchOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInputs(kX), "Input", kX, "CinnLaunchOp");
+    OP_INOUT_CHECK(ctx->HasInputs(kX) || ctx->HasInputs(kNoNeedBufferX),
+                   "Input", string::format_string("%s|%s", kX, kNoNeedBufferX),
+                   "CinnLaunchOp");
     OP_INOUT_CHECK(ctx->HasOutputs(kOutputs), "Output", kOutputs,
                    "CinnLaunchOp");
   }
@@ -117,8 +119,15 @@ class CinnLaunchOpMaker : public framework::OpProtoAndCheckerMaker {
   void Make() override {
     AddInput(kX,
              "(vector<LoDTensor>)"
-             "which are the input of graph inside the CinnLaunchOp.")
+             "which are the input of graph inside the CinnLaunchOp, "
+             "excluding kNoNeedBufferX.")
         .AsDuplicable();
+    AddInput(kNoNeedBufferX,
+             "(vector<LoDTensor>)"
+             "which are the input of graph inside the CinnLaunchOp but "
+             "whose buffers are not needed.")
+        .AsDuplicable()
+        .AsDispensable();
     AddOutput(kOutputs,
               "(vector<LoDTensor>)"
               "which are the output of graph inside the CinnLaunchOp.")
@@ -155,12 +164,16 @@ It accomplishes the computation of graph following several steps:
   }
 };
 
+DECLARE_NO_NEED_BUFFER_VARS_INFERER(CinnLaunchOpNoBufVarsInferer,
+                                    kNoNeedBufferX);
+
 }  // namespace operators
 }  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(
     cinn_launch, ops::CinnLaunchOp, ops::CinnLaunchOpMaker,
+    ops::CinnLaunchOpNoBufVarsInferer,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 /* see [Why use single type kernel] */
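DECLARE_NO_NEED_BUFFER_VARS_INFERER(CinnLaunchOpNoBufVarsInferer, kNoNeedBufferX) tells the framework that tensors fed through NoNeedBufferX are never dereferenced by the kernel, so their buffers may be released early or skipped. Conceptually — this is not the macro's actual expansion — it contributes a callable along these lines:

    // Sketch: what the inferer amounts to — a callable naming the input
    // parameters whose tensor buffers the kernel never touches.
    #include <iostream>
    #include <string>
    #include <unordered_set>

    struct NoBufVarsInfererSketch {
      std::unordered_set<std::string> operator()() const {
        // Only metadata (dims, place, dtype) of these inputs must survive.
        return {"NoNeedBufferX"};
      }
    };

    int main() {
      for (const auto& param : NoBufVarsInfererSketch()()) {
        std::cout << param << ": buffer not required\n";
      }
      return 0;
    }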
diff --git a/paddle/fluid/operators/cinn/cinn_launch_op.h b/paddle/fluid/operators/cinn/cinn_launch_op.h
index 170546ed23041237095474830b4d6fd2d11d8783..23dfa9d84c01203f3edbef6216cccbc340ffda52 100644
--- a/paddle/fluid/operators/cinn/cinn_launch_op.h
+++ b/paddle/fluid/operators/cinn/cinn_launch_op.h
@@ -32,6 +32,7 @@ namespace paddle {
 namespace operators {
 
 constexpr char kX[] = "X";
+constexpr char kNoNeedBufferX[] = "NoNeedBufferX";
 constexpr char kOutputs[] = "Out";
 constexpr char kCompilationKey[] = "compilation_key";
 
@@ -87,15 +88,33 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
             << "value:\n"
             << CinnCompiler::GetInstance()->ReadableKey(compilation_key);
 
-    auto input_variable_names = ctx.InputNames(kX);
-    const auto& input_tensors = ctx.MultiInput<LoDTensor>(kX);
     std::map<std::string, const LoDTensor*> inputs_name2tensor;
-    std::transform(input_variable_names.begin(), input_variable_names.end(),
-                   input_tensors.begin(),
-                   std::inserter(inputs_name2tensor, inputs_name2tensor.end()),
-                   [](const std::string& name, const LoDTensor* tensor) {
-                     return std::make_pair(name, tensor);
-                   });
+    std::vector<std::string> input_x_variable_names;
+    std::vector<std::string> input_no_need_buffer_variable_names;
+    auto add_name2tensor_fn =
+        [&inputs_name2tensor](const std::vector<std::string>& variable_names,
+                              const std::vector<const LoDTensor*>& tensors) {
+          std::transform(
+              variable_names.begin(), variable_names.end(), tensors.begin(),
+              std::inserter(inputs_name2tensor, inputs_name2tensor.end()),
+              [](const std::string& name, const LoDTensor* tensor) {
+                return std::make_pair(name, tensor);
+              });
+        };
+
+    auto input_x_tensors = ctx.MultiInput<LoDTensor>(kX);
+    if (!input_x_tensors.empty()) {
+      input_x_variable_names = std::move(ctx.InputNames(kX));
+      add_name2tensor_fn(input_x_variable_names, input_x_tensors);
+    }
+    auto input_no_need_buffer_tensors =
+        ctx.MultiInput<LoDTensor>(kNoNeedBufferX);
+    if (!input_no_need_buffer_tensors.empty()) {
+      input_no_need_buffer_variable_names =
+          std::move(ctx.InputNames(kNoNeedBufferX));
+      add_name2tensor_fn(input_no_need_buffer_variable_names,
+                         input_no_need_buffer_tensors);
+    }
 
     // Step 2. Get compilation result of the graph
     auto target = details::PlaceToCinnTarget(place);
@@ -112,12 +131,21 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
     // 3.1 Prepare input variables: tensors of input variables have
     //     been initialized before graph compiled, just check the
     //     equiality between tensors of paddle and cinn.
-    for (const auto& var_name : input_variable_names) {
+    for (const auto& var_name : input_no_need_buffer_variable_names) {
+      // the input variable declared as 'no need buffer' cannot be used
+      PADDLE_ENFORCE_EQ(
+          launch_context->IsVariableUsed(var_name), false,
+          platform::errors::InvalidArgument(
+              "Input variable(%s) should not be used by cinn in execution",
+              var_name));
+    }
+
+    for (const auto& var_name : input_x_variable_names) {
+      // some input variables are not needed by cinn because they are
+      // eliminated by optimization passes or some cinn operators use
+      // fewer variables
       if (!launch_context->IsVariableUsed(var_name)) {
-        // some input variables don't need for cinn because they are
-        // eliminated by optimized passes or some cinn operators use
-        // less variables
-        VLOG(4) << "Input variable(" << var_name << ") not used by cinn";
+        VLOG(4) << "Input variable(" << var_name << ") not used by cinn";
         continue;
       }
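For reference, Step 1 of the kernel pairs each input-name list with its parallel tensor list through std::transform into a map inserter; the same idiom reduced to a standalone form (pointer values are stand-ins for LoDTensor*):

    // Sketch: zipping parallel name/tensor lists into a map, as in Step 1.
    #include <algorithm>
    #include <iostream>
    #include <iterator>
    #include <map>
    #include <string>
    #include <vector>

    int main() {
      std::vector<std::string> names = {"var1", "var2", "var3"};
      std::vector<const float*> tensors = {nullptr, nullptr, nullptr};

      std::map<std::string, const float*> name2tensor;
      std::transform(names.begin(), names.end(), tensors.begin(),
                     std::inserter(name2tensor, name2tensor.end()),
                     [](const std::string& name, const float* tensor) {
                       return std::make_pair(name, tensor);
                     });
      std::cout << "paired " << name2tensor.size() << " inputs\n";  // 3
      return 0;
    }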