未验证 提交 8f8a6848 编写于 作者: L Leo Chen 提交者: GitHub

unify cpu context (#44049)

上级 2b0c22ad
...@@ -117,15 +117,15 @@ void ScaleAPI(const paddle::experimental::Tensor& x, ...@@ -117,15 +117,15 @@ void ScaleAPI(const paddle::experimental::Tensor& x,
paddle::platform::DeviceContextPool::Instance(); paddle::platform::DeviceContextPool::Instance();
if (expected_kernel_place == paddle::platform::CPUPlace()) { if (expected_kernel_place == paddle::platform::CPUPlace()) {
auto* dev_ctx = dynamic_cast<paddle::platform::CPUDeviceContext*>( auto* dev_ctx =
pool.Get(expected_kernel_place)); dynamic_cast<phi::CPUContext*>(pool.Get(expected_kernel_place));
if (!dev_ctx) { if (!dev_ctx) {
PADDLE_THROW(paddle::platform::errors::Fatal( PADDLE_THROW(paddle::platform::errors::Fatal(
"Cannot convert device_context to CPUDeviceContext." "Cannot convert device_context to phi::CPUContext."
"This indicates backend mismatch." "This indicates backend mismatch."
"Pleas double check your expected place")); "Pleas double check your expected place"));
} }
ScaleDeviceDispatch<paddle::platform::CPUDeviceContext>(*dense_tensor.get(), ScaleDeviceDispatch<phi::CPUContext>(*dense_tensor.get(),
*dev_ctx, *dev_ctx,
scale, scale,
bias, bias,
......
...@@ -133,7 +133,6 @@ constexpr DeviceType kIPU = DeviceType::IPU; ...@@ -133,7 +133,6 @@ constexpr DeviceType kIPU = DeviceType::IPU;
constexpr DeviceType kMLU = DeviceType::MLU; constexpr DeviceType kMLU = DeviceType::MLU;
using DeviceContext = phi::DeviceContext; using DeviceContext = phi::DeviceContext;
using CPUDeviceContext = phi::CPUContext;
template <typename Place> template <typename Place>
struct DefaultDeviceContextType; struct DefaultDeviceContextType;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册