diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h
index 1547615be377f449bc371db6928589e807d1198d..0557104773a95fc5eb8fdf54f05b87cc3e11e4c1 100644
--- a/paddle/fluid/framework/ir/pass_tester_helper.h
+++ b/paddle/fluid/framework/ir/pass_tester_helper.h
@@ -787,20 +787,37 @@ struct Layers {
     return out;
   }
 
-  VarDesc* write_to_array(std::vector<VarDesc*> x, VarDesc* i) {
+  VarDesc* write_to_array(VarDesc* x, VarDesc* i) {
     VarDesc* out = lod_tensor(unique_name());
     OpDesc* op = program_.MutableBlock(0)->AppendOp();
     op->SetType("write_to_array");
-    std::vector<std::string> x_names;
-    for (auto k : x) {
-      x_names.push_back(k->Name());
-    }
-    op->SetInput("X", x_names);
+    op->SetInput("X", {x->Name()});
+    op->SetInput("I", {i->Name()});
+    op->SetOutput("Out", {out->Name()});
+    return out;
+  }
+
+  VarDesc* read_from_array(VarDesc* x, VarDesc* i) {
+    VarDesc* out = lod_tensor(unique_name());
+    OpDesc* op = program_.MutableBlock(0)->AppendOp();
+    op->SetType("read_from_array");
+    op->SetInput("X", {x->Name()});
     op->SetInput("I", {i->Name()});
     op->SetOutput("Out", {out->Name()});
     return out;
   }
 
+  VarDesc* gather(VarDesc* x, VarDesc* index, int axis) {
+    VarDesc* out = lod_tensor(unique_name());
+    OpDesc* op = program_.MutableBlock(0)->AppendOp();
+    op->SetType("gather");
+    op->SetInput("X", {x->Name()});
+    op->SetInput("Index", {index->Name()});
+    op->SetAttr("axis", axis);
+    op->SetOutput("Out", {out->Name()});
+    return out;
+  }
+
   VarDesc* is_empty(VarDesc* input) { return unary_op("is_empty", input); }
 
   VarDesc* logical_not(VarDesc* input) {
diff --git a/paddle/fluid/framework/ir/xpu/one_beam_size_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/one_beam_size_fuse_pass.cc
index 2944f9c7a517494097bac877280530a8c0bf46e2..9ab6cd031b5c72b40b4af989eb736e0407e08314 100644
--- a/paddle/fluid/framework/ir/xpu/one_beam_size_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/xpu/one_beam_size_fuse_pass.cc
@@ -259,20 +259,21 @@ bool OnlyOneBeamSearchAndOneBeamSize(ir::Graph* graph) {
          beam_search_nodes[0]->Op()->GetAttrIfExists<int>("beam_size") == 1;
 }
 
-Node* FindOpNodeByInputName(Graph* graph,
-                            const std::string& op_type,
-                            const std::string& arg_name,
-                            const std::string& var_name) {
+std::vector<Node*> FindOpNodeByInputName(Graph* graph,
+                                         const std::string& var_name) {
+  std::vector<Node*> ret;
   for (auto* node : graph->Nodes()) {
-    if (!node->IsOp() || node->Op()->Type() != op_type) continue;
+    if (!node->IsOp()) continue;
     auto inputs = node->Op()->Inputs();
-    if (inputs.count(arg_name) == 0) continue;
-    auto in_names = inputs.at(arg_name);
-    if (std::find(in_names.begin(), in_names.end(), var_name) == in_names.end())
-      continue;
-    return node;
+    for (auto input : inputs) {
+      auto in_names = input.second;
+      if (std::count(in_names.begin(), in_names.end(), var_name) > 0) {
+        ret.push_back(node);
+        break;
+      }
+    }
   }
-  return nullptr;
+  return ret;
 }
 
 void OneBeamSizeFusePass::RemoveAssignGather(ir::Graph* graph) const {
@@ -287,9 +288,9 @@ void OneBeamSizeFusePass::RemoveAssignGather(ir::Graph* graph) const {
     GET_IR_NODE(assign);
     GET_IR_NODE(assign_out);
     // Assign_out may not link to gather, so we find gather by input name.
-    auto* gather =
-        FindOpNodeByInputName(graph, "gather", "X", assign_out->Name());
-    if (gather == nullptr) return;
+    auto next_ops = FindOpNodeByInputName(graph, assign_out->Name());
+    if (next_ops.size() != 1 || next_ops[0]->Name() != "gather") return;
+    auto* gather = next_ops[0];
 
     // "assign_out" is used in multi blocks. "assign_out" should be reserved.
     auto* assign_in = assign->inputs[0];
@@ -449,6 +450,143 @@ void OneBeamSizeFusePass::RemoveBeamSearchAssociatedOps(
   AddStatis(found_subgraph_count);
 }
 
+namespace patterns {
+struct WriteToArrayPattern : public PatternBase {
+  WriteToArrayPattern(PDPattern* pattern, const std::string& name_scope);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(write_to_array);
+  // declare variable node's name
+  PATTERN_DECL_NODE(write_x);
+  PATTERN_DECL_NODE(write_out);
+};
+
+WriteToArrayPattern::WriteToArrayPattern(PDPattern* pattern,
+                                         const std::string& name_scope)
+    : PatternBase(pattern, name_scope, name_scope) {
+  auto* write_x = pattern->NewNode(write_x_repr())
+                      ->assert_is_op_input("write_to_array", "X")
+                      ->assert_is_persistable_var();
+  auto* write_to_array =
+      pattern->NewNode(write_to_array_repr())->assert_is_op("write_to_array");
+  auto* write_out = pattern->NewNode(write_out_repr())
+                        ->assert_is_op_output("write_to_array", "Out");
+
+  write_to_array->LinksFrom({write_x}).LinksTo({write_out});
+}
+}  // namespace patterns
+
+void OneBeamSizeFusePass::RemoveWriteReadArrayOps(ir::Graph* graph) const {
+  GraphPatternDetector gpd;
+  patterns::WriteToArrayPattern pattern(gpd.mutable_pattern(), name_scope_);
+  int found_subgraph_count = 0;
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle RemoveWriteReadArrayOps";
+    GET_IR_NODE(write_to_array);
+    GET_IR_NODE(write_x);
+    GET_IR_NODE(write_out);
+    auto* scope = param_scope();
+
+    // write_out is from graph0 and does not link to any op, so we find
+    // "read_from_array" by write_out name.
+    auto next_ops = FindOpNodeByInputName(graph, write_out->Name());
+    if (next_ops.size() != 1 || next_ops[0]->Name() != "read_from_array")
+      return;
+    auto* read_from_array = next_ops[0];
+    auto* read_out = read_from_array->outputs[0];
+    read_out->Var()->SetPersistable(true);
+    auto* write_x_tensor =
+        scope->Var(write_x->Name())->GetMutable<phi::DenseTensor>();
+    auto* read_out_tensor =
+        scope->Var(read_out->Name())->GetMutable<phi::DenseTensor>();
+    Assign(*write_x_tensor, read_out_tensor);
+
+    std::unordered_set<const Node*> delete_nodes{
+        write_to_array, write_out, read_from_array};
+    GraphSafeRemoveNodes(graph, delete_nodes);
+    found_subgraph_count++;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_subgraph_count);
+}
+
+namespace patterns {
+struct GatherPattern : public PatternBase {
+  GatherPattern(PDPattern* pattern, const std::string& name_scope);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(gather);
+  // declare variable node's name
+  PATTERN_DECL_NODE(gather_x);
+  PATTERN_DECL_NODE(gather_i);
+  PATTERN_DECL_NODE(gather_out);
+};
+
+GatherPattern::GatherPattern(PDPattern* pattern, const std::string& name_scope)
+    : PatternBase(pattern, name_scope, name_scope) {
+  auto* gather_x =
+      pattern->NewNode(gather_x_repr())->assert_is_op_input("gather", "X");
+  auto* gather_i = pattern->NewNode(gather_i_repr())
+                       ->assert_is_op_input("gather", "Index")
+                       ->assert_is_persistable_var();
+  auto* gather = pattern->NewNode(gather_repr())
+                     ->assert_is_op("gather")
+                     ->assert_more([&](Node* node) {
+                       return node->Op()->GetAttrIfExists<int>("axis") == 0;
+                     });
+  auto* gather_out =
+      pattern->NewNode(gather_out_repr())->assert_is_op_output("gather", "Out");
+
+  gather->LinksFrom({gather_x, gather_i}).LinksTo({gather_out});
+}
+}  // namespace patterns
+
+void OneBeamSizeFusePass::RemoveGatherOps(ir::Graph* graph) const {
+  GraphPatternDetector gpd;
+  patterns::GatherPattern pattern(gpd.mutable_pattern(), name_scope_);
+  int found_subgraph_count = 0;
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle RemoveGatherOps";
+    GET_IR_NODE(gather);
+    GET_IR_NODE(gather_x);
+    GET_IR_NODE(gather_i);
+    GET_IR_NODE(gather_out);
+    auto* scope = param_scope();
+
+    // gather_i should be 0
+    auto* gather_i_tensor =
+        scope->Var(gather_i->Name())->GetMutable<phi::DenseTensor>();
+    auto gather_i_dims = gather_i_tensor->dims();
+    if (gather_i_dims.size() != 1 || gather_i_dims[0] != 1) return;
+    if (gather_i_tensor->dtype() == phi::DataType::INT32) {
+      auto* i_data = gather_i_tensor->data<int>();
+      if (i_data[0] != 0) return;
+    } else {
+      auto* i_data = gather_i_tensor->data<int64_t>();
+      if (i_data[0] != 0) return;
+    }
+
+    auto gather_x_name = gather_x->Name();
+    auto gather_out_name = gather_out->Name();
+    for (auto* next_op : gather_out->outputs) {
+      next_op->Op()->RenameInput(gather_out_name, gather_x_name);
+      IR_NODE_LINK_TO(gather_x, next_op);
+    }
+
+    std::unordered_set<const Node*> delete_nodes{gather, gather_out};
+    GraphSafeRemoveNodes(graph, delete_nodes);
+    found_subgraph_count++;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_subgraph_count);
+}
+
 void OneBeamSizeFusePass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::PreconditionNotMet("graph should not be null."));
@@ -458,6 +596,8 @@ void OneBeamSizeFusePass::ApplyImpl(ir::Graph* graph) const {
   RemoveAssignGather(graph);
   FoldShapeAssociatedOps(graph);
   RemoveBeamSearchAssociatedOps(graph);
+  RemoveWriteReadArrayOps(graph);
+  RemoveGatherOps(graph);
 }
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/xpu/one_beam_size_fuse_pass.h b/paddle/fluid/framework/ir/xpu/one_beam_size_fuse_pass.h
index 79e0504bff0d3e38da0f56890d0d0b0b8ee7bcee..5ea5abc3fcae8571b54ea47dec957edd48a8523a 100644
--- a/paddle/fluid/framework/ir/xpu/one_beam_size_fuse_pass.h
+++ b/paddle/fluid/framework/ir/xpu/one_beam_size_fuse_pass.h
@@ -103,6 +103,38 @@ class OneBeamSizeFusePass : public FusePassBase {
   */
   void RemoveBeamSearchAssociatedOps(ir::Graph* graph) const;
 
+  /*
+  Origin subgraph:
+      (x: persistable)   (index)
+               \           /
+              write_to_array
+                    |
+             read_from_array
+                    |
+                 any_op
+
+  Fused subgraph:
+      (x: persistable)
+              |
+           any_op
+  */
+  void RemoveWriteReadArrayOps(ir::Graph* graph) const;
+
+  /*
+  Origin subgraph:
+      (x: dims0=1)   (index=[0])
+             \          /
+            gather(axis=0)
+                  |
+               any_op
+
+  Fused subgraph:
+         (x)
+          |
+       any_op
+  */
+  void RemoveGatherOps(ir::Graph* graph) const;
+
   const std::string name_scope_{"one_beam_size_fuse_pass"};
 };
 
diff --git a/paddle/fluid/framework/ir/xpu/one_beam_size_fuse_pass_test.cc b/paddle/fluid/framework/ir/xpu/one_beam_size_fuse_pass_test.cc
index c98c9891bb04b30bb0b3bf3440cf1843c46df5db..5eccd822e8259c9ba7fa157043b2100a206330dc 100644
--- a/paddle/fluid/framework/ir/xpu/one_beam_size_fuse_pass_test.cc
+++ b/paddle/fluid/framework/ir/xpu/one_beam_size_fuse_pass_test.cc
@@ -20,6 +20,21 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
+template <typename T = int>
+void AddVarToScope(Scope* param_scope,
+                   const std::string& name,
+                   const DDim& dims,
+                   T value = 0) {
+  auto* tensor = param_scope->Var(name)->GetMutable<phi::DenseTensor>();
+  tensor->Resize(dims);
+  auto* cpu_ctx = static_cast<phi::CPUContext*>(
+      platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
+  auto* data = cpu_ctx->Alloc<T>(tensor);
+  for (int64_t i = 0; i < tensor->numel(); i++) {
+    data[i] = value;
+  }
+}
+
 VarDesc* Data(paddle::framework::BlockDesc* block,
               std::string name,
               std::vector<int64_t> shape = {},
@@ -128,9 +143,9 @@ TEST(RemoveBeamSearchAssociatedOps, basic) {
   auto* selected_scores = beam_search_outs[2];
   auto* write_to_array_0_i = layers.data("write_to_array_0_i");
-  layers.write_to_array({selected_ids}, write_to_array_0_i);
+  layers.write_to_array(selected_ids, write_to_array_0_i);
   auto* write_to_array_1_i = layers.data("write_to_array_1_i");
-  layers.write_to_array({selected_scores}, write_to_array_1_i);
+  layers.write_to_array(selected_scores, write_to_array_1_i);
   auto* is_empty_out = layers.is_empty(selected_ids);
   layers.logical_not(is_empty_out);
   layers.cast(parent_idx);
@@ -147,6 +162,58 @@ TEST(RemoveBeamSearchAssociatedOps, basic) {
                         "beam_search op should be removed from the graph."));
 }
 
+TEST(RemoveWriteReadArrayOps, basic) {
+  Layers layers;
+  auto* block = layers.Block();
+  OpDesc* beam_search_op = block->AppendOp();
+  beam_search_op->SetType("beam_search");
+  beam_search_op->SetAttr("beam_size", 1);
+
+  auto* write_x = layers.data("write_x", {1}, true);
+  auto* write_i = layers.data("write_i");
+  auto* write_out = layers.write_to_array(write_x, write_i);
+  auto* read_i = layers.data("read_i");
+  layers.read_from_array(write_out, read_i);
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
+  auto* param_scope = new Scope();
+  graph->Set("__param_scope__", param_scope);
+  AddVarToScope(param_scope, write_x->Name(), {1});
+  auto pass = PassRegistry::Instance().Get("one_beam_size_fuse_pass");
+  pass->Apply(graph.get());
+  auto write_read_num = GetNumOpNodes(graph, "write_to_array") +
+                        GetNumOpNodes(graph, "read_from_array");
+  PADDLE_ENFORCE_EQ(write_read_num,
+                    0,
+                    platform::errors::PreconditionNotMet(
+                        "write_to_array and read_from_array ops should be "
+                        "removed from the graph."));
+}
+
+TEST(RemoveGatherOps, basic) {
+  Layers layers;
+  auto* block = layers.Block();
+  OpDesc* beam_search_op = block->AppendOp();
+  beam_search_op->SetType("beam_search");
+  beam_search_op->SetAttr("beam_size", 1);
+
+  auto* gather_x = layers.data("gather_x");
+  auto* gather_i = layers.data("gather_i", {1}, true);
+  layers.gather(gather_x, gather_i, 0);
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
+  auto* param_scope = new Scope();
+  graph->Set("__param_scope__", param_scope);
+  AddVarToScope(param_scope, gather_i->Name(), {1}, 0);
+  auto pass = PassRegistry::Instance().Get("one_beam_size_fuse_pass");
+  pass->Apply(graph.get());
+  auto gather_num = GetNumOpNodes(graph, "gather");
+  PADDLE_ENFORCE_EQ(gather_num,
+                    0,
+                    platform::errors::PreconditionNotMet(
+                        "gather op should be removed from the graph."));
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/platform/device/xpu/xpu_info.cc b/paddle/fluid/platform/device/xpu/xpu_info.cc
index 548fe89dc5ceab37461fda3ae115c9a4401d73f6..624a19dd7de91acc4ba0295ea012a54f30afac9e 100644
--- a/paddle/fluid/platform/device/xpu/xpu_info.cc
+++ b/paddle/fluid/platform/device/xpu/xpu_info.cc
@@ -57,7 +57,6 @@ void MemcpySyncH2D(void* dst,
                    const platform::XPUPlace& dst_place) {
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   auto* dev_ctx = pool.GetByPlace(dst_place);
-  dev_ctx->Wait();
   phi::backends::xpu::MemcpySyncH2D(dst, src, count, dst_place, *dev_ctx);
 }
 
@@ -67,7 +66,6 @@ void MemcpySyncD2H(void* dst,
                    const platform::XPUPlace& src_place) {
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   auto* dev_ctx = pool.GetByPlace(src_place);
-  dev_ctx->Wait();
   phi::backends::xpu::MemcpySyncD2H(dst, src, count, src_place, *dev_ctx);
 }
 
diff --git a/paddle/phi/core/tensor_utils.cc b/paddle/phi/core/tensor_utils.cc
index 79f8138222a505939997523ee6eb885f746eb5e4..c73d6dc369bb83d13dbfd23427697d7338e7777d 100644
--- a/paddle/phi/core/tensor_utils.cc
+++ b/paddle/phi/core/tensor_utils.cc
@@ -64,7 +64,6 @@ void Copy(const Context& dev_ctx,
     dst_ptr = dev_ctx.Alloc(
         dst, src.dtype(), 0, dst_place.GetType() == AllocationType::GPUPINNED);
 #endif
-
 #ifdef PADDLE_WITH_XPU
   } else if (dst_place.GetType() == AllocationType::XPU) {
     dst_ptr = dev_ctx.Alloc(dst, src.dtype());
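
The RemoveGatherOps rewrite in the patch above relies on one simple property: once the beam size is 1, a gather along axis 0 with index [0] on a tensor whose leading dimension is 1 returns its input unchanged, so the op can be dropped and its consumers rewired to the gather input. A minimal standalone sketch of that property follows; it is illustrative only, not part of the patch, and GatherAxis0 is a hypothetical stand-in for the real gather kernel.

// Illustrative sketch (not part of the patch): why gather(axis=0, index={0})
// is an identity when the first dimension of x is 1, which is the property
// RemoveGatherOps depends on once beam_size == 1.
#include <cassert>
#include <vector>

// Gather rows of a row-major [rows x cols] buffer along axis 0.
std::vector<float> GatherAxis0(const std::vector<float>& x,
                               int rows,
                               int cols,
                               const std::vector<int>& index) {
  std::vector<float> out;
  out.reserve(index.size() * cols);
  for (int idx : index) {
    assert(idx >= 0 && idx < rows);
    out.insert(out.end(), x.begin() + idx * cols, x.begin() + (idx + 1) * cols);
  }
  return out;
}

int main() {
  // One beam => the gathered dimension is 1, so the only valid index is {0}.
  std::vector<float> x = {1.0f, 2.0f, 3.0f};  // shape [1, 3]
  auto out = GatherAxis0(x, /*rows=*/1, /*cols=*/3, /*index=*/{0});
  assert(out == x);  // identity, which is why the gather op can be removed
  return 0;
}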