From cbabbe2e9b1d4a21ecb55d38b5a5b801c2ce5f5d Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Fri, 2 Sep 2022 09:50:48 +0800 Subject: [PATCH] [XPU]Migrate Adam XPU kernel into Phi (#45572) * [XPU]Migrate Adam XPU kernel into Phi * test=kunlun --- .../operators/math/selected_rows_functor.cc | 10 +- .../fluid/operators/optimizers/adam_op_xpu.cc | 643 ------------------ paddle/phi/kernels/CMakeLists.txt | 20 +- paddle/phi/kernels/funcs/adam_functors.h | 134 ++++ .../kernels/selected_rows/xpu/adam_kernel.cc | 308 +++++++++ paddle/phi/kernels/xpu/adam_kernel.cc | 252 +++++++ 6 files changed, 703 insertions(+), 664 deletions(-) delete mode 100644 paddle/fluid/operators/optimizers/adam_op_xpu.cc create mode 100644 paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc create mode 100644 paddle/phi/kernels/xpu/adam_kernel.cc diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index 9ec1172c410..354af32beab 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -569,8 +569,8 @@ TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(platform::complex) #ifdef PADDLE_WITH_XPU template <typename T> -struct MergeAdd<platform::XPUDeviceContext, T> { - phi::SelectedRows operator()(const platform::XPUDeviceContext& context, +struct MergeAdd<phi::XPUContext, T> { + phi::SelectedRows operator()(const phi::XPUContext& context, const phi::SelectedRows& input, const bool sorted_result = false) { phi::SelectedRows out; @@ -578,7 +578,7 @@ struct MergeAdd<platform::XPUDeviceContext, T> { return out; } - void operator()(const platform::XPUDeviceContext& context, + void operator()(const phi::XPUContext& context, const phi::SelectedRows& input, phi::SelectedRows* output, const bool sorted_result = false) { @@ -633,7 +633,7 @@ struct MergeAdd<platform::XPUDeviceContext, T> { PADDLE_ENFORCE_XDNN_SUCCESS(r, "merge_dup_rows"); } - void operator()(const platform::XPUDeviceContext& context, + void operator()(const phi::XPUContext& context, const std::vector<const phi::SelectedRows*>& inputs, phi::SelectedRows* output, const bool sorted_result = false) { @@ -838,7 +838,7 @@ struct MergeAverage { }; #ifdef PADDLE_WITH_XPU -template struct MergeAdd<platform::XPUDeviceContext, float>; +template struct MergeAdd<phi::XPUContext, float>; #endif template struct MergeAverage; diff --git a/paddle/fluid/operators/optimizers/adam_op_xpu.cc b/paddle/fluid/operators/optimizers/adam_op_xpu.cc deleted file mode 100644 index c9e2f71c9e2..00000000000 --- a/paddle/fluid/operators/optimizers/adam_op_xpu.cc +++ /dev/null @@ -1,643 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License.
*/ - -#include "gflags/gflags.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/memory/memcpy.h" -#include "paddle/fluid/operators/optimizers/adam_op_functor.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using float16 = paddle::platform::float16; - -#ifdef PADDLE_WITH_XPU -template -static int ConvertDataByType(const T1* x, - T2** y, - int len, - bool allocateFlag, - const framework::ExecutionContext& ctx) { - if (nullptr == x || nullptr == y || len <= 0) - return xpu::Error_t::INVALID_PARAM; - int r = 0; - if (allocateFlag) { - r = xpu_malloc(reinterpret_cast(y), sizeof(T2) * len); - - PADDLE_ENFORCE_EQ( - r, - xpu::Error_t::SUCCESS, - platform::errors::External( - "Alloc memory in xpu for result data failed with [%d]", r)); - } - - T1* cpu_data = reinterpret_cast(malloc(sizeof(T1) * len)); - - paddle::memory::Copy(paddle::platform::CPUPlace(), - cpu_data, - ctx.GetPlace(), - x, - len * sizeof(T1)); - - T2* cpu_real_data = reinterpret_cast(malloc(sizeof(T2) * len)); - for (int i = 0; i < len; i++) cpu_real_data[i] = static_cast(cpu_data[i]); - - paddle::memory::Copy(ctx.GetPlace(), - *y, - paddle::platform::CPUPlace(), - cpu_real_data, - len * sizeof(T2)); - - free(cpu_data); - free(cpu_real_data); - - return xpu::Error_t::SUCCESS; -} - -template -static void getDataPointer(const phi::DenseTensor& tensorData, - T** result, - const framework::ExecutionContext& ctx) { - if (tensorData.dtype() == paddle::experimental::DataType::FLOAT16) { - const float16* real_data = - tensorData.template data(); - int len = tensorData.numel(); - - int r = ConvertDataByType(real_data, result, len, true, ctx); - PADDLE_ENFORCE_EQ( - r, - xpu::Error_t::SUCCESS, - platform::errors::External( - "execute function ConvertDataByType failed with [%d]", r)); - } -} - -template -static void getOutDataPointer(phi::DenseTensor* tensorData, - Tensor* out, - T** result, - const framework::ExecutionContext& ctx) { - if (tensorData->dtype() == paddle::experimental::DataType::FLOAT16) { - *result = out->template mutable_data(ctx.GetPlace()); - } else { - *result = tensorData->template mutable_data(ctx.GetPlace()); - } -} - -template -static void copyOutData(const Tensor& srcTensor, - phi::DenseTensor* dstTensor, - const framework::ExecutionContext& ctx) { - if (dstTensor->dtype() == paddle::experimental::DataType::FLOAT16) { - const T* xpu_out_data = srcTensor.template data(); - float16* out_data = - dstTensor->template mutable_data(ctx.GetPlace()); - - int len = srcTensor.numel(); - - int r = - ConvertDataByType(xpu_out_data, &out_data, len, false, ctx); - PADDLE_ENFORCE_EQ( - r, - xpu::Error_t::SUCCESS, - platform::errors::External( - "execute function ConvertDataByType failed with[%d]", r)); - } -} - -template -static void setBetaData(const phi::DenseTensor& beta_pow, - phi::DenseTensor* beta_pow_out, - const T& beta) { - if (beta_pow.dtype() == paddle::experimental::DataType::FLOAT16) { - const float16* beta_pow_p = beta_pow.template data(); - beta_pow_out->mutable_data(platform::CPUPlace())[0] = - static_cast(beta) * beta_pow_p[0]; - } else { - const T* beta_pow_p = beta_pow.template data(); - beta_pow_out->mutable_data(platform::CPUPlace())[0] = - beta * beta_pow_p[0]; - } -} - -template -static void scale(phi::DenseTensor* beta_pow_out, - const phi::DenseTensor& beta_pow, - T* beta_pow_ptr, - const T& beta, - const framework::ExecutionContext& ctx) { - float16* beta_pow_out_p2 = - beta_pow_out->mutable_data(ctx.GetPlace()); - - Tensor xpu_beta_pow_out; 
- const phi::DenseTensorMeta meta_beta_pow_out( - paddle::experimental::DataType::FLOAT32, beta_pow_out->dims()); - xpu_beta_pow_out.set_meta(meta_beta_pow_out); - - T* beta_pow_out_ptr = - xpu_beta_pow_out.template mutable_data(ctx.GetPlace()); - - auto& dev_ctx = ctx.template device_context(); - int r = xpu::scale(dev_ctx.x_context(), - beta_pow_ptr, - beta_pow_out_ptr, - beta_pow.numel(), - false, - beta, - 0.0f); - PADDLE_ENFORCE_EQ(r, - xpu::SUCCESS, - platform::errors::External( - "XPU kernel scale occur error in adam error code ", - r, - XPUAPIErrorMsg[r])); - - const float* xpu_beta_pow_out_data = xpu_beta_pow_out.template data(); - int len = xpu_beta_pow_out.numel(); - - r = ConvertDataByType( - xpu_beta_pow_out_data, &beta_pow_out_p2, len, false, ctx); - PADDLE_ENFORCE_EQ( - r, - xpu::Error_t::SUCCESS, - platform::errors::External( - "execute function ConvertDataByType failed with [%d]", r)); -} - -template -static void freeData(const phi::DenseTensor& tensorData, T* dataPtr) { - if (tensorData.dtype() == paddle::experimental::DataType::FLOAT16) - xpu_free(dataPtr); -} - -template -class AdamOpXPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const auto* param_var = ctx.InputVar("Param"); - PADDLE_ENFORCE_EQ(param_var->IsType(), - true, - platform::errors::InvalidArgument( - "Tensor holds the wrong type,Expected Var(%s)'s " - "type is LoDTensor, " - "but the received is %s", - ctx.InputNames("Param").front(), - framework::ToTypeName(param_var->Type()))); - using paddle::framework::LoDTensor; - - auto& param = GET_DATA_SAFELY( - ctx.Input("Param"), "Input", "Param", "Adam"); - - float* param_ptr = nullptr; - getDataPointer(param, ¶m_ptr, ctx); - - auto* grad_var = ctx.InputVar("Grad"); - float* grad_c = nullptr; - - auto& mom1 = GET_DATA_SAFELY( - ctx.Input("Moment1"), "Input", "Moment1", "Adam"); - float* mom1_ptr = nullptr; - getDataPointer(mom1, &mom1_ptr, ctx); - - auto& mom2 = GET_DATA_SAFELY( - ctx.Input("Moment2"), "Input", "Moment2", "Adam"); - float* mom2_ptr = nullptr; - getDataPointer(mom2, &mom2_ptr, ctx); - - auto& lr = GET_DATA_SAFELY( - ctx.Input("LearningRate"), "Input", "LearningRate", "Adam"); - float* lr_ptr = nullptr; - getDataPointer(lr, &lr_ptr, ctx); - - auto& beta1_pow = GET_DATA_SAFELY( - ctx.Input("Beta1Pow"), "Input", "Beta1Pow", "Adam"); - auto& dev_ctx = ctx.template device_context(); - float* beta1_pow_ptr = nullptr; - const float* beta1_const_pow_ptr = nullptr; - if (beta1_pow.place() == platform::CPUPlace()) { - Tensor xpu_beta1_pow; - paddle::framework::TensorCopy( - beta1_pow, ctx.GetPlace(), dev_ctx, &xpu_beta1_pow); - if (xpu_beta1_pow.dtype() == paddle::experimental::DataType::FLOAT16) - getDataPointer(xpu_beta1_pow, &beta1_pow_ptr, ctx); - else - beta1_const_pow_ptr = xpu_beta1_pow.template data(); - } else { - if (beta1_pow.dtype() == paddle::experimental::DataType::FLOAT16) - getDataPointer(beta1_pow, &beta1_pow_ptr, ctx); - else - beta1_const_pow_ptr = beta1_pow.template data(); - } - - auto& beta2_pow = GET_DATA_SAFELY( - ctx.Input("Beta2Pow"), "Input", "Beta2Pow", "Adam"); - float* beta2_pow_ptr = nullptr; - const float* beta2_const_pow_ptr = nullptr; - if (beta2_pow.place() == platform::CPUPlace()) { - Tensor xpu_beta2_pow; - paddle::framework::TensorCopy( - beta2_pow, ctx.GetPlace(), dev_ctx, &xpu_beta2_pow); - if (xpu_beta2_pow.dtype() == paddle::experimental::DataType::FLOAT16) - getDataPointer(xpu_beta2_pow, &beta2_pow_ptr, ctx); - else - beta2_const_pow_ptr = 
xpu_beta2_pow.template data(); - } else { - if (beta2_pow.dtype() == paddle::experimental::DataType::FLOAT16) - getDataPointer(beta2_pow, &beta2_pow_ptr, ctx); - else - beta2_const_pow_ptr = beta2_pow.template data(); - } - - auto& param_out = GET_DATA_SAFELY( - ctx.Output("ParamOut"), "Output", "ParamOut", "Adam"); - Tensor xpu_param_out; - float* param_out_ptr = nullptr; - const phi::DenseTensorMeta meta_param( - paddle::experimental::DataType::FLOAT32, param_out.dims()); - xpu_param_out.set_meta(meta_param); - getOutDataPointer(¶m_out, &xpu_param_out, ¶m_out_ptr, ctx); - - auto& mom1_out = GET_DATA_SAFELY( - ctx.Output("Moment1Out"), "Output", "Moment1Out", "Adam"); - Tensor xpu_mom1_out; - float* mom1_out_ptr = nullptr; - const phi::DenseTensorMeta meta_mom1( - paddle::experimental::DataType::FLOAT32, mom1_out.dims()); - xpu_mom1_out.set_meta(meta_mom1); - getOutDataPointer(&mom1_out, &xpu_mom1_out, &mom1_out_ptr, ctx); - - auto& mom2_out = GET_DATA_SAFELY( - ctx.Output("Moment2Out"), "Output", "Moment2Out", "Adam"); - Tensor xpu_mom2_out; - float* mom2_out_ptr = nullptr; - const phi::DenseTensorMeta meta_mom2( - paddle::experimental::DataType::FLOAT32, mom2_out.dims()); - xpu_mom2_out.set_meta(meta_mom2); - getOutDataPointer(&mom2_out, &xpu_mom2_out, &mom2_out_ptr, ctx); - - auto* beta1_pow_out = ctx.Output("Beta1PowOut"); - auto* beta2_pow_out = ctx.Output("Beta2PowOut"); - - bool skip_update = false; - if (ctx.HasInput("SkipUpdate")) { - auto* skip_update_tensor = ctx.Input("SkipUpdate"); - PADDLE_ENFORCE_EQ(skip_update_tensor->numel(), - 1, - platform::errors::InvalidArgument( - "Input(SkipUpdate) size must be 1, but get %d", - skip_update_tensor->numel())); - std::vector skip_update_vec; - paddle::framework::TensorToVector( - *skip_update_tensor, ctx.device_context(), &skip_update_vec); - skip_update = skip_update_vec[0]; - } - // skip_update=true, just copy input to output, and TensorCopy will call - // mutable_data - if (skip_update) { - VLOG(4) << "Adam skip update"; - framework::TensorCopy( - param, - ctx.GetPlace(), - ctx.template device_context(), - ¶m_out); - framework::TensorCopy( - mom1, - ctx.GetPlace(), - ctx.template device_context(), - &mom1_out); - framework::TensorCopy( - mom2, - ctx.GetPlace(), - ctx.template device_context(), - &mom2_out); - framework::TensorCopy( - beta1_pow, - beta1_pow.place(), - ctx.template device_context(), - beta1_pow_out); - framework::TensorCopy( - beta2_pow, - beta2_pow.place(), - ctx.template device_context(), - beta2_pow_out); - return; - } - - PADDLE_ENFORCE_EQ(beta1_pow_out->numel(), - 1, - platform::errors::InvalidArgument( - "Tensor holds the wrong size, Expected beta1 pow " - "output size is 1, but received " - "value is:%d.", - beta1_pow_out->numel())); - - PADDLE_ENFORCE_EQ(beta2_pow_out->numel(), - 1, - platform::errors::InvalidArgument( - "Tensor holds the wrong size, Expected beta2 pow " - "output size is 1, but received " - "value is:%d.", - beta2_pow_out->numel())); - - bool use_global_beta_pow = ctx.Attr("use_global_beta_pow"); - VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow; - - float beta1 = static_cast(ctx.Attr("beta1")); - if (ctx.HasInput("Beta1Tensor")) { - auto* beta1_tensor = ctx.Input("Beta1Tensor"); - beta1 = static_cast(GetAttrFromTensor(beta1_tensor)); - } - float beta2 = static_cast(ctx.Attr("beta2")); - if (ctx.HasInput("Beta2Tensor")) { - auto* beta2_tensor = ctx.Input("Beta2Tensor"); - beta2 = static_cast(GetAttrFromTensor(beta2_tensor)); - } - float epsilon = static_cast(ctx.Attr("epsilon")); - 
if (ctx.HasInput("EpsilonTensor")) { - auto* epsilon_tensor = ctx.Input("EpsilonTensor"); - epsilon = static_cast(GetAttrFromTensor(epsilon_tensor)); - } - - if (grad_var->IsType()) { - auto& grad = GET_DATA_SAFELY( - ctx.Input("Grad"), "Input", "Grad", "Adam"); - getDataPointer(grad, &grad_c, ctx); - - int r = xpu::adam( - dev_ctx.x_context(), - grad_c != nullptr ? grad_c : grad.template data(), - mom1_ptr != nullptr ? mom1_ptr : mom1.template data(), - mom2_ptr != nullptr ? mom2_ptr : mom2.template data(), - param_ptr != nullptr ? param_ptr : param.template data(), - beta1_pow_ptr != nullptr ? beta1_pow_ptr : beta1_const_pow_ptr, - beta2_pow_ptr != nullptr ? beta2_pow_ptr : beta2_const_pow_ptr, - lr_ptr != nullptr ? lr_ptr : lr.template data(), - mom1_out_ptr, - mom2_out_ptr, - param_out_ptr, - beta1, - beta2, - epsilon, - param.numel()); - - xpu_wait(dev_ctx.x_context()->xpu_stream); - PADDLE_ENFORCE_EQ( - r == xpu::Error_t::SUCCESS, - true, - platform::errors::External("XPU API return wrong value[%d],", r)); - - freeData(grad, grad_c); - - copyOutData(xpu_mom1_out, &mom1_out, ctx); - copyOutData(xpu_mom2_out, &mom2_out, ctx); - copyOutData(xpu_param_out, ¶m_out, ctx); - - if (!use_global_beta_pow) { - // update in cpu and then copy to xpu - if (beta1_pow.place() == platform::CPUPlace() && - beta2_pow.place() == platform::CPUPlace()) { - setBetaData(beta1_pow, beta1_pow_out, beta1); - - setBetaData(beta2_pow, beta2_pow_out, beta2); - } else { - float* beta1_pow_out_p1 = nullptr; - - if (beta1_pow_out->dtype() == - paddle::experimental::DataType::FLOAT16) { - scale( - beta1_pow_out, beta1_pow, beta1_pow_ptr, beta1, ctx); - } else { - const float* beta1_pow_data = beta1_pow.template data(); - beta1_pow_out_p1 = - beta1_pow_out->mutable_data(ctx.GetPlace()); - r = xpu::scale(dev_ctx.x_context(), - beta1_pow_data, - beta1_pow_out_p1, - beta1_pow.numel(), - false, - beta1, - 0.0f); - xpu_wait(dev_ctx.x_context()->xpu_stream); - PADDLE_ENFORCE_EQ( - r, - xpu::SUCCESS, - platform::errors::External( - "XPU kernel scale occur error in adam error code ", - r, - XPUAPIErrorMsg[r])); - } - - float* beta2_pow_out_p1 = nullptr; - if (beta2_pow_out->dtype() == - paddle::experimental::DataType::FLOAT16) { - scale( - beta2_pow_out, beta2_pow, beta2_pow_ptr, beta2, ctx); - } else { - const float* beta2_pow_data = beta2_pow.template data(); - beta2_pow_out_p1 = - beta2_pow_out->mutable_data(ctx.GetPlace()); - r = xpu::scale(dev_ctx.x_context(), - beta2_pow_data, - beta2_pow_out_p1, - beta2_pow.numel(), - false, - beta2, - 0.0f); - xpu_wait(dev_ctx.x_context()->xpu_stream); - PADDLE_ENFORCE_EQ( - r, - xpu::SUCCESS, - platform::errors::External( - "XPU kernel scale occur error in adam error code ", - r, - XPUAPIErrorMsg[r])); - } - } - } - } else if (grad_var->IsType()) { - auto* grad = ctx.Input("Grad"); - - if (grad->rows().size() == 0) { - VLOG(3) << "grad row size is 0!!"; - return; - } - - std::vector cpu_rows(grad->rows().begin(), grad->rows().end()); - bool is_strict_sorted = true; - for (size_t i = 1; i < cpu_rows.size(); ++i) { - if (cpu_rows[i - 1] >= cpu_rows[i]) { - is_strict_sorted = false; - break; - } - } - - phi::SelectedRows tmp_grad_merge; - const phi::SelectedRows* grad_merge_ptr; - if (is_strict_sorted) { - grad_merge_ptr = grad; - } else { - scatter::MergeAdd merge_func; - merge_func(ctx.template device_context(), - *grad, - &tmp_grad_merge, - true); - - xpu_wait(dev_ctx.x_context()->xpu_stream); - grad_merge_ptr = &tmp_grad_merge; - } - - auto& grad_merge = *grad_merge_ptr; - auto& 
grad_tensor = grad_merge.value(); - - getDataPointer(grad_tensor, &grad_c, ctx); - - int row_count = grad_merge.rows().size(); - std::vector rows(row_count); - xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); - int* xpu_rows = RAII_GUARD.alloc_l3_or_gm(row_count); - std::vector merge_rows(grad_merge.rows().begin(), - grad_merge.rows().end()); - for (size_t i = 0; i < grad_merge.rows().size(); ++i) { - rows[i] = static_cast(merge_rows[i]); - } - xpu_wait(dev_ctx.x_context()->xpu_stream); - memory::Copy(ctx.GetPlace(), - xpu_rows, - platform::CPUPlace(), - rows.data(), - row_count * sizeof(int)); - auto row_numel = grad_tensor.numel() / grad_merge.rows().size(); - auto ori_rows = param.numel() / row_numel; - - int lazy_mode = static_cast(ctx.Attr("lazy_mode")); - int r = xpu::sparse_adam( - dev_ctx.x_context(), - grad_c != nullptr ? grad_c : grad_tensor.template data(), - mom1_ptr != nullptr ? mom1_ptr : mom1.template data(), - mom2_ptr != nullptr ? mom2_ptr : mom2.template data(), - param_ptr != nullptr ? param_ptr : param.template data(), - beta1_pow_ptr != nullptr ? beta1_pow_ptr : beta1_const_pow_ptr, - beta2_pow_ptr != nullptr ? beta2_pow_ptr : beta2_const_pow_ptr, - lr_ptr != nullptr ? lr_ptr : lr.template data(), - mom1_out_ptr, - mom2_out_ptr, - param_out_ptr, - beta1, - beta2, - epsilon, - ori_rows, - xpu_rows, - row_numel, - grad_merge.rows().size(), - lazy_mode); - - PADDLE_ENFORCE_EQ( - r == xpu::Error_t::SUCCESS, - true, - platform::errors::External("XPU API return wrong value[%d],", r)); - - freeData(grad_tensor, grad_c); - - copyOutData(xpu_mom1_out, &mom1_out, ctx); - copyOutData(xpu_mom2_out, &mom2_out, ctx); - copyOutData(xpu_param_out, ¶m_out, ctx); - - if (!use_global_beta_pow) { - // update in cpu and then copy to xpu - if (beta1_pow.place() == platform::CPUPlace() && - beta2_pow.place() == platform::CPUPlace()) { - setBetaData(beta1_pow, beta1_pow_out, beta1); - - setBetaData(beta2_pow, beta2_pow_out, beta2); - } else { - float* beta1_pow_out_p1 = nullptr; - - if (beta1_pow_out->dtype() == - paddle::experimental::DataType::FLOAT16) { - scale( - beta1_pow_out, beta1_pow, beta1_pow_ptr, beta1, ctx); - } else { - const float* beta1_pow_data = beta1_pow.template data(); - beta1_pow_out_p1 = - beta1_pow_out->mutable_data(ctx.GetPlace()); - r = xpu::scale(dev_ctx.x_context(), - beta1_pow_data, - beta1_pow_out_p1, - beta1_pow.numel(), - false, - beta1, - 0.0f); - xpu_wait(dev_ctx.x_context()->xpu_stream); - PADDLE_ENFORCE_EQ( - r, - xpu::SUCCESS, - platform::errors::External( - "XPU kernel scale occur error in adam error code ", - r, - XPUAPIErrorMsg[r])); - } - - float* beta2_pow_out_p1 = nullptr; - if (beta2_pow_out->dtype() == - paddle::experimental::DataType::FLOAT16) { - scale( - beta2_pow_out, beta2_pow, beta2_pow_ptr, beta2, ctx); - } else { - const float* beta2_pow_data = beta2_pow.template data(); - beta2_pow_out_p1 = - beta2_pow_out->mutable_data(ctx.GetPlace()); - r = xpu::scale(dev_ctx.x_context(), - beta2_pow_data, - beta2_pow_out_p1, - beta2_pow.numel(), - false, - beta2, - 0.0f); - xpu_wait(dev_ctx.x_context()->xpu_stream); - PADDLE_ENFORCE_EQ( - r, - xpu::SUCCESS, - platform::errors::External( - "XPU kernel scale occur error in adam error code ", - r, - XPUAPIErrorMsg[r])); - } - } - } - } else { - PADDLE_ENFORCE_EQ(1, - 2, - platform::errors::InvalidArgument( - "Variable type not supported by adam_op")); - } - - freeData(param, param_ptr); - freeData(mom1, mom1_ptr); - freeData(mom2, mom2_ptr); - freeData(lr, lr_ptr); - } -}; -#endif - -} // namespace operators 
-} // namespace paddle - -namespace ops = paddle::operators; -#ifdef PADDLE_WITH_XPU -REGISTER_OP_XPU_KERNEL( - adam, - ops::AdamOpXPUKernel, - ops::AdamOpXPUKernel); -#endif diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 10d3e730cc1..66867c938dd 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -22,6 +22,7 @@ set_property(GLOBAL PROPERTY PHI_KERNELS "") # [ 1. Common kernel compilation dependencies ] set(COMMON_KERNEL_DEPS dense_tensor + string_tensor sparse_coo_tensor sparse_csr_tensor kernel_context @@ -30,6 +31,7 @@ set(COMMON_KERNEL_DEPS convert_utils lod_utils custom_kernel + string_infermeta phi_tensor_utils) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} @@ -67,21 +69,7 @@ set(COMMON_KERNEL_DEPS sequence_padding sequence_scale fft - phi_data_layout_transform) - -set(COMMON_KERNEL_DEPS - ${COMMON_KERNEL_DEPS} - dense_tensor - string_tensor - sparse_coo_tensor - sparse_csr_tensor - kernel_context - kernel_factory - arg_map_context - convert_utils - lod_utils - custom_kernel - string_infermeta + phi_data_layout_transform gpc utf8proc device_memory_aligment) @@ -136,7 +124,7 @@ else() "strings/cpu/*.cc") endif() -file(GLOB kernel_xpu "xpu/*.cc") +file(GLOB kernel_xpu "xpu/*.cc" "selected_rows/xpu/*.cc") add_library(phi_cpu ${kernel_cc}) kernel_declare("${kernel_cc}") diff --git a/paddle/phi/kernels/funcs/adam_functors.h b/paddle/phi/kernels/funcs/adam_functors.h index b14ee7f072e..4edc83ca30a 100644 --- a/paddle/phi/kernels/funcs/adam_functors.h +++ b/paddle/phi/kernels/funcs/adam_functors.h @@ -19,8 +19,142 @@ #include "paddle/phi/kernels/funcs/algorithm.h" +#ifdef PADDLE_WITH_XPU +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/backends/xpu/xpu_header.h" +// See Note [ Why still include the fluid headers? 
] +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/memory/memcpy.h" +#endif + namespace phi { namespace funcs { +using float16 = dtype::float16; + +#ifdef PADDLE_WITH_XPU + +template +static int ConvertDataByType( + const T1* x, T2** y, int len, bool allocateFlag, const Context& dev_ctx) { + if (nullptr == x || nullptr == y || len <= 0) + return xpu::Error_t::INVALID_PARAM; + int r = 0; + if (allocateFlag) { + r = xpu_malloc(reinterpret_cast(y), sizeof(T2) * len); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam"); + } + + T1* cpu_data = reinterpret_cast(malloc(sizeof(T1) * len)); + + paddle::memory::Copy( + CPUPlace(), cpu_data, dev_ctx.GetPlace(), x, len * sizeof(T1)); + + T2* cpu_real_data = reinterpret_cast(malloc(sizeof(T2) * len)); + for (int i = 0; i < len; i++) cpu_real_data[i] = static_cast(cpu_data[i]); + + paddle::memory::Copy( + dev_ctx.GetPlace(), *y, CPUPlace(), cpu_real_data, len * sizeof(T2)); + + free(cpu_data); + free(cpu_real_data); + + return xpu::Error_t::SUCCESS; +} + +template +static void GetDataPointer(const phi::DenseTensor& tensorData, + T** result, + const Context& dev_ctx) { + if (tensorData.dtype() == DataType::FLOAT16) { + const float16* real_data = tensorData.template data(); + int len = tensorData.numel(); + + int r = ConvertDataByType( + real_data, result, len, true, dev_ctx); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam"); + } +} + +template +static void GetOutDataPointer(DenseTensor* tensorData, + DenseTensor* out, + T** result, + const Context& dev_ctx) { + if (tensorData->dtype() == DataType::FLOAT16) { + *result = dev_ctx.template Alloc(out); + } else { + *result = dev_ctx.template Alloc(tensorData); + } +} + +template +static void CopyOutData(const DenseTensor& srcTensor, + phi::DenseTensor* dstTensor, + const Context& dev_ctx) { + if (dstTensor->dtype() == DataType::FLOAT16) { + const T* xpu_out_data = srcTensor.template data(); + float16* out_data = dev_ctx.template Alloc(dstTensor); + int len = srcTensor.numel(); + + int r = ConvertDataByType( + xpu_out_data, &out_data, len, false, dev_ctx); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam"); + } +} + +template +static void FreeData(const phi::DenseTensor& tensorData, T* dataPtr) { + if (tensorData.dtype() == DataType::FLOAT16) xpu_free(dataPtr); +} + +template +static void SetBetaData(const phi::DenseTensor& beta_pow, + phi::DenseTensor* beta_pow_out, + const T& beta, + const Context& dev_ctx) { + if (beta_pow.dtype() == DataType::FLOAT16) { + const float16* beta_pow_p = beta_pow.template data(); + dev_ctx.template HostAlloc(beta_pow_out)[0] = + static_cast(beta) * beta_pow_p[0]; + } else { + const T* beta_pow_p = beta_pow.template data(); + dev_ctx.template HostAlloc(beta_pow_out)[0] = beta * beta_pow_p[0]; + } +} + +template +static void Scale(phi::DenseTensor* beta_pow_out, + const phi::DenseTensor& beta_pow, + T* beta_pow_ptr, + const T& beta, + const Context& dev_ctx) { + float16* beta_pow_out_p2 = dev_ctx.template Alloc(beta_pow_out); + + DenseTensor xpu_beta_pow_out; + const phi::DenseTensorMeta meta_beta_pow_out(DataType::FLOAT32, + beta_pow_out->dims()); + xpu_beta_pow_out.set_meta(meta_beta_pow_out); + + T* beta_pow_out_ptr = dev_ctx.template Alloc(&xpu_beta_pow_out); + + int r = xpu::scale(dev_ctx.x_context(), + beta_pow_ptr, + beta_pow_out_ptr, + beta_pow.numel(), + false, + beta, + 0.0f); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam"); + + const float* xpu_beta_pow_out_data = + dev_ctx.template Alloc(&xpu_beta_pow_out); + int len = xpu_beta_pow_out.numel(); + + r = ConvertDataByType( + 
xpu_beta_pow_out_data, &beta_pow_out_p2, len, false, dev_ctx); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam"); +} +#endif struct GPUAdam; struct CPUAdam; diff --git a/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc b/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc new file mode 100644 index 00000000000..c9cd5f563fc --- /dev/null +++ b/paddle/phi/kernels/selected_rows/xpu/adam_kernel.cc @@ -0,0 +1,308 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/selected_rows/adam_kernel.h" + +#include "paddle/phi/backends/xpu/xpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/adam_functors.h" +// See Note [ Why still include the fluid headers? ] +#include "paddle/fluid/operators/math/selected_rows_functor.h" + +namespace phi { +namespace sr { +using float16 = dtype::float16; + +template +void AdamDenseParamSparseGradKernel( + const Context& dev_ctx, + const DenseTensor& param, + const SelectedRows& grad, + const DenseTensor& learning_rate, + const DenseTensor& moment1, + const DenseTensor& moment2, + const DenseTensor& beta1_pow, + const DenseTensor& beta2_pow, + const paddle::optional& master_param, + const paddle::optional& skip_update, + const Scalar& beta1, + const Scalar& beta2, + const Scalar& epsilon, + bool lazy_mode, + int64_t min_row_size_to_use_multithread, + bool multi_precision, + bool use_global_beta_pow, + DenseTensor* param_out, + DenseTensor* moment1_out, + DenseTensor* moment2_out, + DenseTensor* beta1_pow_out, + DenseTensor* beta2_pow_out, + DenseTensor* master_param_outs) { + float* param_ptr = nullptr; + funcs::GetDataPointer(param, ¶m_ptr, dev_ctx); + + float* mom1_ptr = nullptr; + funcs::GetDataPointer(moment1, &mom1_ptr, dev_ctx); + + float* mom2_ptr = nullptr; + funcs::GetDataPointer(moment2, &mom2_ptr, dev_ctx); + + float* lr_ptr = nullptr; + funcs::GetDataPointer(learning_rate, &lr_ptr, dev_ctx); + + float* beta1_pow_ptr = nullptr; + const float* beta1_const_pow_ptr = nullptr; + if (beta1_pow.place() == CPUPlace()) { + DenseTensor xpu_beta1_pow; + phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, &xpu_beta1_pow); + if (xpu_beta1_pow.dtype() == DataType::FLOAT16) + funcs::GetDataPointer( + xpu_beta1_pow, &beta1_pow_ptr, dev_ctx); + else + beta1_const_pow_ptr = xpu_beta1_pow.template data(); + } else { + if (beta1_pow.dtype() == DataType::FLOAT16) + funcs::GetDataPointer(beta1_pow, &beta1_pow_ptr, dev_ctx); + else + beta1_const_pow_ptr = beta1_pow.template data(); + } + + float* beta2_pow_ptr = nullptr; + const float* beta2_const_pow_ptr = nullptr; + if (beta2_pow.place() == CPUPlace()) { + DenseTensor xpu_beta2_pow; + phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, &xpu_beta2_pow); + if (xpu_beta2_pow.dtype() == DataType::FLOAT16) + funcs::GetDataPointer( + xpu_beta2_pow, &beta2_pow_ptr, dev_ctx); + else + beta2_const_pow_ptr = xpu_beta2_pow.template data(); + } else 
{ + if (beta2_pow.dtype() == DataType::FLOAT16) + funcs::GetDataPointer(beta2_pow, &beta2_pow_ptr, dev_ctx); + else + beta2_const_pow_ptr = beta2_pow.template data(); + } + + DenseTensor xpu_param_out; + float* param_out_ptr = nullptr; + const phi::DenseTensorMeta meta_param(DataType::FLOAT32, param_out->dims()); + xpu_param_out.set_meta(meta_param); + funcs::GetOutDataPointer( + param_out, &xpu_param_out, &param_out_ptr, dev_ctx); + + DenseTensor xpu_mom1_out; + float* mom1_out_ptr = nullptr; + const phi::DenseTensorMeta meta_mom1(DataType::FLOAT32, moment1_out->dims()); + xpu_mom1_out.set_meta(meta_mom1); + funcs::GetOutDataPointer( + moment1_out, &xpu_mom1_out, &mom1_out_ptr, dev_ctx); + + DenseTensor xpu_mom2_out; + float* mom2_out_ptr = nullptr; + const phi::DenseTensorMeta meta_mom2(DataType::FLOAT32, moment2_out->dims()); + xpu_mom2_out.set_meta(meta_mom2); + funcs::GetOutDataPointer( + moment2_out, &xpu_mom2_out, &mom2_out_ptr, dev_ctx); + + bool skip_update_ = false; + if (skip_update.is_initialized()) { + PADDLE_ENFORCE_EQ( + skip_update->numel(), + 1, + errors::InvalidArgument("Input(SkipUpdate) size must be 1, but get %d", + skip_update->numel())); + std::vector skip_update_vec; + paddle::framework::TensorToVector(*skip_update, dev_ctx, &skip_update_vec); + skip_update_ = skip_update_vec[0]; + } + + if (skip_update_) { + VLOG(4) << "Adam skip update"; + phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); + phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out); + phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); + phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); + phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); + return; + } + + PADDLE_ENFORCE_EQ( + beta1_pow_out->numel(), + 1, + errors::InvalidArgument("Tensor holds the wrong size, Expected beta1 pow " + "output size is 1, but received " + "value is:%d.", + beta1_pow_out->numel())); + + PADDLE_ENFORCE_EQ( + beta2_pow_out->numel(), + 1, + errors::InvalidArgument("Tensor holds the wrong size, Expected beta2 pow " + "output size is 1, but received " + "value is:%d.", + beta2_pow_out->numel())); + + VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow; + + auto beta1_ = beta1.to(); + auto beta2_ = beta2.to(); + auto epsilon_ = epsilon.to(); + + float* grad_c = nullptr; + if (grad.rows().size() == 0) { + VLOG(3) << "grad row size is 0!!"; + return; + } + + std::vector cpu_rows(grad.rows().begin(), grad.rows().end()); + bool is_strict_sorted = true; + for (size_t i = 1; i < cpu_rows.size(); ++i) { + if (cpu_rows[i - 1] >= cpu_rows[i]) { + is_strict_sorted = false; + break; + } + } + + SelectedRows tmp_grad_merge; + const SelectedRows* grad_merge_ptr; + if (is_strict_sorted) { + grad_merge_ptr = &grad; + } else { + paddle::operators::math::scatter::MergeAdd merge_func; + merge_func(dev_ctx, grad, &tmp_grad_merge, true); + + xpu_wait(dev_ctx.x_context()->xpu_stream); + grad_merge_ptr = &tmp_grad_merge; + } + + auto& grad_merge = *grad_merge_ptr; + auto& grad_tensor = grad_merge.value(); + + funcs::GetDataPointer(grad_tensor, &grad_c, dev_ctx); + + int row_count = grad_merge.rows().size(); + std::vector rows(row_count); + xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); + int* xpu_rows = RAII_GUARD.alloc_l3_or_gm(row_count); + std::vector merge_rows(grad_merge.rows().begin(), + grad_merge.rows().end()); + for (size_t i = 0; i < grad_merge.rows().size(); ++i) { + rows[i] = static_cast(merge_rows[i]); + } + 
xpu_wait(dev_ctx.x_context()->xpu_stream); + paddle::memory::Copy(dev_ctx.GetPlace(), + xpu_rows, + CPUPlace(), + rows.data(), + row_count * sizeof(int)); + auto row_numel = grad_tensor.numel() / grad_merge.rows().size(); + auto ori_rows = param.numel() / row_numel; + + int r = xpu::sparse_adam( + dev_ctx.x_context(), + grad_c != nullptr ? grad_c : grad_tensor.template data(), + mom1_ptr != nullptr ? mom1_ptr : moment1.template data(), + mom2_ptr != nullptr ? mom2_ptr : moment2.template data(), + param_ptr != nullptr ? param_ptr : param.template data(), + beta1_pow_ptr != nullptr ? beta1_pow_ptr : beta1_const_pow_ptr, + beta2_pow_ptr != nullptr ? beta2_pow_ptr : beta2_const_pow_ptr, + lr_ptr != nullptr ? lr_ptr : learning_rate.template data(), + mom1_out_ptr, + mom2_out_ptr, + param_out_ptr, + beta1_, + beta2_, + epsilon_, + ori_rows, + xpu_rows, + row_numel, + grad_merge.rows().size(), + lazy_mode); + + PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam"); + + funcs::FreeData(grad_tensor, grad_c); + + funcs::CopyOutData(xpu_mom1_out, moment1_out, dev_ctx); + funcs::CopyOutData(xpu_mom2_out, moment2_out, dev_ctx); + funcs::CopyOutData(xpu_param_out, param_out, dev_ctx); + + if (!use_global_beta_pow) { + // update in cpu and then copy to xpu + if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) { + funcs::SetBetaData( + beta1_pow, beta1_pow_out, beta1_, dev_ctx); + + funcs::SetBetaData( + beta2_pow, beta2_pow_out, beta2_, dev_ctx); + } else { + float* beta1_pow_out_p1 = nullptr; + + if (beta1_pow_out->dtype() == DataType::FLOAT16) { + funcs::Scale( + beta1_pow_out, beta1_pow, beta1_pow_ptr, beta1_, dev_ctx); + } else { + const float* beta1_pow_data = beta1_pow.template data(); + beta1_pow_out_p1 = dev_ctx.template Alloc(beta1_pow_out); + r = xpu::scale(dev_ctx.x_context(), + beta1_pow_data, + beta1_pow_out_p1, + beta1_pow.numel(), + false, + beta1_, + 0.0f); + xpu_wait(dev_ctx.x_context()->xpu_stream); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam"); + } + + float* beta2_pow_out_p1 = nullptr; + if (beta2_pow_out->dtype() == DataType::FLOAT16) { + funcs::Scale( + beta2_pow_out, beta2_pow, beta2_pow_ptr, beta2_, dev_ctx); + } else { + const float* beta2_pow_data = beta2_pow.template data(); + beta2_pow_out_p1 = dev_ctx.template Alloc(beta2_pow_out); + r = xpu::scale(dev_ctx.x_context(), + beta2_pow_data, + beta2_pow_out_p1, + beta2_pow.numel(), + false, + beta2_, + 0.0f); + xpu_wait(dev_ctx.x_context()->xpu_stream); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam"); + } + } + } + funcs::FreeData(param, param_ptr); + funcs::FreeData(moment1, mom1_ptr); + funcs::FreeData(moment2, mom2_ptr); + funcs::FreeData(learning_rate, lr_ptr); +} +} // namespace sr +} // namespace phi + +PD_REGISTER_KERNEL(adam_dense_param_sparse_grad, + XPU, + ALL_LAYOUT, + phi::sr::AdamDenseParamSparseGradKernel, + float, + phi::dtype::float16) { + // Skip beta1_pow, beta2_pow, skip_update data transform + kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND); +} diff --git a/paddle/phi/kernels/xpu/adam_kernel.cc b/paddle/phi/kernels/xpu/adam_kernel.cc new file mode 100644 index 00000000000..b4d3301667e --- /dev/null +++ b/paddle/phi/kernels/xpu/adam_kernel.cc @@ -0,0 +1,252 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/adam_kernel.h" + +#include "paddle/phi/backends/xpu/xpu_context.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/adam_functors.h" + +namespace phi { + +using float16 = dtype::float16; + +template +void AdamDenseKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& grad, + const DenseTensor& learning_rate, + const DenseTensor& moment1, + const DenseTensor& moment2, + const DenseTensor& beta1_pow, + const DenseTensor& beta2_pow, + const paddle::optional& master_param, + const paddle::optional& skip_update, + const Scalar& beta1, + const Scalar& beta2, + const Scalar& epsilon, + bool lazy_mode, + int64_t min_row_size_to_use_multithread, + bool multi_precision, + bool use_global_beta_pow, + DenseTensor* param_out, + DenseTensor* moment1_out, + DenseTensor* moment2_out, + DenseTensor* beta1_pow_out, + DenseTensor* beta2_pow_out, + DenseTensor* master_param_outs) { + float* param_ptr = nullptr; + funcs::GetDataPointer(param, &param_ptr, dev_ctx); + + float* mom1_ptr = nullptr; + funcs::GetDataPointer(moment1, &mom1_ptr, dev_ctx); + + float* mom2_ptr = nullptr; + funcs::GetDataPointer(moment2, &mom2_ptr, dev_ctx); + + float* lr_ptr = nullptr; + funcs::GetDataPointer(learning_rate, &lr_ptr, dev_ctx); + + float* beta1_pow_ptr = nullptr; + const float* beta1_const_pow_ptr = nullptr; + if (beta1_pow.place() == CPUPlace()) { + DenseTensor xpu_beta1_pow; + phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, &xpu_beta1_pow); + if (xpu_beta1_pow.dtype() == DataType::FLOAT16) + funcs::GetDataPointer( + xpu_beta1_pow, &beta1_pow_ptr, dev_ctx); + else + beta1_const_pow_ptr = xpu_beta1_pow.template data(); + } else { + if (beta1_pow.dtype() == DataType::FLOAT16) + funcs::GetDataPointer(beta1_pow, &beta1_pow_ptr, dev_ctx); + else + beta1_const_pow_ptr = beta1_pow.template data(); + } + + float* beta2_pow_ptr = nullptr; + const float* beta2_const_pow_ptr = nullptr; + if (beta2_pow.place() == CPUPlace()) { + DenseTensor xpu_beta2_pow; + phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, &xpu_beta2_pow); + if (xpu_beta2_pow.dtype() == DataType::FLOAT16) + funcs::GetDataPointer( + xpu_beta2_pow, &beta2_pow_ptr, dev_ctx); + else + beta2_const_pow_ptr = xpu_beta2_pow.template data(); + } else { + if (beta2_pow.dtype() == DataType::FLOAT16) + funcs::GetDataPointer(beta2_pow, &beta2_pow_ptr, dev_ctx); + else + beta2_const_pow_ptr = beta2_pow.template data(); + } + + DenseTensor xpu_param_out; + float* param_out_ptr = nullptr; + const phi::DenseTensorMeta meta_param(DataType::FLOAT32, param_out->dims()); + xpu_param_out.set_meta(meta_param); + funcs::GetOutDataPointer( + param_out, &xpu_param_out, &param_out_ptr, dev_ctx); + + DenseTensor xpu_mom1_out; + float* mom1_out_ptr = nullptr; + const phi::DenseTensorMeta meta_mom1(DataType::FLOAT32, moment1_out->dims()); + xpu_mom1_out.set_meta(meta_mom1); + funcs::GetOutDataPointer( + moment1_out, &xpu_mom1_out, &mom1_out_ptr, dev_ctx); + + 
DenseTensor xpu_mom2_out; + float* mom2_out_ptr = nullptr; + const phi::DenseTensorMeta meta_mom2(DataType::FLOAT32, moment2_out->dims()); + xpu_mom2_out.set_meta(meta_mom2); + funcs::GetOutDataPointer( + moment2_out, &xpu_mom2_out, &mom2_out_ptr, dev_ctx); + + bool skip_update_ = false; + if (skip_update.is_initialized()) { + PADDLE_ENFORCE_EQ( + skip_update->numel(), + 1, + errors::InvalidArgument("Input(SkipUpdate) size must be 1, but get %d", + skip_update->numel())); + std::vector skip_update_vec; + paddle::framework::TensorToVector(*skip_update, dev_ctx, &skip_update_vec); + skip_update_ = skip_update_vec[0]; + } + + if (skip_update_) { + VLOG(4) << "Adam skip update"; + phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out); + phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out); + phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out); + phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out); + phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out); + return; + } + + PADDLE_ENFORCE_EQ( + beta1_pow_out->numel(), + 1, + errors::InvalidArgument("Tensor holds the wrong size, Expected beta1 pow " + "output size is 1, but received " + "value is:%d.", + beta1_pow_out->numel())); + + PADDLE_ENFORCE_EQ( + beta2_pow_out->numel(), + 1, + errors::InvalidArgument("Tensor holds the wrong size, Expected beta2 pow " + "output size is 1, but received " + "value is:%d.", + beta2_pow_out->numel())); + + VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow; + + auto beta1_ = beta1.to(); + auto beta2_ = beta2.to(); + auto epsilon_ = epsilon.to(); + + float* grad_c = nullptr; + funcs::GetDataPointer(grad, &grad_c, dev_ctx); + + int r = xpu::adam( + dev_ctx.x_context(), + grad_c != nullptr ? grad_c : grad.template data(), + mom1_ptr != nullptr ? mom1_ptr : moment1.template data(), + mom2_ptr != nullptr ? mom2_ptr : moment2.template data(), + param_ptr != nullptr ? param_ptr : param.template data(), + beta1_pow_ptr != nullptr ? beta1_pow_ptr : beta1_const_pow_ptr, + beta2_pow_ptr != nullptr ? beta2_pow_ptr : beta2_const_pow_ptr, + lr_ptr != nullptr ? 
lr_ptr : learning_rate.template data(), + mom1_out_ptr, + mom2_out_ptr, + param_out_ptr, + beta1_, + beta2_, + epsilon_, + param.numel()); + + xpu_wait(dev_ctx.x_context()->xpu_stream); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam"); + + funcs::FreeData(grad, grad_c); + + funcs::CopyOutData(xpu_mom1_out, moment1_out, dev_ctx); + funcs::CopyOutData(xpu_mom2_out, moment2_out, dev_ctx); + funcs::CopyOutData(xpu_param_out, param_out, dev_ctx); + + if (!use_global_beta_pow) { + // update in cpu and then copy to xpu + if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) { + funcs::SetBetaData( + beta1_pow, beta1_pow_out, beta1_, dev_ctx); + + funcs::SetBetaData( + beta2_pow, beta2_pow_out, beta2_, dev_ctx); + } else { + float* beta1_pow_out_p1 = nullptr; + + if (beta1_pow_out->dtype() == DataType::FLOAT16) { + funcs::Scale( + beta1_pow_out, beta1_pow, beta1_pow_ptr, beta1_, dev_ctx); + } else { + const float* beta1_pow_data = beta1_pow.template data(); + beta1_pow_out_p1 = dev_ctx.template Alloc(beta1_pow_out); + r = xpu::scale(dev_ctx.x_context(), + beta1_pow_data, + beta1_pow_out_p1, + beta1_pow.numel(), + false, + beta1_, + 0.0f); + xpu_wait(dev_ctx.x_context()->xpu_stream); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam"); + } + + float* beta2_pow_out_p1 = nullptr; + if (beta2_pow_out->dtype() == DataType::FLOAT16) { + funcs::Scale( + beta2_pow_out, beta2_pow, beta2_pow_ptr, beta2_, dev_ctx); + } else { + const float* beta2_pow_data = beta2_pow.template data(); + beta2_pow_out_p1 = dev_ctx.template Alloc(beta2_pow_out); + r = xpu::scale(dev_ctx.x_context(), + beta2_pow_data, + beta2_pow_out_p1, + beta2_pow.numel(), + false, + beta2_, + 0.0f); + xpu_wait(dev_ctx.x_context()->xpu_stream); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam"); + } + } + } + funcs::FreeData(param, param_ptr); + funcs::FreeData(moment1, mom1_ptr); + funcs::FreeData(moment2, mom2_ptr); + funcs::FreeData(learning_rate, lr_ptr); +} +} // namespace phi + +PD_REGISTER_KERNEL( + adam, XPU, ALL_LAYOUT, phi::AdamDenseKernel, float, phi::dtype::float16) { + // Skip beta1_pow, beta2_pow, skip_update data transform + kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND); + kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND); +} -- GitLab
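
Note on the update rule: both kernels above delegate the arithmetic to XDNN (xpu::adam for dense gradients, xpu::sparse_adam for the merged-row path), so the update itself never appears in the patch. The sketch below is a minimal, standalone host-side C++ rendering of the step those calls are assumed to perform: the textbook Adam rule with the bias correction folded into the step size, which is consistent with the beta1_pow/beta2_pow plumbing and the xpu::scale(beta_pow_out = beta * beta_pow) maintenance in both kernels. AdamStep and its parameter names are illustrative only, not a Paddle or XDNN API.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Host-side reference of one dense Adam step. One call corresponds to one
// xpu::adam invocation above: moments and parameters are read from the input
// buffers and written out-of-place, mirroring the in/out pointer pairs.
static void AdamStep(const float* grad, const float* mom1, const float* mom2,
                     const float* param, float beta1_pow, float beta2_pow,
                     float lr, float* mom1_out, float* mom2_out,
                     float* param_out, float beta1, float beta2, float epsilon,
                     std::size_t numel) {
  // Bias-corrected step size. After k steps, beta1_pow/beta2_pow hold
  // beta1^k/beta2^k; the kernels above advance them with xpu::scale.
  const float lr_t = lr * std::sqrt(1.0f - beta2_pow) / (1.0f - beta1_pow);
  for (std::size_t i = 0; i < numel; ++i) {
    mom1_out[i] = beta1 * mom1[i] + (1.0f - beta1) * grad[i];
    mom2_out[i] = beta2 * mom2[i] + (1.0f - beta2) * grad[i] * grad[i];
    param_out[i] =
        param[i] - lr_t * mom1_out[i] / (std::sqrt(mom2_out[i]) + epsilon);
  }
}

int main() {
  const float beta1 = 0.9f, beta2 = 0.999f, epsilon = 1e-8f, lr = 1e-3f;
  std::vector<float> grad = {0.1f, -0.2f};
  std::vector<float> mom1 = {0.0f, 0.0f}, mom2 = {0.0f, 0.0f};
  std::vector<float> param = {1.0f, 2.0f};
  std::vector<float> mom1_out(2), mom2_out(2), param_out(2);
  // First step: the pow inputs already include this step's factor (beta^1).
  AdamStep(grad.data(), mom1.data(), mom2.data(), param.data(),
           /*beta1_pow=*/beta1, /*beta2_pow=*/beta2, lr, mom1_out.data(),
           mom2_out.data(), param_out.data(), beta1, beta2, epsilon,
           param.size());
  std::printf("param_out = {%.6f, %.6f}\n", param_out[0], param_out[1]);
  return 0;
}

For the toy inputs this prints param_out = {0.999000, 2.001000}: on the first step the bias-corrected update has magnitude close to lr regardless of gradient scale, which is the property the beta-pow inputs exist to provide. xpu::sparse_adam is assumed to apply the same per-element update restricted to the rows named in xpu_rows.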