From 766c50ac67a445f69c98b0ffc001fca1c89a28d1 Mon Sep 17 00:00:00 2001 From: "joanna.wozna.intel" Date: Mon, 9 May 2022 09:37:31 +0200 Subject: [PATCH] [Need approval] Add AdamW-CPU FP32 JIT assembly kernel (#42522) * Add AdamW jit kernel * Second implementation * Add missing header * Correct number of jit kernels in the test --- paddle/fluid/operators/jit/gen/CMakeLists.txt | 1 + paddle/fluid/operators/jit/gen/adam.cc | 6 +- paddle/fluid/operators/jit/gen/adamw.cc | 165 ++++++++++++++++++ paddle/fluid/operators/jit/gen/adamw.h | 81 +++++++++ paddle/fluid/operators/jit/helper.cc | 1 + paddle/fluid/operators/jit/kernel_base.h | 10 ++ .../fluid/operators/jit/refer/CMakeLists.txt | 1 + paddle/fluid/operators/jit/refer/refer.cc | 1 + paddle/fluid/operators/jit/refer/refer.h | 16 ++ paddle/fluid/operators/jit/test.cc | 72 +++++++- paddle/phi/kernels/cpu/adamw_kernel.cc | 125 +++++++++---- 11 files changed, 442 insertions(+), 37 deletions(-) create mode 100644 paddle/fluid/operators/jit/gen/adamw.cc create mode 100644 paddle/fluid/operators/jit/gen/adamw.h diff --git a/paddle/fluid/operators/jit/gen/CMakeLists.txt b/paddle/fluid/operators/jit/gen/CMakeLists.txt index 79fcb780fe..ab8829b7ba 100644 --- a/paddle/fluid/operators/jit/gen/CMakeLists.txt +++ b/paddle/fluid/operators/jit/gen/CMakeLists.txt @@ -33,5 +33,6 @@ USE_JITKERNEL_GEN(kHMax) USE_JITKERNEL_GEN(kHSum) USE_JITKERNEL_GEN(kEmbSeqPool) USE_JITKERNEL_GEN(kAdam) +USE_JITKERNEL_GEN(kAdamW) USE_JITKERNEL_GEN(kSgd) USE_JITKERNEL_GEN(kVBroadcast) diff --git a/paddle/fluid/operators/jit/gen/adam.cc b/paddle/fluid/operators/jit/gen/adam.cc index 7e8cb7f59e..38ef6772f0 100644 --- a/paddle/fluid/operators/jit/gen/adam.cc +++ b/paddle/fluid/operators/jit/gen/adam.cc @@ -80,7 +80,7 @@ void AdamJitCode::mainCode() { // beta2 * mom2 + (1 - beta2) * g * g vmulps(ymm7 | k1, ymm7, ymm7); vmulps(ymm7 | k1, ymm_one_sub_beta2, ymm7); - vfmadd231ps(ymm7 | k1, ymm1, ptr[reg_mom2_ptr + reg_offset]); + vfmadd231ps(ymm7 | k1, ymm_beta2, ptr[reg_mom2_ptr + reg_offset]); // store mom1 and mom2 vmovups(ptr[reg_mom1_out_ptr + reg_offset] | k1, ymm8); @@ -88,11 +88,11 @@ void AdamJitCode::mainCode() { // sqrt(mom2) + eps vsqrtps(ymm7 | k1, ymm7); - vaddps(ymm7 | k1, ymm7, ymm3); + vaddps(ymm7 | k1, ymm7, ymm_eps); // p + (-lr) * (mom1 / sqrt(mom2) + eps) vdivps(ymm7 | k1, ymm8, ymm7); - vfmadd213ps(ymm7 | k1, ymm2, ptr[reg_param_ptr + reg_offset]); + vfmadd213ps(ymm7 | k1, ymm_lr, ptr[reg_param_ptr + reg_offset]); // store p vmovups(ptr[reg_param_out_ptr + reg_offset] | k1, ymm7); diff --git a/paddle/fluid/operators/jit/gen/adamw.cc b/paddle/fluid/operators/jit/gen/adamw.cc new file mode 100644 index 0000000000..b470143fb7 --- /dev/null +++ b/paddle/fluid/operators/jit/gen/adamw.cc @@ -0,0 +1,165 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
*/ + +#include "paddle/fluid/operators/jit/gen/adamw.h" + +#include // offsetof + +#include "paddle/fluid/operators/jit/registry.h" +#include "paddle/fluid/platform/cpu_info.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace gen { + +void AdamWJitCode::loadArgs() { + static constexpr int32_t one_as_float = 0x3f800000; + static constexpr int32_t mask_all_ones = 0xFFFFFFFF; + static constexpr int64_t mask_8_divisible = 0xFFFFFFFFFFFFFFF8; + static constexpr int64_t abi_pushes_offset = num_g_abi_regs * 8; + + mov(reg_mom2_out_ptr, ptr[rsp + (abi_pushes_offset + 8)]); + mov(reg_param_out_ptr, ptr[rsp + (abi_pushes_offset + 16)]); + mov(eax, one_as_float); + movd(xmm_one, eax); + + vbroadcastss(ymm_one, xmm_one); // 1 + vbroadcastss(ymm_beta1, xmm_beta1); // beta1 + vbroadcastss(ymm_beta2, xmm_beta2); // beta2 + vbroadcastss(ymm_lr, xmm_lr); // -lr + vbroadcastss(ymm_eps, xmm_eps); // eps + vbroadcastss(ymm_old_lr, xmm_old_lr); // old lr + vbroadcastss(ymm_lr_ratio, xmm_lr_ratio); // lr_ratio + vbroadcastss(ymm_coeff, xmm_coeff); // coeff + vsubps(ymm_one_sub_beta1, ymm_one, ymm_beta1); // 1 - beta1 + vsubps(ymm_one_sub_beta2, ymm_one, ymm_beta2); // 1 - beta2 + + mov(reg_numel_without_tail, reg_numel); + and_(reg_numel_without_tail, mask_8_divisible); // make it 8-divisible + + shl(reg_numel_without_tail, 2); // * 4 to treat it as float offset + shl(reg_numel, 2); + + mov(eax, mask_all_ones); + kmovw(k1, eax); + + xor_(reg_offset, reg_offset); +} + +void AdamWJitCode::setTailOpmask() { + mov(r13, rcx); + + mov(rcx, reg_numel); + sub(rcx, reg_offset); // get tail numel as float size + shr(rcx, 2); // as elements + mov(r14, 1); + shl(r14, cl); // 2 ^ elements + dec(r14); // 2 ^ elements - 1, so numel first bits are set to 1 + kmovw(k1, r14d); + + mov(rcx, r13); +} + +void AdamWJitCode::mainCode() { + // load p + vmovups(ymm10 | k1, ptr[reg_param_ptr + reg_offset]); + + // ((lr * lr_ratio) * coeff) + vmulps(ymm11 | k1, ymm_old_lr, ymm_lr_ratio); + vmulps(ymm11 | k1, ymm11, ymm_coeff); + + // - (lr * lr_ratio) * coeff) * p + p + // p is stored in ymm11 + vfnmadd132ps(ymm11 | k1, ymm10, ymm10); + + // load grad + vmovups(ymm10 | k1, ptr[reg_grad_ptr + reg_offset]); + + // beta1 * mom1 + (1 - beta1) * g + vmulps(ymm12 | k1, ymm_one_sub_beta1, ymm10); + vfmadd231ps(ymm12 | k1, ymm_beta1, ptr[reg_mom1_ptr + reg_offset]); + + // beta2 * mom2 + (1 - beta2) * g * g + vmulps(ymm10 | k1, ymm10, ymm10); + vmulps(ymm10 | k1, ymm_one_sub_beta2, ymm10); + vfmadd231ps(ymm10 | k1, ymm_beta2, ptr[reg_mom2_ptr + reg_offset]); + + // store mom1 and mom2 + vmovups(ptr[reg_mom1_out_ptr + reg_offset] | k1, ymm12); + vmovups(ptr[reg_mom2_out_ptr + reg_offset] | k1, ymm10); + + // sqrt(mom2) + eps + vsqrtps(ymm10 | k1, ymm10); + vaddps(ymm10 | k1, ymm10, ymm_eps); + + // p + (-lr) * (mom1 / sqrt(mom2) + eps) + vdivps(ymm10 | k1, ymm12, ymm10); + vfmadd213ps(ymm10 | k1, ymm_lr, ymm11); + + // store p + vmovups(ptr[reg_param_out_ptr + reg_offset] | k1, ymm10); +} + +void AdamWJitCode::genCode() { + static constexpr int64_t main_loop_elems_size = + 8 * sizeof(float); // 8 floats in YMM + static constexpr int64_t offset_increment = main_loop_elems_size; + preCode(); + loadArgs(); + + cmp(reg_numel, main_loop_elems_size); + jl("process_tail"); + + L("main_loop"); + { + mainCode(); + add(reg_offset, offset_increment); + cmp(reg_numel_without_tail, reg_offset); + jg("main_loop"); + } + + cmp(reg_numel, reg_offset); + je("end", T_NEAR); // size between jmp and label is larger than 127 byte, + // T_NEAR 
allow long jump + + L("process_tail"); + { + setTailOpmask(); + mainCode(); + } + + L("end"); + postCode(); +} + +class AdamWCreator : public JitCodeCreator { + public: + bool CanBeUsed(const int& attr) const override { + return platform::MayIUse(platform::avx512f); + } + size_t CodeSize(const int& attr) const override { return 96 + 32 * 8; } + std::unique_ptr CreateJitCode(const int& attr) const override { + return make_unique(attr, CodeSize(attr)); + } +}; + +} // namespace gen +} // namespace jit +} // namespace operators +} // namespace paddle + +namespace gen = paddle::operators::jit::gen; + +REGISTER_JITKERNEL_GEN(kAdamW, gen::AdamWCreator); diff --git a/paddle/fluid/operators/jit/gen/adamw.h b/paddle/fluid/operators/jit/gen/adamw.h new file mode 100644 index 0000000000..759dcd62c8 --- /dev/null +++ b/paddle/fluid/operators/jit/gen/adamw.h @@ -0,0 +1,81 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once + +#include + +#include "glog/logging.h" +#include "paddle/fluid/operators/jit/gen/jitcode.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace gen { + +class AdamWJitCode : public JitCode { + public: + explicit AdamWJitCode(const int& attr, size_t code_size = 256 * 1024, + void* code_ptr = nullptr) + : JitCode(code_size, code_ptr) { + this->genCode(); + } + + DECLARE_JIT_CODE(AdamJitCode); + void genCode() override; + void loadArgs(); + void setTailOpmask(); + void mainCode(); + + private: + reg64_t reg_numel{abi_param1}; + reg64_t reg_grad_ptr{abi_param2}; + reg64_t reg_mom1_ptr{abi_param3}; + reg64_t reg_mom2_ptr{abi_param4}; + reg64_t reg_param_ptr{abi_param5}; + reg64_t reg_mom1_out_ptr{abi_param6}; + + xmm_t xmm_beta1 = xmm_t(0); + xmm_t xmm_beta2 = xmm_t(1); + xmm_t xmm_lr = xmm_t(2); + xmm_t xmm_eps = xmm_t(3); + xmm_t xmm_old_lr = xmm_t(4); + xmm_t xmm_lr_ratio = xmm_t(5); + xmm_t xmm_coeff = xmm_t(6); + xmm_t xmm_one_sub_beta1 = xmm_t(7); + xmm_t xmm_one_sub_beta2 = xmm_t(8); + xmm_t xmm_one = xmm_t(9); + + ymm_t ymm_beta1 = ymm_t(0); + ymm_t ymm_beta2 = ymm_t(1); + ymm_t ymm_lr = ymm_t(2); + ymm_t ymm_eps = ymm_t(3); + ymm_t ymm_old_lr = ymm_t(4); + ymm_t ymm_lr_ratio = ymm_t(5); + ymm_t ymm_coeff = ymm_t(6); + ymm_t ymm_one_sub_beta1 = ymm_t(7); + ymm_t ymm_one_sub_beta2 = ymm_t(8); + ymm_t ymm_one = ymm_t(9); + + reg64_t reg_mom2_out_ptr{r10}; + reg64_t reg_param_out_ptr{r11}; + reg64_t reg_numel_without_tail{r12}; + reg64_t reg_offset{rax}; +}; + +} // namespace gen +} // namespace jit +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/jit/helper.cc b/paddle/fluid/operators/jit/helper.cc index 4bdb650305..46da6fba2e 100644 --- a/paddle/fluid/operators/jit/helper.cc +++ b/paddle/fluid/operators/jit/helper.cc @@ -59,6 +59,7 @@ const char* to_string(KernelType kt) { ONE_CASE(kMatMul); ONE_CASE(kHMax); ONE_CASE(kAdam); + ONE_CASE(kAdamW); ONE_CASE(kHSum); ONE_CASE(kStrideASum); 
ONE_CASE(kSoftmax); diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h index 40ea04d3c2..9a48d9c3c8 100644 --- a/paddle/fluid/operators/jit/kernel_base.h +++ b/paddle/fluid/operators/jit/kernel_base.h @@ -25,6 +25,7 @@ typedef enum { kNone = 0, // sort by alphabet kAdam = 1, + kAdamW, kCRFDecoding, kEmbSeqPool, kGRUH1, @@ -285,6 +286,15 @@ struct AdamTuple { const T*, T*, T*, T*); }; +template +struct AdamWTuple { + static constexpr KernelType kernel_type = kAdamW; + typedef T data_type; + typedef int attr_type; + typedef void (*func_type)(T, T, T, T, T, T, T, int64_t, const T*, const T*, + const T*, const T*, T*, T*, T*); +}; + typedef struct matmul_attr_s { int m, n, k; void* packed_weight{nullptr}; diff --git a/paddle/fluid/operators/jit/refer/CMakeLists.txt b/paddle/fluid/operators/jit/refer/CMakeLists.txt index e4e3263e01..a1ee4508f7 100644 --- a/paddle/fluid/operators/jit/refer/CMakeLists.txt +++ b/paddle/fluid/operators/jit/refer/CMakeLists.txt @@ -37,5 +37,6 @@ USE_JITKERNEL_REFER(kStrideASum) USE_JITKERNEL_REFER(kSoftmax) USE_JITKERNEL_REFER(kEmbSeqPool) USE_JITKERNEL_REFER(kAdam) +USE_JITKERNEL_REFER(kAdamW) USE_JITKERNEL_REFER(kSgd) USE_JITKERNEL_REFER(kVBroadcast) diff --git a/paddle/fluid/operators/jit/refer/refer.cc b/paddle/fluid/operators/jit/refer/refer.cc index 8669bfe372..779d4c172b 100644 --- a/paddle/fluid/operators/jit/refer/refer.cc +++ b/paddle/fluid/operators/jit/refer/refer.cc @@ -56,6 +56,7 @@ REGISTER_REFER_KERNEL(StrideASum); REGISTER_REFER_KERNEL(Softmax); REGISTER_REFER_KERNEL(EmbSeqPool); REGISTER_REFER_KERNEL(Adam); +REGISTER_REFER_KERNEL(AdamW); REGISTER_REFER_KERNEL(Sgd); REGISTER_REFER_KERNEL(VBroadcast); diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h index 3545b35a70..79b2e174ef 100644 --- a/paddle/fluid/operators/jit/refer/refer.h +++ b/paddle/fluid/operators/jit/refer/refer.h @@ -565,6 +565,21 @@ void Adam(T beta1, T beta2, T lr, T eps, int64_t numel, const T* grad_ptr, } } +template +void AdamW(T beta1, T beta2, T lr, T eps, T old_lr, T lr_ratio, T coeff, + int64_t numel, const T* grad_ptr, const T* mom1_ptr, + const T* mom2_ptr, const T* param_ptr, T* mom1_out_ptr, + T* mom2_out_ptr, T* param_out_ptr) { + for (int i = 0; i < numel; ++i) { + auto param_tmp = param_ptr[i] - old_lr * lr_ratio * coeff * param_ptr[i]; + mom1_out_ptr[i] = beta1 * mom1_ptr[i] + (1 - beta1) * grad_ptr[i]; + mom2_out_ptr[i] = + beta2 * mom2_ptr[i] + (1 - beta2) * grad_ptr[i] * grad_ptr[i]; + param_out_ptr[i] = + param_tmp + lr * (mom1_out_ptr[i] / (sqrt(mom2_out_ptr[i]) + eps)); + } +} + #define DECLARE_REFER_KERNEL(name) \ template \ class name##Kernel : public ReferKernel> { \ @@ -617,6 +632,7 @@ DECLARE_REFER_KERNEL(MatMul); DECLARE_REFER_KERNEL(Softmax); DECLARE_REFER_KERNEL(EmbSeqPool); DECLARE_REFER_KERNEL(Adam); +DECLARE_REFER_KERNEL(AdamW); DECLARE_REFER_KERNEL(Sgd); DECLARE_REFER_KERNEL(VBroadcast); diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index 675db4a72b..74f2d62c64 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -907,6 +907,73 @@ void TestKernelAdam() { param, mom1_out, mom2_out, param_out); } +template +void TestKernelAdamW() { + using T = typename KernelTuple::data_type; + VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); + const T old_lr = 0.1; + const T beta1 = 0.99; + const T beta2 = 0.95; + const T beta1_pow = beta1 * beta1; + const T beta2_pow = 
beta2 * beta2; + + const T epsilon = 0.000001; + const int64_t numel = 123; + const T lr_ratio = 0.2; + const T coeff = 0.3; + + T learning_rate = old_lr * (sqrt(1 - beta2_pow) / (1 - beta1_pow)); + T eps = epsilon * sqrt(1 - beta2_pow); + + std::vector param(numel); + std::vector grad(numel); + std::vector mom1(numel); + std::vector mom2(numel); + + std::vector param_out(param.size()); + std::vector mom1_out(mom1.size()); + std::vector mom2_out(mom2.size()); + + RandomVec(numel, param.data(), 0.5f); + RandomVec(numel, grad.data(), 0.5f); + RandomVec(numel, mom1.data(), 0.5f); + RandomVec(numel, mom2.data(), 0.5f); + auto ref = jit::GetReferFunc(); + EXPECT_TRUE(ref != nullptr); + ref(beta1, beta2, -learning_rate, eps, old_lr, lr_ratio, coeff, numel, + grad.data(), mom1.data(), mom2.data(), param.data(), mom1_out.data(), + mom2_out.data(), param_out.data()); + + auto verifier = []( + const typename KernelTuple::func_type tgt, T beta1, T beta2, T lr, T eps, + T old_lr, T lr_ratio, T coeff, int64_t numel, const std::vector& grad, + const std::vector& mom1, const std::vector& mom2, + const std::vector& param, const std::vector& ref_mom1_out, + const std::vector& ref_mom2_out, const std::vector& ref_param_out) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(param.size(), static_cast(numel)); + EXPECT_EQ(grad.size(), static_cast(numel)); + EXPECT_EQ(mom1.size(), static_cast(numel)); + EXPECT_EQ(mom2.size(), static_cast(numel)); + + std::vector jit_mom1_out(ref_mom1_out.size()); + std::vector jit_mom2_out(ref_mom2_out.size()); + std::vector jit_param_out(ref_param_out.size()); + + tgt(beta1, beta2, -lr, eps, old_lr, lr_ratio, coeff, numel, grad.data(), + mom1.data(), mom2.data(), param.data(), jit_mom1_out.data(), + jit_mom2_out.data(), jit_param_out.data()); + + ExpectEQ(ref_mom1_out.data(), jit_mom1_out.data(), numel); + ExpectEQ(ref_mom2_out.data(), jit_mom2_out.data(), numel); + ExpectEQ(ref_param_out.data(), jit_param_out.data(), numel); + }; + + TestAllImpls( + 1, verifier, beta1, beta2, learning_rate, eps, old_lr, lr_ratio, coeff, + numel, grad, mom1, mom2, param, mom1_out, mom2_out, param_out); +} + template void TestKernelSgd() { using T = typename KernelTuple::data_type; @@ -1046,7 +1113,7 @@ TEST(JITKernel_pool, jitcreator) { #if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__) EXPECT_EQ(jitcreators.size(), 0UL); #else - EXPECT_EQ(jitcreators.size(), 26UL); + EXPECT_EQ(jitcreators.size(), 27UL); #endif } @@ -1080,7 +1147,7 @@ TEST(JITKernel_pool, more) { TEST(JITKernel_pool, refer) { const auto& kers = jit::ReferKernelPool::Instance().AllKernels(); - EXPECT_EQ(kers.size(), 32UL); + EXPECT_EQ(kers.size(), 33UL); } // test helper @@ -1464,6 +1531,7 @@ TEST_CPU_KERNEL(EmbSeqPool); TEST_CPU_KERNEL(MatMul); TEST_CPU_KERNEL(Softmax); TEST_CPU_KERNEL(Adam); +TEST_CPU_KERNEL(AdamW); TEST_CPU_KERNEL(Sgd); TEST_CPU_KERNEL(VBroadcast); diff --git a/paddle/phi/kernels/cpu/adamw_kernel.cc b/paddle/phi/kernels/cpu/adamw_kernel.cc index 3a7869a062..f2c98fded4 100644 --- a/paddle/phi/kernels/cpu/adamw_kernel.cc +++ b/paddle/phi/kernels/cpu/adamw_kernel.cc @@ -17,6 +17,7 @@ #include #include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/operators/jit/kernels.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/core/kernel_registry.h" @@ -92,41 +93,101 @@ void AdamwDenseKernel(const Context& dev_ctx, return; } - auto* param_ = - master_param.is_initialized() ? 
master_param.get_ptr() : &param;
+  T beta1_ = beta1.to<T>();
+  T beta2_ = beta2.to<T>();
+  T epsilon_ = epsilon.to<T>();
   T coeff_ = static_cast<T>(coeff);
   T lr_ratio_ = static_cast<T>(lr_ratio);
-  funcs::AdamWFunctor<T, funcs::CPUAdamW> functor(
-      coeff_,
-      lr_ratio_,
-      learning_rate.data<T>(),
-      const_cast<T*>(param_->data<T>()));
-  functor(param_->numel());
-
-  AdamDenseKernel<T, Context>(dev_ctx,
-                              param,
-                              grad,
-                              learning_rate,
-                              moment1,
-                              moment2,
-                              beta1_pow,
-                              beta2_pow,
-                              master_param,
-                              skip_update,
-                              beta1,
-                              beta2,
-                              epsilon,
-                              lazy_mode,
-                              min_row_size_to_use_multithread,
-                              multi_precision,
-                              use_global_beta_pow,
-                              param_out,
-                              moment1_out,
-                              moment2_out,
-                              beta1_pow_out,
-                              beta2_pow_out,
-                              master_param_outs);
+  VLOG(3) << "beta1_pow.numel() : " << beta1_pow.numel();
+  VLOG(3) << "beta2_pow.numel() : " << beta2_pow.numel();
+  VLOG(3) << "param.numel(): " << param.numel();
+
+  PADDLE_ENFORCE_EQ(
+      beta1_pow_out->numel(),
+      1,
+      errors::InvalidArgument("beta1 pow output size should be 1, but received "
+                              "value is:%d.",
+                              beta1_pow_out->numel()));
+
+  PADDLE_ENFORCE_EQ(
+      beta2_pow_out->numel(),
+      1,
+      errors::InvalidArgument("beta2 pow output size should be 1, but received "
+                              "value is:%d.",
+                              beta2_pow_out->numel()));
+
+  T beta1_p = beta1_pow.data<T>()[0];
+  T beta2_p = beta2_pow.data<T>()[0];
+
+  if (!use_global_beta_pow) {
+    dev_ctx.template Alloc<T>(beta1_pow_out)[0] = beta1_ * beta1_p;
+    dev_ctx.template Alloc<T>(beta2_pow_out)[0] = beta2_ * beta2_p;
+  }
+
+  T* param_out_ptr = dev_ctx.template Alloc<T>(param_out);
+  T* mom1_out_ptr = dev_ctx.template Alloc<T>(moment1_out);
+  T* mom2_out_ptr = dev_ctx.template Alloc<T>(moment2_out);
+  T old_lr = learning_rate.data<T>()[0];
+  T learning_rate_ =
+      learning_rate.data<T>()[0] * (sqrt(1 - beta2_p) / (1 - beta1_p));
+  T eps = epsilon_ * sqrt(1 - beta2_p);
+
+  int64_t numel = param.numel();
+
+  const T* param_ptr = param.data<T>();
+  const T* mom1_ptr = moment1.data<T>();
+  const T* mom2_ptr = moment2.data<T>();
+  const T* grad_ptr = grad.data<T>();
+
+  auto adamw =
+      paddle::operators::jit::KernelFuncs<paddle::operators::jit::AdamWTuple<T>,
+                                          phi::CPUPlace>::Cache()
+          .At(1);
+
+  static constexpr int64_t chunk_size = 512;
+
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+  for (int64_t i = 0; i < numel / chunk_size; ++i) {
+    const int64_t offset = i * chunk_size;
+    adamw(beta1_,
+          beta2_,
+          -learning_rate_,
+          eps,
+          old_lr,
+          lr_ratio_,
+          coeff_,
+          chunk_size,
+          grad_ptr + offset,
+          mom1_ptr + offset,
+          mom2_ptr + offset,
+          param_ptr + offset,
+          mom1_out_ptr + offset,
+          mom2_out_ptr + offset,
+          param_out_ptr + offset);
+  }
+
+  if (numel % chunk_size != 0) {
+    const int64_t offset = (numel / chunk_size) * chunk_size;
+    const int64_t tail_numel = numel % chunk_size;
+    adamw(beta1_,
+          beta2_,
+          -learning_rate_,
+          eps,
+          old_lr,
+          lr_ratio_,
+          coeff_,
+          tail_numel,
+          grad_ptr + offset,
+          mom1_ptr + offset,
+          mom2_ptr + offset,
+          param_ptr + offset,
+          mom1_out_ptr + offset,
+          mom2_out_ptr + offset,
+          param_out_ptr + offset);
+  }
 }
 
 }  // namespace phi
-- 
GitLab
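
A minimal sketch of the new kernel's calling convention, for readers of AdamwDenseKernel and TestKernelAdamW above: the kernel receives the bias-corrected step size already negated and an epsilon pre-scaled by sqrt(1 - beta2_pow), while the decoupled weight-decay factor is rebuilt inside the kernel from old_lr * lr_ratio * coeff. The function-pointer alias below mirrors the AdamWTuple<float>::func_type added in kernel_base.h; the wrapper RunAdamW and its argument names are illustrative assumptions, not identifiers from the patch.

// Sketch only: argument preparation assumed from AdamwDenseKernel above.
#include <cmath>
#include <cstdint>

// Mirrors AdamWTuple<float>::func_type: beta1, beta2, lr, eps, old_lr,
// lr_ratio, coeff, numel, grad, mom1, mom2, param, mom1_out, mom2_out,
// param_out.
using AdamWFunc = void (*)(float, float, float, float, float, float, float,
                           int64_t, const float*, const float*, const float*,
                           const float*, float*, float*, float*);

// Per element the kernel computes (see the refer::AdamW added in refer.h):
//   p'    = p - old_lr * lr_ratio * coeff * p        (decoupled weight decay)
//   m1'   = beta1 * m1 + (1 - beta1) * g
//   m2'   = beta2 * m2 + (1 - beta2) * g * g
//   p_out = p' + lr * m1' / (sqrt(m2') + eps)        (lr is passed negated)
void RunAdamW(AdamWFunc adamw, float old_lr, float beta1, float beta2,
              float beta1_pow, float beta2_pow, float epsilon, float lr_ratio,
              float coeff, int64_t numel, const float* grad, const float* mom1,
              const float* mom2, const float* param, float* mom1_out,
              float* mom2_out, float* param_out) {
  // Bias-corrected step size and pre-scaled epsilon, computed the same way
  // as in AdamwDenseKernel and in the unit test before invoking the kernel.
  const float lr = old_lr * (std::sqrt(1.0f - beta2_pow) / (1.0f - beta1_pow));
  const float eps = epsilon * std::sqrt(1.0f - beta2_pow);
  adamw(beta1, beta2, -lr, eps, old_lr, lr_ratio, coeff, numel, grad, mom1,
        mom2, param, mom1_out, mom2_out, param_out);
}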