Unverified · Commit cbabbe2e, authored by Aurelius84, committed by GitHub

[XPU]Migrate Adam XPU kernel into Phi (#45572)

* [XPU]Migrate Adam XPU kernel into Phi

* test=kunlun
Parent e3e92c9a
@@ -569,8 +569,8 @@ TEMPLATE_SPECIALIZED_FOR_MERGEADD_CPU(platform::complex<double>)
#ifdef PADDLE_WITH_XPU
template <typename T>
struct MergeAdd<platform::XPUDeviceContext, T> {
phi::SelectedRows operator()(const platform::XPUDeviceContext& context,
struct MergeAdd<phi::XPUContext, T> {
phi::SelectedRows operator()(const phi::XPUContext& context,
const phi::SelectedRows& input,
const bool sorted_result = false) {
phi::SelectedRows out;
@@ -578,7 +578,7 @@ struct MergeAdd<platform::XPUDeviceContext, T> {
return out;
}
void operator()(const platform::XPUDeviceContext& context,
void operator()(const phi::XPUContext& context,
const phi::SelectedRows& input,
phi::SelectedRows* output,
const bool sorted_result = false) {
@@ -633,7 +633,7 @@ struct MergeAdd<platform::XPUDeviceContext, T> {
PADDLE_ENFORCE_XDNN_SUCCESS(r, "merge_dup_rows");
}
void operator()(const platform::XPUDeviceContext& context,
void operator()(const phi::XPUContext& context,
const std::vector<const phi::SelectedRows*>& inputs,
phi::SelectedRows* output,
const bool sorted_result = false) {
@@ -838,7 +838,7 @@ struct MergeAverage<phi::CPUContext, T> {
};
#ifdef PADDLE_WITH_XPU
template struct MergeAdd<platform::XPUDeviceContext, float>;
template struct MergeAdd<phi::XPUContext, float>;
#endif
template struct MergeAverage<phi::CPUContext, int>;
......
@@ -22,6 +22,7 @@ set_property(GLOBAL PROPERTY PHI_KERNELS "")
# [ 1. Common kernel compilation dependencies ]
set(COMMON_KERNEL_DEPS
dense_tensor
string_tensor
sparse_coo_tensor
sparse_csr_tensor
kernel_context
@@ -30,6 +31,7 @@ set(COMMON_KERNEL_DEPS
convert_utils
lod_utils
custom_kernel
string_infermeta
phi_tensor_utils)
set(COMMON_KERNEL_DEPS
${COMMON_KERNEL_DEPS}
@@ -67,21 +69,7 @@ set(COMMON_KERNEL_DEPS
sequence_padding
sequence_scale
fft
phi_data_layout_transform)
set(COMMON_KERNEL_DEPS
${COMMON_KERNEL_DEPS}
dense_tensor
string_tensor
sparse_coo_tensor
sparse_csr_tensor
kernel_context
kernel_factory
arg_map_context
convert_utils
lod_utils
custom_kernel
string_infermeta
phi_data_layout_transform
gpc
utf8proc
device_memory_aligment)
@@ -136,7 +124,7 @@ else()
"strings/cpu/*.cc")
endif()
file(GLOB kernel_xpu "xpu/*.cc")
file(GLOB kernel_xpu "xpu/*.cc" "selected_rows/xpu/*.cc")
add_library(phi_cpu ${kernel_cc})
kernel_declare("${kernel_cc}")
......
@@ -19,8 +19,142 @@
#include "paddle/phi/kernels/funcs/algorithm.h"
#ifdef PADDLE_WITH_XPU
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/backends/xpu/xpu_header.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/memory/memcpy.h"
#endif
namespace phi {
namespace funcs {
using float16 = dtype::float16;
#ifdef PADDLE_WITH_XPU
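// Helper for the XPU build: converts a device buffer of T1 into T2 by
// round-tripping through host memory (copy to CPU, static_cast element-wise,
// copy back). When allocateFlag is true the destination buffer is first
// allocated with xpu_malloc and must later be released via xpu_free.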
template <typename Context, typename T1, typename T2>
static int ConvertDataByType(
const T1* x, T2** y, int len, bool allocateFlag, const Context& dev_ctx) {
if (nullptr == x || nullptr == y || len <= 0)
return xpu::Error_t::INVALID_PARAM;
int r = 0;
if (allocateFlag) {
r = xpu_malloc(reinterpret_cast<void**>(y), sizeof(T2) * len);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam");
}
T1* cpu_data = reinterpret_cast<T1*>(malloc(sizeof(T1) * len));
paddle::memory::Copy(
CPUPlace(), cpu_data, dev_ctx.GetPlace(), x, len * sizeof(T1));
T2* cpu_real_data = reinterpret_cast<T2*>(malloc(sizeof(T2) * len));
for (int i = 0; i < len; i++) cpu_real_data[i] = static_cast<T2>(cpu_data[i]);
paddle::memory::Copy(
dev_ctx.GetPlace(), *y, CPUPlace(), cpu_real_data, len * sizeof(T2));
free(cpu_data);
free(cpu_real_data);
return xpu::Error_t::SUCCESS;
}
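// For FLOAT16 tensors, allocates a device-side T buffer and fills it with the
// converted values; for other dtypes *result is left untouched (nullptr), so
// callers fall back to the tensor's original data pointer.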
template <typename Context, typename T>
static void GetDataPointer(const phi::DenseTensor& tensorData,
T** result,
const Context& dev_ctx) {
if (tensorData.dtype() == DataType::FLOAT16) {
const float16* real_data = tensorData.template data<float16>();
int len = tensorData.numel();
int r = ConvertDataByType<Context, float16, T>(
real_data, result, len, true, dev_ctx);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam");
}
}
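// Selects the buffer the XPU kernel writes to: for FLOAT16 outputs a temporary
// FP32 tensor `out` is allocated (converted back later by CopyOutData),
// otherwise the real output tensor is allocated directly.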
template <typename Context, typename T>
static void GetOutDataPointer(DenseTensor* tensorData,
DenseTensor* out,
T** result,
const Context& dev_ctx) {
if (tensorData->dtype() == DataType::FLOAT16) {
*result = dev_ctx.template Alloc<T>(out);
} else {
*result = dev_ctx.template Alloc<T>(tensorData);
}
}
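// Converts the temporary FP32 result back into a FLOAT16 destination tensor;
// does nothing when the destination is already FP32.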
template <typename Context, typename T>
static void CopyOutData(const DenseTensor& srcTensor,
phi::DenseTensor* dstTensor,
const Context& dev_ctx) {
if (dstTensor->dtype() == DataType::FLOAT16) {
const T* xpu_out_data = srcTensor.template data<T>();
float16* out_data = dev_ctx.template Alloc<float16>(dstTensor);
int len = srcTensor.numel();
int r = ConvertDataByType<Context, T, float16>(
xpu_out_data, &out_data, len, false, dev_ctx);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam");
}
}
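// Releases the scratch buffer allocated by GetDataPointer for FLOAT16 tensors.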
template <typename T>
static void FreeData(const phi::DenseTensor& tensorData, T* dataPtr) {
if (tensorData.dtype() == DataType::FLOAT16) xpu_free(dataPtr);
}
template <typename Context, typename T>
static void SetBetaData(const phi::DenseTensor& beta_pow,
phi::DenseTensor* beta_pow_out,
const T& beta,
const Context& dev_ctx) {
if (beta_pow.dtype() == DataType::FLOAT16) {
const float16* beta_pow_p = beta_pow.template data<float16>();
dev_ctx.template HostAlloc<float16>(beta_pow_out)[0] =
static_cast<float16>(beta) * beta_pow_p[0];
} else {
const T* beta_pow_p = beta_pow.template data<T>();
dev_ctx.template HostAlloc<T>(beta_pow_out)[0] = beta * beta_pow_p[0];
}
}
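// Device-side update of a FLOAT16 beta power accumulator: scales the FP32
// scratch value with xpu::scale into a temporary FP32 tensor, then converts
// the result back to FLOAT16 in beta_pow_out.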
template <typename Context, typename T>
static void Scale(phi::DenseTensor* beta_pow_out,
const phi::DenseTensor& beta_pow,
T* beta_pow_ptr,
const T& beta,
const Context& dev_ctx) {
float16* beta_pow_out_p2 = dev_ctx.template Alloc<float16>(beta_pow_out);
DenseTensor xpu_beta_pow_out;
const phi::DenseTensorMeta meta_beta_pow_out(DataType::FLOAT32,
beta_pow_out->dims());
xpu_beta_pow_out.set_meta(meta_beta_pow_out);
T* beta_pow_out_ptr = dev_ctx.template Alloc<T>(&xpu_beta_pow_out);
int r = xpu::scale(dev_ctx.x_context(),
beta_pow_ptr,
beta_pow_out_ptr,
beta_pow.numel(),
false,
beta,
0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam");
const float* xpu_beta_pow_out_data =
dev_ctx.template Alloc<T>(&xpu_beta_pow_out);
int len = xpu_beta_pow_out.numel();
r = ConvertDataByType<Context, T, float16>(
xpu_beta_pow_out_data, &beta_pow_out_p2, len, false, dev_ctx);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam");
}
#endif
struct GPUAdam;
struct CPUAdam;
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/selected_rows/adam_kernel.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/adam_functors.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/operators/math/selected_rows_functor.h"
namespace phi {
namespace sr {
using float16 = dtype::float16;
template <typename T, typename Context>
void AdamDenseParamSparseGradKernel(
const Context& dev_ctx,
const DenseTensor& param,
const SelectedRows& grad,
const DenseTensor& learning_rate,
const DenseTensor& moment1,
const DenseTensor& moment2,
const DenseTensor& beta1_pow,
const DenseTensor& beta2_pow,
const paddle::optional<DenseTensor>& master_param,
const paddle::optional<DenseTensor>& skip_update,
const Scalar& beta1,
const Scalar& beta2,
const Scalar& epsilon,
bool lazy_mode,
int64_t min_row_size_to_use_multithread,
bool multi_precision,
bool use_global_beta_pow,
DenseTensor* param_out,
DenseTensor* moment1_out,
DenseTensor* moment2_out,
DenseTensor* beta1_pow_out,
DenseTensor* beta2_pow_out,
DenseTensor* master_param_outs) {
float* param_ptr = nullptr;
funcs::GetDataPointer<Context, float>(param, &param_ptr, dev_ctx);
float* mom1_ptr = nullptr;
funcs::GetDataPointer<Context, float>(moment1, &mom1_ptr, dev_ctx);
float* mom2_ptr = nullptr;
funcs::GetDataPointer<Context, float>(moment2, &mom2_ptr, dev_ctx);
float* lr_ptr = nullptr;
funcs::GetDataPointer<Context, float>(learning_rate, &lr_ptr, dev_ctx);
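// beta1_pow/beta2_pow may live on the CPU or on the XPU and may be FP16:
// FP16 values are converted into float scratch buffers (beta*_pow_ptr),
// otherwise the float data pointers are used directly (beta*_const_pow_ptr).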
float* beta1_pow_ptr = nullptr;
const float* beta1_const_pow_ptr = nullptr;
if (beta1_pow.place() == CPUPlace()) {
DenseTensor xpu_beta1_pow;
phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, &xpu_beta1_pow);
if (xpu_beta1_pow.dtype() == DataType::FLOAT16)
funcs::GetDataPointer<Context, float>(
xpu_beta1_pow, &beta1_pow_ptr, dev_ctx);
else
beta1_const_pow_ptr = xpu_beta1_pow.template data<float>();
} else {
if (beta1_pow.dtype() == DataType::FLOAT16)
funcs::GetDataPointer<Context, float>(beta1_pow, &beta1_pow_ptr, dev_ctx);
else
beta1_const_pow_ptr = beta1_pow.template data<float>();
}
float* beta2_pow_ptr = nullptr;
const float* beta2_const_pow_ptr = nullptr;
if (beta2_pow.place() == CPUPlace()) {
DenseTensor xpu_beta2_pow;
phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, &xpu_beta2_pow);
if (xpu_beta2_pow.dtype() == DataType::FLOAT16)
funcs::GetDataPointer<Context, float>(
xpu_beta2_pow, &beta2_pow_ptr, dev_ctx);
else
beta2_const_pow_ptr = xpu_beta2_pow.template data<float>();
} else {
if (beta2_pow.dtype() == DataType::FLOAT16)
funcs::GetDataPointer<Context, float>(beta2_pow, &beta2_pow_ptr, dev_ctx);
else
beta2_const_pow_ptr = beta2_pow.template data<float>();
}
DenseTensor xpu_param_out;
float* param_out_ptr = nullptr;
const phi::DenseTensorMeta meta_param(DataType::FLOAT32, param_out->dims());
xpu_param_out.set_meta(meta_param);
funcs::GetOutDataPointer<Context, float>(
param_out, &xpu_param_out, &param_out_ptr, dev_ctx);
DenseTensor xpu_mom1_out;
float* mom1_out_ptr = nullptr;
const phi::DenseTensorMeta meta_mom1(DataType::FLOAT32, moment1_out->dims());
xpu_mom1_out.set_meta(meta_mom1);
funcs::GetOutDataPointer<Context, float>(
moment1_out, &xpu_mom1_out, &mom1_out_ptr, dev_ctx);
DenseTensor xpu_mom2_out;
float* mom2_out_ptr = nullptr;
const phi::DenseTensorMeta meta_mom2(DataType::FLOAT32, moment2_out->dims());
xpu_mom2_out.set_meta(meta_mom2);
funcs::GetOutDataPointer<Context, float>(
moment2_out, &xpu_mom2_out, &mom2_out_ptr, dev_ctx);
bool skip_update_ = false;
if (skip_update.is_initialized()) {
PADDLE_ENFORCE_EQ(
skip_update->numel(),
1,
errors::InvalidArgument("Input(SkipUpdate) size must be 1, but get %d",
skip_update->numel()));
std::vector<bool> skip_update_vec;
paddle::framework::TensorToVector(*skip_update, dev_ctx, &skip_update_vec);
skip_update_ = skip_update_vec[0];
}
if (skip_update_) {
VLOG(4) << "Adam skip update";
phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out);
phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out);
phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out);
phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out);
phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out);
return;
}
PADDLE_ENFORCE_EQ(
beta1_pow_out->numel(),
1,
errors::InvalidArgument("Tensor holds the wrong size, Expected beta1 pow "
"output size is 1, but received "
"value is:%d.",
beta1_pow_out->numel()));
PADDLE_ENFORCE_EQ(
beta2_pow_out->numel(),
1,
errors::InvalidArgument("Tensor holds the wrong size, Expected beta2 pow "
"output size is 1, but received "
"value is:%d.",
beta2_pow_out->numel()));
VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
auto beta1_ = beta1.to<float>();
auto beta2_ = beta2.to<float>();
auto epsilon_ = epsilon.to<float>();
float* grad_c = nullptr;
if (grad.rows().size() == 0) {
VLOG(3) << "grad row size is 0!!";
return;
}
std::vector<int64_t> cpu_rows(grad.rows().begin(), grad.rows().end());
bool is_strict_sorted = true;
for (size_t i = 1; i < cpu_rows.size(); ++i) {
if (cpu_rows[i - 1] >= cpu_rows[i]) {
is_strict_sorted = false;
break;
}
}
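// If the gradient rows are not strictly increasing (unsorted or duplicated),
// merge them into a sorted SelectedRows before calling xpu::sparse_adam.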
SelectedRows tmp_grad_merge;
const SelectedRows* grad_merge_ptr;
if (is_strict_sorted) {
grad_merge_ptr = &grad;
} else {
paddle::operators::math::scatter::MergeAdd<Context, float> merge_func;
merge_func(dev_ctx, grad, &tmp_grad_merge, true);
xpu_wait(dev_ctx.x_context()->xpu_stream);
grad_merge_ptr = &tmp_grad_merge;
}
auto& grad_merge = *grad_merge_ptr;
auto& grad_tensor = grad_merge.value();
funcs::GetDataPointer<Context, float>(grad_tensor, &grad_c, dev_ctx);
int row_count = grad_merge.rows().size();
std::vector<int> rows(row_count);
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
int* xpu_rows = RAII_GUARD.alloc_l3_or_gm<int>(row_count);
std::vector<int64_t> merge_rows(grad_merge.rows().begin(),
grad_merge.rows().end());
for (size_t i = 0; i < grad_merge.rows().size(); ++i) {
rows[i] = static_cast<int>(merge_rows[i]);
}
xpu_wait(dev_ctx.x_context()->xpu_stream);
paddle::memory::Copy(dev_ctx.GetPlace(),
xpu_rows,
CPUPlace(),
rows.data(),
row_count * sizeof(int));
auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
auto ori_rows = param.numel() / row_numel;
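// ori_rows is the row count of the full parameter (param.numel() / row_numel);
// the merged gradient only covers grad_merge.rows().size() of those rows.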
int r = xpu::sparse_adam(
dev_ctx.x_context(),
grad_c != nullptr ? grad_c : grad_tensor.template data<float>(),
mom1_ptr != nullptr ? mom1_ptr : moment1.template data<float>(),
mom2_ptr != nullptr ? mom2_ptr : moment2.template data<float>(),
param_ptr != nullptr ? param_ptr : param.template data<float>(),
beta1_pow_ptr != nullptr ? beta1_pow_ptr : beta1_const_pow_ptr,
beta2_pow_ptr != nullptr ? beta2_pow_ptr : beta2_const_pow_ptr,
lr_ptr != nullptr ? lr_ptr : learning_rate.template data<float>(),
mom1_out_ptr,
mom2_out_ptr,
param_out_ptr,
beta1_,
beta2_,
epsilon_,
ori_rows,
xpu_rows,
row_numel,
grad_merge.rows().size(),
lazy_mode);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam");
funcs::FreeData<float>(grad_tensor, grad_c);
funcs::CopyOutData<Context, float>(xpu_mom1_out, moment1_out, dev_ctx);
funcs::CopyOutData<Context, float>(xpu_mom2_out, moment2_out, dev_ctx);
funcs::CopyOutData<Context, float>(xpu_param_out, param_out, dev_ctx);
if (!use_global_beta_pow) {
// update in cpu and then copy to xpu
if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) {
funcs::SetBetaData<Context, float>(
beta1_pow, beta1_pow_out, beta1_, dev_ctx);
funcs::SetBetaData<Context, float>(
beta2_pow, beta2_pow_out, beta2_, dev_ctx);
} else {
float* beta1_pow_out_p1 = nullptr;
if (beta1_pow_out->dtype() == DataType::FLOAT16) {
funcs::Scale<Context, float>(
beta1_pow_out, beta1_pow, beta1_pow_ptr, beta1_, dev_ctx);
} else {
const float* beta1_pow_data = beta1_pow.template data<float>();
beta1_pow_out_p1 = dev_ctx.template Alloc<float>(beta1_pow_out);
r = xpu::scale(dev_ctx.x_context(),
beta1_pow_data,
beta1_pow_out_p1,
beta1_pow.numel(),
false,
beta1_,
0.0f);
xpu_wait(dev_ctx.x_context()->xpu_stream);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam");
}
float* beta2_pow_out_p1 = nullptr;
if (beta2_pow_out->dtype() == DataType::FLOAT16) {
funcs::Scale<Context, float>(
beta2_pow_out, beta2_pow, beta2_pow_ptr, beta2_, dev_ctx);
} else {
const float* beta2_pow_data = beta2_pow.template data<float>();
beta2_pow_out_p1 = dev_ctx.template Alloc<float>(beta2_pow_out);
r = xpu::scale(dev_ctx.x_context(),
beta2_pow_data,
beta2_pow_out_p1,
beta2_pow.numel(),
false,
beta2_,
0.0f);
xpu_wait(dev_ctx.x_context()->xpu_stream);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam");
}
}
}
funcs::FreeData<float>(param, param_ptr);
funcs::FreeData<float>(moment1, mom1_ptr);
funcs::FreeData<float>(moment2, mom2_ptr);
funcs::FreeData<float>(learning_rate, lr_ptr);
}
} // namespace sr
} // namespace phi
PD_REGISTER_KERNEL(adam_dense_param_sparse_grad,
XPU,
ALL_LAYOUT,
phi::sr::AdamDenseParamSparseGradKernel,
float,
phi::dtype::float16) {
// Skip beta1_pow, beta2_pow, skip_update data transform
kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/adam_kernel.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/adam_functors.h"
namespace phi {
using float16 = dtype::float16;
template <typename T, typename Context>
void AdamDenseKernel(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& grad,
const DenseTensor& learning_rate,
const DenseTensor& moment1,
const DenseTensor& moment2,
const DenseTensor& beta1_pow,
const DenseTensor& beta2_pow,
const paddle::optional<DenseTensor>& master_param,
const paddle::optional<DenseTensor>& skip_update,
const Scalar& beta1,
const Scalar& beta2,
const Scalar& epsilon,
bool lazy_mode,
int64_t min_row_size_to_use_multithread,
bool multi_precision,
bool use_global_beta_pow,
DenseTensor* param_out,
DenseTensor* moment1_out,
DenseTensor* moment2_out,
DenseTensor* beta1_pow_out,
DenseTensor* beta2_pow_out,
DenseTensor* master_param_outs) {
float* param_ptr = nullptr;
funcs::GetDataPointer<Context, float>(param, &param_ptr, dev_ctx);
float* mom1_ptr = nullptr;
funcs::GetDataPointer<Context, float>(moment1, &mom1_ptr, dev_ctx);
float* mom2_ptr = nullptr;
funcs::GetDataPointer<Context, float>(moment2, &mom2_ptr, dev_ctx);
float* lr_ptr = nullptr;
funcs::GetDataPointer<Context, float>(learning_rate, &lr_ptr, dev_ctx);
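// As in the sparse-grad kernel above, FP16 beta power tensors are converted
// into float scratch buffers; FP32 tensors are used through their data
// pointers directly.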
float* beta1_pow_ptr = nullptr;
const float* beta1_const_pow_ptr = nullptr;
if (beta1_pow.place() == CPUPlace()) {
DenseTensor xpu_beta1_pow;
phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, &xpu_beta1_pow);
if (xpu_beta1_pow.dtype() == DataType::FLOAT16)
funcs::GetDataPointer<Context, float>(
xpu_beta1_pow, &beta1_pow_ptr, dev_ctx);
else
beta1_const_pow_ptr = xpu_beta1_pow.template data<float>();
} else {
if (beta1_pow.dtype() == DataType::FLOAT16)
funcs::GetDataPointer<Context, float>(beta1_pow, &beta1_pow_ptr, dev_ctx);
else
beta1_const_pow_ptr = beta1_pow.template data<float>();
}
float* beta2_pow_ptr = nullptr;
const float* beta2_const_pow_ptr = nullptr;
if (beta2_pow.place() == CPUPlace()) {
DenseTensor xpu_beta2_pow;
phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, &xpu_beta2_pow);
if (xpu_beta2_pow.dtype() == DataType::FLOAT16)
funcs::GetDataPointer<Context, float>(
xpu_beta2_pow, &beta2_pow_ptr, dev_ctx);
else
beta2_const_pow_ptr = xpu_beta2_pow.template data<float>();
} else {
if (beta2_pow.dtype() == DataType::FLOAT16)
funcs::GetDataPointer<Context, float>(beta2_pow, &beta2_pow_ptr, dev_ctx);
else
beta2_const_pow_ptr = beta2_pow.template data<float>();
}
DenseTensor xpu_param_out;
float* param_out_ptr = nullptr;
const phi::DenseTensorMeta meta_param(DataType::FLOAT32, param_out->dims());
xpu_param_out.set_meta(meta_param);
funcs::GetOutDataPointer<Context, float>(
param_out, &xpu_param_out, &param_out_ptr, dev_ctx);
DenseTensor xpu_mom1_out;
float* mom1_out_ptr = nullptr;
const phi::DenseTensorMeta meta_mom1(DataType::FLOAT32, moment1_out->dims());
xpu_mom1_out.set_meta(meta_mom1);
funcs::GetOutDataPointer<Context, float>(
moment1_out, &xpu_mom1_out, &mom1_out_ptr, dev_ctx);
DenseTensor xpu_mom2_out;
float* mom2_out_ptr = nullptr;
const phi::DenseTensorMeta meta_mom2(DataType::FLOAT32, moment2_out->dims());
xpu_mom2_out.set_meta(meta_mom2);
funcs::GetOutDataPointer<Context, float>(
moment2_out, &xpu_mom2_out, &mom2_out_ptr, dev_ctx);
bool skip_update_ = false;
if (skip_update.is_initialized()) {
PADDLE_ENFORCE_EQ(
skip_update->numel(),
1,
errors::InvalidArgument("Input(SkipUpdate) size must be 1, but get %d",
skip_update->numel()));
std::vector<bool> skip_update_vec;
paddle::framework::TensorToVector(*skip_update, dev_ctx, &skip_update_vec);
skip_update_ = skip_update_vec[0];
}
if (skip_update_) {
VLOG(4) << "Adam skip update";
phi::Copy(dev_ctx, param, dev_ctx.GetPlace(), false, param_out);
phi::Copy(dev_ctx, moment1, dev_ctx.GetPlace(), false, moment1_out);
phi::Copy(dev_ctx, moment2, dev_ctx.GetPlace(), false, moment2_out);
phi::Copy(dev_ctx, beta1_pow, beta1_pow.place(), false, beta1_pow_out);
phi::Copy(dev_ctx, beta2_pow, beta2_pow.place(), false, beta2_pow_out);
return;
}
PADDLE_ENFORCE_EQ(
beta1_pow_out->numel(),
1,
errors::InvalidArgument("Tensor holds the wrong size, Expected beta1 pow "
"output size is 1, but received "
"value is:%d.",
beta1_pow_out->numel()));
PADDLE_ENFORCE_EQ(
beta2_pow_out->numel(),
1,
errors::InvalidArgument("Tensor holds the wrong size, Expected beta2 pow "
"output size is 1, but received "
"value is:%d.",
beta2_pow_out->numel()));
VLOG(4) << "use_global_beta_pow:" << use_global_beta_pow;
auto beta1_ = beta1.to<float>();
auto beta2_ = beta2.to<float>();
auto epsilon_ = epsilon.to<float>();
float* grad_c = nullptr;
funcs::GetDataPointer<Context, float>(grad, &grad_c, dev_ctx);
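// Each argument falls back to the tensor's native float data when no FP16
// conversion buffer was allocated above.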
int r = xpu::adam(
dev_ctx.x_context(),
grad_c != nullptr ? grad_c : grad.template data<float>(),
mom1_ptr != nullptr ? mom1_ptr : moment1.template data<float>(),
mom2_ptr != nullptr ? mom2_ptr : moment2.template data<float>(),
param_ptr != nullptr ? param_ptr : param.template data<float>(),
beta1_pow_ptr != nullptr ? beta1_pow_ptr : beta1_const_pow_ptr,
beta2_pow_ptr != nullptr ? beta2_pow_ptr : beta2_const_pow_ptr,
lr_ptr != nullptr ? lr_ptr : learning_rate.template data<float>(),
mom1_out_ptr,
mom2_out_ptr,
param_out_ptr,
beta1_,
beta2_,
epsilon_,
param.numel());
xpu_wait(dev_ctx.x_context()->xpu_stream);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam");
funcs::FreeData<float>(grad, grad_c);
funcs::CopyOutData<Context, float>(xpu_mom1_out, moment1_out, dev_ctx);
funcs::CopyOutData<Context, float>(xpu_mom2_out, moment2_out, dev_ctx);
funcs::CopyOutData<Context, float>(xpu_param_out, param_out, dev_ctx);
if (!use_global_beta_pow) {
// update in cpu and then copy to xpu
if (beta1_pow.place() == CPUPlace() && beta2_pow.place() == CPUPlace()) {
funcs::SetBetaData<Context, float>(
beta1_pow, beta1_pow_out, beta1_, dev_ctx);
funcs::SetBetaData<Context, float>(
beta2_pow, beta2_pow_out, beta2_, dev_ctx);
} else {
float* beta1_pow_out_p1 = nullptr;
if (beta1_pow_out->dtype() == DataType::FLOAT16) {
funcs::Scale<Context, float>(
beta1_pow_out, beta1_pow, beta1_pow_ptr, beta1_, dev_ctx);
} else {
const float* beta1_pow_data = beta1_pow.template data<float>();
beta1_pow_out_p1 = dev_ctx.template Alloc<float>(beta1_pow_out);
r = xpu::scale(dev_ctx.x_context(),
beta1_pow_data,
beta1_pow_out_p1,
beta1_pow.numel(),
false,
beta1_,
0.0f);
xpu_wait(dev_ctx.x_context()->xpu_stream);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam");
}
float* beta2_pow_out_p1 = nullptr;
if (beta2_pow_out->dtype() == DataType::FLOAT16) {
funcs::Scale<Context, float>(
beta2_pow_out, beta2_pow, beta2_pow_ptr, beta2_, dev_ctx);
} else {
const float* beta2_pow_data = beta2_pow.template data<float>();
beta2_pow_out_p1 = dev_ctx.template Alloc<float>(beta2_pow_out);
r = xpu::scale(dev_ctx.x_context(),
beta2_pow_data,
beta2_pow_out_p1,
beta2_pow.numel(),
false,
beta2_,
0.0f);
xpu_wait(dev_ctx.x_context()->xpu_stream);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam");
}
}
}
funcs::FreeData<float>(param, param_ptr);
funcs::FreeData<float>(moment1, mom1_ptr);
funcs::FreeData<float>(moment2, mom2_ptr);
funcs::FreeData<float>(learning_rate, lr_ptr);
}
} // namespace phi
PD_REGISTER_KERNEL(
adam, XPU, ALL_LAYOUT, phi::AdamDenseKernel, float, phi::dtype::float16) {
// Skip beta1_pow, beta2_pow, skip_update data transform
kernel->InputAt(5).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(6).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(8).SetBackend(phi::Backend::ALL_BACKEND);
}