From 77606f5d67d1107056f82289d41df260ae87dbc2 Mon Sep 17 00:00:00 2001
From: shentanyue <34421038+shentanyue@users.noreply.github.com>
Date: Mon, 20 Feb 2023 15:09:13 +0800
Subject: [PATCH] [XPU] fix fc_xpu_fuse_pass (#50569)

---
 paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc |  2 ++
 paddle/fluid/framework/ir/xpu/quant_utils.cc      | 12 ++++++++++--
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
index c7cc1dfc07f..22a7229b70e 100644
--- a/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc
@@ -310,6 +310,8 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
       delete_nodes = {mul, mul_out, act};
     } else if (add) {
       delete_nodes = {mul, mul_out, add};
+    } else {
+      delete_nodes = {mul};
     }
     GraphSafeRemoveNodes(graph, delete_nodes);
     found_subgraph_count++;

diff --git a/paddle/fluid/framework/ir/xpu/quant_utils.cc b/paddle/fluid/framework/ir/xpu/quant_utils.cc
index 3c249f6995d..b1aaace6952 100644
--- a/paddle/fluid/framework/ir/xpu/quant_utils.cc
+++ b/paddle/fluid/framework/ir/xpu/quant_utils.cc
@@ -160,8 +160,16 @@ void QuantWeight(phi::DenseTensor* weight,
     weight->Resize(weight_trans.dims());
   }
   // Find max
-  auto* xpu_ctx = static_cast<phi::XPUContext*>(
-      platform::DeviceContextPool::Instance().Get(phi::XPUPlace()));
+  paddle::platform::DeviceContextPool& pool =
+      paddle::platform::DeviceContextPool::Instance();
+  const auto& dev_ctxs = pool.device_contexts();
+  auto place = phi::XPUPlace();  // xpu:0
+  for (auto it = dev_ctxs.begin(); it != dev_ctxs.end(); it++) {
+    if (it->first.GetType() == phi::AllocationType::XPU) {  // maybe xpu:1
+      place = it->first;
+    }
+  }
+  phi::XPUContext* xpu_ctx = static_cast<phi::XPUContext*>(pool.Get(place));
   int max_ptr_size = xpu_ctx->x_context()->max_ptr_size();
   int size = weight->numel();
   float max_val = FindMaxAbs(weight_data, size);
--
GitLab
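
Editor's note (not part of the patch above): a minimal standalone sketch of the device lookup that the quant_utils.cc hunk introduces. Instead of hard-coding xpu:0, it scans the DeviceContextPool for a registered XPU place, so weight quantization runs on the XPU device the predictor was actually configured with. The helper name GetQuantXPUContext and the include paths are assumptions based on the usual Paddle source layout; only DeviceContextPool::Instance(), device_contexts(), GetType(), and Get() come from the patch itself.

// Sketch only: header paths assumed, verify against your Paddle version.
#include "paddle/fluid/platform/device_context.h"  // paddle::platform::DeviceContextPool, phi::XPUContext
#include "paddle/phi/common/place.h"               // phi::Place, phi::XPUPlace, phi::AllocationType

// Hypothetical helper mirroring the patched QuantWeight logic: prefer whichever
// XPU place is already registered in the pool (e.g. xpu:1), falling back to xpu:0.
phi::XPUContext* GetQuantXPUContext() {
  auto& pool = paddle::platform::DeviceContextPool::Instance();
  phi::Place place = phi::XPUPlace();  // default: xpu:0
  for (const auto& ctx : pool.device_contexts()) {
    if (ctx.first.GetType() == phi::AllocationType::XPU) {
      place = ctx.first;  // use the XPU device already registered in the pool
    }
  }
  return static_cast<phi::XPUContext*>(pool.Get(place));
}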