Unverified commit e8e9d6c5, authored by zhupengyang, committed by GitHub

optimize write_read_array, gather if beam_size=1 (#53130)

Parent fc6d4399
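
With a single beam_search op whose beam_size is 1, the write_to_array / read_from_array pair and the gather(axis=0, index=[0]) ops in the decoding graph reduce to identities, which is what the new RemoveWriteReadArrayOps and RemoveGatherOps passes below fold away. A minimal standalone sketch of the write/read identity, using plain C++ containers in place of Paddle's LoDTensorArray (illustrative only, not part of the commit):

#include <cassert>
#include <vector>

int main() {
  std::vector<float> x = {0.5f};             // stands in for write_to_array's input X
  std::vector<std::vector<float>> array(1);  // stands in for the LoDTensorArray
  const size_t i = 0;                        // the write/read index I
  array[i] = x;                              // write_to_array(X, I)
  std::vector<float> out = array[i];         // read_from_array(array, I)
  assert(out == x);                          // identity: out can be replaced by X
  return 0;
}

In the pass itself the read output is kept as a persistable variable and write_x's tensor is Assign-ed into it, since the downstream graph already refers to the read output's name.
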
......@@ -787,20 +787,37 @@ struct Layers {
return out;
}
VarDesc* write_to_array(std::vector<VarDesc*> x, VarDesc* i) {
VarDesc* write_to_array(VarDesc* x, VarDesc* i) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("write_to_array");
std::vector<std::string> x_names;
for (auto k : x) {
x_names.push_back(k->Name());
}
op->SetInput("X", x_names);
op->SetInput("X", {x->Name()});
op->SetInput("I", {i->Name()});
op->SetOutput("Out", {out->Name()});
return out;
}
VarDesc* read_from_array(VarDesc* x, VarDesc* i) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("read_from_array");
op->SetInput("X", {x->Name()});
op->SetInput("I", {i->Name()});
op->SetOutput("Out", {out->Name()});
return out;
}
VarDesc* gather(VarDesc* x, VarDesc* index, int axis) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("gather");
op->SetInput("X", {x->Name()});
op->SetInput("Index", {index->Name()});
op->SetAttr("axis", axis);
op->SetOutput("Out", {out->Name()});
return out;
}
VarDesc* is_empty(VarDesc* input) { return unary_op("is_empty", input); }
VarDesc* logical_not(VarDesc* input) {
......
......@@ -259,20 +259,21 @@ bool OnlyOneBeamSearchAndOneBeamSize(ir::Graph* graph) {
beam_search_nodes[0]->Op()->GetAttrIfExists<int>("beam_size") == 1;
}
Node* FindOpNodeByInputName(Graph* graph,
const std::string& op_type,
const std::string& arg_name,
const std::string& var_name) {
std::vector<Node*> FindOpNodeByInputName(Graph* graph,
const std::string& var_name) {
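// Collect every op node that takes `var_name` as one of its inputs (through any input slot).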
std::vector<Node*> ret;
for (auto* node : graph->Nodes()) {
if (!node->IsOp() || node->Op()->Type() != op_type) continue;
if (!node->IsOp()) continue;
auto inputs = node->Op()->Inputs();
if (inputs.count(arg_name) == 0) continue;
auto in_names = inputs.at(arg_name);
if (std::find(in_names.begin(), in_names.end(), var_name) == in_names.end())
continue;
return node;
for (auto input : inputs) {
auto in_names = input.second;
if (std::count(in_names.begin(), in_names.end(), var_name) > 0) {
ret.push_back(node);
break;
}
}
}
return nullptr;
return ret;
}
void OneBeamSizeFusePass::RemoveAssignGather(ir::Graph* graph) const {
......@@ -287,9 +288,9 @@ void OneBeamSizeFusePass::RemoveAssignGather(ir::Graph* graph) const {
GET_IR_NODE(assign);
GET_IR_NODE(assign_out);
// assign_out may not link directly to gather, so we find gather by its input name.
auto* gather =
FindOpNodeByInputName(graph, "gather", "X", assign_out->Name());
if (gather == nullptr) return;
auto next_ops = FindOpNodeByInputName(graph, assign_out->Name());
if (next_ops.size() != 1 || next_ops[0]->Name() != "gather") return;
auto* gather = next_ops[0];
// "assign_out" is used in multi blocks. "assign_out" should be reserved.
auto* assign_in = assign->inputs[0];
......@@ -449,6 +450,143 @@ void OneBeamSizeFusePass::RemoveBeamSearchAssociatedOps(
AddStatis(found_subgraph_count);
}
namespace patterns {
struct WriteToArrayPattern : public PatternBase {
WriteToArrayPattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(write_to_array);
// declare variable node's name
PATTERN_DECL_NODE(write_x);
PATTERN_DECL_NODE(write_out);
};
WriteToArrayPattern::WriteToArrayPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* write_x = pattern->NewNode(write_x_repr())
->assert_is_op_input("write_to_array", "X")
->assert_is_persistable_var();
auto* write_to_array =
pattern->NewNode(write_to_array_repr())->assert_is_op("write_to_array");
auto* write_out = pattern->NewNode(write_out_repr())
->assert_is_op_output("write_to_array", "Out");
write_to_array->LinksFrom({write_x}).LinksTo({write_out});
}
} // namespace patterns
void OneBeamSizeFusePass::RemoveWriteReadArrayOps(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::WriteToArrayPattern pattern(gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle RemoveWriteReadArrayOps";
GET_IR_NODE(write_to_array);
GET_IR_NODE(write_x);
GET_IR_NODE(write_out);
auto* scope = param_scope();
// write_out comes from graph0 and does not link to any op, so we find
// "read_from_array" by write_out's name.
auto next_ops = FindOpNodeByInputName(graph, write_out->Name());
if (next_ops.size() != 1 || next_ops[0]->Name() != "read_from_array")
return;
auto* read_from_array = next_ops[0];
auto* read_out = read_from_array->outputs[0];
read_out->Var()->SetPersistable(true);
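// Copy write_x's tensor into read_out in the scope so that consumers of
// read_out still see the written value once the array ops are removed.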
auto* write_x_tensor =
scope->Var(write_x->Name())->GetMutable<phi::DenseTensor>();
auto* read_out_tensor =
scope->Var(read_out->Name())->GetMutable<phi::DenseTensor>();
Assign(*write_x_tensor, read_out_tensor);
std::unordered_set<const Node*> delete_nodes{
write_to_array, write_out, read_from_array};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
namespace patterns {
struct GatherPattern : public PatternBase {
GatherPattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(gather);
// declare variable node's name
PATTERN_DECL_NODE(gather_x);
PATTERN_DECL_NODE(gather_i);
PATTERN_DECL_NODE(gather_out);
};
GatherPattern::GatherPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* gather_x =
pattern->NewNode(gather_x_repr())->assert_is_op_input("gather", "X");
auto* gather_i = pattern->NewNode(gather_i_repr())
->assert_is_op_input("gather", "Index")
->assert_is_persistable_var();
auto* gather = pattern->NewNode(gather_repr())
->assert_is_op("gather")
->assert_more([&](Node* node) {
return node->Op()->GetAttrIfExists<int>("axis") == 0;
});
auto* gather_out =
pattern->NewNode(gather_out_repr())->assert_is_op_output("gather", "Out");
gather->LinksFrom({gather_x, gather_i}).LinksTo({gather_out});
}
} // namespace patterns
void OneBeamSizeFusePass::RemoveGatherOps(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::GatherPattern pattern(gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle RemoveGatherOps";
GET_IR_NODE(gather);
GET_IR_NODE(gather_x);
GET_IR_NODE(gather_i);
GET_IR_NODE(gather_out);
auto* scope = param_scope();
// The fuse only applies when gather_i is a single-element tensor whose value is 0.
auto* gather_i_tensor =
scope->Var(gather_i->Name())->GetMutable<phi::DenseTensor>();
auto gather_i_dims = gather_i_tensor->dims();
if (gather_i_dims.size() != 1 || gather_i_dims[0] != 1) return;
if (gather_i_tensor->dtype() == phi::DataType::INT32) {
auto* i_data = gather_i_tensor->data<int>();
if (i_data[0] != 0) return;
} else {
auto* i_data = gather_i_tensor->data<int64_t>();
if (i_data[0] != 0) return;
}
auto gather_x_name = gather_x->Name();
auto gather_out_name = gather_out->Name();
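// Rewire every consumer of gather_out to read gather_x directly; the gather
// op and gather_out are then deleted.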
for (auto* next_op : gather_out->outputs) {
next_op->Op()->RenameInput(gather_out_name, gather_x_name);
IR_NODE_LINK_TO(gather_x, next_op);
}
std::unordered_set<const Node*> delete_nodes{gather, gather_out};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
void OneBeamSizeFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
......@@ -458,6 +596,8 @@ void OneBeamSizeFusePass::ApplyImpl(ir::Graph* graph) const {
RemoveAssignGather(graph);
FoldShapeAssociatedOps(graph);
RemoveBeamSearchAssociatedOps(graph);
RemoveWriteReadArrayOps(graph);
RemoveGatherOps(graph);
}
} // namespace ir
......
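
The RemoveGatherOps handler above relies on gather(axis=0) with a single zero index being an identity when the input's leading dimension is 1. A standalone sketch of that identity in plain C++ (illustrative only; shapes and values are made up):

#include <array>
#include <cassert>

int main() {
  std::array<float, 4> x = {1.f, 2.f, 3.f, 4.f};  // logical shape [1, 4]
  const int index[] = {0};                         // gather index, a single zero
  std::array<float, 4> out{};
  for (int j = 0; j < 4; ++j) {
    out[j] = x[index[0] * 4 + j];                  // gather row index[0] along axis 0
  }
  assert(out == x);                                // identity when dims[0] == 1
  return 0;
}

The pass implements this by renaming gather_out to gather_x in every consumer and relinking the IR nodes, rather than copying any tensor data.
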
......@@ -103,6 +103,38 @@ class OneBeamSizeFusePass : public FusePassBase {
*/
void RemoveBeamSearchAssociatedOps(ir::Graph* graph) const;
/*
Origin subgraph:
(x: persistable) (index)
\ /
write_to_array
|
read_from_array
|
any_op
Fused subgraph:
(x: persistable)
|
any_op
*/
void RemoveWriteReadArrayOps(ir::Graph* graph) const;
/*
Origin subgraph:
(x: dims0=1) (index=[0])
\ /
gather(axis=0)
|
any_op
Fused subgraph:
(x)
|
any_op
*/
void RemoveGatherOps(ir::Graph* graph) const;
const std::string name_scope_{"one_beam_size_fuse_pass"};
};
......
......@@ -20,6 +20,21 @@ namespace paddle {
namespace framework {
namespace ir {
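// Test helper: creates variable `name` in `param_scope` as a CPU DenseTensor,
// resizes it to `dims`, and fills it with `value`.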
template <typename T = float>
void AddVarToScope(Scope* param_scope,
const std::string& name,
const DDim& dims,
T value = 0) {
auto* tensor = param_scope->Var(name)->GetMutable<phi::DenseTensor>();
tensor->Resize(dims);
auto* cpu_ctx = static_cast<phi::CPUContext*>(
platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
auto* data = cpu_ctx->Alloc<T>(tensor);
for (int64_t i = 0; i < tensor->numel(); i++) {
data[i] = value;
}
}
VarDesc* Data(paddle::framework::BlockDesc* block,
std::string name,
std::vector<int64_t> shape = {},
......@@ -128,9 +143,9 @@ TEST(RemoveBeamSearchAssociatedOps, basic) {
auto* selected_scores = beam_search_outs[2];
auto* write_to_array_0_i = layers.data("write_to_array_0_i");
layers.write_to_array({selected_ids}, write_to_array_0_i);
layers.write_to_array(selected_ids, write_to_array_0_i);
auto* write_to_array_1_i = layers.data("write_to_array_1_i");
layers.write_to_array({selected_scores}, write_to_array_1_i);
layers.write_to_array(selected_scores, write_to_array_1_i);
auto* is_empty_out = layers.is_empty(selected_ids);
layers.logical_not(is_empty_out);
layers.cast(parent_idx);
......@@ -147,6 +162,58 @@ TEST(RemoveBeamSearchAssociatedOps, basic) {
"beam_search op should be removed from the graph."));
}
TEST(RemoveWriteReadArrayOps, basic) {
Layers layers;
auto* block = layers.Block();
OpDesc* beam_search_op = block->AppendOp();
beam_search_op->SetType("beam_search");
beam_search_op->SetAttr("beam_size", 1);
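// A single beam_search op with beam_size = 1 makes the graph eligible for one_beam_size_fuse_pass.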
auto* write_x = layers.data("write_x", {1}, true);
auto* write_i = layers.data("write_i");
auto* write_out = layers.write_to_array(write_x, write_i);
auto* read_i = layers.data("read_i");
layers.read_from_array(write_out, read_i);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto* param_scope = new Scope();
graph->Set("__param_scope__", param_scope);
AddVarToScope(param_scope, write_x->Name(), {1});
auto pass = PassRegistry::Instance().Get("one_beam_size_fuse_pass");
pass->Apply(graph.get());
auto write_read_num = GetNumOpNodes(graph, "write_to_array") +
GetNumOpNodes(graph, "read_from_array");
PADDLE_ENFORCE_EQ(write_read_num,
0,
platform::errors::PreconditionNotMet(
"write_to_array and read_from_array ops should be "
"removed from the graph."));
}
TEST(RemoveGatherOps, basic) {
Layers layers;
auto* block = layers.Block();
OpDesc* beam_search_op = block->AppendOp();
beam_search_op->SetType("beam_search");
beam_search_op->SetAttr("beam_size", 1);
auto* gather_x = layers.data("gather_x");
auto* gather_i = layers.data("gather_i", {1}, true);
layers.gather(gather_x, gather_i, 0);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto* param_scope = new Scope();
graph->Set("__param_scope__", param_scope);
AddVarToScope<int>(param_scope, gather_i->Name(), {1}, 0);
auto pass = PassRegistry::Instance().Get("one_beam_size_fuse_pass");
pass->Apply(graph.get());
auto gather_num = GetNumOpNodes(graph, "gather");
PADDLE_ENFORCE_EQ(gather_num,
0,
platform::errors::PreconditionNotMet(
"gather op should be removed from the graph."));
}
} // namespace ir
} // namespace framework
} // namespace paddle
......
......@@ -57,7 +57,6 @@ void MemcpySyncH2D(void* dst,
const platform::XPUPlace& dst_place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.GetByPlace(dst_place);
dev_ctx->Wait();
phi::backends::xpu::MemcpySyncH2D(dst, src, count, dst_place, *dev_ctx);
}
......@@ -67,7 +66,6 @@ void MemcpySyncD2H(void* dst,
const platform::XPUPlace& src_place) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.GetByPlace(src_place);
dev_ctx->Wait();
phi::backends::xpu::MemcpySyncD2H(dst, src, count, src_place, *dev_ctx);
}
......
......@@ -64,7 +64,6 @@ void Copy(const Context& dev_ctx,
dst_ptr = dev_ctx.Alloc(
dst, src.dtype(), 0, dst_place.GetType() == AllocationType::GPUPINNED);
#endif
#ifdef PADDLE_WITH_XPU
} else if (dst_place.GetType() == AllocationType::XPU) {
dst_ptr = dev_ctx.Alloc(dst, src.dtype());
......