Unverified commit e8e9d6c5, authored by zhupengyang, committed by GitHub

optimize write_read_array, gather if beam_size=1 (#53130)

Parent commit: fc6d4399
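In short, the optimization below exploits the fact that with beam_size == 1 a write_to_array that is immediately read back by read_from_array on a persistable tensor reduces to a one-time tensor copy done at pass time, and a gather with index [0] along axis 0 of a size-1 dimension is an identity, so both patterns can be removed from the graph. A rough illustration of the gather case (hypothetical values, not taken from the commit):

#include <vector>
// Sketch only: with beam_size == 1 the gather index tensor holds a single 0,
// so gather(x, index, axis=0) simply returns x.
std::vector<float> x = {0.7f};           // beam dimension has size 1
std::vector<int> index = {0};            // index tensor is [0]
std::vector<float> out = {x[index[0]]};  // out == x, the op is a no-op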
@@ -787,20 +787,37 @@ struct Layers {
    return out;
  }

-  VarDesc* write_to_array(std::vector<VarDesc*> x, VarDesc* i) {
+  VarDesc* write_to_array(VarDesc* x, VarDesc* i) {
    VarDesc* out = lod_tensor(unique_name());
    OpDesc* op = program_.MutableBlock(0)->AppendOp();
    op->SetType("write_to_array");
-    std::vector<std::string> x_names;
-    for (auto k : x) {
-      x_names.push_back(k->Name());
-    }
-    op->SetInput("X", x_names);
+    op->SetInput("X", {x->Name()});
    op->SetInput("I", {i->Name()});
    op->SetOutput("Out", {out->Name()});
    return out;
  }

+  VarDesc* read_from_array(VarDesc* x, VarDesc* i) {
+    VarDesc* out = lod_tensor(unique_name());
+    OpDesc* op = program_.MutableBlock(0)->AppendOp();
+    op->SetType("read_from_array");
+    op->SetInput("X", {x->Name()});
+    op->SetInput("I", {i->Name()});
+    op->SetOutput("Out", {out->Name()});
+    return out;
+  }
+
+  VarDesc* gather(VarDesc* x, VarDesc* index, int axis) {
+    VarDesc* out = lod_tensor(unique_name());
+    OpDesc* op = program_.MutableBlock(0)->AppendOp();
+    op->SetType("gather");
+    op->SetInput("X", {x->Name()});
+    op->SetInput("Index", {index->Name()});
+    op->SetAttr("axis", axis);
+    op->SetOutput("Out", {out->Name()});
+    return out;
+  }
+
  VarDesc* is_empty(VarDesc* input) { return unary_op("is_empty", input); }

  VarDesc* logical_not(VarDesc* input) {
......
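For reference, the three Layers helpers added above compose like the tests later in this commit; a minimal sketch (the variable names and the {1} shape are illustrative, not part of the commit):

// Sketch only: build a tiny test program with the new helpers.
Layers layers;
auto* x = layers.data("x", {1}, true);            // persistable input
auto* i = layers.data("i");                       // array index
auto* arr = layers.write_to_array(x, i);          // write x into a tensor array
auto* read_out = layers.read_from_array(arr, i);  // read the same slot back
auto* idx = layers.data("idx", {1}, true);        // gather index
layers.gather(read_out, idx, 0);                  // gather along axis 0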
@@ -259,20 +259,21 @@ bool OnlyOneBeamSearchAndOneBeamSize(ir::Graph* graph) {
         beam_search_nodes[0]->Op()->GetAttrIfExists<int>("beam_size") == 1;
}

-Node* FindOpNodeByInputName(Graph* graph,
-                            const std::string& op_type,
-                            const std::string& arg_name,
-                            const std::string& var_name) {
+std::vector<Node*> FindOpNodeByInputName(Graph* graph,
+                                         const std::string& var_name) {
+  std::vector<Node*> ret;
  for (auto* node : graph->Nodes()) {
-    if (!node->IsOp() || node->Op()->Type() != op_type) continue;
+    if (!node->IsOp()) continue;
    auto inputs = node->Op()->Inputs();
-    if (inputs.count(arg_name) == 0) continue;
-    auto in_names = inputs.at(arg_name);
-    if (std::find(in_names.begin(), in_names.end(), var_name) == in_names.end())
-      continue;
-    return node;
+    for (auto input : inputs) {
+      auto in_names = input.second;
+      if (std::count(in_names.begin(), in_names.end(), var_name) > 0) {
+        ret.push_back(node);
+        break;
+      }
+    }
  }
-  return nullptr;
+  return ret;
}

void OneBeamSizeFusePass::RemoveAssignGather(ir::Graph* graph) const {
@@ -287,9 +288,9 @@ void OneBeamSizeFusePass::RemoveAssignGather(ir::Graph* graph) const {
    GET_IR_NODE(assign);
    GET_IR_NODE(assign_out);
    // Assign_out may not link to gather, so we find gather by input name.
-    auto* gather =
-        FindOpNodeByInputName(graph, "gather", "X", assign_out->Name());
-    if (gather == nullptr) return;
+    auto next_ops = FindOpNodeByInputName(graph, assign_out->Name());
+    if (next_ops.size() != 1 || next_ops[0]->Name() != "gather") return;
+    auto* gather = next_ops[0];

    // "assign_out" is used in multi blocks. "assign_out" should be reserved.
    auto* assign_in = assign->inputs[0];
@@ -449,6 +450,143 @@ void OneBeamSizeFusePass::RemoveBeamSearchAssociatedOps(
  AddStatis(found_subgraph_count);
}
namespace patterns {
struct WriteToArrayPattern : public PatternBase {
WriteToArrayPattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(write_to_array);
// declare variable node's name
PATTERN_DECL_NODE(write_x);
PATTERN_DECL_NODE(write_out);
};
WriteToArrayPattern::WriteToArrayPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* write_x = pattern->NewNode(write_x_repr())
->assert_is_op_input("write_to_array", "X")
->assert_is_persistable_var();
auto* write_to_array =
pattern->NewNode(write_to_array_repr())->assert_is_op("write_to_array");
auto* write_out = pattern->NewNode(write_out_repr())
->assert_is_op_output("write_to_array", "Out");
write_to_array->LinksFrom({write_x}).LinksTo({write_out});
}
} // namespace patterns
void OneBeamSizeFusePass::RemoveWriteReadArrayOps(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::WriteToArrayPattern pattern(gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle RemoveWriteReadArrayOps";
GET_IR_NODE(write_to_array);
GET_IR_NODE(write_x);
GET_IR_NODE(write_out);
auto* scope = param_scope();
// write_out is from graph0 and do not link to any op, so we find
// "read_from_array" by write_out name.
auto next_ops = FindOpNodeByInputName(graph, write_out->Name());
if (next_ops.size() != 1 || next_ops[0]->Name() != "read_from_array")
return;
auto* read_from_array = next_ops[0];
auto* read_out = read_from_array->outputs[0];
read_out->Var()->SetPersistable(true);
auto* write_x_tensor =
scope->Var(write_x->Name())->GetMutable<phi::DenseTensor>();
auto* read_out_tensor =
scope->Var(read_out->Name())->GetMutable<phi::DenseTensor>();
Assign(*write_x_tensor, read_out_tensor);
std::unordered_set<const Node*> delete_nodes{
write_to_array, write_out, read_from_array};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
namespace patterns {
struct GatherPattern : public PatternBase {
GatherPattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(gather);
// declare variable node's name
PATTERN_DECL_NODE(gather_x);
PATTERN_DECL_NODE(gather_i);
PATTERN_DECL_NODE(gather_out);
};
GatherPattern::GatherPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* gather_x =
pattern->NewNode(gather_x_repr())->assert_is_op_input("gather", "X");
auto* gather_i = pattern->NewNode(gather_i_repr())
->assert_is_op_input("gather", "Index")
->assert_is_persistable_var();
auto* gather = pattern->NewNode(gather_repr())
->assert_is_op("gather")
->assert_more([&](Node* node) {
return node->Op()->GetAttrIfExists<int>("axis") == 0;
});
auto* gather_out =
pattern->NewNode(gather_out_repr())->assert_is_op_output("gather", "Out");
gather->LinksFrom({gather_x, gather_i}).LinksTo({gather_out});
}
} // namespace patterns
void OneBeamSizeFusePass::RemoveGatherOps(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::GatherPattern pattern(gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle RemoveGatherOps";
GET_IR_NODE(gather);
GET_IR_NODE(gather_x);
GET_IR_NODE(gather_i);
GET_IR_NODE(gather_out);
auto* scope = param_scope();
// gather_i should be 0
auto* gather_i_tensor =
scope->Var(gather_i->Name())->GetMutable<phi::DenseTensor>();
auto gather_i_dims = gather_i_tensor->dims();
if (gather_i_dims.size() != 1 || gather_i_dims[0] != 1) return;
if (gather_i_tensor->dtype() == phi::DataType::INT32) {
auto* i_data = gather_i_tensor->data<int>();
if (i_data[0] != 0) return;
} else {
auto* i_data = gather_i_tensor->data<int64_t>();
if (i_data[0] != 0) return;
}
auto gather_x_name = gather_x->Name();
auto gather_out_name = gather_out->Name();
for (auto* next_op : gather_out->outputs) {
next_op->Op()->RenameInput(gather_out_name, gather_x_name);
IR_NODE_LINK_TO(gather_x, next_op);
}
std::unordered_set<const Node*> delete_nodes{gather, gather_out};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
void OneBeamSizeFusePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
@@ -458,6 +596,8 @@ void OneBeamSizeFusePass::ApplyImpl(ir::Graph* graph) const {
  RemoveAssignGather(graph);
  FoldShapeAssociatedOps(graph);
  RemoveBeamSearchAssociatedOps(graph);
+  RemoveWriteReadArrayOps(graph);
+  RemoveGatherOps(graph);
}

}  // namespace ir
......
@@ -103,6 +103,38 @@ class OneBeamSizeFusePass : public FusePassBase {
  */
  void RemoveBeamSearchAssociatedOps(ir::Graph* graph) const;
  /*
  Origin subgraph:
      (x: persistable)   (index)
              \             /
             write_to_array
                    |
            read_from_array
                    |
                 any_op

  Fused subgraph:
      (x: persistable)
              |
           any_op
  */
  void RemoveWriteReadArrayOps(ir::Graph* graph) const;

  /*
  Origin subgraph:
      (x: dims0=1)   (index=[0])
             \            /
            gather(axis=0)
                   |
                any_op

  Fused subgraph:
        (x)
         |
      any_op
  */
  void RemoveGatherOps(ir::Graph* graph) const;
  const std::string name_scope_{"one_beam_size_fuse_pass"};
};
......
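The two members declared above are invoked from ApplyImpl (see the .cc diff); in a test, the pass is fetched from the registry and applied to a graph roughly as follows (a sketch mirroring the unit tests below, which assumes graph and its __param_scope__ are already set up):

// Sketch only: run the pass and check that the target ops are gone.
auto pass = PassRegistry::Instance().Get("one_beam_size_fuse_pass");
pass->Apply(graph.get());
auto gather_num = GetNumOpNodes(graph, "gather");  // expected to be 0 afterwards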
@@ -20,6 +20,21 @@ namespace paddle {
namespace framework {
namespace ir {
template <typename T = float>
void AddVarToScope(Scope* param_scope,
const std::string& name,
const DDim& dims,
T value = 0) {
auto* tensor = param_scope->Var(name)->GetMutable<phi::DenseTensor>();
tensor->Resize(dims);
auto* cpu_ctx = static_cast<phi::CPUContext*>(
platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
auto* data = cpu_ctx->Alloc<T>(tensor);
for (int64_t i = 0; i < tensor->numel(); i++) {
data[i] = value;
}
}
VarDesc* Data(paddle::framework::BlockDesc* block,
              std::string name,
              std::vector<int64_t> shape = {},
@@ -128,9 +143,9 @@ TEST(RemoveBeamSearchAssociatedOps, basic) {
  auto* selected_scores = beam_search_outs[2];

  auto* write_to_array_0_i = layers.data("write_to_array_0_i");
-  layers.write_to_array({selected_ids}, write_to_array_0_i);
+  layers.write_to_array(selected_ids, write_to_array_0_i);
  auto* write_to_array_1_i = layers.data("write_to_array_1_i");
-  layers.write_to_array({selected_scores}, write_to_array_1_i);
+  layers.write_to_array(selected_scores, write_to_array_1_i);
  auto* is_empty_out = layers.is_empty(selected_ids);
  layers.logical_not(is_empty_out);
  layers.cast(parent_idx);
@@ -147,6 +162,58 @@ TEST(RemoveBeamSearchAssociatedOps, basic) {
          "beam_search op should be removed from the graph."));
}
TEST(RemoveWriteReadArrayOps, basic) {
Layers layers;
auto* block = layers.Block();
OpDesc* beam_search_op = block->AppendOp();
beam_search_op->SetType("beam_search");
beam_search_op->SetAttr("beam_size", 1);
auto* write_x = layers.data("write_x", {1}, true);
auto* write_i = layers.data("write_i");
auto* write_out = layers.write_to_array(write_x, write_i);
auto* read_i = layers.data("read_i");
layers.read_from_array(write_out, read_i);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto* param_scope = new Scope();
graph->Set("__param_scope__", param_scope);
AddVarToScope(param_scope, write_x->Name(), {1});
auto pass = PassRegistry::Instance().Get("one_beam_size_fuse_pass");
pass->Apply(graph.get());
auto write_read_num = GetNumOpNodes(graph, "write_to_array") +
GetNumOpNodes(graph, "read_from_array");
PADDLE_ENFORCE_EQ(write_read_num,
0,
platform::errors::PreconditionNotMet(
"write_to_array and read_from_array ops should be "
"removed from the graph."));
}
TEST(RemoveGatherOps, basic) {
Layers layers;
auto* block = layers.Block();
OpDesc* beam_search_op = block->AppendOp();
beam_search_op->SetType("beam_search");
beam_search_op->SetAttr("beam_size", 1);
auto* gather_x = layers.data("gather_x");
auto* gather_i = layers.data("gather_i", {1}, true);
layers.gather(gather_x, gather_i, 0);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto* param_scope = new Scope();
graph->Set("__param_scope__", param_scope);
AddVarToScope<int>(param_scope, gather_i->Name(), {1}, 0);
auto pass = PassRegistry::Instance().Get("one_beam_size_fuse_pass");
pass->Apply(graph.get());
auto gather_num = GetNumOpNodes(graph, "gather");
PADDLE_ENFORCE_EQ(gather_num,
0,
platform::errors::PreconditionNotMet(
"gather op should be removed from the graph."));
}
}  // namespace ir
}  // namespace framework
}  // namespace paddle
......
@@ -57,7 +57,6 @@ void MemcpySyncH2D(void* dst,
                   const platform::XPUPlace& dst_place) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.GetByPlace(dst_place);
-  dev_ctx->Wait();
  phi::backends::xpu::MemcpySyncH2D(dst, src, count, dst_place, *dev_ctx);
}
@@ -67,7 +66,6 @@ void MemcpySyncD2H(void* dst,
                   const platform::XPUPlace& src_place) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.GetByPlace(src_place);
-  dev_ctx->Wait();
  phi::backends::xpu::MemcpySyncD2H(dst, src, count, src_place, *dev_ctx);
}
......
@@ -64,7 +64,6 @@ void Copy(const Context& dev_ctx,
    dst_ptr = dev_ctx.Alloc(
        dst, src.dtype(), 0, dst_place.GetType() == AllocationType::GPUPINNED);
#endif
-
#ifdef PADDLE_WITH_XPU
  } else if (dst_place.GetType() == AllocationType::XPU) {
    dst_ptr = dev_ctx.Alloc(dst, src.dtype());
......