From 599388e3cfeffca7b6a3f004e8237b213c956c3b Mon Sep 17 00:00:00 2001
From: zhupengyang
Date: Wed, 29 Mar 2023 13:15:12 +0800
Subject: [PATCH] [XPU] optimize pass (#52099)

---
 paddle/fluid/framework/ir/CMakeLists.txt      |   5 +
 .../fluid/framework/ir/inplace_op_var_pass.cc | 125 +++++++++------
 .../fluid/framework/ir/inplace_op_var_pass.h  |  12 ++
 paddle/fluid/framework/ir/pass.cc             |  26 ++-
 .../fluid/framework/ir/pass_tester_helper.h   |  14 ++
 .../ir/xpu/delete_isolated_node_pass.cc       |   9 +-
 .../fluid/framework/ir/xpu/stack_fuse_pass.cc | 148 ++++++++++++++++++
 .../framework/ir/xpu/stack_fuse_pass_test.cc  |  53 +++++++
 .../inference/analysis/ir_pass_manager.cc     |   4 +-
 .../fluid/inference/api/analysis_predictor.cc |   4 +-
 .../inference/api/paddle_pass_builder.cc      |   5 +-
 paddle/phi/kernels/squeeze_kernel.cc          |   3 +
 paddle/phi/kernels/unsqueeze_kernel.cc        |   3 +
 13 files changed, 351 insertions(+), 60 deletions(-)
 mode change 100755 => 100644 paddle/fluid/framework/ir/inplace_op_var_pass.cc
 mode change 100755 => 100644 paddle/fluid/framework/ir/inplace_op_var_pass.h
 create mode 100644 paddle/fluid/framework/ir/xpu/stack_fuse_pass.cc
 create mode 100644 paddle/fluid/framework/ir/xpu/stack_fuse_pass_test.cc

diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index be115c4d8e7..ebed80b6bc9 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -238,6 +238,7 @@ if(WITH_XPU)
                ${XPU_PASS_DEPS})
   pass_library(fused_multi_transformer_xpu_quant_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
+  pass_library(stack_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
 endif()
 
 cc_library(
@@ -512,4 +513,8 @@ if(WITH_XPU)
     test_one_beam_size_fuse_pass
     SRCS xpu/one_beam_size_fuse_pass_test.cc
     DEPS one_beam_size_fuse_pass)
+  cc_test(
+    test_stack_fuse_pass
+    SRCS xpu/stack_fuse_pass_test.cc
+    DEPS stack_fuse_pass)
 endif()
diff --git a/paddle/fluid/framework/ir/inplace_op_var_pass.cc b/paddle/fluid/framework/ir/inplace_op_var_pass.cc
old mode 100755
new mode 100644
index 0cd58963802..0ccac637be3
--- a/paddle/fluid/framework/ir/inplace_op_var_pass.cc
+++ b/paddle/fluid/framework/ir/inplace_op_var_pass.cc
@@ -25,62 +25,38 @@ namespace ir {
 
 class Graph;
 
-void InplaceOpVarPass::ApplyImpl(ir::Graph* graph) const {
-  FusePassBase::Init("inplace_op_var", graph);
-  int found_subgraph_count = 0;
-
-  auto nodes = graph->Nodes();
-  auto is_valid_reshape = [](Node* node) {
-    // Some cases need to consider, please refer to
-    // https://github.com/PaddlePaddle/Paddle/pull/49146
-    if (node->IsOp() && node->Op()->Type() == "reshape2") {
-      auto x_name = node->Op()->Input("X").front();
-      for (auto* var_node : node->inputs) {
-        if (var_node->Name() == x_name) {
-          if (!var_node->Var()->Persistable() && var_node->outputs.size() == 1)
-            return true;
-        }
-      }
-    }
+bool InplaceOpVarPass::IsValidInplaceOp(
+    Node* node, const std::unordered_set<std::string>& deny_var_names) const {
+  if (!node->IsOp() || inplace_ops_.count(node->Op()->Type()) == 0)
     return false;
-  };
-  // Record all reshape2 op's input name and output name in block 0.
-  // If the name used in other block, we can not inplace reshape op.
-  std::unordered_set<std::string> var_names, deny_var_names;
-  for (auto* node : nodes) {
-    if (is_valid_reshape(node)) {
-      for (auto n : node->inputs) var_names.insert(n->Name());
-      for (auto n : node->outputs) var_names.insert(n->Name());
-    }
-  }
-  for (size_t i = 1; i < graph->SubGraphsSize(); ++i) {
-    auto sub_graph = graph->GetSubGraph(i);
-    for (auto* node : sub_graph->Nodes()) {
-      if (node->IsOp()) {
-        for (auto var_node : node->inputs) {
-          if (var_names.count(var_node->Name()))
-            deny_var_names.insert(var_node->Name());
-        }
-        for (auto var_node : node->outputs) {
-          if (var_names.count(var_node->Name()))
-            deny_var_names.insert(var_node->Name());
-        }
-      }
-    }
+  // in_var_node should only have one out_op_node
+  auto x_name = node->Op()->Input("X").front();
+  for (auto* var_node : node->inputs) {
+    if (var_node->Name() != x_name) continue;
+    if (var_node->Var()->Persistable() || var_node->outputs.size() != 1)
+      return false;
   }
+  // in/out_var_node should not be used in multiple graphs.
+  auto out_name = node->Op()->Output("Out").front();
+  if (deny_var_names.count(x_name) > 0 || deny_var_names.count(out_name) > 0)
+    return false;
+
+  return true;
+}
+
+int InplaceOpVarPass::ApplyImpl(
+    ir::Graph* graph,
+    const std::unordered_set<std::string>& deny_var_names) const {
+  int found_subgraph_count = 0;
   // inplace all reshape op.
   auto topo_nodes = TopologySortOperations(*graph);
   for (auto* node : topo_nodes) {
-    if (!is_valid_reshape(node)) continue;
+    if (!IsValidInplaceOp(node, deny_var_names)) continue;
     auto* op_node = node->Op();
     auto input_name = op_node->Input("X")[0];
     auto output_name = op_node->Output("Out")[0];
-    if (deny_var_names.count(input_name) ||
-        deny_var_names.count(output_name)) {
-      continue;
-    }
-    ++found_subgraph_count;
     for (auto* out_var : node->outputs) {
       if (out_var->Name() == output_name) {
         out_var->RenameVar(input_name);
@@ -90,9 +66,50 @@ void InplaceOpVarPass::ApplyImpl(ir::Graph* graph) const {
         }
       }
     }
-
     op_node->RenameOutput(output_name, input_name);
     op_node->Flush();
+    found_subgraph_count++;
+  }
+  return found_subgraph_count;
+}
+
+std::vector<std::string> InplaceOpVarPass::GetControlFlowVarNames(
+    ir::Graph* graph) const {
+  std::vector<std::string> control_flow_var_names;
+  for (auto* node : graph->Nodes()) {
+    if (!node->IsOp() || control_flow_ops_.count(node->Op()->Type()) == 0)
+      continue;
+    for (auto in_names : node->Op()->Inputs()) {
+      auto var_names = in_names.second;
+      control_flow_var_names.insert(
+          control_flow_var_names.end(), var_names.begin(), var_names.end());
+    }
+    for (auto out_names : node->Op()->Outputs()) {
+      auto var_names = out_names.second;
+      control_flow_var_names.insert(
+          control_flow_var_names.end(), var_names.begin(), var_names.end());
+    }
+  }
+  return control_flow_var_names;
+}
+
+void InplaceOpVarPass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init("inplace_op_var", graph);
+  if (!graph->IsMainGraph()) {
+    VLOG(3) << "Pass(apply in main graph) will work on all subgraphs.";
+    return;
+  }
+
+  std::unordered_set<std::string> deny_var_names;
+  for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
+    auto control_flow_var_names =
+        GetControlFlowVarNames(graph->GetSubGraph(i));
+    deny_var_names.insert(control_flow_var_names.begin(),
+                          control_flow_var_names.end());
+  }
+
+  int found_subgraph_count = 0;
+  for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
+    found_subgraph_count += ApplyImpl(graph->GetSubGraph(i), deny_var_names);
   }
   AddStatis(found_subgraph_count);
 }
@@ -105,4 +122,16 @@ REGISTER_PASS(inplace_op_var_pass, paddle::framework::ir::InplaceOpVarPass);
 
 REGISTER_PASS_CAPABILITY(inplace_op_var_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination().EQ(
-            "reshape2", 0));
+            "reshape2", 0))
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
+            "unsqueeze2", 0))
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
+            "unsqueeze", 0))
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
+            "squeeze2", 0))
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
+            "squeeze", 0));
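The rewrite above replaces the old reshape2-only lambda with IsValidInplaceOp plus a per-graph ApplyImpl that renames each inplace op's output variable to its input name and then patches every downstream reader. A standalone toy sketch of that renaming step, using a hypothetical ToyOp record in place of Paddle's ir::Graph nodes (all names and types here are illustrative, not Paddle APIs):

    #include <iostream>
    #include <string>
    #include <vector>

    // Hypothetical stand-in for an op node; the real pass walks ir::Graph.
    struct ToyOp {
      std::string type;
      std::vector<std::string> inputs;
      std::vector<std::string> outputs;
    };

    // Mirrors the pass body: rename the producer's output to its input name
    // (op_node->RenameOutput) and fix up every later reader (RenameInput).
    void InplaceRename(std::vector<ToyOp>* ops, size_t producer) {
      const std::string in = (*ops)[producer].inputs[0];
      const std::string out = (*ops)[producer].outputs[0];
      (*ops)[producer].outputs[0] = in;  // output now aliases the input buffer
      for (size_t i = producer + 1; i < ops->size(); ++i)
        for (auto& name : (*ops)[i].inputs)
          if (name == out) name = in;
    }

    int main() {
      std::vector<ToyOp> ops{{"reshape2", {"x"}, {"reshape2_out"}},
                             {"relu", {"reshape2_out"}, {"y"}}};
      InplaceRename(&ops, 0);
      std::cout << ops[1].inputs[0] << "\n";  // prints "x": the copy is gone
      return 0;
    }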
diff --git a/paddle/fluid/framework/ir/inplace_op_var_pass.h b/paddle/fluid/framework/ir/inplace_op_var_pass.h
old mode 100755
new mode 100644
index 0a579d1c2f6..50c9c502915
--- a/paddle/fluid/framework/ir/inplace_op_var_pass.h
+++ b/paddle/fluid/framework/ir/inplace_op_var_pass.h
@@ -28,6 +28,18 @@ class InplaceOpVarPass : public FusePassBase {
 
  private:
   virtual ~InplaceOpVarPass() = default;
+
+  int ApplyImpl(ir::Graph* graph,
+                const std::unordered_set<std::string>& deny_var_names) const;
+
+  bool IsValidInplaceOp(
+      Node* node,
+      const std::unordered_set<std::string>& deny_var_names) const;
+
+  std::vector<std::string> GetControlFlowVarNames(ir::Graph* graph) const;
+
+  std::set<std::string> inplace_ops_{
+      "reshape", "unsqueeze", "unsqueeze2", "squeeze", "squeeze2"};
+  std::set<std::string> control_flow_ops_{"while", "conditional_block"};
 };
 
 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc
index f51e028f965..b48b3606594 100644
--- a/paddle/fluid/framework/ir/pass.cc
+++ b/paddle/fluid/framework/ir/pass.cc
@@ -49,10 +49,22 @@ static const std::vector<std::string> support_subgraph_passes = {
     "fuse_multi_transformer_layer_pass",
     "delete_quant_dequant_linear_op_pass",
     "delete_weight_dequant_linear_op_pass",
+};
+
+static const std::vector<std::string> xpu_support_subgraph_passes = {
+    "delete_dropout_op_pass",
+    "identity_scale_op_clean_pass",
+    "delete_op_device_pass",
+    "constant_folding_pass",
+    "generate_sequence_xpu_fuse_pass",
+    "embedding_with_eltwise_add_xpu_fuse_pass",
+    "multi_encoder_xpu_fuse_pass",
+    "multi_encoder_xpu_slice_fuse_pass",
     "one_beam_size_fuse_pass",
+    "stack_fuse_pass",
     "fused_multi_transformer_xpu_quant_pass",
     "fc_xpu_fuse_pass",
-    "delete_op_device_pass",
+    "link_xpu_op_max_pass",
 };
 
 Graph *Pass::Apply(Graph *graph) const {
@@ -90,9 +102,15 @@ Graph *Pass::Apply(Graph *graph) const {
   }
   graph->Get<PassRecorder>(kPassRecorder).insert(Type());
 
-  if (graph->IsMainGraph() && std::count(support_subgraph_passes.begin(),
-                                         support_subgraph_passes.end(),
-                                         Type())) {
+  std::vector<std::string> subgraph_passes;
+  bool use_xpu = Has("use_xpu") && Get<bool>("use_xpu");
+  if (use_xpu) {
+    subgraph_passes = xpu_support_subgraph_passes;
+  } else {
+    subgraph_passes = support_subgraph_passes;
+  }
+  if (graph->IsMainGraph() &&
+      std::count(subgraph_passes.begin(), subgraph_passes.end(), Type())) {
     for (size_t i = 1; i < graph->SubGraphsSize(); i++) {
       auto *sub_graph = graph->GetSubGraph(i);
       if (!sub_graph->Has(framework::ir::kParamScopeAttr)) {
diff --git a/paddle/fluid/framework/ir/pass_tester_helper.h b/paddle/fluid/framework/ir/pass_tester_helper.h
index 10bb062e8ea..1547615be37 100644
--- a/paddle/fluid/framework/ir/pass_tester_helper.h
+++ b/paddle/fluid/framework/ir/pass_tester_helper.h
@@ -807,6 +807,20 @@ struct Layers {
     return unary_op("logical_not", input);
   }
 
+  VarDesc* stack(std::vector<VarDesc*> inputs, int axis = -1) {
+    VarDesc* out = lod_tensor(unique_name());
+    OpDesc* op = program_.MutableBlock(0)->AppendOp();
+    op->SetType("stack");
+    std::vector<std::string> input_names;
+    for (auto* input : inputs) {
+      input_names.push_back(input->Name());
+    }
+    op->SetInput("X", input_names);
+    op->SetAttr("axis", axis);
+    op->SetOutput("Y", {out->Name()});
+    return out;
+  }
+
  private:
   VarDesc* lod_tensor(std::string name,
                       std::vector<int64_t> shape = {},
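Pass::Apply now consults an XPU-specific subgraph pass list whenever the use_xpu attribute is set. The selection itself is a plain std::count over the chosen list; a self-contained sketch of the same decision (the bool here stands in for Has("use_xpu") && Get<bool>("use_xpu"), and the lists are trimmed samples of the ones above):

    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <vector>

    int main() {
      const std::vector<std::string> support_subgraph_passes{
          "fuse_multi_transformer_layer_pass",
          "delete_weight_dequant_linear_op_pass"};
      const std::vector<std::string> xpu_support_subgraph_passes{
          "stack_fuse_pass", "one_beam_size_fuse_pass", "link_xpu_op_max_pass"};

      bool use_xpu = true;  // stand-in for the pass attribute lookup
      std::string pass_type = "stack_fuse_pass";

      const auto& subgraph_passes =
          use_xpu ? xpu_support_subgraph_passes : support_subgraph_passes;
      bool apply_to_subgraphs =
          std::count(subgraph_passes.begin(), subgraph_passes.end(), pass_type) >
          0;
      std::cout << std::boolalpha << apply_to_subgraphs << "\n";  // true
      return 0;
    }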
diff --git a/paddle/fluid/framework/ir/xpu/delete_isolated_node_pass.cc b/paddle/fluid/framework/ir/xpu/delete_isolated_node_pass.cc
index 41a822e3e2b..c1137c319d2 100644
--- a/paddle/fluid/framework/ir/xpu/delete_isolated_node_pass.cc
+++ b/paddle/fluid/framework/ir/xpu/delete_isolated_node_pass.cc
@@ -60,10 +60,11 @@ class DeleteIsolatedNodePass : public Pass {
 void DeleteIsolatedNodePass::ApplyImpl(Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::PreconditionNotMet("graph should not be null."));
-  PADDLE_ENFORCE(graph->IsMainGraph(),
-                 platform::errors::PreconditionNotMet(
-                     "Pass(apply in main graph) will delete isolated nodes in "
-                     "all subgraphs. Do not apply pass in subgraph."));
+  if (!graph->IsMainGraph()) {
+    VLOG(3) << "Pass(apply in main graph) will delete isolated nodes in all "
+               "subgraphs.";
+    return;
+  }
 
   std::unordered_set<std::string> reserved_persistable_node_names;
   for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
diff --git a/paddle/fluid/framework/ir/xpu/stack_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/stack_fuse_pass.cc
new file mode 100644
index 00000000000..a128011159d
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/stack_fuse_pass.cc
@@ -0,0 +1,148 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace phi {
+class DenseTensor;
+}  // namespace phi
+
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+namespace patterns {
+
+struct StackPattern : public PatternBase {
+  StackPattern(PDPattern* pattern, const std::string& name_scope);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(stack);
+  // declare variable node's name
+  PATTERN_DECL_NODE(stack_out);
+};
+
+StackPattern::StackPattern(PDPattern* pattern, const std::string& name_scope)
+    : PatternBase(pattern, name_scope, name_scope) {
+  auto* stack = pattern->NewNode(stack_repr())
+                    ->assert_is_op("stack")
+                    ->assert_more([](Node* node) {
+                      auto input_names = node->Op()->Input("X");
+                      auto first_name = input_names[0];
+                      for (auto name : input_names) {
+                        if (name != first_name) return false;
+                      }
+                      return true;
+                    });
+  auto* stack_out =
+      pattern->NewNode(stack_out_repr())
+          ->assert_is_op_output("stack", "Y")
+          ->assert_more([](Node* node) {
+            std::map<std::string, std::string> support_out_ops{
+                {"elementwise_add", "Y"},
+                {"fused_multi_transformer", "SrcMask"}};
+            auto var_name = node->Name();
+            for (auto* out_node : node->outputs) {
+              auto op_type = out_node->Op()->Type();
+              if (support_out_ops.count(op_type) == 0) return false;
+              auto out_op_in_names =
+                  out_node->Op()->Input(support_out_ops[op_type]);
+              if (std::find(out_op_in_names.begin(),
+                            out_op_in_names.end(),
+                            var_name) == out_op_in_names.end())
+                return false;
+            }
+            return true;
+          });
+  stack->LinksTo({stack_out});
+}
+
+}  // namespace patterns
+
+/*
+"stack" can be replaced by "unsqueeze2" if:
+1. all "stack" inputs are the same;
+2. the "stack" output is an "elementwise_add" input or a
+   "fused_multi_transformer" src_mask input.
+*/ +class StackFusePass : public FusePassBase { + protected: + void ApplyImpl(ir::Graph* graph) const override; + + private: + const std::string name_scope_{"stack_fuse_pass"}; +}; + +void StackFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + Init(name_scope_, graph); + GraphPatternDetector gpd; + patterns::StackPattern pattern(gpd.mutable_pattern(), name_scope_); + + int found_subgraph_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle StackFusePass fuse"; + GET_IR_NODE(stack); + GET_IR_NODE(stack_out); + + stack->RenameOp("unsqueeze2"); + auto* op_desc = stack->Op(); + int axis = op_desc->GetAttrIfExists("axis"); + op_desc->SetAttr("axes", std::vector{axis}); + op_desc->RemoveAttr("axis"); + + op_desc->MutableInputs()->at("X").resize(1); + auto* stack_in = stack->inputs[0]; + IR_NODE_UNLINK(stack_in, stack); + IR_NODE_LINK_TO(stack_in, stack); + + auto* outputs = op_desc->MutableOutputs(); + (*outputs)["Out"] = outputs->at("Y"); + outputs->erase("Y"); + + auto stack_out_shape = stack_out->Var()->GetShape(); + if (axis < 0) axis += stack_out_shape.size(); + stack_out_shape[axis] = 1; + stack_out->Var()->SetShape(stack_out_shape); + + found_subgraph_count++; + }; + + gpd(graph, handler); + AddStatis(found_subgraph_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(stack_fuse_pass, paddle::framework::ir::StackFusePass); + +REGISTER_PASS_CAPABILITY(stack_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination().EQ( + "stack", 0)); diff --git a/paddle/fluid/framework/ir/xpu/stack_fuse_pass_test.cc b/paddle/fluid/framework/ir/xpu/stack_fuse_pass_test.cc new file mode 100644 index 00000000000..6b4026d0023 --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/stack_fuse_pass_test.cc @@ -0,0 +1,53 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
diff --git a/paddle/fluid/framework/ir/xpu/stack_fuse_pass_test.cc b/paddle/fluid/framework/ir/xpu/stack_fuse_pass_test.cc
new file mode 100644
index 00000000000..6b4026d0023
--- /dev/null
+++ b/paddle/fluid/framework/ir/xpu/stack_fuse_pass_test.cc
@@ -0,0 +1,53 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+TEST(StackFusePass, basic) {
+  Layers layers;
+  auto* block = layers.Block();
+
+  auto* stack_x = layers.data("stack_x", {-1, 64, 64});
+  auto* stack_out = layers.stack({stack_x, stack_x, stack_x}, 1);
+  stack_out->SetShape({-1, 3, 64, 64});
+  auto* add_x = layers.data("add_x", {-1, 24, 64, 64});
+  layers.elementwise_add(add_x, stack_out);
+
+  OpDesc* fused_multi_transformer_op = block->AppendOp();
+  fused_multi_transformer_op->SetType("fused_multi_transformer");
+  fused_multi_transformer_op->SetInput("SrcMask", {stack_out->Name()});
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
+  auto pass = PassRegistry::Instance().Get("stack_fuse_pass");
+  pass->Apply(graph.get());
+  auto stack_num = GetNumOpNodes(graph, "stack");
+  PADDLE_ENFORCE_EQ(stack_num,
+                    0,
+                    platform::errors::PreconditionNotMet(
+                        "stack op should be removed from graph, but graph "
+                        "still has %d stack op.",
+                        stack_num));
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(stack_fuse_pass);
diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc
index 33562d339a5..9dc5f0f961b 100644
--- a/paddle/fluid/inference/analysis/ir_pass_manager.cc
+++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -105,6 +105,9 @@ void IRPassManager::CreatePasses(Argument *argument,
                 new int(argument->mixed_precision_mode()));
       pass->Set("model_precision", new int(argument->model_precision()));
 
+      // "use_xpu" is used for passes in subgraphs.
+      pass->Set("use_xpu", new bool(argument->use_xpu()));
+
       if (pass_name == "graph_viz_pass") {
         std::string optim_cache_dir = argument->optim_cache_dir();
         std::string dot_file_path;
@@ -260,7 +263,6 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("enable_int8", new bool(lite_enable_int8));
       pass->Set("use_gpu", new bool(argument->use_gpu()));
       pass->Set("zero_copy", new bool(argument->lite_zero_copy()));
-      pass->Set("use_xpu", new bool(argument->use_xpu()));
       pass->Set("xpu_l3_workspace_size",
                 new int(argument->xpu_l3_workspace_size()));
       pass->Set("use_opencl", new bool(argument->use_opencl()));
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index c3fc6667581..91dee8a9ae4 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -1201,6 +1201,7 @@ void AnalysisPredictor::PrepareArgument() {
     argument_->SetDlnnePrecisionMode(config_.dlnne_precision_mode_);
   }
 
+  argument_->SetUseXpu(config_.use_xpu_);
   if (config_.lite_engine_enabled()) {
     argument_->SetCpuMathLibraryNumThreads(
         config_.cpu_math_library_num_threads());
@@ -1208,7 +1209,6 @@ void AnalysisPredictor::PrepareArgument() {
     argument_->SetLitePassesFilter(config_.lite_passes_filter_);
     argument_->SetLiteOpsFilter(config_.lite_ops_filter_);
     argument_->SetLiteZeroCopy(config_.lite_zero_copy_);
-    argument_->SetUseXpu(config_.use_xpu_);
     argument_->SetXpuL3WorkspaceSize(config_.xpu_l3_workspace_size_);
     argument_->SetXpuLocked(config_.xpu_locked_);
     argument_->SetXpuAutotune(config_.xpu_autotune_);
@@ -1316,7 +1316,7 @@ void AnalysisPredictor::PrepareArgument() {
   // processed in a single
   if (model_precision_ != phi::DataType::FLOAT32) {
     LOG(INFO) << "Model is mixed precision type with " << model_precision_
-              << ", we will use a new PassStrategy. Note that only the GPU "
+              << ", we will use a new PassStrategy. Note that only GPU/XPU "
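Both kernel tweaks above exist because inplace_op_var_pass can leave x and out sharing one allocation, in which case phi::Copy from a buffer onto itself is redundant; the kernels now return early when the Holders match. The guard reduced to a standalone sketch (shared_ptr is a hypothetical stand-in for the DenseTensor Holder, and SqueezeInfer for the real kernel):

    #include <iostream>
    #include <memory>
    #include <vector>

    // shared_ptr stands in for the tensor's allocation handle (Holder()).
    struct ToyTensor {
      std::shared_ptr<std::vector<float>> holder;
      std::vector<long long> dims;
    };

    // Mirrors the guarded kernel: if the inplace pass aliased x and out,
    // skip the copy; otherwise copy, then restore the dims the copy reset.
    void SqueezeInfer(const ToyTensor& x, ToyTensor* out) {
      auto out_dims = out->dims;
      if (x.holder == out->holder) return;  // same buffer: nothing to move
      *out->holder = *x.holder;             // stands in for phi::Copy
      out->dims = out_dims;                 // copy will reset the dims
    }

    int main() {
      auto buf = std::make_shared<std::vector<float>>(6, 1.f);
      ToyTensor x{buf, {2, 1, 3}};
      ToyTensor out{buf, {2, 3}};  // aliased by inplace_op_var_pass
      SqueezeInfer(x, &out);
      std::cout << "dims kept: " << out.dims.size() << "\n";  // 2, no self-copy
      return 0;
    }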
                  "backend is supported for now.";
     if (!config_.use_cinn_compiler_) {
       const auto &deleted_passes = pass_builder->GetAllDeletedPasses();
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 43fa40c4fa7..2e2a896f754 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -520,15 +520,18 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
   passes_.assign({
       "delete_dropout_op_pass",
       "identity_scale_op_clean_pass",
+      "delete_op_device_pass",
+      "constant_folding_pass",
       "generate_sequence_xpu_fuse_pass",
       "embedding_with_eltwise_add_xpu_fuse_pass",
       "multi_encoder_xpu_fuse_pass",
       "multi_encoder_xpu_slice_fuse_pass",
       "one_beam_size_fuse_pass",
+      "stack_fuse_pass",
       "fused_multi_transformer_xpu_quant_pass",
       "fc_xpu_fuse_pass",
       "link_xpu_op_max_pass",
-      "delete_op_device_pass",
+      "inplace_op_var_pass",
       "delete_isolated_node_pass",
   });
   use_xpu_ = true;
diff --git a/paddle/phi/kernels/squeeze_kernel.cc b/paddle/phi/kernels/squeeze_kernel.cc
index d36e42c8126..a0b72381601 100644
--- a/paddle/phi/kernels/squeeze_kernel.cc
+++ b/paddle/phi/kernels/squeeze_kernel.cc
@@ -27,6 +27,9 @@ void SqueezeInferKernel(const Context& dev_ctx,
                         DenseTensor* out) {
   auto out_dims = out->dims();
   dev_ctx.template Alloc<T>(out);
+  if (x.Holder() == out->Holder()) {
+    return;
+  }
   phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
   out->Resize(out_dims);  // copy will reset the dims.
 }
diff --git a/paddle/phi/kernels/unsqueeze_kernel.cc b/paddle/phi/kernels/unsqueeze_kernel.cc
index 4008d7883d4..4354b09c753 100644
--- a/paddle/phi/kernels/unsqueeze_kernel.cc
+++ b/paddle/phi/kernels/unsqueeze_kernel.cc
@@ -32,6 +32,9 @@ void UnsqueezeInferKernel(const Context& dev_ctx,
   }
   out->Resize(out_dims);
   dev_ctx.template Alloc<T>(out);
+  if (x.Holder() == out->Holder()) {
+    return;
+  }
   phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
   out->Resize(out_dims);  // copy will reset the dims.
 }