diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index 03b4801e2cada8e8fad0a2f6e8a3957d658f60be..32d140c0e182244cbe98ada4f591d8cf68287981 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -27,19 +27,17 @@ ELSEIF(WITH_CENTOS) SET(XPU_XRE_DIR_NAME "xre-centos7_x86_64") SET(XPU_XDNN_DIR_NAME "xdnn-centos7_x86_64") SET(XPU_XCCL_DIR_NAME "xccl-bdcentos_x86_64") + ELSE () SET(XPU_XRE_DIR_NAME "xre-ubuntu_x86_64") SET(XPU_XDNN_DIR_NAME "xdnn-ubuntu_x86_64") SET(XPU_XCCL_DIR_NAME "xccl-bdcentos_x86_64") ENDIF() -IF(NOT XPU_BASE_URL) - SET(XPU_BASE_URL "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev/20210527") -ENDIF() - -SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) -SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) -SET(XPU_XCCL_URL "${XPU_BASE_URL}/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) +SET(XPU_BASE_URL "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") +SET(XPU_XRE_URL "${XPU_BASE_URL}/20210625/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) +SET(XPU_XDNN_URL "${XPU_BASE_URL}/20210625/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) +SET(XPU_XCCL_URL "${XPU_BASE_URL}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) SET(XPU_PACK_DEPENCE_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/pack_paddle_depence.sh" CACHE STRING "" FORCE) SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu") diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc index 647b7cb34f6598bafd519b3dbf8d4dbcba40fafa..eba30ff8edebf9b4fd0b101c45a13c0a9086e42b 100644 --- a/paddle/fluid/imperative/amp_auto_cast.cc +++ b/paddle/fluid/imperative/amp_auto_cast.cc @@ -33,7 +33,8 @@ AmpOperators::AmpOperators() for (auto it = all_kernels.begin(); it != all_kernels.end(); it++) { bool supported = false; for (auto& kernel_type : it->second) { - if (platform::is_gpu_place(kernel_type.first.place_) && + if ((platform::is_gpu_place(kernel_type.first.place_) || + platform::is_xpu_place(kernel_type.first.place_)) && kernel_type.first.data_type_ == fp16_dtype) { supported = true; } @@ -91,7 +92,8 @@ inline std::string GetDtypeStr( inline bool NeedCast(const std::shared_ptr& var) { if (platform::is_gpu_place(var->Place()) || - platform::is_cuda_pinned_place(var->Place())) { + platform::is_cuda_pinned_place(var->Place()) || + platform::is_xpu_place(var->Place())) { // CudaPinndePlace is added for varbase created by dataloader if (var->DataType() == framework::proto::VarType::FP32 || var->DataType() == framework::proto::VarType::FP16) { diff --git a/paddle/fluid/operators/cast_op_xpu.cc b/paddle/fluid/operators/cast_op_xpu.cc index ca15858cf67d756fc8eb41f4e26a2e0b923abef6..c7c0f81f2131f73d0d9f89a7871550aab38cece8 100644 --- a/paddle/fluid/operators/cast_op_xpu.cc +++ b/paddle/fluid/operators/cast_op_xpu.cc @@ -23,21 +23,9 @@ limitations under the License. */ namespace paddle { namespace operators { -template -class XPUFPTypeTrait { - public: - using Type = T; -}; - -template <> -class XPUFPTypeTrait { - public: - using Type = float16; -}; - template class CastXPUKernel : public framework::OpKernel { - using XPUInTDType = typename XPUFPTypeTrait::Type; + using XPUInTDType = typename XPUTypeTrait::Type; public: void Compute(const framework::ExecutionContext& context) const override { @@ -49,7 +37,6 @@ class CastXPUKernel : public framework::OpKernel { context.Attr("out_dtype")); auto* in_data = in->data(); - // using XPUOutTDType = typename XPUFPTypeTrait::Type; auto numel = in->numel(); auto& dev_ctx = context.template device_context(); int r = -1; diff --git a/paddle/fluid/operators/matmul_op_xpu.cc b/paddle/fluid/operators/matmul_op_xpu.cc index 6fa96aca4be147e9d70c6e62500acaae88822315..7097b5327d86fab115ff85fd114dce6dd9e5ae2f 100644 --- a/paddle/fluid/operators/matmul_op_xpu.cc +++ b/paddle/fluid/operators/matmul_op_xpu.cc @@ -102,6 +102,7 @@ template static void MatMulXPUFunction(const Tensor *x, const Tensor *y, Tensor *out, bool trans_x, bool trans_y, const paddle::framework::ExecutionContext &ctx) { + using XPUType = typename XPUTypeTrait::Type; const auto &x_dims = x->dims(); const auto &y_dims = y->dims(); auto &dev_ctx = @@ -162,34 +163,36 @@ static void MatMulXPUFunction(const Tensor *x, const Tensor *y, Tensor *out, int ldout = n; if (batch_size <= 1) { int r = 0; - r = xpu::fc_fusion( - dev_ctx.x_context(), x->data(), y->data(), data_c, m, n, k, - mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx, ldy, - ldout, alpha, 0, nullptr, xpu::Activation_t::LINEAR); + r = xpu::fc_fusion( + dev_ctx.x_context(), reinterpret_cast(x->data()), + reinterpret_cast(y->data()), + reinterpret_cast(data_c), m, n, k, mat_dim_a.trans_, + mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx, ldy, ldout, alpha, 0, + nullptr, xpu::Activation_t::LINEAR); PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( "XPU fc_fusion kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r])); } else { // batch matmul - int r = xpu::fc_batched( - dev_ctx.x_context(), // Context* ctx, - batch_size, // int batch_size, - mat_dim_a.trans_, // bool x_trans, - mat_dim_b.trans_, // bool w_trans, - m, // int m, - n, // int n, - k, // int k, - alpha, // float alpha, - reinterpret_cast(x->data()), // const TX* x, - mat_dim_a.stride_, // int stride_a, - reinterpret_cast(y->data()), // const TW* w, - mat_dim_b.stride_, // int stride_b, - 0.0, // float beta, - reinterpret_cast(data_c), // TY* y, - m * n, // int stride_c, - nullptr, // const float* x_maxptr, - nullptr); // const float* w_maxptr + int r = xpu::fc_batched( + dev_ctx.x_context(), // Context* ctx, + batch_size, // int batch_size, + mat_dim_a.trans_, // bool x_trans, + mat_dim_b.trans_, // bool w_trans, + m, // int m, + n, // int n, + k, // int k, + alpha, // float alpha, + reinterpret_cast(x->data()), // const TX* x, + mat_dim_a.stride_, // int stride_a, + reinterpret_cast(y->data()), // const TW* w, + mat_dim_b.stride_, // int stride_b, + 0.0, // float beta, + reinterpret_cast(data_c), // TY* y, + m * n, // int stride_c, + nullptr, // const float* x_maxptr, + nullptr); // const float* w_maxptr PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( @@ -210,10 +213,14 @@ class MatMulXPUKernel : public framework::OpKernel { out->mutable_data(context.GetPlace()); bool trans_x = context.Attr("transpose_X"); bool trans_y = context.Attr("transpose_Y"); - if (std::getenv("XPU_PADDLE_MAT_MUL_FCINT32") != nullptr) { - MatMulXPUFunction(x, y, out, trans_x, trans_y, context); - } else { + if (std::is_same::value) { MatMulXPUFunction(x, y, out, trans_x, trans_y, context); + } else { + if (std::getenv("XPU_PADDLE_MAT_MUL_FCINT32") != nullptr) { + MatMulXPUFunction(x, y, out, trans_x, trans_y, context); + } else { + MatMulXPUFunction(x, y, out, trans_x, trans_y, context); + } } } }; @@ -224,6 +231,7 @@ class MatMulXPUKernel : public framework::OpKernel { template static framework::Tensor XPUFoldHeadAndLastDims( const DeviceContext &context, const framework::Tensor &input) { + using XPUType = typename XPUTypeTrait::Type; auto in_dims = input.dims(); if (in_dims.size() != 3) { return input; @@ -236,8 +244,9 @@ static framework::Tensor XPUFoldHeadAndLastDims( static_cast(in_dims[1]), static_cast(in_dims[2])}; std::vector axis_host = {1, 0, 2}; - int r = xpu::transpose(context.x_context(), input.data(), output.data(), - in_shape_host, axis_host); + int r = xpu::transpose( + context.x_context(), reinterpret_cast(input.data()), + reinterpret_cast(output.data()), in_shape_host, axis_host); PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( "XPU transpose kernel return wrong value[%d %s]", r, @@ -280,10 +289,14 @@ class MatMulGradXPUKernel : public framework::OpKernel { const framework::Tensor &b, bool trans_b, framework::Tensor *out) const { out->mutable_data(context.GetPlace()); - if (std::getenv("XPU_PADDLE_MAT_MUL_GRAD_FCINT32") != nullptr) { - MatMulXPUFunction(&a, &b, out, trans_a, trans_b, context); - } else { + if (std::is_same::value) { MatMulXPUFunction(&a, &b, out, trans_a, trans_b, context); + } else { + if (std::getenv("XPU_PADDLE_MAT_MUL_GRAD_FCINT32") != nullptr) { + MatMulXPUFunction(&a, &b, out, trans_a, trans_b, context); + } else { + MatMulXPUFunction(&a, &b, out, trans_a, trans_b, context); + } } } @@ -370,10 +383,14 @@ class MatMulGradXPUKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_XPU_KERNEL( - matmul, ops::MatMulXPUKernel); + matmul, ops::MatMulXPUKernel, + ops::MatMulXPUKernel); REGISTER_OP_XPU_KERNEL( matmul_grad, - ops::MatMulGradXPUKernel); + ops::MatMulGradXPUKernel, + ops::MatMulGradXPUKernel); #endif diff --git a/paddle/fluid/operators/matmul_v2_op_xpu.cc b/paddle/fluid/operators/matmul_v2_op_xpu.cc index d992ef847db2aca8bc284781fdd1408d36bd14e5..ae1e9358f68115e4952696325051d142a25789f8 100644 --- a/paddle/fluid/operators/matmul_v2_op_xpu.cc +++ b/paddle/fluid/operators/matmul_v2_op_xpu.cc @@ -25,6 +25,7 @@ template static void MatMulXPUFunction(const Tensor* x, const Tensor* y, Tensor* out, bool trans_x, bool trans_y, const paddle::framework::ExecutionContext& ctx) { + using XPUType = typename XPUTypeTrait::Type; const auto& x_dims = x->dims(); const auto& y_dims = y->dims(); auto& dev_ctx = @@ -75,9 +76,11 @@ static void MatMulXPUFunction(const Tensor* x, const Tensor* y, Tensor* out, int batch_size = mat_dim_a.batch_size_; if (batch_size <= 1) { int r = 0; - r = xpu::fc(dev_ctx.x_context(), x->data(), y->data(), - data_c, m, n, k, mat_dim_a.trans_, - mat_dim_b.trans_, nullptr, nullptr, nullptr); + r = xpu::fc( + dev_ctx.x_context(), reinterpret_cast(x->data()), + reinterpret_cast(y->data()), + reinterpret_cast(data_c), m, n, k, mat_dim_a.trans_, + mat_dim_b.trans_, nullptr, nullptr, nullptr); PADDLE_ENFORCE_EQ( r, XPU_SUCCESS, platform::errors::External( @@ -87,24 +90,24 @@ static void MatMulXPUFunction(const Tensor* x, const Tensor* y, Tensor* out, r, XPUAPIErrorMsg[r], m, n, k, mat_dim_a.trans_, mat_dim_b.trans_)); } else { // batch matmul - int r = xpu::fc_batched( - dev_ctx.x_context(), // Context* ctx, - batch_size, // int batch_size, - mat_dim_a.trans_, // bool x_trans, - mat_dim_b.trans_, // bool w_trans, - m, // int m, - n, // int n, - k, // int k, - 1.0, // float alpha, - reinterpret_cast(x->data()), // const TX* x, - mat_dim_a.stride_, // int stride_a, - reinterpret_cast(y->data()), // const TW* w, - mat_dim_b.stride_, // int stride_b, - 0.0, // float beta, - reinterpret_cast(data_c), // TY* y, - m * n, // int stride_c, - nullptr, // const float* x_maxptr, - nullptr); // const float* w_maxptr + int r = xpu::fc_batched( + dev_ctx.x_context(), // Context* ctx, + batch_size, // int batch_size, + mat_dim_a.trans_, // bool x_trans, + mat_dim_b.trans_, // bool w_trans, + m, // int m, + n, // int n, + k, // int k, + 1.0, // float alpha, + reinterpret_cast(x->data()), // const TX* x, + mat_dim_a.stride_, // int stride_a, + reinterpret_cast(y->data()), // const TW* w, + mat_dim_b.stride_, // int stride_b, + 0.0, // float beta, + reinterpret_cast(data_c), // TY* y, + m * n, // int stride_c, + nullptr, // const float* x_maxptr, + nullptr); // const float* w_maxptr PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( @@ -123,10 +126,14 @@ class MatMulV2XPUKernel : public framework::OpKernel { bool trans_x = ctx.Attr("trans_x"); bool trans_y = ctx.Attr("trans_y"); out->mutable_data(ctx.GetPlace()); - if (std::getenv("XPU_PADDLE_MAT_MUL_V2_FCINT32") != nullptr) { - MatMulXPUFunction(x, y, out, trans_x, trans_y, ctx); - } else { + if (std::is_same::value) { MatMulXPUFunction(x, y, out, trans_x, trans_y, ctx); + } else { + if (std::getenv("XPU_PADDLE_MAT_MUL_V2_FCINT32") != nullptr) { + MatMulXPUFunction(x, y, out, trans_x, trans_y, ctx); + } else { + MatMulXPUFunction(x, y, out, trans_x, trans_y, ctx); + } } } }; @@ -134,6 +141,7 @@ class MatMulV2XPUKernel : public framework::OpKernel { template static framework::Tensor XPUFoldHeadAndLastDims( const DeviceContext& context, const framework::Tensor& input) { + using XPUType = typename XPUTypeTrait::Type; auto in_dims = input.dims(); if (in_dims.size() != 3) { return input; @@ -147,8 +155,9 @@ static framework::Tensor XPUFoldHeadAndLastDims( static_cast(in_dims[2])}; std::vector axis_host = {1, 0, 2}; - int r = xpu::transpose(context.x_context(), input.data(), output.data(), - in_shape_host, axis_host); + int r = xpu::transpose( + context.x_context(), reinterpret_cast(input.data()), + reinterpret_cast(output.data()), in_shape_host, axis_host); PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( "XPU transpose kernel return wrong value[%d %s]", r, @@ -166,10 +175,14 @@ class MatMulV2XPUGradKernel : public framework::OpKernel { const framework::Tensor& b, bool trans_b, framework::Tensor* out) const { out->mutable_data(ctx.GetPlace()); - if (std::getenv("XPU_PADDLE_MAT_MUL_GRAD_V2_FCINT32") != nullptr) { - MatMulXPUFunction(&a, &b, out, trans_a, trans_b, ctx); - } else { + if (std::is_same::value) { MatMulXPUFunction(&a, &b, out, trans_a, trans_b, ctx); + } else { + if (std::getenv("XPU_PADDLE_MAT_MUL_GRAD_V2_FCINT32") != nullptr) { + MatMulXPUFunction(&a, &b, out, trans_a, trans_b, ctx); + } else { + MatMulXPUFunction(&a, &b, out, trans_a, trans_b, ctx); + } } } @@ -261,8 +274,10 @@ class MatMulV2XPUGradKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; - -REGISTER_OP_XPU_KERNEL(matmul_v2, ops::MatMulV2XPUKernel); -REGISTER_OP_XPU_KERNEL(matmul_v2_grad, ops::MatMulV2XPUGradKernel); +namespace plat = paddle::platform; +REGISTER_OP_XPU_KERNEL(matmul_v2, ops::MatMulV2XPUKernel, + ops::MatMulV2XPUKernel); +REGISTER_OP_XPU_KERNEL(matmul_v2_grad, ops::MatMulV2XPUGradKernel, + ops::MatMulV2XPUGradKernel); #endif diff --git a/paddle/fluid/operators/softmax_op_xpu.cc b/paddle/fluid/operators/softmax_op_xpu.cc index ed7034ef6ab416a4e98ddcd02f045af459298d65..3527478f7661058e193d14d95f815beb28f1e92a 100644 --- a/paddle/fluid/operators/softmax_op_xpu.cc +++ b/paddle/fluid/operators/softmax_op_xpu.cc @@ -47,8 +47,8 @@ class SoftmaxXPUKernel : public framework::OpKernel { int len = x->numel(); T* clip_x_data = clip_x.mutable_data(context.GetPlace(), len * sizeof(T)); - r = xpu::clip(dev_ctx.x_context(), x->data(), clip_x_data, len, - -1e30, 1e30); + r = xpu::clip_v2(dev_ctx.x_context(), x->data(), clip_x_data, len, + static_cast(-1e20), static_cast(1e20)); PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External("XPU API(clip) return wrong " "value[%d %s]", diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc index 8635def2ecf138550bf02f0013b31b59647777b9..a79e31eb8d028d3d319176e397ba5da9da54cd0e 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc @@ -54,8 +54,9 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel { int len = logits->numel(); T* clip_logits_data = clip_logits.mutable_data(context.GetPlace(), len * sizeof(T)); - r = xpu::clip(dev_ctx.x_context(), logits->data(), clip_logits_data, - len, -1e30, 1e30); + r = xpu::clip_v2(dev_ctx.x_context(), logits->data(), + clip_logits_data, len, static_cast(-1e20), + static_cast(1e20)); PADDLE_ENFORCE_EQ( r, xpu::Error_t::SUCCESS, platform::errors::External("XPU kernel error. clip " diff --git a/paddle/fluid/platform/xpu_header.h b/paddle/fluid/platform/xpu_header.h index 9f2befc123f224aeda3cb4a3d196cbce470d51b2..99f4224b5d408a6450d801ff643f658b74333387 100644 --- a/paddle/fluid/platform/xpu_header.h +++ b/paddle/fluid/platform/xpu_header.h @@ -1,4 +1,4 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ #include #include "paddle/fluid/platform/errors.h" +#include "paddle/fluid/platform/float16.h" #include "xpu/api.h" #include "xpu/refactor/fusion.h" #include "xpu/refactor/math.h" @@ -58,4 +59,16 @@ static std::map XPUAPIErrorMsg = { {xpu::Error_t::RUNTIME_ERROR, "xpu api runtime error"}, {xpu::Error_t::NO_ENOUGH_WORKSPACE, "xpu api no enough workspace"}}; +template +class XPUTypeTrait { + public: + using Type = T; +}; + +template <> +class XPUTypeTrait { + public: + using Type = float16; +}; + #endif diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 67f004e61cbfdf5212e47a35a2779589f80f4393..883ade66d4fd9eee7f64f0d720d0e3da525ac368 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -225,7 +225,9 @@ OpSupportedInfos(const std::string &place, [](unsigned char c) { return std::toupper(c); }); using fn_type = std::add_pointer::type; std::unordered_map is_target_place{ - {"GPU", &platform::is_gpu_place}, {"CPU", &platform::is_cpu_place}, + {"GPU", &platform::is_gpu_place}, + {"CPU", &platform::is_cpu_place}, + {"XPU", &platform::is_xpu_place}, }; PADDLE_ENFORCE_NE( is_target_place.count(query_place), 0, diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py index 5cfa77b3d9a4f8a0ea5a74c8746d51fb8312c72f..44f8e5027fba8e4ed72a972bb2c050bdc52e072c 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py @@ -14,6 +14,7 @@ import copy from ... import core +import paddle.fluid as fluid __all__ = ["CustomOpLists", "AutoMixedPrecisionLists"] @@ -152,8 +153,14 @@ gray_list = { # The set of ops that don't support fp16 calculation # lookup_table fp16 is slower than fp32, though fp16 is supported. -_, _, _sys_unsupported_fp16_list = core.op_supported_infos( - 'GPU', core.VarDesc.VarType.FP16) +_sys_unsupported_fp16_list = [] +if fluid.is_compiled_with_xpu(): + _, _, _sys_unsupported_fp16_list = core.op_supported_infos( + 'XPU', core.VarDesc.VarType.FP16) +else: + _, _, _sys_unsupported_fp16_list = core.op_supported_infos( + 'GPU', core.VarDesc.VarType.FP16) + unsupported_fp16_list = {'lookup_table', 'lookup_table_v2'} | _sys_unsupported_fp16_list diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index b14b2be7394a1c036b031a4a1f662c229402ce9c..7af8c18e33f8f7c6bf1d4adc927dba035be612d2 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -130,9 +130,10 @@ def amp_guard(enable=True, custom_white_list=None, custom_black_list=None): raise ValueError( "current_tracer is None, maybe it is not in imperative mode.") - if enable and not tracer._expected_place.is_gpu_place(): + if enable and not (tracer._expected_place.is_gpu_place() or + tracer._expected_place.is_xpu_place()): warnings.warn( - 'amp_guard can only be enabled on CUDAPlace, current place is %s, so it makes no effect.' + 'amp_guard can only be enabled on CUDAPlace and XPUPlace, current place is %s, so it makes no effect.' % tracer._expected_place) enable = False diff --git a/python/paddle/fluid/dygraph/amp/loss_scaler.py b/python/paddle/fluid/dygraph/amp/loss_scaler.py index ff57f30dcd2ec73d55ff06e751767deea0a2eead..e0bd60fbeb4a73db22ce1141750c2cf22ac29288 100644 --- a/python/paddle/fluid/dygraph/amp/loss_scaler.py +++ b/python/paddle/fluid/dygraph/amp/loss_scaler.py @@ -90,9 +90,10 @@ class AmpScaler(object): raise ValueError( "current_tracer is None, maybe it is not in imperative mode.") - if enable and not tracer._expected_place.is_gpu_place(): + if enable and not (tracer._expected_place.is_gpu_place() or + tracer._expected_place.is_xpu_place()): warnings.warn( - 'AmpScaler can only be enabled on CUDAPlace, current place is %s, so it makes no effect.' + 'AmpScaler can only be enabled on CUDAPlace and XPUPlace, current place is %s, so it makes no effect.' % tracer._expected_place) enable = False