Unverified commit 77606f5d authored by S shentanyue, committed by GitHub

[XPU] fix fc_xpu_fuse_pass (#50569)

Parent 2e07c8b7
......@@ -310,6 +310,8 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
       delete_nodes = {mul, mul_out, act};
     } else if (add) {
       delete_nodes = {mul, mul_out, add};
+    } else {
+      delete_nodes = {mul};
     }
     GraphSafeRemoveNodes(graph, delete_nodes);
     found_subgraph_count++;
......
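For readers skimming the hunk above: it adds the final else branch so that a matched mul node without a trailing elementwise add or activation still gets removed after fusion. Below is a minimal, self-contained sketch of that branch selection; the Node struct, NodesToDelete helper, and main driver are hypothetical stand-ins for illustration only, not the pass's real ir::Node machinery.

#include <iostream>
#include <string>
#include <unordered_set>

// Hypothetical stand-in for an IR node; the real pass works on ir::Node*.
struct Node { std::string name; };

// Pick the nodes to erase after fusing mul (+ optional add and act) into fc_xpu.
// The commit above adds the final branch so the mul-only match is also cleaned up.
std::unordered_set<Node*> NodesToDelete(Node* mul, Node* mul_out, Node* add, Node* act) {
  if (act) return {mul, mul_out, act};
  if (add) return {mul, mul_out, add};
  return {mul};  // previously unhandled: mul with neither add nor act
}

int main() {
  Node mul{"mul"}, mul_out{"mul_out"};
  auto to_delete = NodesToDelete(&mul, &mul_out, nullptr, nullptr);
  std::cout << to_delete.size() << " node(s) scheduled for removal\n";  // prints 1
}

The point of the branch structure is that every match shape must yield a complete delete set; otherwise the fallback case would leave the standalone mul node behind in the graph.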
......@@ -160,8 +160,16 @@ void QuantWeight(phi::DenseTensor* weight,
     weight->Resize(weight_trans.dims());
   }
   // Find max
-  auto* xpu_ctx = static_cast<phi::XPUContext*>(
-      platform::DeviceContextPool::Instance().Get(phi::XPUPlace()));
+  paddle::platform::DeviceContextPool& pool =
+      paddle::platform::DeviceContextPool::Instance();
+  const auto& dev_ctxs = pool.device_contexts();
+  auto place = phi::XPUPlace();  // xpu:0
+  for (auto it = dev_ctxs.begin(); it != dev_ctxs.end(); it++) {
+    if (it->first.GetType() == phi::AllocationType::XPU) {  // maybe xpu:1
+      place = it->first;
+    }
+  }
+  phi::XPUContext* xpu_ctx = static_cast<phi::XPUContext*>(pool.Get(place));
   int max_ptr_size = xpu_ctx->x_context()->max_ptr_size();
   int size = weight->numel();
   float max_val = FindMaxAbs(weight_data, size);
......
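The second hunk changes how QuantWeight obtains the XPU context used to query max_ptr_size: instead of always asking the pool for phi::XPUPlace() (i.e. xpu:0), it walks the registered device contexts and uses whichever XPU place is actually present, which may be xpu:1 on a multi-card machine. A minimal sketch of that lookup pattern follows; AllocationType, Place, and PickXpuPlace are hypothetical stand-ins for the phi types and the pool iteration, not the framework's real API.

#include <iostream>
#include <vector>

// Hypothetical stand-ins for phi::AllocationType and phi::Place.
enum class AllocationType { CPU, GPU, XPU };
struct Place {
  AllocationType type;
  int device_id;
};

// Prefer a registered XPU place over the hard-coded xpu:0 default,
// mirroring the loop over pool.device_contexts() in the hunk above.
Place PickXpuPlace(const std::vector<Place>& registered_places) {
  Place place{AllocationType::XPU, 0};  // default: xpu:0
  for (const auto& p : registered_places) {
    if (p.type == AllocationType::XPU) place = p;  // may be xpu:1, etc.
  }
  return place;
}

int main() {
  std::vector<Place> pool = {{AllocationType::CPU, 0}, {AllocationType::XPU, 1}};
  Place chosen = PickXpuPlace(pool);
  std::cout << "quantization will use xpu:" << chosen.device_id << "\n";  // xpu:1
}

Presumably the motivation (suggested by the "maybe xpu:1" comment in the diff) is to avoid touching a context for device 0 when the process is actually bound to a different XPU card; the rest of the hunk then queries max_ptr_size and scans the weights with FindMaxAbs as before.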