Unverified commit 599388e3, authored by zhupengyang, committed by GitHub

[XPU] optimize pass (#52099)

Parent: ad76d37e
@@ -238,6 +238,7 @@ if(WITH_XPU)
                ${XPU_PASS_DEPS})
   pass_library(fused_multi_transformer_xpu_quant_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
+  pass_library(stack_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
 endif()

 cc_library(
@@ -512,4 +513,8 @@ if(WITH_XPU)
     test_one_beam_size_fuse_pass
     SRCS xpu/one_beam_size_fuse_pass_test.cc
     DEPS one_beam_size_fuse_pass)
+  cc_test(
+    test_stack_fuse_pass
+    SRCS xpu/stack_fuse_pass_test.cc
+    DEPS stack_fuse_pass)
 endif()
@@ -25,62 +25,38 @@ namespace ir {

 class Graph;

-void InplaceOpVarPass::ApplyImpl(ir::Graph* graph) const {
-  FusePassBase::Init("inplace_op_var", graph);
-  int found_subgraph_count = 0;
-  auto nodes = graph->Nodes();
-  auto is_valid_reshape = [](Node* node) {
-    // Some cases need to consider, please refer to
-    // https://github.com/PaddlePaddle/Paddle/pull/49146
-    if (node->IsOp() && node->Op()->Type() == "reshape2") {
-      auto x_name = node->Op()->Input("X").front();
-      for (auto* var_node : node->inputs) {
-        if (var_node->Name() == x_name) {
-          if (!var_node->Var()->Persistable() && var_node->outputs.size() == 1)
-            return true;
-        }
-      }
-    }
-    return false;
-  };
-  // Record all reshape2 op's input name and output name in block 0.
-  // If the name is used in another block, we can not inplace the reshape op.
-  std::unordered_set<std::string> var_names, deny_var_names;
-  for (auto* node : nodes) {
-    if (is_valid_reshape(node)) {
-      for (auto n : node->inputs) var_names.insert(n->Name());
-      for (auto n : node->outputs) var_names.insert(n->Name());
-    }
-  }
-  for (size_t i = 1; i < graph->SubGraphsSize(); ++i) {
-    auto sub_graph = graph->GetSubGraph(i);
-    for (auto* node : sub_graph->Nodes()) {
-      if (node->IsOp()) {
-        for (auto var_node : node->inputs) {
-          if (var_names.count(var_node->Name()))
-            deny_var_names.insert(var_node->Name());
-        }
-        for (auto var_node : node->outputs) {
-          if (var_names.count(var_node->Name()))
-            deny_var_names.insert(var_node->Name());
-        }
-      }
-    }
-  }
+bool InplaceOpVarPass::IsValidInplaceOp(
+    Node* node, const std::unordered_set<std::string>& deny_var_names) const {
+  if (!node->IsOp() || inplace_ops_.count(node->Op()->Type()) == 0)
+    return false;
+
+  // in_var_node should only have one out_op_node.
+  auto x_name = node->Op()->Input("X").front();
+  for (auto* var_node : node->inputs) {
+    if (var_node->Name() != x_name) continue;
+    if (var_node->Var()->Persistable() || var_node->outputs.size() != 1)
+      return false;
+  }
+
+  // in/out_var_node should not be used in multiple graphs.
+  auto out_name = node->Op()->Output("Out").front();
+  if (deny_var_names.count(x_name) > 0 || deny_var_names.count(out_name) > 0)
+    return false;
+
+  return true;
+}
+
+int InplaceOpVarPass::ApplyImpl(
+    ir::Graph* graph,
+    const std::unordered_set<std::string>& deny_var_names) const {
+  int found_subgraph_count = 0;
   // inplace all reshape op.
   auto topo_nodes = TopologySortOperations(*graph);
   for (auto* node : topo_nodes) {
-    if (!is_valid_reshape(node)) continue;
+    if (!IsValidInplaceOp(node, deny_var_names)) continue;
     auto* op_node = node->Op();
     auto input_name = op_node->Input("X")[0];
     auto output_name = op_node->Output("Out")[0];
-    if (deny_var_names.count(input_name) || deny_var_names.count(output_name)) {
-      continue;
-    }
-    ++found_subgraph_count;
     for (auto* out_var : node->outputs) {
       if (out_var->Name() == output_name) {
        out_var->RenameVar(input_name);
@@ -90,9 +66,50 @@ void InplaceOpVarPass::ApplyImpl(ir::Graph* graph) const {
         }
       }
     }
     op_node->RenameOutput(output_name, input_name);
     op_node->Flush();
+    found_subgraph_count++;
+  }
+  return found_subgraph_count;
+}
+
+std::vector<std::string> InplaceOpVarPass::GetControlFlowVarNames(
+    ir::Graph* graph) const {
+  std::vector<std::string> control_flow_var_names;
+  for (auto* node : graph->Nodes()) {
+    if (!node->IsOp() || control_flow_ops_.count(node->Op()->Type()) == 0)
+      continue;
+    for (auto in_names : node->Op()->Inputs()) {
+      auto var_names = in_names.second;
+      control_flow_var_names.insert(
+          control_flow_var_names.end(), var_names.begin(), var_names.end());
+    }
+    for (auto out_names : node->Op()->Outputs()) {
+      auto var_names = out_names.second;
+      control_flow_var_names.insert(
+          control_flow_var_names.end(), var_names.begin(), var_names.end());
+    }
+  }
+  return control_flow_var_names;
+}
+
+void InplaceOpVarPass::ApplyImpl(ir::Graph* graph) const {
+  FusePassBase::Init("inplace_op_var", graph);
+  if (!graph->IsMainGraph()) {
+    VLOG(3) << "Pass(apply in main graph) will work on all subgraphs.";
+    return;
+  }
+
+  std::unordered_set<std::string> deny_var_names;
+  for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
+    auto control_flow_var_names = GetControlFlowVarNames(graph->GetSubGraph(i));
+    deny_var_names.insert(control_flow_var_names.begin(),
+                          control_flow_var_names.end());
+  }
+
+  int found_subgraph_count = 0;
+  for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
+    found_subgraph_count += ApplyImpl(graph->GetSubGraph(i), deny_var_names);
   }
   AddStatis(found_subgraph_count);
 }
@@ -105,4 +122,16 @@ REGISTER_PASS(inplace_op_var_pass, paddle::framework::ir::InplaceOpVarPass);
 REGISTER_PASS_CAPABILITY(inplace_op_var_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination().EQ(
-            "reshape2", 0));
+            "reshape2", 0))
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
+            "unsqueeze2", 0))
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
+            "unsqueeze", 0))
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
+            "squeeze2", 0))
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
+            "squeeze", 0));
@@ -28,6 +28,18 @@ class InplaceOpVarPass : public FusePassBase {
  private:
   virtual ~InplaceOpVarPass() = default;

+  int ApplyImpl(ir::Graph* graph,
+                const std::unordered_set<std::string>& deny_var_names) const;
+
+  bool IsValidInplaceOp(
+      Node* node, const std::unordered_set<std::string>& deny_var_names) const;
+
+  std::vector<std::string> GetControlFlowVarNames(ir::Graph* graph) const;
+
+  std::set<std::string> inplace_ops_{
+      "reshape", "unsqueeze", "unsqueeze2", "squeeze", "squeeze2"};
+
+  std::set<std::string> control_flow_ops_{"while", "conditional_block"};
 };

 }  // namespace ir
...
@@ -49,10 +49,22 @@ static const std::vector<std::string> support_subgraph_passes = {
     "fuse_multi_transformer_layer_pass",
     "delete_quant_dequant_linear_op_pass",
     "delete_weight_dequant_linear_op_pass",
+};
+
+static const std::vector<std::string> xpu_support_subgraph_passes = {
+    "delete_dropout_op_pass",
+    "identity_scale_op_clean_pass",
+    "delete_op_device_pass",
+    "constant_folding_pass",
+    "generate_sequence_xpu_fuse_pass",
+    "embedding_with_eltwise_add_xpu_fuse_pass",
+    "multi_encoder_xpu_fuse_pass",
+    "multi_encoder_xpu_slice_fuse_pass",
     "one_beam_size_fuse_pass",
+    "stack_fuse_pass",
     "fused_multi_transformer_xpu_quant_pass",
     "fc_xpu_fuse_pass",
-    "delete_op_device_pass",
+    "link_xpu_op_max_pass",
 };

 Graph *Pass::Apply(Graph *graph) const {
@@ -90,9 +102,15 @@ Graph *Pass::Apply(Graph *graph) const {
   }
   graph->Get<PassRecorder>(kPassRecorder).insert(Type());

-  if (graph->IsMainGraph() && std::count(support_subgraph_passes.begin(),
-                                         support_subgraph_passes.end(),
-                                         Type())) {
+  std::vector<std::string> subgraph_passes;
+  bool use_xpu = Has("use_xpu") && Get<bool>("use_xpu");
+  if (use_xpu) {
+    subgraph_passes = xpu_support_subgraph_passes;
+  } else {
+    subgraph_passes = support_subgraph_passes;
+  }
+  if (graph->IsMainGraph() &&
+      std::count(subgraph_passes.begin(), subgraph_passes.end(), Type())) {
     for (size_t i = 1; i < graph->SubGraphsSize(); i++) {
       auto *sub_graph = graph->GetSubGraph(i);
       if (!sub_graph->Has(framework::ir::kParamScopeAttr)) {
...
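The net effect of this hunk: whether a pass that just ran on the main graph is also applied to control-flow subgraphs is now decided against a backend-specific list. A minimal sketch of the selection logic; the two pass lists here are abbreviated stand-ins, not the real contents of support_subgraph_passes:

// Minimal sketch of the backend-keyed dispatch (hypothetical helper and
// abbreviated pass lists).
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

bool RunsOnSubgraphs(const std::string& pass_type, bool use_xpu) {
  static const std::vector<std::string> generic = {"constant_folding_pass"};
  static const std::vector<std::string> xpu = {"constant_folding_pass",
                                               "stack_fuse_pass"};
  const auto& list = use_xpu ? xpu : generic;  // pick the backend's list
  return std::count(list.begin(), list.end(), pass_type) > 0;
}

int main() {
  assert(RunsOnSubgraphs("stack_fuse_pass", /*use_xpu=*/true));
  assert(!RunsOnSubgraphs("stack_fuse_pass", /*use_xpu=*/false));
  return 0;
}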
@@ -807,6 +807,20 @@ struct Layers {
     return unary_op("logical_not", input);
   }

+  VarDesc* stack(std::vector<VarDesc*> inputs, int axis = -1) {
+    VarDesc* out = lod_tensor(unique_name());
+    OpDesc* op = program_.MutableBlock(0)->AppendOp();
+    op->SetType("stack");
+    std::vector<std::string> input_names;
+    for (auto* input : inputs) {
+      input_names.push_back(input->Name());
+    }
+    op->SetInput("X", input_names);
+    op->SetAttr("axis", axis);
+    op->SetOutput("Y", {out->Name()});
+    return out;
+  }
+
  private:
   VarDesc* lod_tensor(std::string name,
                       std::vector<int64_t> shape = {},
...
@@ -60,10 +60,11 @@ class DeleteIsolatedNodePass : public Pass {
 void DeleteIsolatedNodePass::ApplyImpl(Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::PreconditionNotMet("graph should not be null."));
-  PADDLE_ENFORCE(graph->IsMainGraph(),
-                 platform::errors::PreconditionNotMet(
-                     "Pass(apply in main graph) will delete isolated nodes in "
-                     "all subgraphs. Do not apply pass in subgraph."));
+  if (!graph->IsMainGraph()) {
+    VLOG(3) << "Pass(apply in main graph) will delete isolated nodes in all "
+               "subgraphs.";
+    return;
+  }

   std::unordered_set<std::string> reserved_persistable_node_names;
   for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
...
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct StackPattern : public PatternBase {
StackPattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(stack);
// declare variable node's name
PATTERN_DECL_NODE(stack_out);
};
StackPattern::StackPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* stack = pattern->NewNode(stack_repr())
->assert_is_op("stack")
->assert_more([](Node* node) {
auto input_names = node->Op()->Input("X");
auto first_name = input_names[0];
for (auto name : input_names) {
if (name != first_name) return false;
}
return true;
});
auto* stack_out = pattern->NewNode(stack_out_repr())
->assert_is_op_output("stack", "Y")
->assert_more([](Node* node) {
std::map<std::string, std::string> support_out_ops{
{"elementwise_add", "Y"},
{"fused_multi_transformer", "SrcMask"}};
auto var_name = node->Name();
for (auto* out_node : node->outputs) {
auto op_type = out_node->Op()->Type();
if (support_out_ops.count(op_type) == 0)
return false;
auto out_op_in_names =
out_node->Op()->Input(support_out_ops[op_type]);
if (std::find(out_op_in_names.begin(),
out_op_in_names.end(),
var_name) == out_op_in_names.end())
return false;
}
return true;
});
stack->LinksTo({stack_out});
}
} // namespace patterns
/*
"stack" can be replaced by "unsqueeze" if:
1. "stack inputs" are the same.
2. "stack output" is "elementwise_add input" or "fused_multi_transformer
   src_mask input".
*/
class StackFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
const std::string name_scope_{"stack_fuse_pass"};
};
void StackFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
GraphPatternDetector gpd;
patterns::StackPattern pattern(gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle StackFusePass fuse";
GET_IR_NODE(stack);
GET_IR_NODE(stack_out);
stack->RenameOp("unsqueeze2");
auto* op_desc = stack->Op();
int axis = op_desc->GetAttrIfExists<int>("axis");
op_desc->SetAttr("axes", std::vector<int>{axis});
op_desc->RemoveAttr("axis");
op_desc->MutableInputs()->at("X").resize(1);
auto* stack_in = stack->inputs[0];
IR_NODE_UNLINK(stack_in, stack);
IR_NODE_LINK_TO(stack_in, stack);
auto* outputs = op_desc->MutableOutputs();
(*outputs)["Out"] = outputs->at("Y");
outputs->erase("Y");
auto stack_out_shape = stack_out->Var()->GetShape();
if (axis < 0) axis += stack_out_shape.size();
stack_out_shape[axis] = 1;
stack_out->Var()->SetShape(stack_out_shape);
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(stack_fuse_pass, paddle::framework::ir::StackFusePass);
REGISTER_PASS_CAPABILITY(stack_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"stack", 0));
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
TEST(StackFusePass, basic) {
Layers layers;
auto* block = layers.Block();
auto* stack_x = layers.data("stack_x", {-1, 64, 64});
auto* stack_out = layers.stack({stack_x, stack_x, stack_x}, 1);
stack_out->SetShape({-1, 3, 64, 64});
auto* add_x = layers.data("add_x", {-1, 24, 64, 64});
layers.elementwise_add(add_x, stack_out);
OpDesc* fused_multi_transformer_op = block->AppendOp();
fused_multi_transformer_op->SetType("fused_multi_transformer");
fused_multi_transformer_op->SetInput("SrcMask", {stack_out->Name()});
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("stack_fuse_pass");
pass->Apply(graph.get());
auto stack_num = GetNumOpNodes(graph, "stack");
PADDLE_ENFORCE_EQ(stack_num,
0,
platform::errors::PreconditionNotMet(
"stack op should be removed from graph, but graph "
"still has %d stack op.",
stack_num));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(stack_fuse_pass);
@@ -105,6 +105,9 @@ void IRPassManager::CreatePasses(Argument *argument,
              new int(argument->mixed_precision_mode()));
     pass->Set("model_precision", new int(argument->model_precision()));

+    // "use_xpu" is used for passes in subgraphs.
+    pass->Set("use_xpu", new bool(argument->use_xpu()));
+
     if (pass_name == "graph_viz_pass") {
       std::string optim_cache_dir = argument->optim_cache_dir();
       std::string dot_file_path;
@@ -260,7 +263,6 @@ void IRPassManager::CreatePasses(Argument *argument,
       pass->Set("enable_int8", new bool(lite_enable_int8));
       pass->Set("use_gpu", new bool(argument->use_gpu()));
       pass->Set("zero_copy", new bool(argument->lite_zero_copy()));
-      pass->Set("use_xpu", new bool(argument->use_xpu()));
       pass->Set("xpu_l3_workspace_size",
                 new int(argument->xpu_l3_workspace_size()));
       pass->Set("use_opencl", new bool(argument->use_opencl()));
...
@@ -1201,6 +1201,7 @@ void AnalysisPredictor::PrepareArgument() {
     argument_->SetDlnnePrecisionMode(config_.dlnne_precision_mode_);
   }

+  argument_->SetUseXpu(config_.use_xpu_);
   if (config_.lite_engine_enabled()) {
     argument_->SetCpuMathLibraryNumThreads(
         config_.cpu_math_library_num_threads());
@@ -1208,7 +1209,6 @@ void AnalysisPredictor::PrepareArgument() {
     argument_->SetLitePassesFilter(config_.lite_passes_filter_);
     argument_->SetLiteOpsFilter(config_.lite_ops_filter_);
     argument_->SetLiteZeroCopy(config_.lite_zero_copy_);
-    argument_->SetUseXpu(config_.use_xpu_);
     argument_->SetXpuL3WorkspaceSize(config_.xpu_l3_workspace_size_);
     argument_->SetXpuLocked(config_.xpu_locked_);
     argument_->SetXpuAutotune(config_.xpu_autotune_);
@@ -1316,7 +1316,7 @@ void AnalysisPredictor::PrepareArgument() {
   // processed in a single
   if (model_precision_ != phi::DataType::FLOAT32) {
     LOG(INFO) << "Model is mixed precision type with " << model_precision_
-              << ", we will use a new PassStrategy. Note that only the GPU "
+              << ", we will use a new PassStrategy. Note that only GPU/XPU "
                  "backend is supported for now.";
     if (!config_.use_cinn_compiler_) {
       const auto &deleted_passes = pass_builder->GetAllDeletedPasses();
...
@@ -520,15 +520,18 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
   passes_.assign({
       "delete_dropout_op_pass",
       "identity_scale_op_clean_pass",
+      "delete_op_device_pass",
+      "constant_folding_pass",
       "generate_sequence_xpu_fuse_pass",
      "embedding_with_eltwise_add_xpu_fuse_pass",
       "multi_encoder_xpu_fuse_pass",
       "multi_encoder_xpu_slice_fuse_pass",
       "one_beam_size_fuse_pass",
+      "stack_fuse_pass",
       "fused_multi_transformer_xpu_quant_pass",
       "fc_xpu_fuse_pass",
       "link_xpu_op_max_pass",
-      "delete_op_device_pass",
+      "inplace_op_var_pass",
       "delete_isolated_node_pass",
   });
   use_xpu_ = true;
...
@@ -27,6 +27,9 @@ void SqueezeInferKernel(const Context& dev_ctx,
                         DenseTensor* out) {
   auto out_dims = out->dims();
   dev_ctx.template Alloc<T>(out);
+  if (x.Holder() == out->Holder()) {
+    return;
+  }
   phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
   out->Resize(out_dims);  // copy will reset the dims.
 }
...
@@ -32,6 +32,9 @@ void UnsqueezeInferKernel(const Context& dev_ctx,
   }
   out->Resize(out_dims);
   dev_ctx.template Alloc<T>(out);
+  if (x.Holder() == out->Holder()) {
+    return;
+  }
   phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
   out->Resize(out_dims);  // copy will reset the dims.
 }
...
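Both kernel changes guard the same situation: once inplace_op_var_pass has renamed Out to X, input and output share one allocation, so copying the buffer onto itself is wasted work and must be skipped. A standalone sketch of the aliasing check with a toy tensor, not phi::DenseTensor:

// Minimal sketch of the Holder() aliasing guard (toy types, not phi APIs).
#include <cassert>
#include <memory>
#include <vector>

struct ToyTensor {
  std::shared_ptr<std::vector<float>> holder;  // analogue of Holder()
};

bool NeedsCopy(const ToyTensor& x, const ToyTensor& out) {
  return x.holder != out.holder;  // equal holders => already in-place
}

int main() {
  auto buf = std::make_shared<std::vector<float>>(16, 1.f);
  ToyTensor x{buf}, out{buf};  // aliased, as after the in-place rewrite
  assert(!NeedsCopy(x, out));  // skip the copy; only the dims need fixing
  return 0;
}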