diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index a8c33618a61359e01e89399ceb0546a208179691..8d202b5a99bfc0e4065be9ea54466b8e086e0841 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -27,19 +27,18 @@ ELSEIF(WITH_CENTOS) SET(XPU_XRE_DIR_NAME "xre-centos7_x86_64") SET(XPU_XDNN_DIR_NAME "xdnn-centos7_x86_64") SET(XPU_XCCL_DIR_NAME "xccl-bdcentos_x86_64") + ELSE () SET(XPU_XRE_DIR_NAME "xre-ubuntu_x86_64") SET(XPU_XDNN_DIR_NAME "xdnn-ubuntu_x86_64") SET(XPU_XCCL_DIR_NAME "xccl-bdcentos_x86_64") ENDIF() -IF(NOT XPU_BASE_URL) - SET(XPU_BASE_URL "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev/20210527") -ENDIF() - +SET(XPU_BASE_URL_WITHOUT_DATE "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") +SET(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210701") SET(XPU_XRE_URL "${XPU_BASE_URL}/${XPU_XRE_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) SET(XPU_XDNN_URL "${XPU_BASE_URL}/${XPU_XDNN_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) -SET(XPU_XCCL_URL "${XPU_BASE_URL}/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) +SET(XPU_XCCL_URL "${XPU_BASE_URL_WITHOUT_DATE}/20210623/${XPU_XCCL_DIR_NAME}.tar.gz" CACHE STRING "" FORCE) SET(XPU_PACK_DEPENCE_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/pack_paddle_depence.sh" CACHE STRING "" FORCE) SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu") @@ -96,7 +95,11 @@ ELSE(WITH_XPU_BKCL) TARGET_LINK_LIBRARIES(xpulib ${XPU_API_LIB} ${XPU_RT_LIB}) ENDIF(WITH_XPU_BKCL) -ADD_DEPENDENCIES(xpulib ${XPU_PROJECT}) +if(NOT XPU_SDK_ROOT) + ADD_DEPENDENCIES(xpulib ${XPU_PROJECT}) +else() + ADD_CUSTOM_TARGET(extern_xpu DEPENDS xpulib) +endif() # Ensure that xpu/api.h can be included without dependency errors. file(GENERATE OUTPUT ${CMAKE_CURRENT_BINARY_DIR}/.xpu_headers_dummy.cc CONTENT "") diff --git a/paddle/fluid/imperative/amp_auto_cast.cc b/paddle/fluid/imperative/amp_auto_cast.cc index b4154737e0fbc6245617fb0208f6623e4ebb5943..d67a5483155414f232ef1d7ecd9129808fa3bac9 100644 --- a/paddle/fluid/imperative/amp_auto_cast.cc +++ b/paddle/fluid/imperative/amp_auto_cast.cc @@ -33,7 +33,8 @@ AmpOperators::AmpOperators() for (auto it = all_kernels.begin(); it != all_kernels.end(); it++) { bool supported = false; for (auto& kernel_type : it->second) { - if (platform::is_gpu_place(kernel_type.first.place_) && + if ((platform::is_gpu_place(kernel_type.first.place_) || + platform::is_xpu_place(kernel_type.first.place_)) && kernel_type.first.data_type_ == fp16_dtype) { supported = true; } @@ -91,7 +92,8 @@ inline std::string GetDtypeStr( inline bool NeedCast(const std::shared_ptr& var) { if (platform::is_gpu_place(var->Place()) || - platform::is_cuda_pinned_place(var->Place())) { + platform::is_cuda_pinned_place(var->Place()) || + platform::is_xpu_place(var->Place())) { // CudaPinndePlace is added for varbase created by dataloader if (var->DataType() == framework::proto::VarType::FP32 || var->DataType() == framework::proto::VarType::FP16) { diff --git a/paddle/fluid/operators/amp/check_finite_and_unscale_op_xpu.cc b/paddle/fluid/operators/amp/check_finite_and_unscale_op_xpu.cc new file mode 100644 index 0000000000000000000000000000000000000000..210f3e098f95f490f9c5d4adf53d9ee4f20f3e97 --- /dev/null +++ b/paddle/fluid/operators/amp/check_finite_and_unscale_op_xpu.cc @@ -0,0 +1,170 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_XPU +#include "paddle/fluid/operators/amp/check_finite_and_unscale_op.h" +#include "paddle/fluid/operators/amp/fp16_type_traits.h" +#include "paddle/fluid/platform/float16.h" +namespace paddle { +namespace operators { +template +class CheckFiniteAndUnscaleXPUKernel : public framework::OpKernel { + using MPDType = typename details::MPTypeTrait::Type; + using XPUTyp = typename XPUTypeTrait::Type; + + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto& dev_ctx = ctx.template device_context(); + const auto xs = ctx.MultiInput("X"); + const auto* scale = ctx.Input("Scale"); + auto outs = ctx.MultiOutput("Out"); + auto* found_inf = ctx.Output("FoundInfinite"); + + const MPDType* scale_data = scale->data(); + bool* found_inf_data = found_inf->mutable_data(dev_ctx.GetPlace()); + + // cpy to cpu + bool cpu_found_inf_data = false; + + MPDType cpu_scale_data; + if (platform::is_xpu_place(scale->place())) { + xpu_memcpy(&cpu_scale_data, scale_data, sizeof(MPDType), + XPUMemcpyKind::XPU_DEVICE_TO_HOST); + } else { + cpu_scale_data = (*scale_data); + } + MPDType inverse_scale = 1.0 / cpu_scale_data; + for (size_t i = 0; i < xs.size(); ++i) { + const auto* x = xs[i]; + auto* out = outs[i]; + out->mutable_data(dev_ctx.GetPlace()); + framework::Tensor is_finite = + ctx.AllocateTmpTensor(x->dims(), + dev_ctx); + framework::Tensor is_nan = + ctx.AllocateTmpTensor(x->dims(), + dev_ctx); + framework::Tensor is_finite_and_nan = + ctx.AllocateTmpTensor(x->dims(), + dev_ctx); + if (cpu_found_inf_data == false) { + int r = xpu::isfinite(dev_ctx.x_context(), + reinterpret_cast(x->data()), + is_finite.data(), x->numel()); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( + "XPU API(isfinite) return wrong " + "value[%d %s]", + r, XPUAPIErrorMsg[r])); + r = xpu::logical_not(dev_ctx.x_context(), reinterpret_cast( + is_finite.data()), + is_finite.data(), x->numel()); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External("XPU API(logical_not) return wrong " + "value[%d %s]", + r, XPUAPIErrorMsg[r])); + r = xpu::isnan(dev_ctx.x_context(), + reinterpret_cast(x->data()), + is_nan.data(), x->numel()); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( + "XPU API(isnan) return wrong " + "value[%d %s]", + r, XPUAPIErrorMsg[r])); + r = xpu::logical_or(dev_ctx.x_context(), is_finite.data(), + is_nan.data(), is_finite.data(), + x->numel()); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External("XPU API(logical_or) return wrong " + "value[%d %s]", + r, XPUAPIErrorMsg[r])); + r = xpu::any(dev_ctx.x_context(), is_finite.data(), + found_inf_data, x->numel()); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( + "XPU API(any) return wrong " + "value[%d %s]", + r, XPUAPIErrorMsg[r])); + memory::Copy(platform::CPUPlace(), &cpu_found_inf_data, + BOOST_GET_CONST(platform::XPUPlace, dev_ctx.GetPlace()), + found_inf_data, sizeof(bool)); + } + + if (cpu_found_inf_data) { + inverse_scale = 0.0; + } + auto dev_env = XPUEnv::getenv("XPUSIM_DEVICE_MODEL"); + + if (std::is_same::value && + (dev_env == nullptr || std::strcmp(dev_env, 
"KUNLUN1"))) { + framework::Tensor float_x; + framework::Tensor float_out; + float_x.mutable_data(dev_ctx.GetPlace(), + x->numel() * sizeof(MPDType)); + float_out.mutable_data(dev_ctx.GetPlace(), + out->numel() * sizeof(MPDType)); + int r = xpu::cast_v2(dev_ctx.x_context(), + reinterpret_cast(x->data()), + float_x.data(), x->numel()); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( + "XPU API(cast_v2) return wrong " + "value[%d %s]", + r, XPUAPIErrorMsg[r])); + + r = xpu::scale(dev_ctx.x_context(), float_x.data(), + float_out.data(), x->numel(), false, + inverse_scale, 0.0); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( + "XPU API(scale) return wrong " + "value[%d %s]", + r, XPUAPIErrorMsg[r])); + + r = xpu::cast_v2(dev_ctx.x_context(), float_out.data(), + reinterpret_cast(out->data()), + out->numel()); + + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( + "XPU API(cast_v2) return wrong " + "value[%d %s]", + r, XPUAPIErrorMsg[r])); + if (dev_ctx.x_context()->xpu_stream) { + dev_ctx.Wait(); + } + + } else { + int r = xpu::scale(dev_ctx.x_context(), + reinterpret_cast(x->data()), + reinterpret_cast(out->data()), + x->numel(), false, inverse_scale, 0.0); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( + "XPU API(scale) return wrong " + "value[%d %s]", + r, XPUAPIErrorMsg[r])); + } + } + memory::Copy(BOOST_GET_CONST(platform::XPUPlace, dev_ctx.GetPlace()), + found_inf_data, platform::CPUPlace(), &cpu_found_inf_data, + sizeof(bool)); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OP_XPU_KERNEL(check_finite_and_unscale, + ops::CheckFiniteAndUnscaleXPUKernel, + ops::CheckFiniteAndUnscaleXPUKernel); + +#endif diff --git a/paddle/fluid/operators/amp/update_loss_scaling_op_xpu.cc b/paddle/fluid/operators/amp/update_loss_scaling_op_xpu.cc new file mode 100644 index 0000000000000000000000000000000000000000..1f05e5f246d9c564dbf53b121b07ff4beb84c686 --- /dev/null +++ b/paddle/fluid/operators/amp/update_loss_scaling_op_xpu.cc @@ -0,0 +1,166 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef PADDLE_WITH_XPU +#include "paddle/fluid/operators/amp/update_loss_scaling_op.h" +#include +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/amp/fp16_type_traits.h" +#include "paddle/fluid/platform/float16.h" + +namespace paddle { +namespace operators { + +template +class UpdateLossScalingXPUKernel : public framework::OpKernel { + using MPDType = typename details::MPTypeTrait::Type; + using XPUTyp = typename XPUTypeTrait::Type; + + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto& dev_ctx = ctx.template device_context(); + + const auto xs = ctx.MultiInput("X"); + auto outs = ctx.MultiOutput("Out"); + const auto* found_inf = ctx.Input("FoundInfinite"); + PADDLE_ENFORCE_EQ(found_inf->numel(), 1, + platform::errors::InvalidArgument( + "FoundInfinite must has only one element.")); + const bool* found_inf_data = found_inf->data(); + bool cpu_found_inf_data = false; + if (platform::is_xpu_place(found_inf->place())) { + xpu_memcpy(&cpu_found_inf_data, found_inf_data, sizeof(bool), + XPUMemcpyKind::XPU_DEVICE_TO_HOST); + } else { + cpu_found_inf_data = (*found_inf_data); + } + + for (size_t i = 0; i < xs.size(); ++i) { + auto* out = outs[i]; + T* out_data = out->mutable_data(dev_ctx.GetPlace()); + int num = out->numel(); + if (cpu_found_inf_data) { + VLOG(1) << "-- UpdateLossScaling: Find infinite grads. --"; + int r = 0; + r = xpu::constant(dev_ctx.x_context(), + reinterpret_cast(out_data), num, + XPUTyp(0.0)); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( + "XPU API(constant) return wrong " + "value[%d %s]", + r, XPUAPIErrorMsg[r])); + } + } + const bool stop_update = ctx.Attr("stop_update"); + if (stop_update) { + return; + } + + const auto* pre_loss_scaling = ctx.Input("PrevLossScaling"); + const auto* good_in = ctx.Input("InGoodSteps"); + const auto* bad_in = ctx.Input("InBadSteps"); + auto* updated_loss_scaling = ctx.Output("LossScaling"); + auto* good_out = ctx.Output("OutGoodSteps"); + auto* bad_out = ctx.Output("OutBadSteps"); + const MPDType* pre_loss_scaling_data = pre_loss_scaling->data(); + const int* good_in_data = good_in->data(); + const int* bad_in_data = bad_in->data(); + + MPDType* updated_loss_scaling_data = + updated_loss_scaling->mutable_data(dev_ctx.GetPlace()); + int* good_out_data = good_out->mutable_data(dev_ctx.GetPlace()); + int* bad_out_data = bad_out->mutable_data(dev_ctx.GetPlace()); + + const int incr_every_n_steps = ctx.Attr("incr_every_n_steps"); + const int decr_every_n_nan_or_inf = + ctx.Attr("decr_every_n_nan_or_inf"); + const float incr_ratio = ctx.Attr("incr_ratio"); + const float decr_ratio = ctx.Attr("decr_ratio"); + + int cpu_bad_in_data; + int cpu_good_in_data; + MPDType cpu_pre_loss_scaling_data; + if (platform::is_xpu_place(bad_in->place())) { + xpu_memcpy(&cpu_bad_in_data, bad_in_data, sizeof(int), + XPUMemcpyKind::XPU_DEVICE_TO_HOST); + } else { + cpu_bad_in_data = (*bad_in_data); + } + + if (platform::is_xpu_place(good_in->place())) { + xpu_memcpy(&cpu_good_in_data, good_in_data, sizeof(int), + XPUMemcpyKind::XPU_DEVICE_TO_HOST); + } else { + cpu_good_in_data = (*good_in_data); + } + + if (platform::is_xpu_place(pre_loss_scaling->place())) { + xpu_memcpy(&cpu_pre_loss_scaling_data, pre_loss_scaling_data, + sizeof(MPDType), XPUMemcpyKind::XPU_DEVICE_TO_HOST); + } else { + cpu_pre_loss_scaling_data = (*pre_loss_scaling_data); + } + + int cpu_good_out_data = 0; + int cpu_bad_out_data = 0; + MPDType cpu_updated_loss_scaling_data; + + 
if (cpu_found_inf_data) { + cpu_good_out_data = 0; + cpu_bad_out_data = cpu_bad_in_data + 1; + if (cpu_bad_out_data == decr_every_n_nan_or_inf) { + MPDType new_loss_scaling = cpu_pre_loss_scaling_data * decr_ratio; + cpu_updated_loss_scaling_data = + (new_loss_scaling < static_cast(1)) + ? (static_cast(1)) + : (new_loss_scaling); + cpu_bad_out_data = 0; + } + } else { + cpu_bad_out_data = 0; + cpu_good_out_data = cpu_good_in_data + 1; + if (cpu_good_out_data == incr_every_n_steps) { + MPDType new_loss_scaling = cpu_pre_loss_scaling_data * incr_ratio; + cpu_updated_loss_scaling_data = (std::isfinite(new_loss_scaling)) + ? new_loss_scaling + : cpu_pre_loss_scaling_data; + cpu_good_out_data = 0; + } + } + + // copy to host + memory::Copy(BOOST_GET_CONST(platform::XPUPlace, dev_ctx.GetPlace()), + bad_out_data, platform::CPUPlace(), &cpu_bad_out_data, + sizeof(int)); + memory::Copy(BOOST_GET_CONST(platform::XPUPlace, dev_ctx.GetPlace()), + good_out_data, platform::CPUPlace(), &cpu_good_out_data, + sizeof(int)); + memory::Copy(BOOST_GET_CONST(platform::XPUPlace, dev_ctx.GetPlace()), + updated_loss_scaling_data, platform::CPUPlace(), + &cpu_updated_loss_scaling_data, sizeof(MPDType)); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_XPU_KERNEL(update_loss_scaling, + ops::UpdateLossScalingXPUKernel, + ops::UpdateLossScalingXPUKernel); +#endif diff --git a/paddle/fluid/operators/cast_op_xpu.cc b/paddle/fluid/operators/cast_op_xpu.cc index ca15858cf67d756fc8eb41f4e26a2e0b923abef6..c7c0f81f2131f73d0d9f89a7871550aab38cece8 100644 --- a/paddle/fluid/operators/cast_op_xpu.cc +++ b/paddle/fluid/operators/cast_op_xpu.cc @@ -23,21 +23,9 @@ limitations under the License. 
*/ namespace paddle { namespace operators { -template -class XPUFPTypeTrait { - public: - using Type = T; -}; - -template <> -class XPUFPTypeTrait { - public: - using Type = float16; -}; - template class CastXPUKernel : public framework::OpKernel { - using XPUInTDType = typename XPUFPTypeTrait::Type; + using XPUInTDType = typename XPUTypeTrait::Type; public: void Compute(const framework::ExecutionContext& context) const override { @@ -49,7 +37,6 @@ class CastXPUKernel : public framework::OpKernel { context.Attr("out_dtype")); auto* in_data = in->data(); - // using XPUOutTDType = typename XPUFPTypeTrait::Type; auto numel = in->numel(); auto& dev_ctx = context.template device_context(); int r = -1; diff --git a/paddle/fluid/operators/dropout_op_xpu.cc b/paddle/fluid/operators/dropout_op_xpu.cc index f5d831fa24012031897eca2ce5a1cd9004f5a03b..79d239074845ad29f4f40e64a7d1ecc9f19168bb 100644 --- a/paddle/fluid/operators/dropout_op_xpu.cc +++ b/paddle/fluid/operators/dropout_op_xpu.cc @@ -16,11 +16,11 @@ namespace paddle { namespace operators { #ifdef PADDLE_WITH_XPU -static std::map mask_data_tables; -static const int max_data_size = 32 * 1024 * 1024; -static std::mutex s_mask_data_table_lock; + template class DropoutXPUKernel : public framework::OpKernel { + using XPUTyp = typename XPUTypeTrait::Type; + public: void Compute(const framework::ExecutionContext& context) const override { auto* x = context.Input("X"); @@ -30,93 +30,70 @@ class DropoutXPUKernel : public framework::OpKernel { float dropout_prob = context.Attr("dropout_prob"); auto dropout_implementation = context.Attr("dropout_implementation"); - float* mask_data_table = nullptr; + auto& dev_ctx = context.template device_context(); + PADDLE_ENFORCE_EQ(!context.HasInput("Seed"), true, platform::errors::InvalidArgument( ("Input(Seed) not supported on XPU"))); + int is_upscale = (dropout_implementation == "upscale_in_train"); + if (!context.Attr("is_test")) { - int dev_id = - BOOST_GET_CONST(platform::XPUPlace, context.GetPlace()).GetDeviceId(); - int prop = static_cast(dropout_prob * 100); - int is_upscale = (dropout_implementation == "upscale_in_train"); - /* mask_data_tables key contains 3 part: - * | 31-16 | 15-8 | 7-0 | - * | dev_id | prob | is_upscale | - */ - int index = (dev_id << 16) + (prop << 8) + is_upscale; - std::lock_guard lock(s_mask_data_table_lock); - if (mask_data_tables.find(index) == mask_data_tables.end()) { - float* mask_data_host = new float[max_data_size]; - std::random_device rnd; - std::minstd_rand engine; - int seed = - context.Attr("fix_seed") ? context.Attr("seed") : rnd(); - engine.seed(seed); - std::uniform_real_distribution dist(0, 1); - for (size_t i = 0; i < max_data_size; ++i) { - if (dist(engine) < dropout_prob) { - mask_data_host[i] = 0.0f; - } else { - if (is_upscale) { - mask_data_host[i] = 1.0f / static_cast(1.0f - dropout_prob); - } else { - mask_data_host[i] = 1.0; - } - } - } - PADDLE_ENFORCE_EQ( - xpu_malloc(reinterpret_cast(&mask_data_table), - max_data_size * sizeof(float)), - XPU_SUCCESS, - platform::errors::ResourceExhausted( - "\n\nOut of memory error on XPU, Cannot" - "allocate %s memory on XPU. 
\n\nPlease " - "check whether there is any other process " - "using XPU.\n", - string::HumanReadableSize(max_data_size * sizeof(void*)))); - memory::Copy(BOOST_GET_CONST(platform::XPUPlace, context.GetPlace()), - mask_data_table, platform::CPUPlace(), mask_data_host, - max_data_size * sizeof(float)); - mask_data_tables[index] = mask_data_table; - free(mask_data_host); + std::random_device rnd; + // int seed = (context.Attr("fix_seed")) ? + // int(context.Attr("seed")) : (rnd()); + int seed = 0; + if (context.Attr("fix_seed") == true) { + seed = static_cast(context.Attr("seed")); } else { - mask_data_table = mask_data_tables[index]; + seed = rnd(); } - } - if (!context.Attr("is_test")) { // Train + auto* mask = context.Output("Mask"); auto* mask_data = mask->mutable_data(context.GetPlace()); - size_t size = framework::product(mask->dims()); - auto& dev_ctx = context.template device_context(); - int r = xpu::dropout(dev_ctx.x_context(), mask_data_table, x_data, - mask_data, y_data, max_data_size, size); - PADDLE_ENFORCE_EQ( - r, xpu::Error_t::SUCCESS, - platform::errors::External( - "XPU dropout return wrong value[%d], please check whether " - "Baidu Kunlun Card is properly installed.", - r)); - } else { // Infer - float scale = 0.0f; - if (dropout_implementation == "upscale_in_train") { - scale = 1.0f; - } else { - scale = static_cast(1.0f - dropout_prob); + // Special case when dropout_prob is 1.0 + if (dropout_prob == 1.0f) { + int r = xpu::constant(dev_ctx.x_context(), + reinterpret_cast(y_data), y->numel(), + XPUTyp(0)); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( + "XPU API(constant) return wrong " + "value[%d %s]", + r, XPUAPIErrorMsg[r])); + r = xpu::constant(dev_ctx.x_context(), + reinterpret_cast(mask_data), mask->numel(), + XPUTyp(0)); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( + "XPU API(constant) return wrong " + "value[%d %s]", + r, XPUAPIErrorMsg[r])); + return; } - auto& dev_ctx = context.template device_context(); - int r = xpu::scale(dev_ctx.x_context(), x->numel(), scale, 0.0f, 0, - x_data, y_data); - PADDLE_ENFORCE_EQ( - r, xpu::Error_t::SUCCESS, - platform::errors::External( - "XPU dropout return wrong value[%d], please check whether " - "Baidu Kunlun Card is properly installed.", - r)); + int r = xpu::dropout(dev_ctx.x_context(), + reinterpret_cast(x->data()), + reinterpret_cast(y->data()), + reinterpret_cast(mask_data), seed, + mask->numel(), is_upscale, dropout_prob); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( + "XPU API(dropout) return wrong " + "value[%d %s]", + r, XPUAPIErrorMsg[r])); + } else { + float scale = + (is_upscale) ? 
(1.0) : (static_cast(1.0f - dropout_prob)); + int r = xpu::scale( + dev_ctx.x_context(), reinterpret_cast(x_data), + reinterpret_cast(y_data), x->numel(), false, scale, 0.0f); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( + "XPU API(scale) return wrong " + "value[%d %s]", + r, XPUAPIErrorMsg[r])); } } }; template class DropoutGradXPUKernel : public framework::OpKernel { + using XPUTyp = typename XPUTypeTrait::Type; + public: void Compute(const framework::ExecutionContext& context) const override { PADDLE_ENFORCE_EQ(!context.Attr("is_test"), true, @@ -127,23 +104,47 @@ class DropoutGradXPUKernel : public framework::OpKernel { auto* mask = context.Input("Mask"); grad_x->mutable_data(context.GetPlace()); auto& dev_ctx = context.template device_context(); - int r = xpu::elementwise_mul(dev_ctx.x_context(), grad_y->data(), - mask->data(), grad_x->data(), - grad_y->numel()); - PADDLE_ENFORCE_EQ( - r, xpu::Error_t::SUCCESS, - platform::errors::External( - "XPU dropout return wrong value[%d], please check whether " - "Baidu Kunlun Card is properly installed.", - r)); + auto& dropout_implementation = + context.Attr("dropout_implementation"); + float dropout_prob = context.Attr("dropout_prob"); + const T* mask_data = mask->data(); + framework::Tensor mask_new; + if (dropout_implementation == "upscale_in_train") { + mask_new = context.AllocateTmpTensor( + mask->dims(), dev_ctx); + float scale = + (dropout_prob == 1.0f) ? (1.0f) : (1.0f / (1.0f - dropout_prob)); + int r = xpu::scale(dev_ctx.x_context(), + reinterpret_cast(mask->data()), + reinterpret_cast(mask_new.data()), + mask->numel(), false, scale, 0.0f); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( + "XPU API(scale) return wrong " + "value[%d %s]", + r, XPUAPIErrorMsg[r])); + mask_data = mask_new.data(); + } + + int r = xpu::mul( + dev_ctx.x_context(), reinterpret_cast(grad_y->data()), + reinterpret_cast(mask_data), + reinterpret_cast(grad_x->data()), grad_y->numel()); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External("XPU API(mul) return wrong " + "value[%d %s]", + r, XPUAPIErrorMsg[r])); } }; } // namespace operators } // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_XPU_KERNEL( - dropout, ops::DropoutXPUKernel); + dropout, ops::DropoutXPUKernel, + ops::DropoutXPUKernel); REGISTER_OP_XPU_KERNEL( dropout_grad, - ops::DropoutGradXPUKernel); + ops::DropoutGradXPUKernel, + ops::DropoutGradXPUKernel); #endif diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op_xpu.cc b/paddle/fluid/operators/elementwise/elementwise_add_op_xpu.cc index 8b902acebb4c5d4a8f739c9fe0e5a6f40c31ee9f..2e902bd277b1e4d016d0c3190579c409c8d361f3 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op_xpu.cc +++ b/paddle/fluid/operators/elementwise/elementwise_add_op_xpu.cc @@ -122,33 +122,50 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel { axis)); std::vector x_dims_vec(max_dim, 1); std::vector y_dims_vec(max_dim, 1); + int x_len = 1; + int y_len = 1; if (x_dims.size() == max_dim) { for (int i = 0; i < max_dim; i++) { x_dims_vec[i] = x_dims[i]; + x_len *= x_dims_vec[i]; } } else { for (int i = 0; i < x_dims.size(); i++) { x_dims_vec[i + axis] = x_dims[i]; + x_len *= x_dims_vec[i]; } } if (y_dims.size() == max_dim) { for (int i = 0; i < max_dim; i++) { y_dims_vec[i] = y_dims[i]; + y_len *= y_dims_vec[i]; } } else { for (int i = 0; i < y_dims.size(); i++) { y_dims_vec[i + axis] = y_dims[i]; + y_len *= y_dims_vec[i]; } } const T* 
dz_data = dz->data(); + framework::Tensor dx_local_tensor; + framework::Tensor dy_local_tensor; + bool need_wait = false; T* dx_data = nullptr; T* dy_data = nullptr; if (dx) { dx_data = dx->mutable_data(ctx.GetPlace()); + } else { + dx_data = + dx_local_tensor.mutable_data(ctx.GetPlace(), x_len * sizeof(T)); + need_wait = true; } if (dy) { dy_data = dy->mutable_data(ctx.GetPlace()); + } else { + dy_data = + dy_local_tensor.mutable_data(ctx.GetPlace(), y_len * sizeof(T)); + need_wait = true; } auto& dev_ctx = @@ -161,6 +178,9 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel { platform::errors::External( "XPU kernel Elementwise occur error in XPUElementwise error code ", ret, XPUAPIErrorMsg[ret])); + if (need_wait && dev_ctx.x_context()->xpu_stream) { + dev_ctx.Wait(); + } } }; diff --git a/paddle/fluid/operators/matmul_op_xpu.cc b/paddle/fluid/operators/matmul_op_xpu.cc index 6fa96aca4be147e9d70c6e62500acaae88822315..7097b5327d86fab115ff85fd114dce6dd9e5ae2f 100644 --- a/paddle/fluid/operators/matmul_op_xpu.cc +++ b/paddle/fluid/operators/matmul_op_xpu.cc @@ -102,6 +102,7 @@ template static void MatMulXPUFunction(const Tensor *x, const Tensor *y, Tensor *out, bool trans_x, bool trans_y, const paddle::framework::ExecutionContext &ctx) { + using XPUType = typename XPUTypeTrait::Type; const auto &x_dims = x->dims(); const auto &y_dims = y->dims(); auto &dev_ctx = @@ -162,34 +163,36 @@ static void MatMulXPUFunction(const Tensor *x, const Tensor *y, Tensor *out, int ldout = n; if (batch_size <= 1) { int r = 0; - r = xpu::fc_fusion( - dev_ctx.x_context(), x->data(), y->data(), data_c, m, n, k, - mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx, ldy, - ldout, alpha, 0, nullptr, xpu::Activation_t::LINEAR); + r = xpu::fc_fusion( + dev_ctx.x_context(), reinterpret_cast(x->data()), + reinterpret_cast(y->data()), + reinterpret_cast(data_c), m, n, k, mat_dim_a.trans_, + mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx, ldy, ldout, alpha, 0, + nullptr, xpu::Activation_t::LINEAR); PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( "XPU fc_fusion kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r])); } else { // batch matmul - int r = xpu::fc_batched( - dev_ctx.x_context(), // Context* ctx, - batch_size, // int batch_size, - mat_dim_a.trans_, // bool x_trans, - mat_dim_b.trans_, // bool w_trans, - m, // int m, - n, // int n, - k, // int k, - alpha, // float alpha, - reinterpret_cast(x->data()), // const TX* x, - mat_dim_a.stride_, // int stride_a, - reinterpret_cast(y->data()), // const TW* w, - mat_dim_b.stride_, // int stride_b, - 0.0, // float beta, - reinterpret_cast(data_c), // TY* y, - m * n, // int stride_c, - nullptr, // const float* x_maxptr, - nullptr); // const float* w_maxptr + int r = xpu::fc_batched( + dev_ctx.x_context(), // Context* ctx, + batch_size, // int batch_size, + mat_dim_a.trans_, // bool x_trans, + mat_dim_b.trans_, // bool w_trans, + m, // int m, + n, // int n, + k, // int k, + alpha, // float alpha, + reinterpret_cast(x->data()), // const TX* x, + mat_dim_a.stride_, // int stride_a, + reinterpret_cast(y->data()), // const TW* w, + mat_dim_b.stride_, // int stride_b, + 0.0, // float beta, + reinterpret_cast(data_c), // TY* y, + m * n, // int stride_c, + nullptr, // const float* x_maxptr, + nullptr); // const float* w_maxptr PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( @@ -210,10 +213,14 @@ class MatMulXPUKernel : public framework::OpKernel { out->mutable_data(context.GetPlace()); bool trans_x = 
context.Attr("transpose_X"); bool trans_y = context.Attr("transpose_Y"); - if (std::getenv("XPU_PADDLE_MAT_MUL_FCINT32") != nullptr) { - MatMulXPUFunction(x, y, out, trans_x, trans_y, context); - } else { + if (std::is_same::value) { MatMulXPUFunction(x, y, out, trans_x, trans_y, context); + } else { + if (std::getenv("XPU_PADDLE_MAT_MUL_FCINT32") != nullptr) { + MatMulXPUFunction(x, y, out, trans_x, trans_y, context); + } else { + MatMulXPUFunction(x, y, out, trans_x, trans_y, context); + } } } }; @@ -224,6 +231,7 @@ class MatMulXPUKernel : public framework::OpKernel { template static framework::Tensor XPUFoldHeadAndLastDims( const DeviceContext &context, const framework::Tensor &input) { + using XPUType = typename XPUTypeTrait::Type; auto in_dims = input.dims(); if (in_dims.size() != 3) { return input; @@ -236,8 +244,9 @@ static framework::Tensor XPUFoldHeadAndLastDims( static_cast(in_dims[1]), static_cast(in_dims[2])}; std::vector axis_host = {1, 0, 2}; - int r = xpu::transpose(context.x_context(), input.data(), output.data(), - in_shape_host, axis_host); + int r = xpu::transpose( + context.x_context(), reinterpret_cast(input.data()), + reinterpret_cast(output.data()), in_shape_host, axis_host); PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( "XPU transpose kernel return wrong value[%d %s]", r, @@ -280,10 +289,14 @@ class MatMulGradXPUKernel : public framework::OpKernel { const framework::Tensor &b, bool trans_b, framework::Tensor *out) const { out->mutable_data(context.GetPlace()); - if (std::getenv("XPU_PADDLE_MAT_MUL_GRAD_FCINT32") != nullptr) { - MatMulXPUFunction(&a, &b, out, trans_a, trans_b, context); - } else { + if (std::is_same::value) { MatMulXPUFunction(&a, &b, out, trans_a, trans_b, context); + } else { + if (std::getenv("XPU_PADDLE_MAT_MUL_GRAD_FCINT32") != nullptr) { + MatMulXPUFunction(&a, &b, out, trans_a, trans_b, context); + } else { + MatMulXPUFunction(&a, &b, out, trans_a, trans_b, context); + } } } @@ -370,10 +383,14 @@ class MatMulGradXPUKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; +namespace plat = paddle::platform; REGISTER_OP_XPU_KERNEL( - matmul, ops::MatMulXPUKernel); + matmul, ops::MatMulXPUKernel, + ops::MatMulXPUKernel); REGISTER_OP_XPU_KERNEL( matmul_grad, - ops::MatMulGradXPUKernel); + ops::MatMulGradXPUKernel, + ops::MatMulGradXPUKernel); #endif diff --git a/paddle/fluid/operators/matmul_v2_op_xpu.cc b/paddle/fluid/operators/matmul_v2_op_xpu.cc index d992ef847db2aca8bc284781fdd1408d36bd14e5..ae1e9358f68115e4952696325051d142a25789f8 100644 --- a/paddle/fluid/operators/matmul_v2_op_xpu.cc +++ b/paddle/fluid/operators/matmul_v2_op_xpu.cc @@ -25,6 +25,7 @@ template static void MatMulXPUFunction(const Tensor* x, const Tensor* y, Tensor* out, bool trans_x, bool trans_y, const paddle::framework::ExecutionContext& ctx) { + using XPUType = typename XPUTypeTrait::Type; const auto& x_dims = x->dims(); const auto& y_dims = y->dims(); auto& dev_ctx = @@ -75,9 +76,11 @@ static void MatMulXPUFunction(const Tensor* x, const Tensor* y, Tensor* out, int batch_size = mat_dim_a.batch_size_; if (batch_size <= 1) { int r = 0; - r = xpu::fc(dev_ctx.x_context(), x->data(), y->data(), - data_c, m, n, k, mat_dim_a.trans_, - mat_dim_b.trans_, nullptr, nullptr, nullptr); + r = xpu::fc( + dev_ctx.x_context(), reinterpret_cast(x->data()), + reinterpret_cast(y->data()), + reinterpret_cast(data_c), m, n, k, mat_dim_a.trans_, + mat_dim_b.trans_, nullptr, nullptr, nullptr); PADDLE_ENFORCE_EQ( r, XPU_SUCCESS, 
platform::errors::External( @@ -87,24 +90,24 @@ static void MatMulXPUFunction(const Tensor* x, const Tensor* y, Tensor* out, r, XPUAPIErrorMsg[r], m, n, k, mat_dim_a.trans_, mat_dim_b.trans_)); } else { // batch matmul - int r = xpu::fc_batched( - dev_ctx.x_context(), // Context* ctx, - batch_size, // int batch_size, - mat_dim_a.trans_, // bool x_trans, - mat_dim_b.trans_, // bool w_trans, - m, // int m, - n, // int n, - k, // int k, - 1.0, // float alpha, - reinterpret_cast(x->data()), // const TX* x, - mat_dim_a.stride_, // int stride_a, - reinterpret_cast(y->data()), // const TW* w, - mat_dim_b.stride_, // int stride_b, - 0.0, // float beta, - reinterpret_cast(data_c), // TY* y, - m * n, // int stride_c, - nullptr, // const float* x_maxptr, - nullptr); // const float* w_maxptr + int r = xpu::fc_batched( + dev_ctx.x_context(), // Context* ctx, + batch_size, // int batch_size, + mat_dim_a.trans_, // bool x_trans, + mat_dim_b.trans_, // bool w_trans, + m, // int m, + n, // int n, + k, // int k, + 1.0, // float alpha, + reinterpret_cast(x->data()), // const TX* x, + mat_dim_a.stride_, // int stride_a, + reinterpret_cast(y->data()), // const TW* w, + mat_dim_b.stride_, // int stride_b, + 0.0, // float beta, + reinterpret_cast(data_c), // TY* y, + m * n, // int stride_c, + nullptr, // const float* x_maxptr, + nullptr); // const float* w_maxptr PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( @@ -123,10 +126,14 @@ class MatMulV2XPUKernel : public framework::OpKernel { bool trans_x = ctx.Attr("trans_x"); bool trans_y = ctx.Attr("trans_y"); out->mutable_data(ctx.GetPlace()); - if (std::getenv("XPU_PADDLE_MAT_MUL_V2_FCINT32") != nullptr) { - MatMulXPUFunction(x, y, out, trans_x, trans_y, ctx); - } else { + if (std::is_same::value) { MatMulXPUFunction(x, y, out, trans_x, trans_y, ctx); + } else { + if (std::getenv("XPU_PADDLE_MAT_MUL_V2_FCINT32") != nullptr) { + MatMulXPUFunction(x, y, out, trans_x, trans_y, ctx); + } else { + MatMulXPUFunction(x, y, out, trans_x, trans_y, ctx); + } } } }; @@ -134,6 +141,7 @@ class MatMulV2XPUKernel : public framework::OpKernel { template static framework::Tensor XPUFoldHeadAndLastDims( const DeviceContext& context, const framework::Tensor& input) { + using XPUType = typename XPUTypeTrait::Type; auto in_dims = input.dims(); if (in_dims.size() != 3) { return input; @@ -147,8 +155,9 @@ static framework::Tensor XPUFoldHeadAndLastDims( static_cast(in_dims[2])}; std::vector axis_host = {1, 0, 2}; - int r = xpu::transpose(context.x_context(), input.data(), output.data(), - in_shape_host, axis_host); + int r = xpu::transpose( + context.x_context(), reinterpret_cast(input.data()), + reinterpret_cast(output.data()), in_shape_host, axis_host); PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External( "XPU transpose kernel return wrong value[%d %s]", r, @@ -166,10 +175,14 @@ class MatMulV2XPUGradKernel : public framework::OpKernel { const framework::Tensor& b, bool trans_b, framework::Tensor* out) const { out->mutable_data(ctx.GetPlace()); - if (std::getenv("XPU_PADDLE_MAT_MUL_GRAD_V2_FCINT32") != nullptr) { - MatMulXPUFunction(&a, &b, out, trans_a, trans_b, ctx); - } else { + if (std::is_same::value) { MatMulXPUFunction(&a, &b, out, trans_a, trans_b, ctx); + } else { + if (std::getenv("XPU_PADDLE_MAT_MUL_GRAD_V2_FCINT32") != nullptr) { + MatMulXPUFunction(&a, &b, out, trans_a, trans_b, ctx); + } else { + MatMulXPUFunction(&a, &b, out, trans_a, trans_b, ctx); + } } } @@ -261,8 +274,10 @@ class MatMulV2XPUGradKernel : public framework::OpKernel { } // 
namespace paddle namespace ops = paddle::operators; - -REGISTER_OP_XPU_KERNEL(matmul_v2, ops::MatMulV2XPUKernel); -REGISTER_OP_XPU_KERNEL(matmul_v2_grad, ops::MatMulV2XPUGradKernel); +namespace plat = paddle::platform; +REGISTER_OP_XPU_KERNEL(matmul_v2, ops::MatMulV2XPUKernel, + ops::MatMulV2XPUKernel); +REGISTER_OP_XPU_KERNEL(matmul_v2_grad, ops::MatMulV2XPUGradKernel, + ops::MatMulV2XPUGradKernel); #endif diff --git a/paddle/fluid/operators/softmax_op_xpu.cc b/paddle/fluid/operators/softmax_op_xpu.cc index ed7034ef6ab416a4e98ddcd02f045af459298d65..3527478f7661058e193d14d95f815beb28f1e92a 100644 --- a/paddle/fluid/operators/softmax_op_xpu.cc +++ b/paddle/fluid/operators/softmax_op_xpu.cc @@ -47,8 +47,8 @@ class SoftmaxXPUKernel : public framework::OpKernel { int len = x->numel(); T* clip_x_data = clip_x.mutable_data(context.GetPlace(), len * sizeof(T)); - r = xpu::clip(dev_ctx.x_context(), x->data(), clip_x_data, len, - -1e30, 1e30); + r = xpu::clip_v2(dev_ctx.x_context(), x->data(), clip_x_data, len, + static_cast(-1e20), static_cast(1e20)); PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External("XPU API(clip) return wrong " "value[%d %s]", diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc index 8635def2ecf138550bf02f0013b31b59647777b9..a79e31eb8d028d3d319176e397ba5da9da54cd0e 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc @@ -54,8 +54,9 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel { int len = logits->numel(); T* clip_logits_data = clip_logits.mutable_data(context.GetPlace(), len * sizeof(T)); - r = xpu::clip(dev_ctx.x_context(), logits->data(), clip_logits_data, - len, -1e30, 1e30); + r = xpu::clip_v2(dev_ctx.x_context(), logits->data(), + clip_logits_data, len, static_cast(-1e20), + static_cast(1e20)); PADDLE_ENFORCE_EQ( r, xpu::Error_t::SUCCESS, platform::errors::External("XPU kernel error. clip " diff --git a/paddle/fluid/platform/xpu_header.h b/paddle/fluid/platform/xpu_header.h index 9f2befc123f224aeda3cb4a3d196cbce470d51b2..99f4224b5d408a6450d801ff643f658b74333387 100644 --- a/paddle/fluid/platform/xpu_header.h +++ b/paddle/fluid/platform/xpu_header.h @@ -1,4 +1,4 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
@@ -20,6 +20,7 @@ #include #include "paddle/fluid/platform/errors.h" +#include "paddle/fluid/platform/float16.h" #include "xpu/api.h" #include "xpu/refactor/fusion.h" #include "xpu/refactor/math.h" @@ -58,4 +59,16 @@ static std::map XPUAPIErrorMsg = { {xpu::Error_t::RUNTIME_ERROR, "xpu api runtime error"}, {xpu::Error_t::NO_ENOUGH_WORKSPACE, "xpu api no enough workspace"}}; +template +class XPUTypeTrait { + public: + using Type = T; +}; + +template <> +class XPUTypeTrait { + public: + using Type = float16; +}; + #endif diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 560d8c892b09f9b6f17136040455ee8469587f53..fd4ae63265366a27a090fed4ab694bae6ef261d4 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -224,7 +224,9 @@ OpSupportedInfos(const std::string &place, [](unsigned char c) { return std::toupper(c); }); using fn_type = std::add_pointer::type; std::unordered_map is_target_place{ - {"GPU", &platform::is_gpu_place}, {"CPU", &platform::is_cpu_place}, + {"GPU", &platform::is_gpu_place}, + {"CPU", &platform::is_cpu_place}, + {"XPU", &platform::is_xpu_place}, }; PADDLE_ENFORCE_NE( is_target_place.count(query_place), 0, diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py index f940f6a3143a09fa82d4e10fba38f7d86b9c025d..7c6f32e1e8e6254ef8100fa953f27317252cd110 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_lists.py @@ -149,8 +149,14 @@ gray_list = { # The set of ops that don't support fp16 calculation # lookup_table fp16 is slower than fp32, though fp16 is supported. -_, _, _sys_unsupported_fp16_list = core.op_supported_infos( - 'GPU', core.VarDesc.VarType.FP16) +_sys_unsupported_fp16_list = [] +if core.is_compiled_with_xpu(): + _, _, _sys_unsupported_fp16_list = core.op_supported_infos( + 'XPU', core.VarDesc.VarType.FP16) +else: + _, _, _sys_unsupported_fp16_list = core.op_supported_infos( + 'GPU', core.VarDesc.VarType.FP16) + unsupported_fp16_list = {'lookup_table', 'lookup_table_v2'} | _sys_unsupported_fp16_list diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index 4ff08337875c030827a762eb4199c1a1e28781e4..6121732bf1f723fc56cca48854f85cdc515a4f9f 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -128,9 +128,10 @@ def amp_guard(enable=True, custom_white_list=None, custom_black_list=None): raise ValueError( "current_tracer is None, maybe it is not in imperative mode.") - if enable and not tracer._expected_place.is_gpu_place(): + if enable and not (tracer._expected_place.is_gpu_place() or + tracer._expected_place.is_xpu_place()): warnings.warn( - 'amp_guard can only be enabled on CUDAPlace, current place is %s, so it makes no effect.' + 'amp_guard can only be enabled on CUDAPlace and XPUPlace, current place is %s, so it makes no effect.' 
% tracer._expected_place) enable = False diff --git a/python/paddle/fluid/dygraph/amp/loss_scaler.py b/python/paddle/fluid/dygraph/amp/loss_scaler.py index ff57f30dcd2ec73d55ff06e751767deea0a2eead..e0bd60fbeb4a73db22ce1141750c2cf22ac29288 100644 --- a/python/paddle/fluid/dygraph/amp/loss_scaler.py +++ b/python/paddle/fluid/dygraph/amp/loss_scaler.py @@ -90,9 +90,10 @@ class AmpScaler(object): raise ValueError( "current_tracer is None, maybe it is not in imperative mode.") - if enable and not tracer._expected_place.is_gpu_place(): + if enable and not (tracer._expected_place.is_gpu_place() or + tracer._expected_place.is_xpu_place()): warnings.warn( - 'AmpScaler can only be enabled on CUDAPlace, current place is %s, so it makes no effect.' + 'AmpScaler can only be enabled on CUDAPlace and XPUPlace, current place is %s, so it makes no effect.' % tracer._expected_place) enable = False diff --git a/python/paddle/fluid/tests/unittests/xpu/test_amp_check_finite_and_scale_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_amp_check_finite_and_scale_op_xpu.py new file mode 100644 index 0000000000000000000000000000000000000000..9a2976f82a46019bcafb40c35265ce5a936ff67e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_amp_check_finite_and_scale_op_xpu.py @@ -0,0 +1,99 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +sys.path.append("..") +import paddle +import unittest +import numpy as np +from op_test_xpu import XPUOpTest +from op_test import OpTest, skip_check_grad_ci +import paddle.fluid as fluid +paddle.enable_static() + + +class TestCheckFiniteAndUnscaleOp(XPUOpTest): + def setUp(self): + self.op_type = "check_finite_and_unscale" + self.init_dtype() + x = np.random.random((1024, 1024)).astype(self.dtype) + scale = np.random.random((1)).astype(self.dtype) + # self.attrs = {'stop_gradient': True} + self.inputs = {'X': [('x0', x)], 'Scale': scale} + self.outputs = { + 'FoundInfinite': np.array([0]), + 'Out': [('out0', x / scale)], + } + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + +# class TestCheckFiniteAndUnscaleOpWithNan(XPUOpTest): +# def setUp(self): +# self.op_type = "check_finite_and_unscale" +# self.init_dtype() +# x = np.random.random((1024, 1024)).astype(self.dtype) +# x[128][128] = np.nan +# print("x shape = ", x.shape) +# print(x) +# scale = np.random.random((1)).astype(self.dtype) + +# self.inputs = {'X': [('x0', x)], 'Scale': scale} +# self.outputs = { +# 'FoundInfinite': np.array([1]), +# 'Out': [('out0', x)], +# } + +# def init_dtype(self): +# self.dtype = np.float32 + +# def test_check_output(self): +# # When input contains nan, do not check the output, +# # since the output may be nondeterministic and will be discarded. 
+# if paddle.is_compiled_with_xpu(): +# place = paddle.XPUPlace(0) +# self.check_output_with_place(place, no_check_set=['Out']) + +# class TestCheckFiniteAndUnscaleOpWithInf(XPUOpTest): +# def setUp(self): +# self.op_type = "check_finite_and_unscale" +# self.init_dtype() +# x = np.random.random((1024, 1024)).astype(self.dtype) +# x[128][128] = np.inf +# scale = np.random.random((1)).astype(self.dtype) + +# self.inputs = {'X': [('x0', x)], 'Scale': scale} +# self.outputs = { +# 'FoundInfinite': np.array([1]), +# 'Out': [('out0', x)], +# } + +# def init_dtype(self): +# self.dtype = np.float32 + +# def test_check_output(self): +# # When input contains inf, do not check the output, +# # since the output may be nondeterministic and will be discarded. +# if paddle.is_compiled_with_xpu(): +# place = paddle.XPUPlace(0) +# self.check_output_with_place(place, no_check_set=['Out']) + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_dropout_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_dropout_op_xpu.py index 6c3368c3b6bfc4f8105107df06fd0aa38c18a2db..ca3b3a418abf6c968612fcca7997960f01945ce9 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_dropout_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_dropout_op_xpu.py @@ -22,9 +22,11 @@ from op_test import OpTest, skip_check_grad_ci import paddle import paddle.fluid as fluid from paddle.fluid import Program, program_guard +from op_test_xpu import XPUOpTest +paddle.enable_static() -class TestDropoutOp(OpTest): +class TestDropoutOp(XPUOpTest): def setUp(self): self.op_type = "dropout" self.inputs = {'X': np.random.random((32, 64)).astype("float32")} @@ -47,7 +49,7 @@ class TestDropoutOp(OpTest): self.check_grad_with_place(place, ['X'], 'Out') -class TestDropoutOpInput1d(OpTest): +class TestDropoutOpInput1d(XPUOpTest): def setUp(self): self.op_type = "dropout" self.inputs = {'X': np.random.random((2000, )).astype("float32")} diff --git a/python/paddle/fluid/tests/unittests/xpu/test_update_loss_scaling_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_update_loss_scaling_op_xpu.py new file mode 100644 index 0000000000000000000000000000000000000000..33b13081b54420841a521afd7573c0cb8788ecb6 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_update_loss_scaling_op_xpu.py @@ -0,0 +1,245 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import sys +sys.path.append("..") +import numpy as np +from op_test import OpTest +from op_test_xpu import XPUOpTest +import paddle +import paddle.fluid as fluid +import paddle.fluid.contrib.mixed_precision.amp_nn as amp_nn + +paddle.enable_static() + + +class TestUpdateLossScalingOp(XPUOpTest): + def setUp(self): + self.op_type = "update_loss_scaling" + self.init() + found_inf = np.array([False], dtype=np.bool) + x = np.random.random((1024, 1024)).astype(self.dtype) + + self.inputs = { + 'X': [('x0', x)], + 'FoundInfinite': found_inf, + 'PrevLossScaling': self.prev_loss_scaling, + 'InGoodSteps': self.num_good_steps, + 'InBadSteps': self.num_bad_steps + } + + self.outputs = { + 'Out': [('out0', x)], + 'LossScaling': self.prev_loss_scaling * self.incr_ratio, + 'OutGoodSteps': self.zero_steps, + 'OutBadSteps': self.zero_steps + } + + def init(self): + self.incr_ratio = 2.0 + self.decr_ratio = 0.8 + self.dtype = np.float32 + self.prev_loss_scaling = np.array([2048]).astype(self.dtype) + self.num_good_steps = np.array([999], dtype=np.int32) + self.num_bad_steps = np.array([1], dtype=np.int32) + self.zero_steps = np.array([0], dtype=np.int32) + self.attrs = { + 'incr_every_n_steps': 1000, + 'decr_every_n_nan_or_inf': 2, + 'incr_ratio': self.incr_ratio, + 'decr_ratio': self.decr_ratio, + } + + def test_check_output(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_output_with_place(place, no_check_set=['Out']) + + +class TestUpdateLossScalingOpBad(TestUpdateLossScalingOp): + def setUp(self): + self.op_type = "update_loss_scaling" + self.init() + found_inf = np.array([True], dtype=np.bool) + x = np.random.random((1024, 1024)).astype(self.dtype) + i = np.random.randint(0, 1024, 1) + j = np.random.randint(0, 1024, 1) + x[i[0]][j[0]] = np.inf + + self.inputs = { + 'X': [('x0', x)], + 'FoundInfinite': found_inf, + 'PrevLossScaling': self.prev_loss_scaling, + 'InGoodSteps': self.num_good_steps, + 'InBadSteps': self.num_bad_steps + } + + self.outputs = { + 'Out': [('out0', np.zeros_like(x))], + 'LossScaling': self.prev_loss_scaling * self.decr_ratio, + 'OutGoodSteps': self.zero_steps, + 'OutBadSteps': self.zero_steps + } + + def test_check_output(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + #self.check_output() + + +class TestUpdateLossScalingLayer(unittest.TestCase): + def loss_scaling_check(self, scope=fluid.Scope()): + a = fluid.data(name="a", shape=[1024, 1024], dtype='float32') + b = fluid.data(name="b", shape=[512, 128], dtype='float32') + x = [a, b] + found_inf = fluid.data(name="found_inf", shape=[1], dtype='bool') + prev_loss_scaling = fluid.data( + name="prev_loss_scaling", shape=[1], dtype='float32') + num_good_steps = fluid.data( + name="num_good_steps", shape=[1], dtype='int32') + num_bad_steps = fluid.data( + name="num_bad_steps", shape=[1], dtype='int32') + + a_v = np.random.random([1024, 1024]).astype('float32') + b_v = np.random.random([512, 128]).astype('float32') + found_inf_v = np.array([False]).astype('bool') + prev_loss_scaling_v = np.array([2048]).astype('float32') + num_good_steps_v = np.array([999], dtype=np.int32) + num_bad_steps_v = np.array([1], dtype=np.int32) + + incr_every_n_steps = 1000 + decr_every_n_nan_or_inf = 2 + incr_ratio = 2 + decr_ratio = 0.8 + + result = amp_nn.update_loss_scaling( + x, + found_inf, + prev_loss_scaling, + num_good_steps, + num_bad_steps, + incr_every_n_steps, + decr_every_n_nan_or_inf, + incr_ratio, + decr_ratio, + 
name="update_loss_scaling") + + place = fluid.XPUPlace(0) + exe = fluid.Executor(place) + with fluid.scope_guard(scope): + exe.run(fluid.default_startup_program()) + result_v = exe.run(feed={ + 'a': a_v, + 'b': b_v, + 'found_inf': found_inf_v, + 'prev_loss_scaling': prev_loss_scaling_v, + 'num_good_steps': num_good_steps_v, + 'num_bad_steps': num_bad_steps_v + }, + fetch_list=[ + result, x, found_inf, prev_loss_scaling, + num_good_steps, num_bad_steps + ]) + assert np.array_equal(result_v[0], a_v) + assert np.array_equal(result_v[1], b_v) + assert np.array_equal(result_v[0], result_v[2]) + assert np.array_equal(result_v[1], result_v[3]) + assert np.array_equal(result_v[4], found_inf_v) + assert np.array_equal(result_v[5], prev_loss_scaling_v * incr_ratio) + assert np.array_equal(result_v[6], np.zeros_like(num_good_steps_v)) + assert np.array_equal(result_v[7], np.zeros_like(num_bad_steps_v)) + + def loss_scaling_check_inf(self, use_cuda=True, scope=fluid.Scope()): + a = fluid.data(name="a", shape=[1024, 1024], dtype='float32') + b = fluid.data(name="b", shape=[512, 128], dtype='float32') + x = [a, b] + found_inf = fluid.data(name="found_inf", shape=[1], dtype='bool') + prev_loss_scaling = fluid.data( + name="prev_loss_scaling", shape=[1], dtype='float32') + num_good_steps = fluid.data( + name="num_good_steps", shape=[1], dtype='int32') + num_bad_steps = fluid.data( + name="num_bad_steps", shape=[1], dtype='int32') + + a_v = np.random.random([1024, 1024]).astype('float32') + b_v = np.random.random([512, 128]).astype('float32') + i = np.random.randint(0, 1024, 1) + j = np.random.randint(0, 1024, 1) + a_v[i[0]][j[0]] = np.inf + found_inf_v = np.array([True]).astype('bool') + prev_loss_scaling_v = np.array([2048]).astype('float32') + num_good_steps_v = np.array([999], dtype=np.int32) + num_bad_steps_v = np.array([1], dtype=np.int32) + + incr_every_n_steps = 1000 + decr_every_n_nan_or_inf = 2 + incr_ratio = 2 + decr_ratio = 0.8 + + result = amp_nn.update_loss_scaling( + x, + found_inf, + prev_loss_scaling, + num_good_steps, + num_bad_steps, + incr_every_n_steps, + decr_every_n_nan_or_inf, + incr_ratio, + decr_ratio, + name="update_loss_scaling") + + place = fluid.XPUPlace(0) + exe = fluid.Executor(place) + with fluid.scope_guard(scope): + exe.run(fluid.default_startup_program()) + result_v = exe.run(feed={ + 'a': a_v, + 'b': b_v, + 'found_inf': found_inf_v, + 'prev_loss_scaling': prev_loss_scaling_v, + 'num_good_steps': num_good_steps_v, + 'num_bad_steps': num_bad_steps_v + }, + fetch_list=[ + result, x, found_inf, prev_loss_scaling, + num_good_steps, num_bad_steps + ]) + assert np.array_equal(result_v[0], np.zeros_like(a_v)) + assert np.array_equal(result_v[1], np.zeros_like(b_v)) + assert np.array_equal(result_v[2], np.zeros_like(a_v)) + assert np.array_equal(result_v[3], np.zeros_like(b_v)) + assert np.array_equal(result_v[4], found_inf_v) + assert np.array_equal(result_v[5], prev_loss_scaling_v * decr_ratio) + assert np.array_equal(result_v[6], np.zeros_like(num_good_steps_v)) + assert np.array_equal(result_v[7], np.zeros_like(num_bad_steps_v)) + + def test_loss_scaling(self): + main = fluid.Program() + startup = fluid.Program() + with fluid.unique_name.guard(): + with fluid.program_guard(main, startup): + self.loss_scaling_check() + + def test_loss_scaling_inf(self): + main = fluid.Program() + startup = fluid.Program() + with fluid.unique_name.guard(): + with fluid.program_guard(main, startup): + self.loss_scaling_check_inf() + + +if __name__ == '__main__': + unittest.main()