diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc index 23690cb879123198376314f0bf264be2b97393b5..b9a8dd984560770c76a8cfc2ce5e6a8875e55146 100644 --- a/paddle/fluid/platform/device_context.cc +++ b/paddle/fluid/platform/device_context.cc @@ -189,13 +189,25 @@ XPUDeviceContext::XPUDeviceContext(XPUPlace place) : place_(place) { "Baidu Kunlun Card is properly installed.", ret)); context_ = xpu::create_context(); - void* l3ptr = nullptr; - int l3_size = 13.5 * 1024 * 1024; - xpu_malloc(static_cast(&l3ptr), l3_size, XPU_MEM_L3); - if (l3ptr != nullptr) { - context_->_l3_mgr.set(l3ptr, l3_size); - std::cout << "set l3 size " << l3_size << std::endl; + const int MAX_XPU_NUM = 16; + const int l3_size = 13.5 * 1024 * 1024; + static void* l3ptrs[MAX_XPU_NUM] = {nullptr}; + + auto selected_xpus = GetXPUSelectedDevices(); + for (unsigned int i = 0; i < selected_xpus.size(); i++) { + if (place.device == selected_xpus[i]) { + if (l3ptrs[place.device] == nullptr) { + xpu_malloc(static_cast(&l3ptrs[place.device]), l3_size, + XPU_MEM_L3); + } + if (l3ptrs[place.device] != nullptr) { + context_->_l3_mgr.set(l3ptrs[place.device], l3_size); + VLOG(3) << "xpu place " << place.device << " set l3 size " << l3_size; + } + break; + } } + ret = xpu_set_device(dev_id); PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS, platform::errors::External(