Unverified commit b4cb3589, authored by CtfGo, committed by GitHub

Expose input variables of which only the shape is needed in each subgraph compiled by CINN (#38367)

In build_cinn_pass, collect the input variables of each CINN-compiled subgraph whose data buffers are not needed (only their shapes are used), and expose them to the framework's memory-optimization passes by declaring DECLARE_NO_NEED_BUFFER_VARS_INFERER in the cinn_launch op.
Parent 04f042a5
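For background: an input variable whose buffer is not needed is one that the ops inside the subgraph read only for its shape, e.g. X and Y of elementwise_add_grad, so such feeds can be handed to the framework's memory-optimization passes for early reuse or release. Below is a minimal, self-contained sketch of the extraction rule this change implements in build_cinn_pass; the struct and function names are illustrative stand-ins, not the actual Paddle API.

```cpp
#include <iostream>
#include <set>
#include <string>
#include <vector>

// Illustrative stand-in: each op lists its inputs and which of them it
// declares as shape-only (what an op's NoNeedBufferVarsInferer reports).
struct OpUse {
  std::set<std::string> inputs;
  std::set<std::string> no_need_buffer_inputs;
};

// A subgraph feed can be declared no-need-buffer only if every op that
// consumes it inside the subgraph consumes it as a shape-only input.
std::set<std::string> ExtractNoNeedBufferFeeds(
    const std::set<std::string>& cluster_inputs,
    const std::vector<OpUse>& cluster_ops) {
  std::set<std::string> result;
  for (const auto& var : cluster_inputs) {
    bool used = false;
    bool shape_only_everywhere = true;
    for (const auto& op : cluster_ops) {
      if (op.inputs.count(var) == 0) continue;
      used = true;
      if (op.no_need_buffer_inputs.count(var) == 0) {
        shape_only_everywhere = false;
      }
    }
    if (used && shape_only_everywhere) result.insert(var);
  }
  return result;
}

int main() {
  // Mirrors the test graph added below: add_grad reads var1 (dOut) and the
  // shape-only var2/var3 (X/Y); relu_grad reads var4 normally.
  std::vector<OpUse> ops = {
      {{"var1", "var2", "var3"}, {"var2", "var3"}},  // elementwise_add_grad
      {{"var4"}, {}},                                // relu_grad
  };
  for (const auto& name :
       ExtractNoNeedBufferFeeds({"var1", "var2", "var3"}, ops)) {
    std::cout << name << " needs only its shape\n";  // prints var2, var3
  }
  return 0;
}
```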
@@ -11,7 +11,7 @@ if (WITH_TESTING)
  cc_test(cinn_cache_key_test SRCS cinn_cache_key_test.cc DEPS cinn_cache_key)
  set_tests_properties(cinn_cache_key_test PROPERTIES LABELS "RUN_TYPE=CINN")
-  cc_test(build_cinn_pass_test SRCS build_cinn_pass_test.cc DEPS build_cinn_pass cinn_compiler)
+  cc_test(build_cinn_pass_test SRCS build_cinn_pass_test.cc DEPS build_cinn_pass cinn_compiler op_registry mul_op activation_op elementwise_add_op)
  set_tests_properties(build_cinn_pass_test PROPERTIES LABELS "RUN_TYPE=CINN")
  cc_test(transform_desc_test SRCS transform_desc_test.cc DEPS transform_desc)
@@ -20,6 +20,6 @@ if (WITH_TESTING)
  cc_test(cinn_graph_symbolization_test SRCS cinn_graph_symbolization_test.cc DEPS cinn_graph_symbolization)
  set_tests_properties(cinn_graph_symbolization_test PROPERTIES LABELS "RUN_TYPE=CINN")
-  cc_test(cinn_compiler_test SRCS cinn_compiler_test.cc DEPS cinn_compiler place proto_desc graph_viz_pass build_cinn_pass cinn)
+  cc_test(cinn_compiler_test SRCS cinn_compiler_test.cc DEPS cinn_compiler place proto_desc graph_viz_pass build_cinn_pass cinn mul_op activation_op elementwise_add_op)
  set_tests_properties(cinn_compiler_test PROPERTIES LABELS "RUN_TYPE=CINN")
endif()
@@ -32,6 +32,7 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/subgraph_detector.h"
+#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/operators/cinn/cinn_launch_op.h"
@@ -214,6 +215,73 @@ void AddOutputVar(const GraphNodeSet& output_vars, const GraphNodeSet& cluster,
  }
}
std::unordered_set<std::string> ExtractNoNeedBufferFeeds(
const GraphNodeSet& cluster, const GraphNodeSet& cluster_inputs) {
// 1. Find ops with NoNeedBufferVarsInferer defined and collect their no_need_buffer input nodes
std::unordered_map<Node*, GraphNodeSet> op_node2no_need_buffer_nodes;
for (auto* op_node : cluster) {
auto& inferer =
OpInfoMap::Instance().Get(op_node->Name()).NoNeedBufferVarsInferer();
if (!inferer) {
continue;
}
auto* op_desc = op_node->Op();
PADDLE_ENFORCE_NOT_NULL(
op_desc, platform::errors::PreconditionNotMet(
"The op desc of node in cluster shouldn't be null."));
auto inferred_params =
inferer(op_desc->Inputs(), op_desc->Outputs(), op_desc->GetAttrMap());
std::unordered_set<std::string> inferred_args;
std::for_each(inferred_params.begin(), inferred_params.end(),
[&op_desc, &inferred_args](const std::string& param) {
const auto& args = op_desc->Input(param);
inferred_args.insert(args.begin(), args.end());
});
auto& no_need_buffer_nodes = op_node2no_need_buffer_nodes[op_node];
for (auto* input_node : op_node->inputs) {
if (input_node->Var() && inferred_args.count(input_node->Name())) {
VLOG(4) << "Input node(" << input_node->Name() << ") of op("
<< op_node->Name() << ") is no_need_buffer";
no_need_buffer_nodes.insert(input_node);
}
}
}
// 2. Extract no_need_buffer nodes from cluster_inputs by checking
// that all of their outputs are op nodes with NoNeedBufferVarsInferer
// and that they are used there as no_need_buffer inputs.
auto check_all_used_as_no_need_buffer_fn =
[&op_node2no_need_buffer_nodes](Node* var_node) -> bool {
for (auto* output_node : var_node->outputs) {
auto it = op_node2no_need_buffer_nodes.find(output_node);
if (it == op_node2no_need_buffer_nodes.end()) {
VLOG(4) << "Var node(" << var_node->Name() << ")'s output node("
<< output_node->Name()
<< ") doesn't have NoNeedBufferVarsInferer";
return false;
}
if (it->second.count(var_node) == 0) {
VLOG(4) << "Var node("
<< ") is not used as no_need_buffer inputs";
return false;
}
}
return true;
};
std::unordered_set<std::string> result;
for (const auto& op2inputs_pair : op_node2no_need_buffer_nodes) {
for (auto* input_node : op2inputs_pair.second) {
if (cluster_inputs.count(input_node) &&
check_all_used_as_no_need_buffer_fn(input_node)) {
VLOG(4) << "Input node(" << input_node->Name()
<< ") is declared as no_need_buffer cluster_inputs";
result.insert(input_node->Name());
}
}
}
return result;
}
// Create a new subgraph whose op nodes are the cluster nodes, and all
// var nodes are from internal nodes
std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
@@ -295,7 +363,12 @@ std::unique_ptr<Graph> CreateNewSubGraph(const GraphNodeSet& cluster,
               subgraph.get());
AddOutputVar(output_vars, cluster, old_op2new_op, old_var2new_var,
             subgraph.get());
// Store the input variables whose buffers are not needed as an
// attribute of the graph.
auto no_need_buffer_feeds = std::make_unique<std::unordered_set<std::string>>(
ExtractNoNeedBufferFeeds(cluster, cluster_inputs));
subgraph->Set<std::unordered_set<std::string>>(
kNoNeedBufferFeeds, no_need_buffer_feeds.release());
return subgraph;
}
@@ -374,15 +447,26 @@ void AddCinnOpToGraph(const GraphNodeSet& cluster,
// Add the cinn launch op
framework::OpDesc cinn_op_desc;
cinn_op_desc.SetType(kCinnLaunchOp);
-std::vector<std::string> input_names;
-std::for_each(cluster_inputs.begin(), cluster_inputs.end(),
-              [&input_names, &deny_var_set](Node* n) {
-                if (n->Var() != nullptr && !deny_var_set.count(n->Name())) {
-                  input_names.emplace_back(n->Name());
-                }
-              });
-cinn_op_desc.SetInput(operators::kX, input_names);
+// Divide the input variables into two parts: the ones whose data buffers
+// are not needed and the remaining ones
+std::vector<std::string> op_kx_inputs, no_need_buffer_inputs;
+const auto& subgraph =
+    CinnCompiler::GetInstance()->FindGraph(compilation_key);
+auto& no_need_buffer_feeds =
+    subgraph.Get<std::unordered_set<std::string>>(kNoNeedBufferFeeds);
+for (const auto* n : cluster_inputs) {
+  const auto& var_name = n->Name();
+  if (no_need_buffer_feeds.count(var_name)) {
+    no_need_buffer_inputs.emplace_back(var_name);
+  } else {
+    op_kx_inputs.emplace_back(var_name);
+  }
+}
+cinn_op_desc.SetInput(operators::kX, op_kx_inputs);
+cinn_op_desc.SetInput(operators::kNoNeedBufferX, no_need_buffer_inputs);
std::vector<std::string> output_names;
std::for_each(cluster_outputs.begin(), cluster_outputs.end(),
              [&output_names, &deny_var_set](Node* n) {
......
@@ -21,6 +21,7 @@ namespace framework {
namespace paddle2cinn {
constexpr char kCinnLaunchOp[] = "cinn_launch";
+constexpr char kNoNeedBufferFeeds[] = "no_need_buffer_feeds";
// A pass named BuildCinnPass, the function of this pass is:
//
......
@@ -24,6 +24,8 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_desc.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
@@ -169,7 +171,7 @@ std::unique_ptr<Graph> BuildAllOpSupportCinnGraph() {
//            v4 --
OpDesc add_op;
-add_op.SetType("add");
+add_op.SetType("elementwise_add");
OpDesc mul_op;
mul_op.SetType("mul");
OpDesc relu_op;
@@ -259,7 +261,7 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
// previous ops (mul, add, relu) should all be removed
ASSERT_FALSE(CheckNodeExisted(nodes, "mul"));
-ASSERT_FALSE(CheckNodeExisted(nodes, "add"));
+ASSERT_FALSE(CheckNodeExisted(nodes, "elementwise_add"));
ASSERT_FALSE(CheckNodeExisted(nodes, "relu"));
// After search, there should be just one cinn subgraph
@@ -277,7 +279,7 @@ TEST(BuildCinnPassTest, AllOpSupportCinn) {
ASSERT_TRUE(CheckGraphIndependence(subnodes));
ASSERT_TRUE(CheckNodeExisted(subnodes, "mul"));
-ASSERT_TRUE(CheckNodeExisted(subnodes, "add"));
+ASSERT_TRUE(CheckNodeExisted(subnodes, "elementwise_add"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "relu"));
ASSERT_EQ(CountNode(subnodes, "feed"), 2);
ASSERT_EQ(CountNode(subnodes, "fetch"), 1);
@@ -529,8 +531,136 @@ TEST(BuildCinnPassTest, MultiCinnSubgraph) {
  }
}
std::unique_ptr<Graph> BuildGraphWithNoNeedBufferInput() {
ProgramDesc prog;
auto g = std::make_unique<Graph>(prog);
// fake1 --> v1 --                 --> v4 --> relu_grad --> v6
//           v2 -- | --> add_grad |
//           v3 --                 --> v5 --> fake2
OpDesc fake1_op;
fake1_op.SetType("fake1");
OpDesc add_grad_op;
add_grad_op.SetType("elementwise_add_grad");
add_grad_op.SetInput(::paddle::framework::GradVarName("Out"), {"var1"});
add_grad_op.SetInput("X", {"var2"});
add_grad_op.SetInput("Y", {"var3"});
OpDesc relu_grad_op;
relu_grad_op.SetType("relu_grad");
OpDesc fake2_op;
fake2_op.SetType("fake2");
VarDesc var1("var1");
VarDesc var2("var2");
VarDesc var3("var3");
VarDesc var4("var4");
VarDesc var5("var5");
VarDesc var6("var6");
ir::Node* fake1 = g->CreateOpNode(&fake1_op);
ir::Node* add_grad = g->CreateOpNode(&add_grad_op);
ir::Node* relu_grad = g->CreateOpNode(&relu_grad_op);
ir::Node* fake2 = g->CreateOpNode(&fake2_op);
ir::Node* v1 = g->CreateVarNode(&var1);
ir::Node* v2 = g->CreateVarNode(&var2);
ir::Node* v3 = g->CreateVarNode(&var3);
ir::Node* v4 = g->CreateVarNode(&var4);
ir::Node* v5 = g->CreateVarNode(&var5);
ir::Node* v6 = g->CreateVarNode(&var6);
// fill op node
fake1->outputs = {v1};
add_grad->inputs = {v1, v2, v3};
add_grad->outputs = {v4, v5};
relu_grad->inputs = {v4};
relu_grad->outputs = {v6};
fake2->inputs = {v5};
// fill variable node
v1->inputs = {fake1};
v1->outputs = {add_grad};
v2->outputs = {add_grad};
v3->outputs = {add_grad};
v4->inputs = {add_grad};
v4->outputs = {relu_grad};
v5->inputs = {add_grad};
v5->outputs = {fake2};
v6->inputs = {relu_grad};
return g;
}
TEST(BuildCinnPassTest, NoNeedBufferInput) {
auto g = BuildGraphWithNoNeedBufferInput();
auto pass =
paddle::framework::ir::PassRegistry::Instance().Get("build_cinn_pass");
pass->Apply(g.get());
// After search, the graph should be as follows
// fake1 --> v1 --                      --> v6
//           v2 -- | --> kCinnLaunchOp |
//           v3 --                      --> v5 --> fake2
const auto& nodes = g->Nodes();
ASSERT_EQ(nodes.size(), static_cast<size_t>(8));
ASSERT_TRUE(CheckGraphIndependence(nodes));
// A new op named kCinnLaunchOp should be added and
// its input arguments are set correctly
ASSERT_TRUE(CheckNodeExisted(nodes, kCinnLaunchOp));
ASSERT_EQ(CountNode(nodes, kCinnLaunchOp), 1);
auto* cinn_op_node = GetNode(nodes, kCinnLaunchOp);
ASSERT_EQ(cinn_op_node->Op()->Input(operators::kX),
std::vector<std::string>({"var1"}));
auto& no_need_buffer_x = cinn_op_node->Op()->Input(operators::kNoNeedBufferX);
ASSERT_EQ(std::unordered_set<std::string>(no_need_buffer_x.begin(),
no_need_buffer_x.end()),
std::unordered_set<std::string>({"var2", "var3"}));
// previous op (add_grad, relu_grad) should be removed
ASSERT_FALSE(CheckNodeExisted(nodes, "add_grad"));
ASSERT_FALSE(CheckNodeExisted(nodes, "relu_grad"));
// previous op (fake1, fake2) should be preserved
ASSERT_TRUE(CheckNodeExisted(nodes, "fake1"));
ASSERT_TRUE(CheckNodeExisted(nodes, "fake2"));
// After search, there should be just one cinn subgraph
// feed --> v1 --                                        --> v6 --> fetch
// feed --> v2 -- | --> add_grad --> v4 --> relu_grad |
// feed --> v3 --                                        --> v5 --> fetch
auto compilation_keys = GetCompilationKeys(*g);
ASSERT_EQ(compilation_keys.size(), static_cast<size_t>(1));
auto* cinn_compiler = CinnCompiler::GetInstance();
const auto& subgraph = cinn_compiler->FindGraph(compilation_keys[0]);
const auto& subnodes = subgraph.Nodes();
ASSERT_EQ(subnodes.size(), static_cast<size_t>(13));
ASSERT_TRUE(CheckGraphIndependence(subnodes));
ASSERT_TRUE(CheckNodeExisted(subnodes, "elementwise_add_grad"));
ASSERT_TRUE(CheckNodeExisted(subnodes, "relu_grad"));
ASSERT_EQ(CountNode(subnodes, "feed"), 3);
ASSERT_EQ(CountNode(subnodes, "fetch"), 2);
const auto& no_need_buffer_feeds =
subgraph.Get<std::unordered_set<std::string>>(kNoNeedBufferFeeds);
ASSERT_EQ(no_need_buffer_feeds.size(), 2);
ASSERT_EQ(no_need_buffer_feeds,
std::unordered_set<std::string>({"var2", "var3"}));
}
} // namespace paddle2cinn
} // namespace framework
} // namespace paddle
USE_PASS(build_cinn_pass);
USE_OP(mul);
USE_OP(relu);
USE_OP(elementwise_add);
USE_OP(relu_grad);
USE_OP(elementwise_add_grad);
@@ -293,3 +293,6 @@ TEST(CinnCompilerTest, Compile) {
USE_PASS(build_cinn_pass);
USE_PASS(graph_viz_pass);
USE_OP(mul);
USE_OP(relu);
USE_OP(elementwise_add);
@@ -21,6 +21,7 @@ limitations under the License. */
#include <unordered_set>
#include <vector>
+#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h"
#include "paddle/fluid/framework/paddle2cinn/transform_desc.h"
#include "paddle/fluid/framework/variable.h"
@@ -42,20 +43,35 @@ using FeedInfoMap = CinnGraphSymbolization::FeedInfoMap;
namespace utils {
-OpMapperContext::FeedInfo GetCinnFeedInfoFromTensor(const Tensor& tensor) {
+OpMapperContext::FeedInfo GetCinnFeedInfoFromTensor(
+    const Tensor& tensor, bool skip_trans_type = false) {
OpMapperContext::FeedInfo info;
const auto& dim = tensor.dims();
for (int i = 0; i < dim.size(); i++) {
  info.shape.emplace_back(static_cast<int>(dim[i]));
}
-auto cinn_var_type = TransformVarDataTypeToCinn(tensor.type());
+// Use FP32 as the default type when skip_trans_type=true so that the CINN
+// enforce check (the shape and type of each input must be filled) passes;
+// such feeds are guaranteed not to be used in execution by the cinn_launch op.
+auto tensor_type = ::paddle::framework::proto::VarType::FP32;
+if (!skip_trans_type) {
+  tensor_type = tensor.type();
+}
+auto cinn_var_type = TransformVarDataTypeToCinn(tensor_type);
info.type = ::cinn::frontend::utils::CppVarType2CommonType(cinn_var_type);
return info;
}
} // namespace utils
FeedInfoMap CinnGraphSymbolization::GetFeedInfoMapFromInput() const {
const std::unordered_set<std::string>* no_need_buffer_feeds = nullptr;
if (graph_.Has(kNoNeedBufferFeeds)) {
no_need_buffer_feeds =
&graph_.Get<std::unordered_set<std::string>>(kNoNeedBufferFeeds);
}
FeedInfoMap feed_map;
for (auto& feed_pair : input_tensors_) {
const auto& feed_name = feed_pair.first;
@@ -67,7 +83,14 @@ FeedInfoMap CinnGraphSymbolization::GetFeedInfoMapFromInput() const {
            feed_name.c_str()));
VLOG(4) << "Get feed info from input: " << feed_name;
+// If this feed is declared as no-need-buffer, we cannot access its type,
+// so pass skip_trans_type=true.
+if (no_need_buffer_feeds) {
+  feed_map[feed_name] = utils::GetCinnFeedInfoFromTensor(
+      *tensor, no_need_buffer_feeds->count(feed_name) > 0);
+} else {
-  feed_map[feed_name] = utils::GetCinnFeedInfoFromTensor(*tensor);
+  feed_map[feed_name] = utils::GetCinnFeedInfoFromTensor(*tensor);
+}
PADDLE_ENFORCE_NE(
    feed_map[feed_name].shape.size(), 0UL,
......
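A minimal standalone sketch of the fallback used above (illustrative types only, not the Paddle or CINN API): when a feed is declared no-need-buffer, its tensor may hold no data, so its dtype cannot be queried; only the shape is recorded and the type defaults to FP32, relying on the guarantee that such feeds are never read during execution.

```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Illustrative stand-ins for the framework types.
enum class DType { FP32, FP64, INT64 };

struct FakeTensor {
  std::vector<int64_t> dims;
  std::optional<DType> dtype;  // empty when no buffer has been allocated
};

struct FeedInfo {
  std::vector<int64_t> shape;
  DType type;
};

// Always record the shape; fall back to FP32 when the real dtype must not
// (or cannot) be read, mirroring skip_trans_type=true above.
FeedInfo MakeFeedInfo(const FakeTensor& t, bool skip_trans_type = false) {
  FeedInfo info;
  info.shape = t.dims;
  info.type =
      (skip_trans_type || !t.dtype.has_value()) ? DType::FP32 : *t.dtype;
  return info;
}

int main() {
  FakeTensor no_buffer_feed{{32, 64}, std::nullopt};
  FeedInfo info = MakeFeedInfo(no_buffer_feed, /*skip_trans_type=*/true);
  std::cout << "rank=" << info.shape.size() << ", fp32="
            << (info.type == DType::FP32) << "\n";  // rank=2, fp32=1
  return 0;
}
```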
@@ -2,7 +2,7 @@ include(operators)
register_operators(EXCLUDES cinn_launch_op)
cc_library(cinn_launch_context SRCS cinn_launch_context.cc DEPS ddim lod_tensor scope cinn)
-op_library(cinn_launch_op SRCS cinn_launch_op.cc cinn_launch_op.cu.cc DEPS cinn cinn_compiler cinn_launch_context)
+op_library(cinn_launch_op SRCS cinn_launch_op.cc cinn_launch_op.cu.cc DEPS string_helper cinn cinn_compiler cinn_launch_context)
if (WITH_TESTING)
  cc_test(cinn_launch_context_test SRCS cinn_launch_context_test.cc DEPS ddim lod_tensor scope cinn_launch_context)
......
@@ -86,7 +86,9 @@ class CinnLaunchOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
-  OP_INOUT_CHECK(ctx->HasInputs(kX), "Input", kX, "CinnLaunchOp");
+  OP_INOUT_CHECK(ctx->HasInputs(kX) || ctx->HasInputs(kNoNeedBufferX),
+                 "Input", string::format_string("%s|%s", kX, kNoNeedBufferX),
+                 "CinnLaunchOp");
  OP_INOUT_CHECK(ctx->HasOutputs(kOutputs), "Output", kOutputs,
                 "CinnLaunchOp");
}
@@ -117,8 +119,15 @@ class CinnLaunchOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
  AddInput(kX,
           "(vector<LoDTensor>)"
-           "which are the input of graph inside the CinnLaunchOp.")
+           "which are the inputs of the graph inside the CinnLaunchOp, "
+           "excluding kNoNeedBufferX.")
      .AsDuplicable();
+  AddInput(kNoNeedBufferX,
+           "(vector<LoDTensor>)"
+           "which are the inputs of the graph inside the CinnLaunchOp, but "
+           "their buffers are not needed.")
+      .AsDuplicable()
+      .AsDispensable();
  AddOutput(kOutputs,
            "(vector<LoDTensor>)"
            "which are the output of graph inside the CinnLaunchOp.")
@@ -155,12 +164,16 @@ It accomplishes the computation of graph following several steps:
  }
};

+DECLARE_NO_NEED_BUFFER_VARS_INFERER(CinnLaunchOpNoBufVarsInferer,
+                                    kNoNeedBufferX);
+
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(
    cinn_launch, ops::CinnLaunchOp, ops::CinnLaunchOpMaker,
+    ops::CinnLaunchOpNoBufVarsInferer,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
/* see [Why use single type kernel] */
......
@@ -32,6 +32,7 @@ namespace paddle {
namespace operators {
constexpr char kX[] = "X";
+constexpr char kNoNeedBufferX[] = "NoNeedBufferX";
constexpr char kOutputs[] = "Out";
constexpr char kCompilationKey[] = "compilation_key";
@@ -87,15 +88,33 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
        << "value:\n"
        << CinnCompiler::GetInstance()->ReadableKey(compilation_key);

-auto input_variable_names = ctx.InputNames(kX);
-const auto& input_tensors = ctx.MultiInput<LoDTensor>(kX);
std::map<std::string, const LoDTensor*> inputs_name2tensor;
-std::transform(input_variable_names.begin(), input_variable_names.end(),
-               input_tensors.begin(),
+std::vector<std::string> input_x_variable_names;
+std::vector<std::string> input_no_need_buffer_variable_names;
+auto add_name2tensor_fn = [&inputs_name2tensor](
+    const std::vector<std::string>& variable_names,
+    const std::vector<const LoDTensor*>& tensors) {
+  std::transform(
+      variable_names.begin(), variable_names.end(), tensors.begin(),
      std::inserter(inputs_name2tensor, inputs_name2tensor.end()),
      [](const std::string& name, const LoDTensor* tensor) {
        return std::make_pair(name, tensor);
      });
+};
+
+auto input_x_tensors = ctx.MultiInput<LoDTensor>(kX);
+if (!input_x_tensors.empty()) {
+  input_x_variable_names = std::move(ctx.InputNames(kX));
+  add_name2tensor_fn(input_x_variable_names, input_x_tensors);
+}
+auto input_no_need_buffer_tensors =
+    ctx.MultiInput<LoDTensor>(kNoNeedBufferX);
+if (!input_no_need_buffer_tensors.empty()) {
+  input_no_need_buffer_variable_names =
+      std::move(ctx.InputNames(kNoNeedBufferX));
+  add_name2tensor_fn(input_no_need_buffer_variable_names,
+                     input_no_need_buffer_tensors);
+}

// Step 2. Get compilation result of the graph
auto target = details::PlaceToCinnTarget(place);
@@ -112,12 +131,21 @@ class CinnLaunchOpKernel : public framework::OpKernel<T> {
// 3.1 Prepare input variables: tensors of input variables have
//     been initialized before the graph was compiled; just check the
//     equality between the tensors of paddle and cinn.
-for (const auto& var_name : input_variable_names) {
+for (const auto& var_name : input_no_need_buffer_variable_names) {
+  // an input variable declared as 'no need buffer' must not be used
+  PADDLE_ENFORCE_EQ(
+      launch_context->IsVariableUsed(var_name), false,
+      platform::errors::InvalidArgument(
+          "Input variable(%s) should not be used by cinn in execution",
+          var_name));
+}
+
+for (const auto& var_name : input_x_variable_names) {
  // some input variables are not needed by cinn because they are
  // eliminated by optimization passes or some cinn operators use
  // fewer variables
  if (!launch_context->IsVariableUsed(var_name)) {
    VLOG(4) << "Input variable(" << var_name << ") not used by cinn";
    continue;
  }
......
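A simplified, self-contained sketch of the kernel-side bookkeeping above (plain containers stand in for the LoDTensor pointers and the CINN launch context): feeds from both input slots are gathered into one name-keyed map, and a no-need-buffer feed is additionally required to be unused by the compiled program.

```cpp
#include <iostream>
#include <map>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

int main() {
  // Stand-ins for the two input slots of the cinn_launch op.
  std::vector<std::string> kx_names = {"var1"};
  std::vector<std::string> no_need_buffer_names = {"var2", "var3"};
  // Stand-in for launch_context->IsVariableUsed(name).
  std::set<std::string> used_by_compiled_program = {"var1"};

  // Gather every feed into one map; `true` marks a feed whose buffer exists.
  std::map<std::string, bool> inputs_name2buffer;
  for (const auto& n : kx_names) inputs_name2buffer[n] = true;
  for (const auto& n : no_need_buffer_names) inputs_name2buffer[n] = false;

  // A no-need-buffer feed must never be read by the compiled instructions.
  for (const auto& n : no_need_buffer_names) {
    if (used_by_compiled_program.count(n)) {
      throw std::runtime_error("Input variable(" + n +
                               ") should not be used by cinn in execution");
    }
  }
  std::cout << "collected " << inputs_name2buffer.size() << " feeds\n";
  return 0;
}
```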