Unverified commit c9309942, authored by zhupengyang, committed by GitHub

[XPU] delete op device (#51029)

Parent af149c0c
......@@ -97,6 +97,7 @@ pass_library(shuffle_channel_detect_pass inference)
pass_library(delete_quant_dequant_op_pass inference)
pass_library(delete_quant_dequant_filter_op_pass inference)
pass_library(trt_delete_weight_dequant_linear_op_pass inference)
+pass_library(delete_op_device_pass inference)
pass_library(delete_weight_dequant_linear_op_pass inference)
pass_library(delete_quant_dequant_linear_op_pass inference)
pass_library(delete_dropout_op_pass inference)
......@@ -221,13 +222,16 @@ if(WITH_XPU)
SRCS xpu/pass_utils.cc
DEPS pass)
set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils)
-pass_library(embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu)
+pass_library(embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu DEPS
+${XPU_PASS_DEPS})
pass_library(fc_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(multi_encoder_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
-pass_library(multi_encoder_xpu_slice_fuse_pass inference DIR xpu)
-pass_library(generate_sequence_xpu_fuse_pass inference DIR xpu)
-pass_library(link_xpu_op_max_pass inference DIR xpu)
+pass_library(multi_encoder_xpu_slice_fuse_pass inference DIR xpu DEPS
+${XPU_PASS_DEPS})
+pass_library(generate_sequence_xpu_fuse_pass inference DIR xpu DEPS
+${XPU_PASS_DEPS})
+pass_library(link_xpu_op_max_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
endif()
cc_library(
......@@ -372,6 +376,10 @@ cc_test(
test_generate_pass_cc
SRCS generate_pass_tester.cc
DEPS generate_pass pass_desc_proto)
+cc_test(
+test_delete_op_device_pass
+SRCS delete_op_device_pass_test.cc
+DEPS delete_op_device_pass)
cc_test(
test_delete_dropout_pass_cc
SRCS delete_dropout_op_pass_test.cc
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/pass.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
// "op_device" attr is only used in model training. "op_device" attr will change
// place of op kernel, so we use "delete_op_device_pass" to remove it.
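// Typical source of the attr: training programs that pin ops to a device
// (values look like "gpu:0", as in the unit test for this pass); leaving it
// in place could steer kernel selection to a device other than the one the
// inference engine (e.g. XPU) actually targets.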
class DeleteOpDevicePass : public Pass {
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
void DeleteOpDevicePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
int found_subgraph_count = 0;
for (auto* node : graph->Nodes()) {
if (!node->IsOp() || !node->Op()->HasAttr("op_device")) continue;
node->Op()->RemoveAttr("op_device");
found_subgraph_count++;
}
if (found_subgraph_count > 0) {
LOG(INFO) << "--- detected " << found_subgraph_count << " subgraphs";
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(delete_op_device_pass, paddle::framework::ir::DeleteOpDevicePass);
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/delete_dropout_op_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
TEST(delete_op_device_pass, relu) {
ProgramDesc program;
auto* x_var = program.MutableBlock(0)->Var("relu_x");
auto* out_var = program.MutableBlock(0)->Var("relu_out");
OpDesc* relu_op = program.MutableBlock(0)->AppendOp();
relu_op->SetType("relu");
relu_op->SetInput("X", {x_var->Name()});
relu_op->SetOutput("Out", {out_var->Name()});
relu_op->SetAttr("op_device", std::string{"gpu:0"});
std::unique_ptr<Graph> graph(new Graph(program));
auto pass = PassRegistry::Instance().Get("delete_op_device_pass");
graph.reset(pass->Apply(graph.release()));
for (auto* node : graph->Nodes()) {
if (!node->IsOp()) continue;
if (node->Op()->Type() == "relu") {
PADDLE_ENFORCE(!node->Op()->HasAttr("op_device"),
platform::errors::InvalidArgument(
"Run delete_op_device_pass failed. Relu op still has "
"'op_device' attr."));
}
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(delete_op_device_pass);
......@@ -49,7 +49,7 @@ static const std::vector<std::string> support_subgraph_passes = {
"fuse_multi_transformer_layer_pass",
"delete_quant_dequant_linear_op_pass",
"delete_weight_dequant_linear_op_pass",
-};
+"delete_op_device_pass"};
Graph *Pass::Apply(Graph *graph) const {
VLOG(10) << "start to apply pass " << Type() << " to graph";
......
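Listing the pass in support_subgraph_passes means it is applied not only to the main graph (block 0) but also to the graphs of control-flow sub-blocks such as while or conditional_block. A simplified sketch of that effect, assuming the multi-block Graph API (SubGraphsSize/GetSubGraph); this is illustrative only and does not reproduce the real Pass::Apply body:

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

// Illustrative only: run a sub-graph-aware pass over every block's graph,
// so "op_device" is also stripped inside control-flow blocks.
void RunOnAllBlocks(paddle::framework::ir::Pass* pass,
                    paddle::framework::ir::Graph* main_graph) {
  pass->Apply(main_graph);  // block 0 (main graph)
  for (size_t i = 1; i < main_graph->SubGraphsSize(); ++i) {
    pass->Apply(main_graph->GetSubGraph(i));  // sub-block graphs
  }
}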
......@@ -18,7 +18,6 @@
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
......
......@@ -17,7 +17,6 @@
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
......
......@@ -17,7 +17,6 @@
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
......
......@@ -524,6 +524,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"multi_encoder_xpu_slice_fuse_pass",
"fc_xpu_fuse_pass",
"link_xpu_op_max_pass",
"delete_op_device_pass",
});
use_xpu_ = true;
}
......
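Appending delete_op_device_pass to XpuPassStrategy makes it run automatically whenever XPU inference is selected. A minimal usage sketch with the Paddle Inference C++ API; the model directory is a placeholder, and EnableXpu is assumed to be callable with its default arguments:

#include "paddle_inference_api.h"

int main() {
  // Placeholder path: point this at a real inference model directory.
  paddle_infer::Config config("./my_infer_model");
  // Enabling XPU selects the XPU pass list above, so stale "op_device"
  // attrs left over from training are removed before kernel selection.
  config.EnableXpu();
  auto predictor = paddle_infer::CreatePredictor(config);
  return predictor != nullptr ? 0 : 1;
}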