未验证 提交 77606f5d 编写于 作者: S shentanyue 提交者: GitHub

[XPU] fix fc_xpu_fuse_pass (#50569)

上级 2e07c8b7
...@@ -310,6 +310,8 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph, ...@@ -310,6 +310,8 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
delete_nodes = {mul, mul_out, act}; delete_nodes = {mul, mul_out, act};
} else if (add) { } else if (add) {
delete_nodes = {mul, mul_out, add}; delete_nodes = {mul, mul_out, add};
} else {
delete_nodes = {mul};
} }
GraphSafeRemoveNodes(graph, delete_nodes); GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++; found_subgraph_count++;
......
...@@ -160,8 +160,16 @@ void QuantWeight(phi::DenseTensor* weight, ...@@ -160,8 +160,16 @@ void QuantWeight(phi::DenseTensor* weight,
weight->Resize(weight_trans.dims()); weight->Resize(weight_trans.dims());
} }
// Find max // Find max
auto* xpu_ctx = static_cast<phi::XPUContext*>( paddle::platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance().Get(phi::XPUPlace())); paddle::platform::DeviceContextPool::Instance();
const auto& dev_ctxs = pool.device_contexts();
auto place = phi::XPUPlace(); // xpu:0
for (auto it = dev_ctxs.begin(); it != dev_ctxs.end(); it++) {
if (it->first.GetType() == phi::AllocationType::XPU) { // maybe xpu:1
place = it->first;
}
}
phi::XPUContext* xpu_ctx = static_cast<phi::XPUContext*>(pool.Get(place));
int max_ptr_size = xpu_ctx->x_context()->max_ptr_size(); int max_ptr_size = xpu_ctx->x_context()->max_ptr_size();
int size = weight->numel(); int size = weight->numel();
float max_val = FindMaxAbs(weight_data, size); float max_val = FindMaxAbs(weight_data, size);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册