diff --git a/paddle/fluid/operators/jit/gen/CMakeLists.txt b/paddle/fluid/operators/jit/gen/CMakeLists.txt index 99244ea9bd919a018732b75d1ab811e8bf338516..79fcb780feb9310a35c60e2c9589c538fa0870ec 100644 --- a/paddle/fluid/operators/jit/gen/CMakeLists.txt +++ b/paddle/fluid/operators/jit/gen/CMakeLists.txt @@ -32,5 +32,6 @@ USE_JITKERNEL_GEN(kSeqPool) USE_JITKERNEL_GEN(kHMax) USE_JITKERNEL_GEN(kHSum) USE_JITKERNEL_GEN(kEmbSeqPool) +USE_JITKERNEL_GEN(kAdam) USE_JITKERNEL_GEN(kSgd) USE_JITKERNEL_GEN(kVBroadcast) diff --git a/paddle/fluid/operators/jit/gen/adam.cc b/paddle/fluid/operators/jit/gen/adam.cc new file mode 100644 index 0000000000000000000000000000000000000000..7e8cb7f59eed61e69c14719eba9ac01c7b13868a --- /dev/null +++ b/paddle/fluid/operators/jit/gen/adam.cc @@ -0,0 +1,153 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/operators/jit/gen/adam.h" + +#include // offsetof + +#include "paddle/fluid/operators/jit/registry.h" +#include "paddle/fluid/platform/cpu_info.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace gen { + +void AdamJitCode::loadArgs() { + static constexpr int32_t one_as_float = 0x3f800000; + static constexpr int32_t mask_all_ones = 0xFFFFFFFF; + static constexpr int64_t mask_8_divisible = 0xFFFFFFFFFFFFFFF8; + static constexpr int64_t abi_pushes_offset = num_g_abi_regs * 8; + + mov(reg_mom2_out_ptr, ptr[rsp + (abi_pushes_offset + 8)]); + mov(reg_param_out_ptr, ptr[rsp + (abi_pushes_offset + 16)]); + mov(eax, one_as_float); + movd(xmm_one, eax); + + vbroadcastss(ymm_one, xmm_one); // 1 + vbroadcastss(ymm_beta1, xmm_beta1); // beta1 + vbroadcastss(ymm_beta2, xmm_beta2); // beta2 + vbroadcastss(ymm_lr, xmm_lr); // -lr + vbroadcastss(ymm_eps, xmm_eps); // eps + vsubps(ymm_one_sub_beta1, ymm_one, ymm_beta1); // 1 - beta1 + vsubps(ymm_one_sub_beta2, ymm_one, ymm_beta2); // 1 - beta2 + + mov(reg_numel_without_tail, reg_numel); + and_(reg_numel_without_tail, mask_8_divisible); // make it 8-divisible + + shl(reg_numel_without_tail, 2); // * 4 to treat it as float offset + shl(reg_numel, 2); + + mov(eax, mask_all_ones); + kmovw(k1, eax); + + xor_(reg_offset, reg_offset); +} + +void AdamJitCode::setTailOpmask() { + mov(r13, rcx); + + mov(rcx, reg_numel); + sub(rcx, reg_offset); // get tail numel as float size + shr(rcx, 2); // as elements + mov(r14, 1); + shl(r14, cl); // 2 ^ elements + dec(r14); // 2 ^ elements - 1, so numel first bits are set to 1 + kmovw(k1, r14d); + + mov(rcx, r13); +} + +void AdamJitCode::mainCode() { + // load grad + vmovups(ymm7 | k1, ptr[reg_grad_ptr + reg_offset]); + + // beta1 * mom1 + (1 - beta1) * g + vmulps(ymm8 | k1, ymm_one_sub_beta1, ymm7); + vfmadd231ps(ymm8 | k1, ymm_beta1, ptr[reg_mom1_ptr + reg_offset]); + + // beta2 * mom2 + (1 - beta2) * g * g + vmulps(ymm7 | k1, ymm7, ymm7); + vmulps(ymm7 | k1, ymm_one_sub_beta2, ymm7); + vfmadd231ps(ymm7 | k1, ymm1, ptr[reg_mom2_ptr + reg_offset]); + + // store mom1 and mom2 + vmovups(ptr[reg_mom1_out_ptr + reg_offset] | k1, ymm8); + vmovups(ptr[reg_mom2_out_ptr + reg_offset] | k1, ymm7); + + // sqrt(mom2) + eps + vsqrtps(ymm7 | k1, ymm7); + vaddps(ymm7 | k1, ymm7, ymm3); + + // p + (-lr) * (mom1 / sqrt(mom2) + eps) + vdivps(ymm7 | k1, ymm8, ymm7); + vfmadd213ps(ymm7 | k1, ymm2, ptr[reg_param_ptr + reg_offset]); + + // store p + vmovups(ptr[reg_param_out_ptr + reg_offset] | k1, ymm7); +} + +void AdamJitCode::genCode() { + static constexpr int64_t main_loop_elems_size = + 8 * sizeof(float); // 8 floats in YMM + static constexpr int64_t offset_increment = main_loop_elems_size; + preCode(); + loadArgs(); + + cmp(reg_numel, main_loop_elems_size); + jl("process_tail"); + + L("main_loop"); + { + mainCode(); + add(reg_offset, offset_increment); + cmp(reg_numel_without_tail, reg_offset); + jg("main_loop"); + } + + cmp(reg_numel, reg_offset); + je("end"); + + L("process_tail"); + { + setTailOpmask(); + mainCode(); + } + + L("end"); + postCode(); +} + +class AdamCreator : public JitCodeCreator { + public: + bool CanBeUsed(const adam_attr_t& attr) const override { + return platform::MayIUse(platform::avx512f); + } + size_t CodeSize(const adam_attr_t& attr) const override { + return 96 + 32 * 8; + } + std::unique_ptr CreateJitCode( + const adam_attr_t& attr) const override { + return make_unique(attr, CodeSize(attr)); + } +}; + +} // namespace gen +} // namespace jit +} // namespace operators +} // namespace paddle + +namespace gen = paddle::operators::jit::gen; + +REGISTER_JITKERNEL_GEN(kAdam, gen::AdamCreator); diff --git a/paddle/fluid/operators/jit/gen/adam.h b/paddle/fluid/operators/jit/gen/adam.h new file mode 100644 index 0000000000000000000000000000000000000000..86a38e97ece021245d70ebd59b5dbb3a99f43514 --- /dev/null +++ b/paddle/fluid/operators/jit/gen/adam.h @@ -0,0 +1,75 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once + +#include + +#include "glog/logging.h" +#include "paddle/fluid/operators/jit/gen/jitcode.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace gen { + +class AdamJitCode : public JitCode { + public: + explicit AdamJitCode(const adam_attr_t& attr, size_t code_size = 256 * 1024, + void* code_ptr = nullptr) + : JitCode(code_size, code_ptr) { + this->genCode(); + } + + DECLARE_JIT_CODE(AdamJitCode); + void genCode() override; + void loadArgs(); + void setTailOpmask(); + void mainCode(); + + private: + reg64_t reg_numel{abi_param1}; + reg64_t reg_grad_ptr{abi_param2}; + reg64_t reg_mom1_ptr{abi_param3}; + reg64_t reg_mom2_ptr{abi_param4}; + reg64_t reg_param_ptr{abi_param5}; + reg64_t reg_mom1_out_ptr{abi_param6}; + + xmm_t xmm_beta1 = xmm_t(0); + xmm_t xmm_beta2 = xmm_t(1); + xmm_t xmm_lr = xmm_t(2); + xmm_t xmm_eps = xmm_t(3); + xmm_t xmm_one_sub_beta1 = xmm_t(4); + xmm_t xmm_one_sub_beta2 = xmm_t(5); + xmm_t xmm_one = xmm_t(6); + + ymm_t ymm_beta1 = ymm_t(0); + ymm_t ymm_beta2 = ymm_t(1); + ymm_t ymm_lr = ymm_t(2); + ymm_t ymm_eps = ymm_t(3); + ymm_t ymm_one_sub_beta1 = ymm_t(4); + ymm_t ymm_one_sub_beta2 = ymm_t(5); + ymm_t ymm_one = ymm_t(6); + + reg64_t reg_mom2_out_ptr{r10}; + reg64_t reg_param_out_ptr{r11}; + reg64_t reg_numel_without_tail{r12}; + reg64_t reg_offset{rax}; +}; + +} // namespace gen +} // namespace jit +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/jit/gen/jitcode.h b/paddle/fluid/operators/jit/gen/jitcode.h index 23650c8efc73b007c465b48505684c0e5d182f7e..bd84368a573881e0eaba02d6d19d239985a42940 100644 --- a/paddle/fluid/operators/jit/gen/jitcode.h +++ b/paddle/fluid/operators/jit/gen/jitcode.h @@ -45,6 +45,7 @@ using reg32_t = const Xbyak::Reg32; using xmm_t = const Xbyak::Xmm; using ymm_t = const Xbyak::Ymm; using zmm_t = const Xbyak::Zmm; +using opmask_t = const Xbyak::Opmask; using Label = Xbyak::Label; typedef enum { diff --git a/paddle/fluid/operators/jit/helper.cc b/paddle/fluid/operators/jit/helper.cc index 2085aa41e3b90d79d1a87b9c44d096e818a1f20d..4bdb65030590fdf84360478c73cdb463010a49f5 100644 --- a/paddle/fluid/operators/jit/helper.cc +++ b/paddle/fluid/operators/jit/helper.cc @@ -58,6 +58,7 @@ const char* to_string(KernelType kt) { ONE_CASE(kSeqPool); ONE_CASE(kMatMul); ONE_CASE(kHMax); + ONE_CASE(kAdam); ONE_CASE(kHSum); ONE_CASE(kStrideASum); ONE_CASE(kSoftmax); diff --git a/paddle/fluid/operators/jit/helper.h b/paddle/fluid/operators/jit/helper.h index 0791bb5810526cb930fe1869a60913d4239f72a3..f217cf6e77854780eeb66d4083ee67858c0dace0 100644 --- a/paddle/fluid/operators/jit/helper.h +++ b/paddle/fluid/operators/jit/helper.h @@ -275,6 +275,11 @@ inline std::ostream& operator<<(std::ostream& os, return os; } +inline std::ostream& operator<<(std::ostream& os, const adam_attr_t& attr) { + os << "beta1[" << attr.beta1 << "],beta2[" << attr.beta2 << "]"; + return os; +} + inline std::ostream& operator<<(std::ostream& os, const sgd_attr_t& attr) { os << "param_height[" << attr.param_height << "],param_width[" << attr.param_width << "],grad_height[" << attr.grad_height diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h index 6e0393b820f3780940d37659a067a630a6a0ae2b..40ea04d3c2791d0bde6cd76a4b97f5c12dbfbd99 100644 --- a/paddle/fluid/operators/jit/kernel_base.h +++ b/paddle/fluid/operators/jit/kernel_base.h @@ -24,8 +24,9 @@ namespace jit { typedef enum { kNone = 0, // sort by alphabet - kCRFDecoding = 1, - kEmbSeqPool = 2, + kAdam = 1, + kCRFDecoding, + kEmbSeqPool, kGRUH1, kGRUHtPart1, kGRUHtPart2, @@ -269,6 +270,21 @@ struct SgdTuple { const sgd_attr_t*); }; +typedef struct adam_attr_s { + float beta1, beta2; + adam_attr_s() = default; + explicit adam_attr_s(float beta1, float beta2) : beta1(beta1), beta2(beta2) {} +} adam_attr_t; + +template +struct AdamTuple { + static constexpr KernelType kernel_type = kAdam; + typedef T data_type; + typedef adam_attr_t attr_type; + typedef void (*func_type)(T, T, T, T, int64_t, const T*, const T*, const T*, + const T*, T*, T*, T*); +}; + typedef struct matmul_attr_s { int m, n, k; void* packed_weight{nullptr}; diff --git a/paddle/fluid/operators/jit/kernel_key.cc b/paddle/fluid/operators/jit/kernel_key.cc index a7b1addeb5ded7790f1ec5f12952deb1e522501d..4f652002bc7455180b8eef6d4a5e111b3aa72dfb 100644 --- a/paddle/fluid/operators/jit/kernel_key.cc +++ b/paddle/fluid/operators/jit/kernel_key.cc @@ -63,6 +63,11 @@ int64_t JitCodeKey(const sgd_attr_t& attr) { return attr.grad_width; } +template <> +int64_t JitCodeKey(const adam_attr_t& attr) { + return static_cast(attr.beta1 + attr.beta2); +} + } // namespace jit } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/jit/refer/CMakeLists.txt b/paddle/fluid/operators/jit/refer/CMakeLists.txt index 7133f596620410d37ffe52a2ee92b7a9974bf1cc..e4e3263e01ebae709f04869659d5523fc785e19d 100644 --- a/paddle/fluid/operators/jit/refer/CMakeLists.txt +++ b/paddle/fluid/operators/jit/refer/CMakeLists.txt @@ -36,5 +36,6 @@ USE_JITKERNEL_REFER(kHMax) USE_JITKERNEL_REFER(kStrideASum) USE_JITKERNEL_REFER(kSoftmax) USE_JITKERNEL_REFER(kEmbSeqPool) +USE_JITKERNEL_REFER(kAdam) USE_JITKERNEL_REFER(kSgd) USE_JITKERNEL_REFER(kVBroadcast) diff --git a/paddle/fluid/operators/jit/refer/refer.cc b/paddle/fluid/operators/jit/refer/refer.cc index 460cb6c58076d7f6c49b60fed45584bd9b506c63..8669bfe37232bfd0e5ca274e573a02db7e727071 100644 --- a/paddle/fluid/operators/jit/refer/refer.cc +++ b/paddle/fluid/operators/jit/refer/refer.cc @@ -55,6 +55,7 @@ REGISTER_REFER_KERNEL(HSum); REGISTER_REFER_KERNEL(StrideASum); REGISTER_REFER_KERNEL(Softmax); REGISTER_REFER_KERNEL(EmbSeqPool); +REGISTER_REFER_KERNEL(Adam); REGISTER_REFER_KERNEL(Sgd); REGISTER_REFER_KERNEL(VBroadcast); diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h index 42fb7b4f279c225fb38a49d23e9d76ac1854d12d..3545b35a703f8c87199439b471e8133deabbd7f6 100644 --- a/paddle/fluid/operators/jit/refer/refer.h +++ b/paddle/fluid/operators/jit/refer/refer.h @@ -552,6 +552,19 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows, } } +template +void Adam(T beta1, T beta2, T lr, T eps, int64_t numel, const T* grad_ptr, + const T* mom1_ptr, const T* mom2_ptr, const T* param_ptr, + T* mom1_out_ptr, T* mom2_out_ptr, T* param_out_ptr) { + for (int i = 0; i < numel; ++i) { + mom1_out_ptr[i] = beta1 * mom1_ptr[i] + (1 - beta1) * grad_ptr[i]; + mom2_out_ptr[i] = + beta2 * mom2_ptr[i] + (1 - beta2) * grad_ptr[i] * grad_ptr[i]; + param_out_ptr[i] = + param_ptr[i] + lr * (mom1_out_ptr[i] / (sqrt(mom2_out_ptr[i]) + eps)); + } +} + #define DECLARE_REFER_KERNEL(name) \ template \ class name##Kernel : public ReferKernel> { \ @@ -603,6 +616,7 @@ DECLARE_REFER_KERNEL(SeqPool); DECLARE_REFER_KERNEL(MatMul); DECLARE_REFER_KERNEL(Softmax); DECLARE_REFER_KERNEL(EmbSeqPool); +DECLARE_REFER_KERNEL(Adam); DECLARE_REFER_KERNEL(Sgd); DECLARE_REFER_KERNEL(VBroadcast); diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index ff68565637c5a98f6f8bf5021ac685846edc605d..675db4a72bda33bf8d0fbf3adb698eae33f66ac3 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -841,6 +841,72 @@ void TestKernelStrideScal() { } } +template +void TestKernelAdam() { + using T = typename KernelTuple::data_type; + VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type); + const T lr = 0.1; + const T beta1 = 0.99; + const T beta2 = 0.95; + const T beta1_pow = beta1 * beta1; + const T beta2_pow = beta2 * beta2; + + const T epsilon = 0.000001; + const int64_t numel = 123; + + T learning_rate = lr * (sqrt(1 - beta2_pow) / (1 - beta1_pow)); + T eps = epsilon * sqrt(1 - beta2_pow); + + std::vector param(numel); + std::vector grad(numel); + std::vector mom1(numel); + std::vector mom2(numel); + + std::vector param_out(param.size()); + std::vector mom1_out(mom1.size()); + std::vector mom2_out(mom2.size()); + + RandomVec(numel, param.data(), 0.5f); + RandomVec(numel, grad.data(), 0.5f); + RandomVec(numel, mom1.data(), 0.5f); + RandomVec(numel, mom2.data(), 0.5f); + + auto ref = jit::GetReferFunc(); + EXPECT_TRUE(ref != nullptr); + jit::adam_attr_t attr(beta1, beta2); + ref(beta1, beta2, -learning_rate, eps, numel, grad.data(), mom1.data(), + mom2.data(), param.data(), mom1_out.data(), mom2_out.data(), + param_out.data()); + + auto verifier = []( + const typename KernelTuple::func_type tgt, T beta1, T beta2, T lr, T eps, + int64_t numel, const std::vector& grad, const std::vector& mom1, + const std::vector& mom2, const std::vector& param, + const std::vector& ref_mom1_out, const std::vector& ref_mom2_out, + const std::vector& ref_param_out) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(param.size(), static_cast(numel)); + EXPECT_EQ(grad.size(), static_cast(numel)); + EXPECT_EQ(mom1.size(), static_cast(numel)); + EXPECT_EQ(mom2.size(), static_cast(numel)); + + std::vector jit_mom1_out(ref_mom1_out.size()); + std::vector jit_mom2_out(ref_mom2_out.size()); + std::vector jit_param_out(ref_param_out.size()); + + tgt(beta1, beta2, -lr, eps, numel, grad.data(), mom1.data(), mom2.data(), + param.data(), jit_mom1_out.data(), jit_mom2_out.data(), + jit_param_out.data()); + + ExpectEQ(ref_mom1_out.data(), jit_mom1_out.data(), numel); + ExpectEQ(ref_mom2_out.data(), jit_mom2_out.data(), numel); + ExpectEQ(ref_param_out.data(), jit_param_out.data(), numel); + }; + TestAllImpls( + attr, verifier, beta1, beta2, learning_rate, eps, numel, grad, mom1, mom2, + param, mom1_out, mom2_out, param_out); +} + template void TestKernelSgd() { using T = typename KernelTuple::data_type; @@ -980,7 +1046,7 @@ TEST(JITKernel_pool, jitcreator) { #if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__) EXPECT_EQ(jitcreators.size(), 0UL); #else - EXPECT_EQ(jitcreators.size(), 25UL); + EXPECT_EQ(jitcreators.size(), 26UL); #endif } @@ -1014,7 +1080,7 @@ TEST(JITKernel_pool, more) { TEST(JITKernel_pool, refer) { const auto& kers = jit::ReferKernelPool::Instance().AllKernels(); - EXPECT_EQ(kers.size(), 31UL); + EXPECT_EQ(kers.size(), 32UL); } // test helper @@ -1147,9 +1213,10 @@ TEST(JITKernel_helper, attr) { << jit::to_string(jit::kVExp) << jit::to_string(jit::kVIdentity) << jit::to_string(jit::kVMul) << jit::to_string(jit::kVRelu) << jit::to_string(jit::kVScal) << jit::to_string(jit::kSgd) - << jit::to_string(jit::kVSigmoid) << jit::to_string(jit::kVSquare) - << jit::to_string(jit::kVSub) << jit::to_string(jit::kVTanh); - EXPECT_EQ(out.str().size(), 234UL); + << jit::to_string(jit::kAdam) << jit::to_string(jit::kVSigmoid) + << jit::to_string(jit::kVSquare) << jit::to_string(jit::kVSub) + << jit::to_string(jit::kVTanh); + EXPECT_EQ(out.str().size(), 239UL); // SeqPoolTypes out.str(""); @@ -1296,6 +1363,19 @@ TEST(JITKernel_key, emb_seq_pool) { EXPECT_TRUE(key4 != key5); } +TEST(JITKernel_key, adam) { + jit::adam_attr_t attr1(0.4f, 0.9f); + jit::adam_attr_t attr2(0.4f, 0.9f); + jit::adam_attr_t attr3(0.1f, 0.3f); + + auto key1 = jit::JitCodeKey(attr1); + auto key2 = jit::JitCodeKey(attr2); + auto key3 = jit::JitCodeKey(attr3); + + EXPECT_TRUE(key1 == key2); + EXPECT_TRUE(key2 != key3); +} + TEST(JITKernel_key, sgd) { jit::sgd_attr_t attr1(1, 2, 3, 4, 5); jit::sgd_attr_t attr2(1, 2, 3, 4, 5); @@ -1316,7 +1396,7 @@ TEST(JITKernel_key, sgd) { EXPECT_TRUE(key4 != key5); } -// test kernerls +// test kernels #define TestKernelVMul TestKernelXYZN #define TestKernelVAdd TestKernelXYZN #define TestKernelVAddRelu TestKernelXYZN @@ -1383,6 +1463,7 @@ TEST_CPU_KERNEL(SeqPool); TEST_CPU_KERNEL(EmbSeqPool); TEST_CPU_KERNEL(MatMul); TEST_CPU_KERNEL(Softmax); +TEST_CPU_KERNEL(Adam); TEST_CPU_KERNEL(Sgd); TEST_CPU_KERNEL(VBroadcast); diff --git a/paddle/fluid/operators/optimizers/adam_op.h b/paddle/fluid/operators/optimizers/adam_op.h index bcc314cd57c017b577d8370a6e593366364dbdd9..bdeaa106282d2550e4fd928627da098168ea469a 100644 --- a/paddle/fluid/operators/optimizers/adam_op.h +++ b/paddle/fluid/operators/optimizers/adam_op.h @@ -20,9 +20,11 @@ limitations under the License. */ #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/threadpool.h" +#include "paddle/fluid/operators/jit/kernels.h" #include "paddle/fluid/operators/math/algorithm.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/platform/for_range.h" +#include "paddle/fluid/platform/profiler.h" namespace paddle { namespace operators { @@ -506,21 +508,58 @@ class AdamOpKernel : public framework::OpKernel { beta2_pow_out->numel())); if (grad_var->IsType()) { - auto* grad = ctx.Input("Grad"); + T beta1_p = beta1_pow->data()[0]; + T beta2_p = beta2_pow->data()[0]; - AdamFunctor functor( - beta1, beta2, epsilon, beta1_pow->data(), beta2_pow->data(), - mom1->data(), mom1_out->mutable_data(ctx.GetPlace()), - mom2->data(), mom2_out->mutable_data(ctx.GetPlace()), - lr->data(), grad->data(), param->data(), - param_out->mutable_data(ctx.GetPlace())); - functor(param->numel()); if (!use_global_beta_pow) { beta1_pow_out->mutable_data(ctx.GetPlace())[0] = beta1 * beta1_pow->data()[0]; beta2_pow_out->mutable_data(ctx.GetPlace())[0] = beta2 * beta2_pow->data()[0]; } + + auto* grad = ctx.Input("Grad"); + + T* param_out_ptr = param_out->mutable_data(ctx.GetPlace()); + T* mom1_out_ptr = mom1_out->mutable_data(ctx.GetPlace()); + T* mom2_out_ptr = mom2_out->mutable_data(ctx.GetPlace()); + + T learning_rate = lr->data()[0] * (sqrt(1 - beta2_p) / (1 - beta1_p)); + T eps = epsilon * sqrt(1 - beta2_p); + + jit::adam_attr_t attr(beta1, beta2); + int64_t numel = param->numel(); + + const T* param_ptr = param->data(); + const T* mom1_ptr = mom1->data(); + const T* mom2_ptr = mom2->data(); + const T* grad_ptr = grad->data(); + + auto adam = + jit::KernelFuncs, platform::CPUPlace>::Cache().At( + attr); + + static constexpr int64_t chunk_size = 512; + +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int64_t i = 0; i < numel / chunk_size; ++i) { + const int64_t offset = i * chunk_size; + adam(beta1, beta2, -learning_rate, eps, chunk_size, grad_ptr + offset, + mom1_ptr + offset, mom2_ptr + offset, param_ptr + offset, + mom1_out_ptr + offset, mom2_out_ptr + offset, + param_out_ptr + offset); + } + + if (numel % chunk_size != 0) { + const int64_t offset = (numel / chunk_size) * chunk_size; + const int64_t tail_numel = numel % chunk_size; + adam(beta1, beta2, -learning_rate, eps, tail_numel, grad_ptr + offset, + mom1_ptr + offset, mom2_ptr + offset, param_ptr + offset, + mom1_out_ptr + offset, mom2_out_ptr + offset, + param_out_ptr + offset); + } } else if (grad_var->IsType()) { auto* grad = ctx.Input("Grad"); if (grad->rows().size() == 0) { diff --git a/python/paddle/fluid/tests/unittests/test_adam_op.py b/python/paddle/fluid/tests/unittests/test_adam_op.py index a06f0d390e517d6434b5232c3eb3c5d9b0115150..ecac22553cbcda7cc2dae179603f407eddc8652a 100644 --- a/python/paddle/fluid/tests/unittests/test_adam_op.py +++ b/python/paddle/fluid/tests/unittests/test_adam_op.py @@ -69,15 +69,19 @@ class TestAdamOp1(OpTest): class TestAdamOp2(OpTest): + def set_shape(self): + self.shape = (102, 105) + def setUp(self): '''Test Adam Op with supplied attributes ''' self.op_type = "adam" - param = np.random.uniform(-1, 1, (102, 105)).astype("float32") - grad = np.random.uniform(-1, 1, (102, 105)).astype("float32") - moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32") + self.set_shape() + param = np.random.uniform(-1, 1, self.shape).astype("float32") + grad = np.random.uniform(-1, 1, self.shape).astype("float32") + moment1 = np.random.uniform(-1, 1, self.shape).astype("float32") # The second moment is positive - moment2 = np.random.random((102, 105)).astype("float32") + moment2 = np.random.random(self.shape).astype("float32") learning_rate = 0.001 beta1 = 0.9 @@ -113,6 +117,11 @@ class TestAdamOp2(OpTest): self.check_output() +class TestAdamOnlyTailOp(TestAdamOp2): + def set_shape(self): + self.shape = (3) + + class TestAdamOpMultipleSteps(OpTest): def setUp(self): '''Test Adam Operator with supplied attributes