未验证 提交 766c50ac 编写于 作者: J joanna.wozna.intel 提交者: GitHub

[Need approval] Add AdamW-CPU FP32 JIT assembly kernel (#42522)

* Add AdamW jit kernel

* Second implementation

* Add missing header

* Correct number of jit kernels in the test
上级 bf481550
...@@ -33,5 +33,6 @@ USE_JITKERNEL_GEN(kHMax) ...@@ -33,5 +33,6 @@ USE_JITKERNEL_GEN(kHMax)
USE_JITKERNEL_GEN(kHSum) USE_JITKERNEL_GEN(kHSum)
USE_JITKERNEL_GEN(kEmbSeqPool) USE_JITKERNEL_GEN(kEmbSeqPool)
USE_JITKERNEL_GEN(kAdam) USE_JITKERNEL_GEN(kAdam)
USE_JITKERNEL_GEN(kAdamW)
USE_JITKERNEL_GEN(kSgd) USE_JITKERNEL_GEN(kSgd)
USE_JITKERNEL_GEN(kVBroadcast) USE_JITKERNEL_GEN(kVBroadcast)
...@@ -80,7 +80,7 @@ void AdamJitCode::mainCode() { ...@@ -80,7 +80,7 @@ void AdamJitCode::mainCode() {
// beta2 * mom2 + (1 - beta2) * g * g // beta2 * mom2 + (1 - beta2) * g * g
vmulps(ymm7 | k1, ymm7, ymm7); vmulps(ymm7 | k1, ymm7, ymm7);
vmulps(ymm7 | k1, ymm_one_sub_beta2, ymm7); vmulps(ymm7 | k1, ymm_one_sub_beta2, ymm7);
vfmadd231ps(ymm7 | k1, ymm1, ptr[reg_mom2_ptr + reg_offset]); vfmadd231ps(ymm7 | k1, ymm_beta2, ptr[reg_mom2_ptr + reg_offset]);
// store mom1 and mom2 // store mom1 and mom2
vmovups(ptr[reg_mom1_out_ptr + reg_offset] | k1, ymm8); vmovups(ptr[reg_mom1_out_ptr + reg_offset] | k1, ymm8);
...@@ -88,11 +88,11 @@ void AdamJitCode::mainCode() { ...@@ -88,11 +88,11 @@ void AdamJitCode::mainCode() {
// sqrt(mom2) + eps // sqrt(mom2) + eps
vsqrtps(ymm7 | k1, ymm7); vsqrtps(ymm7 | k1, ymm7);
vaddps(ymm7 | k1, ymm7, ymm3); vaddps(ymm7 | k1, ymm7, ymm_eps);
// p + (-lr) * (mom1 / sqrt(mom2) + eps) // p + (-lr) * (mom1 / sqrt(mom2) + eps)
vdivps(ymm7 | k1, ymm8, ymm7); vdivps(ymm7 | k1, ymm8, ymm7);
vfmadd213ps(ymm7 | k1, ymm2, ptr[reg_param_ptr + reg_offset]); vfmadd213ps(ymm7 | k1, ymm_lr, ptr[reg_param_ptr + reg_offset]);
// store p // store p
vmovups(ptr[reg_param_out_ptr + reg_offset] | k1, ymm7); vmovups(ptr[reg_param_out_ptr + reg_offset] | k1, ymm7);
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/adamw.h"
#include <stddef.h> // offsetof
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
void AdamWJitCode::loadArgs() {
static constexpr int32_t one_as_float = 0x3f800000;
static constexpr int32_t mask_all_ones = 0xFFFFFFFF;
static constexpr int64_t mask_8_divisible = 0xFFFFFFFFFFFFFFF8;
static constexpr int64_t abi_pushes_offset = num_g_abi_regs * 8;
mov(reg_mom2_out_ptr, ptr[rsp + (abi_pushes_offset + 8)]);
mov(reg_param_out_ptr, ptr[rsp + (abi_pushes_offset + 16)]);
mov(eax, one_as_float);
movd(xmm_one, eax);
vbroadcastss(ymm_one, xmm_one); // 1
vbroadcastss(ymm_beta1, xmm_beta1); // beta1
vbroadcastss(ymm_beta2, xmm_beta2); // beta2
vbroadcastss(ymm_lr, xmm_lr); // -lr
vbroadcastss(ymm_eps, xmm_eps); // eps
vbroadcastss(ymm_old_lr, xmm_old_lr); // old lr
vbroadcastss(ymm_lr_ratio, xmm_lr_ratio); // lr_ratio
vbroadcastss(ymm_coeff, xmm_coeff); // coeff
vsubps(ymm_one_sub_beta1, ymm_one, ymm_beta1); // 1 - beta1
vsubps(ymm_one_sub_beta2, ymm_one, ymm_beta2); // 1 - beta2
mov(reg_numel_without_tail, reg_numel);
and_(reg_numel_without_tail, mask_8_divisible); // make it 8-divisible
shl(reg_numel_without_tail, 2); // * 4 to treat it as float offset
shl(reg_numel, 2);
mov(eax, mask_all_ones);
kmovw(k1, eax);
xor_(reg_offset, reg_offset);
}
void AdamWJitCode::setTailOpmask() {
mov(r13, rcx);
mov(rcx, reg_numel);
sub(rcx, reg_offset); // get tail numel as float size
shr(rcx, 2); // as elements
mov(r14, 1);
shl(r14, cl); // 2 ^ elements
dec(r14); // 2 ^ elements - 1, so numel first bits are set to 1
kmovw(k1, r14d);
mov(rcx, r13);
}
void AdamWJitCode::mainCode() {
// load p
vmovups(ymm10 | k1, ptr[reg_param_ptr + reg_offset]);
// ((lr * lr_ratio) * coeff)
vmulps(ymm11 | k1, ymm_old_lr, ymm_lr_ratio);
vmulps(ymm11 | k1, ymm11, ymm_coeff);
// - (lr * lr_ratio) * coeff) * p + p
// p is stored in ymm11
vfnmadd132ps(ymm11 | k1, ymm10, ymm10);
// load grad
vmovups(ymm10 | k1, ptr[reg_grad_ptr + reg_offset]);
// beta1 * mom1 + (1 - beta1) * g
vmulps(ymm12 | k1, ymm_one_sub_beta1, ymm10);
vfmadd231ps(ymm12 | k1, ymm_beta1, ptr[reg_mom1_ptr + reg_offset]);
// beta2 * mom2 + (1 - beta2) * g * g
vmulps(ymm10 | k1, ymm10, ymm10);
vmulps(ymm10 | k1, ymm_one_sub_beta2, ymm10);
vfmadd231ps(ymm10 | k1, ymm_beta2, ptr[reg_mom2_ptr + reg_offset]);
// store mom1 and mom2
vmovups(ptr[reg_mom1_out_ptr + reg_offset] | k1, ymm12);
vmovups(ptr[reg_mom2_out_ptr + reg_offset] | k1, ymm10);
// sqrt(mom2) + eps
vsqrtps(ymm10 | k1, ymm10);
vaddps(ymm10 | k1, ymm10, ymm_eps);
// p + (-lr) * (mom1 / sqrt(mom2) + eps)
vdivps(ymm10 | k1, ymm12, ymm10);
vfmadd213ps(ymm10 | k1, ymm_lr, ymm11);
// store p
vmovups(ptr[reg_param_out_ptr + reg_offset] | k1, ymm10);
}
void AdamWJitCode::genCode() {
static constexpr int64_t main_loop_elems_size =
8 * sizeof(float); // 8 floats in YMM
static constexpr int64_t offset_increment = main_loop_elems_size;
preCode();
loadArgs();
cmp(reg_numel, main_loop_elems_size);
jl("process_tail");
L("main_loop");
{
mainCode();
add(reg_offset, offset_increment);
cmp(reg_numel_without_tail, reg_offset);
jg("main_loop");
}
cmp(reg_numel, reg_offset);
je("end", T_NEAR); // size between jmp and label is larger than 127 byte,
// T_NEAR allow long jump
L("process_tail");
{
setTailOpmask();
mainCode();
}
L("end");
postCode();
}
class AdamWCreator : public JitCodeCreator<int> {
public:
bool CanBeUsed(const int& attr) const override {
return platform::MayIUse(platform::avx512f);
}
size_t CodeSize(const int& attr) const override { return 96 + 32 * 8; }
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override {
return make_unique<AdamWJitCode>(attr, CodeSize(attr));
}
};
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kAdamW, gen::AdamWCreator);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
class AdamWJitCode : public JitCode {
public:
explicit AdamWJitCode(const int& attr, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr) {
this->genCode();
}
DECLARE_JIT_CODE(AdamJitCode);
void genCode() override;
void loadArgs();
void setTailOpmask();
void mainCode();
private:
reg64_t reg_numel{abi_param1};
reg64_t reg_grad_ptr{abi_param2};
reg64_t reg_mom1_ptr{abi_param3};
reg64_t reg_mom2_ptr{abi_param4};
reg64_t reg_param_ptr{abi_param5};
reg64_t reg_mom1_out_ptr{abi_param6};
xmm_t xmm_beta1 = xmm_t(0);
xmm_t xmm_beta2 = xmm_t(1);
xmm_t xmm_lr = xmm_t(2);
xmm_t xmm_eps = xmm_t(3);
xmm_t xmm_old_lr = xmm_t(4);
xmm_t xmm_lr_ratio = xmm_t(5);
xmm_t xmm_coeff = xmm_t(6);
xmm_t xmm_one_sub_beta1 = xmm_t(7);
xmm_t xmm_one_sub_beta2 = xmm_t(8);
xmm_t xmm_one = xmm_t(9);
ymm_t ymm_beta1 = ymm_t(0);
ymm_t ymm_beta2 = ymm_t(1);
ymm_t ymm_lr = ymm_t(2);
ymm_t ymm_eps = ymm_t(3);
ymm_t ymm_old_lr = ymm_t(4);
ymm_t ymm_lr_ratio = ymm_t(5);
ymm_t ymm_coeff = ymm_t(6);
ymm_t ymm_one_sub_beta1 = ymm_t(7);
ymm_t ymm_one_sub_beta2 = ymm_t(8);
ymm_t ymm_one = ymm_t(9);
reg64_t reg_mom2_out_ptr{r10};
reg64_t reg_param_out_ptr{r11};
reg64_t reg_numel_without_tail{r12};
reg64_t reg_offset{rax};
};
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
...@@ -59,6 +59,7 @@ const char* to_string(KernelType kt) { ...@@ -59,6 +59,7 @@ const char* to_string(KernelType kt) {
ONE_CASE(kMatMul); ONE_CASE(kMatMul);
ONE_CASE(kHMax); ONE_CASE(kHMax);
ONE_CASE(kAdam); ONE_CASE(kAdam);
ONE_CASE(kAdamW);
ONE_CASE(kHSum); ONE_CASE(kHSum);
ONE_CASE(kStrideASum); ONE_CASE(kStrideASum);
ONE_CASE(kSoftmax); ONE_CASE(kSoftmax);
......
...@@ -25,6 +25,7 @@ typedef enum { ...@@ -25,6 +25,7 @@ typedef enum {
kNone = 0, kNone = 0,
// sort by alphabet // sort by alphabet
kAdam = 1, kAdam = 1,
kAdamW,
kCRFDecoding, kCRFDecoding,
kEmbSeqPool, kEmbSeqPool,
kGRUH1, kGRUH1,
...@@ -285,6 +286,15 @@ struct AdamTuple { ...@@ -285,6 +286,15 @@ struct AdamTuple {
const T*, T*, T*, T*); const T*, T*, T*, T*);
}; };
template <typename T>
struct AdamWTuple {
static constexpr KernelType kernel_type = kAdamW;
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(T, T, T, T, T, T, T, int64_t, const T*, const T*,
const T*, const T*, T*, T*, T*);
};
typedef struct matmul_attr_s { typedef struct matmul_attr_s {
int m, n, k; int m, n, k;
void* packed_weight{nullptr}; void* packed_weight{nullptr};
......
...@@ -37,5 +37,6 @@ USE_JITKERNEL_REFER(kStrideASum) ...@@ -37,5 +37,6 @@ USE_JITKERNEL_REFER(kStrideASum)
USE_JITKERNEL_REFER(kSoftmax) USE_JITKERNEL_REFER(kSoftmax)
USE_JITKERNEL_REFER(kEmbSeqPool) USE_JITKERNEL_REFER(kEmbSeqPool)
USE_JITKERNEL_REFER(kAdam) USE_JITKERNEL_REFER(kAdam)
USE_JITKERNEL_REFER(kAdamW)
USE_JITKERNEL_REFER(kSgd) USE_JITKERNEL_REFER(kSgd)
USE_JITKERNEL_REFER(kVBroadcast) USE_JITKERNEL_REFER(kVBroadcast)
...@@ -56,6 +56,7 @@ REGISTER_REFER_KERNEL(StrideASum); ...@@ -56,6 +56,7 @@ REGISTER_REFER_KERNEL(StrideASum);
REGISTER_REFER_KERNEL(Softmax); REGISTER_REFER_KERNEL(Softmax);
REGISTER_REFER_KERNEL(EmbSeqPool); REGISTER_REFER_KERNEL(EmbSeqPool);
REGISTER_REFER_KERNEL(Adam); REGISTER_REFER_KERNEL(Adam);
REGISTER_REFER_KERNEL(AdamW);
REGISTER_REFER_KERNEL(Sgd); REGISTER_REFER_KERNEL(Sgd);
REGISTER_REFER_KERNEL(VBroadcast); REGISTER_REFER_KERNEL(VBroadcast);
......
...@@ -565,6 +565,21 @@ void Adam(T beta1, T beta2, T lr, T eps, int64_t numel, const T* grad_ptr, ...@@ -565,6 +565,21 @@ void Adam(T beta1, T beta2, T lr, T eps, int64_t numel, const T* grad_ptr,
} }
} }
template <typename T>
void AdamW(T beta1, T beta2, T lr, T eps, T old_lr, T lr_ratio, T coeff,
int64_t numel, const T* grad_ptr, const T* mom1_ptr,
const T* mom2_ptr, const T* param_ptr, T* mom1_out_ptr,
T* mom2_out_ptr, T* param_out_ptr) {
for (int i = 0; i < numel; ++i) {
auto param_tmp = param_ptr[i] - old_lr * lr_ratio * coeff * param_ptr[i];
mom1_out_ptr[i] = beta1 * mom1_ptr[i] + (1 - beta1) * grad_ptr[i];
mom2_out_ptr[i] =
beta2 * mom2_ptr[i] + (1 - beta2) * grad_ptr[i] * grad_ptr[i];
param_out_ptr[i] =
param_tmp + lr * (mom1_out_ptr[i] / (sqrt(mom2_out_ptr[i]) + eps));
}
}
#define DECLARE_REFER_KERNEL(name) \ #define DECLARE_REFER_KERNEL(name) \
template <typename T> \ template <typename T> \
class name##Kernel : public ReferKernel<name##Tuple<T>> { \ class name##Kernel : public ReferKernel<name##Tuple<T>> { \
...@@ -617,6 +632,7 @@ DECLARE_REFER_KERNEL(MatMul); ...@@ -617,6 +632,7 @@ DECLARE_REFER_KERNEL(MatMul);
DECLARE_REFER_KERNEL(Softmax); DECLARE_REFER_KERNEL(Softmax);
DECLARE_REFER_KERNEL(EmbSeqPool); DECLARE_REFER_KERNEL(EmbSeqPool);
DECLARE_REFER_KERNEL(Adam); DECLARE_REFER_KERNEL(Adam);
DECLARE_REFER_KERNEL(AdamW);
DECLARE_REFER_KERNEL(Sgd); DECLARE_REFER_KERNEL(Sgd);
DECLARE_REFER_KERNEL(VBroadcast); DECLARE_REFER_KERNEL(VBroadcast);
......
...@@ -907,6 +907,73 @@ void TestKernelAdam() { ...@@ -907,6 +907,73 @@ void TestKernelAdam() {
param, mom1_out, mom2_out, param_out); param, mom1_out, mom2_out, param_out);
} }
template <typename KernelTuple, typename PlaceType>
void TestKernelAdamW() {
using T = typename KernelTuple::data_type;
VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
const T old_lr = 0.1;
const T beta1 = 0.99;
const T beta2 = 0.95;
const T beta1_pow = beta1 * beta1;
const T beta2_pow = beta2 * beta2;
const T epsilon = 0.000001;
const int64_t numel = 123;
const T lr_ratio = 0.2;
const T coeff = 0.3;
T learning_rate = old_lr * (sqrt(1 - beta2_pow) / (1 - beta1_pow));
T eps = epsilon * sqrt(1 - beta2_pow);
std::vector<T> param(numel);
std::vector<T> grad(numel);
std::vector<T> mom1(numel);
std::vector<T> mom2(numel);
std::vector<T> param_out(param.size());
std::vector<T> mom1_out(mom1.size());
std::vector<T> mom2_out(mom2.size());
RandomVec<T>(numel, param.data(), 0.5f);
RandomVec<T>(numel, grad.data(), 0.5f);
RandomVec<T>(numel, mom1.data(), 0.5f);
RandomVec<T>(numel, mom2.data(), 0.5f);
auto ref = jit::GetReferFunc<KernelTuple>();
EXPECT_TRUE(ref != nullptr);
ref(beta1, beta2, -learning_rate, eps, old_lr, lr_ratio, coeff, numel,
grad.data(), mom1.data(), mom2.data(), param.data(), mom1_out.data(),
mom2_out.data(), param_out.data());
auto verifier = [](
const typename KernelTuple::func_type tgt, T beta1, T beta2, T lr, T eps,
T old_lr, T lr_ratio, T coeff, int64_t numel, const std::vector<T>& grad,
const std::vector<T>& mom1, const std::vector<T>& mom2,
const std::vector<T>& param, const std::vector<T>& ref_mom1_out,
const std::vector<T>& ref_mom2_out, const std::vector<T>& ref_param_out) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(param.size(), static_cast<size_t>(numel));
EXPECT_EQ(grad.size(), static_cast<size_t>(numel));
EXPECT_EQ(mom1.size(), static_cast<size_t>(numel));
EXPECT_EQ(mom2.size(), static_cast<size_t>(numel));
std::vector<T> jit_mom1_out(ref_mom1_out.size());
std::vector<T> jit_mom2_out(ref_mom2_out.size());
std::vector<T> jit_param_out(ref_param_out.size());
tgt(beta1, beta2, -lr, eps, old_lr, lr_ratio, coeff, numel, grad.data(),
mom1.data(), mom2.data(), param.data(), jit_mom1_out.data(),
jit_mom2_out.data(), jit_param_out.data());
ExpectEQ<T>(ref_mom1_out.data(), jit_mom1_out.data(), numel);
ExpectEQ<T>(ref_mom2_out.data(), jit_mom2_out.data(), numel);
ExpectEQ<T>(ref_param_out.data(), jit_param_out.data(), numel);
};
TestAllImpls<KernelTuple, PlaceType>(
1, verifier, beta1, beta2, learning_rate, eps, old_lr, lr_ratio, coeff,
numel, grad, mom1, mom2, param, mom1_out, mom2_out, param_out);
}
template <typename KernelTuple, typename PlaceType> template <typename KernelTuple, typename PlaceType>
void TestKernelSgd() { void TestKernelSgd() {
using T = typename KernelTuple::data_type; using T = typename KernelTuple::data_type;
...@@ -1046,7 +1113,7 @@ TEST(JITKernel_pool, jitcreator) { ...@@ -1046,7 +1113,7 @@ TEST(JITKernel_pool, jitcreator) {
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__) #if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_EQ(jitcreators.size(), 0UL); EXPECT_EQ(jitcreators.size(), 0UL);
#else #else
EXPECT_EQ(jitcreators.size(), 26UL); EXPECT_EQ(jitcreators.size(), 27UL);
#endif #endif
} }
...@@ -1080,7 +1147,7 @@ TEST(JITKernel_pool, more) { ...@@ -1080,7 +1147,7 @@ TEST(JITKernel_pool, more) {
TEST(JITKernel_pool, refer) { TEST(JITKernel_pool, refer) {
const auto& kers = jit::ReferKernelPool::Instance().AllKernels(); const auto& kers = jit::ReferKernelPool::Instance().AllKernels();
EXPECT_EQ(kers.size(), 32UL); EXPECT_EQ(kers.size(), 33UL);
} }
// test helper // test helper
...@@ -1464,6 +1531,7 @@ TEST_CPU_KERNEL(EmbSeqPool); ...@@ -1464,6 +1531,7 @@ TEST_CPU_KERNEL(EmbSeqPool);
TEST_CPU_KERNEL(MatMul); TEST_CPU_KERNEL(MatMul);
TEST_CPU_KERNEL(Softmax); TEST_CPU_KERNEL(Softmax);
TEST_CPU_KERNEL(Adam); TEST_CPU_KERNEL(Adam);
TEST_CPU_KERNEL(AdamW);
TEST_CPU_KERNEL(Sgd); TEST_CPU_KERNEL(Sgd);
TEST_CPU_KERNEL(VBroadcast); TEST_CPU_KERNEL(VBroadcast);
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/float16.h" #include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
...@@ -92,41 +93,101 @@ void AdamwDenseKernel(const Context& dev_ctx, ...@@ -92,41 +93,101 @@ void AdamwDenseKernel(const Context& dev_ctx,
return; return;
} }
auto* param_ = T beta1_ = beta1.to<T>();
master_param.is_initialized() ? master_param.get_ptr() : &param; T beta2_ = beta2.to<T>();
T epsilon_ = epsilon.to<T>();
T coeff_ = static_cast<T>(coeff); T coeff_ = static_cast<T>(coeff);
T lr_ratio_ = static_cast<T>(lr_ratio); T lr_ratio_ = static_cast<T>(lr_ratio);
funcs::AdamWFunctor<T, funcs::CPUAdamW> functor( VLOG(3) << "beta1_pow.numel() : " << beta1_pow.numel();
coeff_, VLOG(3) << "beta2_pow.numel() : " << beta2_pow.numel();
VLOG(3) << "param.numel(): " << param.numel();
PADDLE_ENFORCE_EQ(
beta1_pow_out->numel(),
1,
errors::InvalidArgument("beta1 pow output size should be 1, but received "
"value is:%d.",
beta1_pow_out->numel()));
PADDLE_ENFORCE_EQ(
beta2_pow_out->numel(),
1,
errors::InvalidArgument("beta2 pow output size should be 1, but received "
"value is:%d.",
beta2_pow_out->numel()));
T beta1_p = beta1_pow.data<T>()[0];
T beta2_p = beta2_pow.data<T>()[0];
if (!use_global_beta_pow) {
dev_ctx.template Alloc<T>(beta1_pow_out)[0] = beta1_ * beta1_p;
dev_ctx.template Alloc<T>(beta2_pow_out)[0] = beta2_ * beta2_p;
}
T* param_out_ptr = dev_ctx.template Alloc<T>(param_out);
T* mom1_out_ptr = dev_ctx.template Alloc<T>(moment1_out);
T* mom2_out_ptr = dev_ctx.template Alloc<T>(moment2_out);
T old_lr = learning_rate.data<T>()[0];
T learning_rate_ =
learning_rate.data<T>()[0] * (sqrt(1 - beta2_p) / (1 - beta1_p));
T eps = epsilon_ * sqrt(1 - beta2_p);
int64_t numel = param.numel();
const T* param_ptr = param.data<T>();
const T* mom1_ptr = moment1.data<T>();
const T* mom2_ptr = moment2.data<T>();
const T* grad_ptr = grad.data<T>();
auto adamw =
paddle::operators::jit::KernelFuncs<paddle::operators::jit::AdamWTuple<T>,
phi::CPUPlace>::Cache()
.At(1);
static constexpr int64_t chunk_size = 512;
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t i = 0; i < numel / chunk_size; ++i) {
const int64_t offset = i * chunk_size;
adamw(beta1_,
beta2_,
-learning_rate_,
eps,
old_lr,
lr_ratio_, lr_ratio_,
learning_rate.data<T>(), coeff_,
const_cast<T*>(param_->data<T>())); chunk_size,
functor(param_->numel()); grad_ptr + offset,
mom1_ptr + offset,
mom2_ptr + offset,
param_ptr + offset,
mom1_out_ptr + offset,
mom2_out_ptr + offset,
param_out_ptr + offset);
}
AdamDenseKernel<T, Context>(dev_ctx, if (numel % chunk_size != 0) {
param, const int64_t offset = (numel / chunk_size) * chunk_size;
grad, const int64_t tail_numel = numel % chunk_size;
learning_rate, adamw(beta1_,
moment1, beta2_,
moment2, -learning_rate_,
beta1_pow, eps,
beta2_pow, old_lr,
master_param, lr_ratio_,
skip_update, coeff_,
beta1, tail_numel,
beta2, grad_ptr + offset,
epsilon, mom1_ptr + offset,
lazy_mode, mom2_ptr + offset,
min_row_size_to_use_multithread, param_ptr + offset,
multi_precision, mom1_out_ptr + offset,
use_global_beta_pow, mom2_out_ptr + offset,
param_out, param_out_ptr + offset);
moment1_out, }
moment2_out,
beta1_pow_out,
beta2_pow_out,
master_param_outs);
} }
} // namespace phi } // namespace phi
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册