未验证 提交 8f8a6848 编写于 作者: L Leo Chen 提交者: GitHub

unify cpu context (#44049)

上级 2b0c22ad
......@@ -117,20 +117,20 @@ void ScaleAPI(const paddle::experimental::Tensor& x,
paddle::platform::DeviceContextPool::Instance();
if (expected_kernel_place == paddle::platform::CPUPlace()) {
auto* dev_ctx = dynamic_cast<paddle::platform::CPUDeviceContext*>(
pool.Get(expected_kernel_place));
auto* dev_ctx =
dynamic_cast<phi::CPUContext*>(pool.Get(expected_kernel_place));
if (!dev_ctx) {
PADDLE_THROW(paddle::platform::errors::Fatal(
"Cannot convert device_context to CPUDeviceContext."
"Cannot convert device_context to phi::CPUContext."
"This indicates backend mismatch."
"Pleas double check your expected place"));
}
ScaleDeviceDispatch<paddle::platform::CPUDeviceContext>(*dense_tensor.get(),
*dev_ctx,
scale,
bias,
bias_after_scale,
dense_out.get());
ScaleDeviceDispatch<phi::CPUContext>(*dense_tensor.get(),
*dev_ctx,
scale,
bias,
bias_after_scale,
dense_out.get());
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
} else if (expected_kernel_place == paddle::platform::CUDAPlace()) {
......
......@@ -133,7 +133,6 @@ constexpr DeviceType kIPU = DeviceType::IPU;
constexpr DeviceType kMLU = DeviceType::MLU;
using DeviceContext = phi::DeviceContext;
using CPUDeviceContext = phi::CPUContext;
template <typename Place>
struct DefaultDeviceContextType;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册