diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index d32f13e68e5582d029fe24d4d806d44c3595a954..e602b899fe62e63524b0e9e08a84216807ec92aa 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -97,6 +97,7 @@ pass_library(shuffle_channel_detect_pass inference) pass_library(delete_quant_dequant_op_pass inference) pass_library(delete_quant_dequant_filter_op_pass inference) pass_library(trt_delete_weight_dequant_linear_op_pass inference) +pass_library(delete_op_device_pass inference) pass_library(delete_weight_dequant_linear_op_pass inference) pass_library(delete_quant_dequant_linear_op_pass inference) pass_library(delete_dropout_op_pass inference) @@ -221,13 +222,16 @@ if(WITH_XPU) SRCS xpu/pass_utils.cc DEPS pass) set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils) - pass_library(embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu) + pass_library(embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu DEPS + ${XPU_PASS_DEPS}) pass_library(fc_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(multi_encoder_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) - pass_library(multi_encoder_xpu_slice_fuse_pass inference DIR xpu) - pass_library(generate_sequence_xpu_fuse_pass inference DIR xpu) - pass_library(link_xpu_op_max_pass inference DIR xpu) + pass_library(multi_encoder_xpu_slice_fuse_pass inference DIR xpu DEPS + ${XPU_PASS_DEPS}) + pass_library(generate_sequence_xpu_fuse_pass inference DIR xpu DEPS + ${XPU_PASS_DEPS}) + pass_library(link_xpu_op_max_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) endif() cc_library( @@ -372,6 +376,10 @@ cc_test( test_generate_pass_cc SRCS generate_pass_tester.cc DEPS generate_pass pass_desc_proto) +cc_test( + test_delete_op_device_pass + SRCS delete_op_device_pass_test.cc + DEPS delete_op_device_pass) cc_test( test_delete_dropout_pass_cc SRCS delete_dropout_op_pass_test.cc diff --git 
a/paddle/fluid/framework/ir/delete_op_device_pass.cc b/paddle/fluid/framework/ir/delete_op_device_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..dfd174a442af639ae9c389895751c5b2bc3c858e
--- /dev/null
+++ b/paddle/fluid/framework/ir/delete_op_device_pass.cc
@@ -0,0 +1,51 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+#include "paddle/fluid/framework/ir/pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+// "op_device" attr is only used in model training. "op_device" attr will change
+// place of op kernel, so we use "delete_op_device_pass" to remove it.
+class DeleteOpDevicePass : public Pass {
+ protected:
+  // Removes the "op_device" attr from every op node of `graph` that has it.
+  void ApplyImpl(ir::Graph* graph) const override;
+};
+
+void DeleteOpDevicePass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::PreconditionNotMet("graph should not be null."));
+  // Counts ops whose "op_device" attr was removed (for logging only).
+  int removed_attr_count = 0;
+  for (auto* node : graph->Nodes()) {
+    // Skip non-op nodes and ops that do not carry the attr.
+    if (!node->IsOp() || !node->Op()->HasAttr("op_device")) continue;
+    node->Op()->RemoveAttr("op_device");
+    ++removed_attr_count;
+  }
+  if (removed_attr_count > 0) {
+    LOG(INFO) << "--- removed 'op_device' attr from " << removed_attr_count
+              << " ops";
+  }
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(delete_op_device_pass, paddle::framework::ir::DeleteOpDevicePass);
diff --git a/paddle/fluid/framework/ir/delete_op_device_pass_test.cc b/paddle/fluid/framework/ir/delete_op_device_pass_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c88c3f4fa6a798ee353a37064f6c7fe22cc32f19
--- /dev/null
+++ b/paddle/fluid/framework/ir/delete_op_device_pass_test.cc
@@ -0,0 +1,52 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gtest/gtest.h>
+
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+TEST(delete_op_device_pass, relu) {
+  // Build a program with a single relu op that carries an "op_device" attr.
+  ProgramDesc program;
+  auto* x_var = program.MutableBlock(0)->Var("relu_x");
+  auto* out_var = program.MutableBlock(0)->Var("relu_out");
+  OpDesc* relu_op = program.MutableBlock(0)->AppendOp();
+  relu_op->SetType("relu");
+  relu_op->SetInput("X", {x_var->Name()});
+  relu_op->SetOutput("Out", {out_var->Name()});
+  relu_op->SetAttr("op_device", std::string{"gpu:0"});
+
+  std::unique_ptr<Graph> graph(new Graph(program));
+  auto pass = PassRegistry::Instance().Get("delete_op_device_pass");
+  graph.reset(pass->Apply(graph.release()));
+  // Every relu op must have lost the attr once the pass has run.
+  for (auto* node : graph->Nodes()) {
+    if (!node->IsOp() || node->Op()->Type() != "relu") continue;
+    PADDLE_ENFORCE(!node->Op()->HasAttr("op_device"),
+                   platform::errors::InvalidArgument(
+                       "Run delete_op_device_pass failed. Relu op still has "
+                       "'op_device' attr."));
+  }
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+USE_PASS(delete_op_device_pass);
diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc
index fbe2b3e748d40cf8e657a0525e36d1e85b619e8e..df15fd6d516a3dddb1bd58e8e718a243106330ea 100644
--- a/paddle/fluid/framework/ir/pass.cc
+++ b/paddle/fluid/framework/ir/pass.cc
@@ -49,7 +49,7 @@ static const std::vector support_subgraph_passes = {
     "fuse_multi_transformer_layer_pass",
     "delete_quant_dequant_linear_op_pass",
     "delete_weight_dequant_linear_op_pass",
-};
+    "delete_op_device_pass"};
 
 Graph *Pass::Apply(Graph *graph) const {
   VLOG(10) << "start to apply pass " << Type() << " to graph";
diff --git a/paddle/fluid/framework/ir/xpu/generate_sequence_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/generate_sequence_xpu_fuse_pass.cc
index ed17144b6b6a17b34b64961e7b3f2df2fb2ea86c..7b40b67824d16f7c02acc0975b0a642933601ea8 100644
--- 
a/paddle/fluid/framework/ir/xpu/generate_sequence_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/generate_sequence_xpu_fuse_pass.cc @@ -18,7 +18,6 @@ #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/xpu/pass_utils.h" -#include "paddle/fluid/framework/ir/xpu/quant_utils.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" diff --git a/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc b/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc index 86b6da2868714d51b2f75834f7bd5739f8eb0158..932d4ca7b8864688121b575f7309741d3caeb821 100644 --- a/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc +++ b/paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc @@ -17,7 +17,6 @@ #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/xpu/pass_utils.h" -#include "paddle/fluid/framework/ir/xpu/quant_utils.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" diff --git a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_slice_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_slice_fuse_pass.cc index 64693ebd082d8e0c6c23946d7afe08184385462c..722ac525d41762e92f23da1c3185a104e3a3087b 100644 --- a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_slice_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_slice_fuse_pass.cc @@ -17,7 +17,6 @@ #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/xpu/pass_utils.h" -#include "paddle/fluid/framework/ir/xpu/quant_utils.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/enforce.h" diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 
7680309744f0ba4e4a6c062573c037ed65419960..fa4224709756ec5b1e591aba0be1bed30c3c0f7a 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -524,6 +524,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "multi_encoder_xpu_slice_fuse_pass", "fc_xpu_fuse_pass", "link_xpu_op_max_pass", + "delete_op_device_pass", }); use_xpu_ = true; }