From 8f8a68485543ad735e0ab212283264d8eaa50898 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Mon, 4 Jul 2022 06:54:10 -0500 Subject: [PATCH] unify cpu context (#44049) --- .../eager_generated/backwards/scale_node.cc | 18 +++++++++--------- paddle/fluid/platform/device_context.h | 1 - 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.cc b/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.cc index 4ee33ad100f..1409119daf1 100644 --- a/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.cc +++ b/paddle/fluid/eager/api/generated/eager_generated/backwards/scale_node.cc @@ -117,20 +117,20 @@ void ScaleAPI(const paddle::experimental::Tensor& x, paddle::platform::DeviceContextPool::Instance(); if (expected_kernel_place == paddle::platform::CPUPlace()) { - auto* dev_ctx = dynamic_cast( - pool.Get(expected_kernel_place)); + auto* dev_ctx = + dynamic_cast(pool.Get(expected_kernel_place)); if (!dev_ctx) { PADDLE_THROW(paddle::platform::errors::Fatal( - "Cannot convert device_context to CPUDeviceContext." + "Cannot convert device_context to phi::CPUContext." "This indicates backend mismatch." "Pleas double check your expected place")); } - ScaleDeviceDispatch(*dense_tensor.get(), - *dev_ctx, - scale, - bias, - bias_after_scale, - dense_out.get()); + ScaleDeviceDispatch(*dense_tensor.get(), + *dev_ctx, + scale, + bias, + bias_after_scale, + dense_out.get()); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) } else if (expected_kernel_place == paddle::platform::CUDAPlace()) { diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h index 1b7aafdac6f..4459c913f00 100644 --- a/paddle/fluid/platform/device_context.h +++ b/paddle/fluid/platform/device_context.h @@ -133,7 +133,6 @@ constexpr DeviceType kIPU = DeviceType::IPU; constexpr DeviceType kMLU = DeviceType::MLU; using DeviceContext = phi::DeviceContext; -using CPUDeviceContext = phi::CPUContext; template struct DefaultDeviceContextType; -- GitLab