Unverified commit ebd14743, authored by jakpiase, committed by GitHub

Added Adam FP32 JIT assembly kernel (#39158)

* Added adam kernel

* CI rerun
Parent commit: e15e4ed0
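For context, this is the per-element update the new kernel performs; the bias correction is folded into the learning rate and epsilon by the callers (see the reference kernel in refer.h and the call site in adam_op.h below), which pass the negated, corrected learning rate. A sketch in LaTeX:

\[
\begin{aligned}
m_i &\leftarrow \beta_1\, m_i + (1-\beta_1)\, g_i \\
v_i &\leftarrow \beta_2\, v_i + (1-\beta_2)\, g_i^2 \\
\theta_i &\leftarrow \theta_i + \widehat{\mathrm{lr}}\,\frac{m_i}{\sqrt{v_i} + \hat\varepsilon},
\qquad
\widehat{\mathrm{lr}} = -\,\mathrm{lr}\,\frac{\sqrt{1-\beta_2^{t}}}{1-\beta_1^{t}},
\qquad
\hat\varepsilon = \varepsilon\,\sqrt{1-\beta_2^{t}}
\end{aligned}
\]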
@@ -32,5 +32,6 @@ USE_JITKERNEL_GEN(kSeqPool)
USE_JITKERNEL_GEN(kHMax)
USE_JITKERNEL_GEN(kHSum)
USE_JITKERNEL_GEN(kEmbSeqPool)
USE_JITKERNEL_GEN(kAdam)
USE_JITKERNEL_GEN(kSgd)
USE_JITKERNEL_GEN(kVBroadcast)
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/adam.h"
#include <stddef.h> // offsetof
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
void AdamJitCode::loadArgs() {
static constexpr int32_t one_as_float = 0x3f800000;
static constexpr int32_t mask_all_ones = 0xFFFFFFFF;
static constexpr int64_t mask_8_divisible = 0xFFFFFFFFFFFFFFF8;
static constexpr int64_t abi_pushes_offset = num_g_abi_regs * 8;
mov(reg_mom2_out_ptr, ptr[rsp + (abi_pushes_offset + 8)]);
mov(reg_param_out_ptr, ptr[rsp + (abi_pushes_offset + 16)]);
mov(eax, one_as_float);
movd(xmm_one, eax);
vbroadcastss(ymm_one, xmm_one); // 1
vbroadcastss(ymm_beta1, xmm_beta1); // beta1
vbroadcastss(ymm_beta2, xmm_beta2); // beta2
vbroadcastss(ymm_lr, xmm_lr); // -lr
vbroadcastss(ymm_eps, xmm_eps); // eps
vsubps(ymm_one_sub_beta1, ymm_one, ymm_beta1); // 1 - beta1
vsubps(ymm_one_sub_beta2, ymm_one, ymm_beta2); // 1 - beta2
mov(reg_numel_without_tail, reg_numel);
and_(reg_numel_without_tail, mask_8_divisible); // make it 8-divisible
shl(reg_numel_without_tail, 2); // * 4 to treat it as float offset
shl(reg_numel, 2);
mov(eax, mask_all_ones);
kmovw(k1, eax);
xor_(reg_offset, reg_offset);
}
void AdamJitCode::setTailOpmask() {
mov(r13, rcx);
mov(rcx, reg_numel);
sub(rcx, reg_offset); // get tail numel as float size
shr(rcx, 2); // as elements
mov(r14, 1);
shl(r14, cl); // 2 ^ elements
dec(r14); // 2 ^ elements - 1: the lowest (tail element count) bits are set to 1
kmovw(k1, r14d);
mov(rcx, r13);
}
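// Illustrative note, not part of the patch: setTailOpmask() builds the AVX-512
// write mask for the final partial vector. For example, with numel = 11 the
// main loop covers 8 floats (32 bytes), leaving a 3-element tail, so
// k1 = (1 << 3) - 1 = 0b00000111 and only lanes 0-2 are loaded and stored.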
void AdamJitCode::mainCode() {
// load grad
vmovups(ymm7 | k1, ptr[reg_grad_ptr + reg_offset]);
// beta1 * mom1 + (1 - beta1) * g
vmulps(ymm8 | k1, ymm_one_sub_beta1, ymm7);
vfmadd231ps(ymm8 | k1, ymm_beta1, ptr[reg_mom1_ptr + reg_offset]);
// beta2 * mom2 + (1 - beta2) * g * g
vmulps(ymm7 | k1, ymm7, ymm7);
vmulps(ymm7 | k1, ymm_one_sub_beta2, ymm7);
vfmadd231ps(ymm7 | k1, ymm1, ptr[reg_mom2_ptr + reg_offset]); // ymm1 aliases ymm_beta2
// store mom1 and mom2
vmovups(ptr[reg_mom1_out_ptr + reg_offset] | k1, ymm8);
vmovups(ptr[reg_mom2_out_ptr + reg_offset] | k1, ymm7);
// sqrt(mom2) + eps
vsqrtps(ymm7 | k1, ymm7);
vaddps(ymm7 | k1, ymm7, ymm3); // ymm3 aliases ymm_eps
// p + (-lr) * (mom1 / (sqrt(mom2) + eps))
vdivps(ymm7 | k1, ymm8, ymm7);
vfmadd213ps(ymm7 | k1, ymm2, ptr[reg_param_ptr + reg_offset]); // ymm2 aliases ymm_lr
// store p
vmovups(ptr[reg_param_out_ptr + reg_offset] | k1, ymm7);
}
void AdamJitCode::genCode() {
static constexpr int64_t main_loop_elems_size =
8 * sizeof(float); // 8 floats in YMM
static constexpr int64_t offset_increment = main_loop_elems_size;
preCode();
loadArgs();
cmp(reg_numel, main_loop_elems_size);
jl("process_tail");
L("main_loop");
{
mainCode();
add(reg_offset, offset_increment);
cmp(reg_numel_without_tail, reg_offset);
jg("main_loop");
}
cmp(reg_numel, reg_offset);
je("end");
L("process_tail");
{
setTailOpmask();
mainCode();
}
L("end");
postCode();
}
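// Illustrative summary, not part of the patch: genCode() emits
//   1) loadArgs()   - broadcast beta1, beta2, (-lr) and eps into YMM registers,
//                     round numel down to a multiple of 8 and set opmask k1 to
//                     all ones;
//   2) main_loop    - one full 8-float YMM iteration of mainCode() per 32-byte
//                     chunk;
//   3) process_tail - if fewer than 8 elements remain, shrink k1 via
//                     setTailOpmask() and run mainCode() once more;
//   4) preCode()/postCode() wrap the body with the usual prologue/epilogue.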
class AdamCreator : public JitCodeCreator<adam_attr_t> {
public:
bool CanBeUsed(const adam_attr_t& attr) const override {
return platform::MayIUse(platform::avx512f);
}
size_t CodeSize(const adam_attr_t& attr) const override {
return 96 + 32 * 8;
}
std::unique_ptr<GenBase> CreateJitCode(
const adam_attr_t& attr) const override {
return make_unique<AdamJitCode>(attr, CodeSize(attr));
}
};
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kAdam, gen::AdamCreator);
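A minimal usage sketch of the kernel registered above, mirroring the call site added to adam_op.h later in this commit; the helper name and wrapper are hypothetical, only the KernelFuncs/AdamTuple API is taken from this patch:

// Hypothetical helper, for illustration only (not part of this patch).
#include <cstdint>
#include "paddle/fluid/operators/jit/kernels.h"

void RunJitAdam(float beta1, float beta2, float neg_lr, float eps,
                int64_t numel, const float* grad, const float* mom1,
                const float* mom2, const float* param, float* mom1_out,
                float* mom2_out, float* param_out) {
  namespace jit = paddle::operators::jit;
  // Picks the JIT code when CanBeUsed() (AVX-512F) holds, otherwise falls back
  // to a slower candidate such as the refer kernel.
  auto adam = jit::KernelFuncs<jit::AdamTuple<float>,
                               paddle::platform::CPUPlace>::Cache()
                  .At(jit::adam_attr_t(beta1, beta2));
  // Callers pass the negated, bias-corrected learning rate and scaled eps.
  adam(beta1, beta2, neg_lr, eps, numel, grad, mom1, mom2, param, mom1_out,
       mom2_out, param_out);
}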
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
class AdamJitCode : public JitCode {
public:
explicit AdamJitCode(const adam_attr_t& attr, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr) {
this->genCode();
}
DECLARE_JIT_CODE(AdamJitCode);
void genCode() override;
void loadArgs();
void setTailOpmask();
void mainCode();
private:
reg64_t reg_numel{abi_param1};
reg64_t reg_grad_ptr{abi_param2};
reg64_t reg_mom1_ptr{abi_param3};
reg64_t reg_mom2_ptr{abi_param4};
reg64_t reg_param_ptr{abi_param5};
reg64_t reg_mom1_out_ptr{abi_param6};
xmm_t xmm_beta1 = xmm_t(0);
xmm_t xmm_beta2 = xmm_t(1);
xmm_t xmm_lr = xmm_t(2);
xmm_t xmm_eps = xmm_t(3);
xmm_t xmm_one_sub_beta1 = xmm_t(4);
xmm_t xmm_one_sub_beta2 = xmm_t(5);
xmm_t xmm_one = xmm_t(6);
ymm_t ymm_beta1 = ymm_t(0);
ymm_t ymm_beta2 = ymm_t(1);
ymm_t ymm_lr = ymm_t(2);
ymm_t ymm_eps = ymm_t(3);
ymm_t ymm_one_sub_beta1 = ymm_t(4);
ymm_t ymm_one_sub_beta2 = ymm_t(5);
ymm_t ymm_one = ymm_t(6);
reg64_t reg_mom2_out_ptr{r10};
reg64_t reg_param_out_ptr{r11};
reg64_t reg_numel_without_tail{r12};
reg64_t reg_offset{rax};
};
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
@@ -45,6 +45,7 @@ using reg32_t = const Xbyak::Reg32;
using xmm_t = const Xbyak::Xmm;
using ymm_t = const Xbyak::Ymm;
using zmm_t = const Xbyak::Zmm;
using opmask_t = const Xbyak::Opmask;
using Label = Xbyak::Label;
typedef enum {
......
@@ -58,6 +58,7 @@ const char* to_string(KernelType kt) {
ONE_CASE(kSeqPool);
ONE_CASE(kMatMul);
ONE_CASE(kHMax);
ONE_CASE(kAdam);
ONE_CASE(kHSum);
ONE_CASE(kStrideASum);
ONE_CASE(kSoftmax);
......
@@ -275,6 +275,11 @@ inline std::ostream& operator<<(std::ostream& os,
return os;
}
inline std::ostream& operator<<(std::ostream& os, const adam_attr_t& attr) {
os << "beta1[" << attr.beta1 << "],beta2[" << attr.beta2 << "]";
return os;
}
inline std::ostream& operator<<(std::ostream& os, const sgd_attr_t& attr) {
os << "param_height[" << attr.param_height << "],param_width["
<< attr.param_width << "],grad_height[" << attr.grad_height
......
@@ -24,8 +24,9 @@ namespace jit {
typedef enum {
kNone = 0,
// sort by alphabet
- kCRFDecoding = 1,
- kEmbSeqPool = 2,
kAdam = 1,
kCRFDecoding,
kEmbSeqPool,
kGRUH1,
kGRUHtPart1,
kGRUHtPart2,
@@ -269,6 +270,21 @@ struct SgdTuple {
const sgd_attr_t*);
};
typedef struct adam_attr_s {
float beta1, beta2;
adam_attr_s() = default;
explicit adam_attr_s(float beta1, float beta2) : beta1(beta1), beta2(beta2) {}
} adam_attr_t;
template <typename T>
struct AdamTuple {
static constexpr KernelType kernel_type = kAdam;
typedef T data_type;
typedef adam_attr_t attr_type;
typedef void (*func_type)(T, T, T, T, int64_t, const T*, const T*, const T*,
const T*, T*, T*, T*);
};
typedef struct matmul_attr_s {
int m, n, k;
void* packed_weight{nullptr};
......
@@ -63,6 +63,11 @@ int64_t JitCodeKey<sgd_attr_t>(const sgd_attr_t& attr) {
return attr.grad_width;
}
template <>
int64_t JitCodeKey<adam_attr_t>(const adam_attr_t& attr) {
return static_cast<int64_t>(attr.beta1 + attr.beta2);
}
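// Illustrative note, not part of the patch: the key truncates beta1 + beta2,
// e.g. (0.4f, 0.9f) -> int64_t(1.3) = 1 while (0.1f, 0.3f) -> int64_t(0.4) = 0
// (cf. the JITKernel_key.adam test below). Different betas can therefore share
// one cached code blob, which is harmless because the generated code reads
// beta1/beta2 from its runtime arguments instead of baking them in.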
} // namespace jit
} // namespace operators
} // namespace paddle
@@ -36,5 +36,6 @@ USE_JITKERNEL_REFER(kHMax)
USE_JITKERNEL_REFER(kStrideASum)
USE_JITKERNEL_REFER(kSoftmax)
USE_JITKERNEL_REFER(kEmbSeqPool)
USE_JITKERNEL_REFER(kAdam)
USE_JITKERNEL_REFER(kSgd)
USE_JITKERNEL_REFER(kVBroadcast)
@@ -55,6 +55,7 @@ REGISTER_REFER_KERNEL(HSum);
REGISTER_REFER_KERNEL(StrideASum);
REGISTER_REFER_KERNEL(Softmax);
REGISTER_REFER_KERNEL(EmbSeqPool);
REGISTER_REFER_KERNEL(Adam);
REGISTER_REFER_KERNEL(Sgd);
REGISTER_REFER_KERNEL(VBroadcast);
......
@@ -552,6 +552,19 @@ void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows,
}
}
template <typename T>
void Adam(T beta1, T beta2, T lr, T eps, int64_t numel, const T* grad_ptr,
const T* mom1_ptr, const T* mom2_ptr, const T* param_ptr,
T* mom1_out_ptr, T* mom2_out_ptr, T* param_out_ptr) {
for (int64_t i = 0; i < numel; ++i) {
mom1_out_ptr[i] = beta1 * mom1_ptr[i] + (1 - beta1) * grad_ptr[i];
mom2_out_ptr[i] =
beta2 * mom2_ptr[i] + (1 - beta2) * grad_ptr[i] * grad_ptr[i];
param_out_ptr[i] =
param_ptr[i] + lr * (mom1_out_ptr[i] / (sqrt(mom2_out_ptr[i]) + eps));
}
}
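// Note, not part of the patch: `lr` is used exactly as passed by the caller;
// both the JIT path in adam_op.h and the unit test hand in the negated,
// bias-corrected learning rate (-lr * sqrt(1 - beta2^t) / (1 - beta1^t)) plus
// a pre-scaled eps, so the addition above effectively subtracts the Adam step.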
#define DECLARE_REFER_KERNEL(name) \
template <typename T> \
class name##Kernel : public ReferKernel<name##Tuple<T>> { \
@@ -603,6 +616,7 @@ DECLARE_REFER_KERNEL(SeqPool);
DECLARE_REFER_KERNEL(MatMul);
DECLARE_REFER_KERNEL(Softmax);
DECLARE_REFER_KERNEL(EmbSeqPool);
DECLARE_REFER_KERNEL(Adam);
DECLARE_REFER_KERNEL(Sgd);
DECLARE_REFER_KERNEL(VBroadcast);
......
@@ -841,6 +841,72 @@ void TestKernelStrideScal() {
}
}
template <typename KernelTuple, typename PlaceType>
void TestKernelAdam() {
using T = typename KernelTuple::data_type;
VLOG(10) << "Test JITKernel: " << jit::to_string(KernelTuple::kernel_type);
const T lr = 0.1;
const T beta1 = 0.99;
const T beta2 = 0.95;
const T beta1_pow = beta1 * beta1;
const T beta2_pow = beta2 * beta2;
const T epsilon = 0.000001;
const int64_t numel = 123;
T learning_rate = lr * (sqrt(1 - beta2_pow) / (1 - beta1_pow));
T eps = epsilon * sqrt(1 - beta2_pow);
std::vector<T> param(numel);
std::vector<T> grad(numel);
std::vector<T> mom1(numel);
std::vector<T> mom2(numel);
std::vector<T> param_out(param.size());
std::vector<T> mom1_out(mom1.size());
std::vector<T> mom2_out(mom2.size());
RandomVec<T>(numel, param.data(), 0.5f);
RandomVec<T>(numel, grad.data(), 0.5f);
RandomVec<T>(numel, mom1.data(), 0.5f);
RandomVec<T>(numel, mom2.data(), 0.5f);
auto ref = jit::GetReferFunc<KernelTuple>();
EXPECT_TRUE(ref != nullptr);
jit::adam_attr_t attr(beta1, beta2);
ref(beta1, beta2, -learning_rate, eps, numel, grad.data(), mom1.data(),
mom2.data(), param.data(), mom1_out.data(), mom2_out.data(),
param_out.data());
auto verifier = [](
const typename KernelTuple::func_type tgt, T beta1, T beta2, T lr, T eps,
int64_t numel, const std::vector<T>& grad, const std::vector<T>& mom1,
const std::vector<T>& mom2, const std::vector<T>& param,
const std::vector<T>& ref_mom1_out, const std::vector<T>& ref_mom2_out,
const std::vector<T>& ref_param_out) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(param.size(), static_cast<size_t>(numel));
EXPECT_EQ(grad.size(), static_cast<size_t>(numel));
EXPECT_EQ(mom1.size(), static_cast<size_t>(numel));
EXPECT_EQ(mom2.size(), static_cast<size_t>(numel));
std::vector<T> jit_mom1_out(ref_mom1_out.size());
std::vector<T> jit_mom2_out(ref_mom2_out.size());
std::vector<T> jit_param_out(ref_param_out.size());
tgt(beta1, beta2, -lr, eps, numel, grad.data(), mom1.data(), mom2.data(),
param.data(), jit_mom1_out.data(), jit_mom2_out.data(),
jit_param_out.data());
ExpectEQ<T>(ref_mom1_out.data(), jit_mom1_out.data(), numel);
ExpectEQ<T>(ref_mom2_out.data(), jit_mom2_out.data(), numel);
ExpectEQ<T>(ref_param_out.data(), jit_param_out.data(), numel);
};
TestAllImpls<KernelTuple, PlaceType>(
attr, verifier, beta1, beta2, learning_rate, eps, numel, grad, mom1, mom2,
param, mom1_out, mom2_out, param_out);
}
template <typename KernelTuple, typename PlaceType>
void TestKernelSgd() {
using T = typename KernelTuple::data_type;
@@ -980,7 +1046,7 @@ TEST(JITKernel_pool, jitcreator) {
#if defined(_WIN32) || defined(__APPLE__) || defined(__OSX__)
EXPECT_EQ(jitcreators.size(), 0UL);
#else
- EXPECT_EQ(jitcreators.size(), 25UL);
EXPECT_EQ(jitcreators.size(), 26UL);
#endif
}
@@ -1014,7 +1080,7 @@ TEST(JITKernel_pool, more) {
TEST(JITKernel_pool, refer) {
const auto& kers = jit::ReferKernelPool::Instance().AllKernels();
- EXPECT_EQ(kers.size(), 31UL);
EXPECT_EQ(kers.size(), 32UL);
}
// test helper
@@ -1147,9 +1213,10 @@ TEST(JITKernel_helper, attr) {
<< jit::to_string(jit::kVExp) << jit::to_string(jit::kVIdentity)
<< jit::to_string(jit::kVMul) << jit::to_string(jit::kVRelu)
<< jit::to_string(jit::kVScal) << jit::to_string(jit::kSgd)
- << jit::to_string(jit::kVSigmoid) << jit::to_string(jit::kVSquare)
- << jit::to_string(jit::kVSub) << jit::to_string(jit::kVTanh);
- EXPECT_EQ(out.str().size(), 234UL);
<< jit::to_string(jit::kAdam) << jit::to_string(jit::kVSigmoid)
<< jit::to_string(jit::kVSquare) << jit::to_string(jit::kVSub)
<< jit::to_string(jit::kVTanh);
EXPECT_EQ(out.str().size(), 239UL);
// SeqPoolTypes
out.str("");
@@ -1296,6 +1363,19 @@ TEST(JITKernel_key, emb_seq_pool) {
EXPECT_TRUE(key4 != key5);
}
TEST(JITKernel_key, adam) {
jit::adam_attr_t attr1(0.4f, 0.9f);
jit::adam_attr_t attr2(0.4f, 0.9f);
jit::adam_attr_t attr3(0.1f, 0.3f);
auto key1 = jit::JitCodeKey<jit::adam_attr_t>(attr1);
auto key2 = jit::JitCodeKey<jit::adam_attr_t>(attr2);
auto key3 = jit::JitCodeKey<jit::adam_attr_t>(attr3);
EXPECT_TRUE(key1 == key2);
EXPECT_TRUE(key2 != key3);
}
TEST(JITKernel_key, sgd) {
jit::sgd_attr_t attr1(1, 2, 3, 4, 5);
jit::sgd_attr_t attr2(1, 2, 3, 4, 5);
@@ -1316,7 +1396,7 @@ TEST(JITKernel_key, sgd) {
EXPECT_TRUE(key4 != key5);
}
- // test kernerls
// test kernels
#define TestKernelVMul TestKernelXYZN
#define TestKernelVAdd TestKernelXYZN
#define TestKernelVAddRelu TestKernelXYZN
@@ -1383,6 +1463,7 @@ TEST_CPU_KERNEL(SeqPool);
TEST_CPU_KERNEL(EmbSeqPool);
TEST_CPU_KERNEL(MatMul);
TEST_CPU_KERNEL(Softmax);
TEST_CPU_KERNEL(Adam);
TEST_CPU_KERNEL(Sgd);
TEST_CPU_KERNEL(VBroadcast);
......
@@ -20,9 +20,11 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/threadpool.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/algorithm.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
@@ -506,21 +508,58 @@ class AdamOpKernel : public framework::OpKernel<T> {
beta2_pow_out->numel()));
if (grad_var->IsType<framework::LoDTensor>()) {
- auto* grad = ctx.Input<LoDTensor>("Grad");
- AdamFunctor<T, CPUAdam> functor(
- beta1, beta2, epsilon, beta1_pow->data<T>(), beta2_pow->data<T>(),
- mom1->data<T>(), mom1_out->mutable_data<T>(ctx.GetPlace()),
- mom2->data<T>(), mom2_out->mutable_data<T>(ctx.GetPlace()),
- lr->data<T>(), grad->data<T>(), param->data<T>(),
- param_out->mutable_data<T>(ctx.GetPlace()));
- functor(param->numel());
T beta1_p = beta1_pow->data<T>()[0];
T beta2_p = beta2_pow->data<T>()[0];
if (!use_global_beta_pow) {
beta1_pow_out->mutable_data<T>(ctx.GetPlace())[0] =
beta1 * beta1_pow->data<T>()[0];
beta2_pow_out->mutable_data<T>(ctx.GetPlace())[0] =
beta2 * beta2_pow->data<T>()[0];
}
auto* grad = ctx.Input<LoDTensor>("Grad");
T* param_out_ptr = param_out->mutable_data<T>(ctx.GetPlace());
T* mom1_out_ptr = mom1_out->mutable_data<T>(ctx.GetPlace());
T* mom2_out_ptr = mom2_out->mutable_data<T>(ctx.GetPlace());
T learning_rate = lr->data<T>()[0] * (sqrt(1 - beta2_p) / (1 - beta1_p));
T eps = epsilon * sqrt(1 - beta2_p);
jit::adam_attr_t attr(beta1, beta2);
int64_t numel = param->numel();
const T* param_ptr = param->data<T>();
const T* mom1_ptr = mom1->data<T>();
const T* mom2_ptr = mom2->data<T>();
const T* grad_ptr = grad->data<T>();
auto adam =
jit::KernelFuncs<jit::AdamTuple<T>, platform::CPUPlace>::Cache().At(
attr);
static constexpr int64_t chunk_size = 512;
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t i = 0; i < numel / chunk_size; ++i) {
const int64_t offset = i * chunk_size;
adam(beta1, beta2, -learning_rate, eps, chunk_size, grad_ptr + offset,
mom1_ptr + offset, mom2_ptr + offset, param_ptr + offset,
mom1_out_ptr + offset, mom2_out_ptr + offset,
param_out_ptr + offset);
}
if (numel % chunk_size != 0) {
const int64_t offset = (numel / chunk_size) * chunk_size;
const int64_t tail_numel = numel % chunk_size;
adam(beta1, beta2, -learning_rate, eps, tail_numel, grad_ptr + offset,
mom1_ptr + offset, mom2_ptr + offset, param_ptr + offset,
mom1_out_ptr + offset, mom2_out_ptr + offset,
param_out_ptr + offset);
}
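// Illustrative note, not part of the patch: the 512-element chunking lets the
// OpenMP loop above parallelize when built with MKLML; e.g. numel = 1100
// yields two full chunks at offsets 0 and 512 plus a 76-element tail at
// offset 1024, handled by the single trailing call.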
} else if (grad_var->IsType<pten::SelectedRows>()) {
auto* grad = ctx.Input<pten::SelectedRows>("Grad");
if (grad->rows().size() == 0) {
......
@@ -69,15 +69,19 @@ class TestAdamOp1(OpTest):
class TestAdamOp2(OpTest):
def set_shape(self):
self.shape = (102, 105)
def setUp(self):
'''Test Adam Op with supplied attributes
'''
self.op_type = "adam"
self.set_shape()
- param = np.random.uniform(-1, 1, (102, 105)).astype("float32")
- grad = np.random.uniform(-1, 1, (102, 105)).astype("float32")
- moment1 = np.random.uniform(-1, 1, (102, 105)).astype("float32")
param = np.random.uniform(-1, 1, self.shape).astype("float32")
grad = np.random.uniform(-1, 1, self.shape).astype("float32")
moment1 = np.random.uniform(-1, 1, self.shape).astype("float32")
# The second moment is positive
- moment2 = np.random.random((102, 105)).astype("float32")
moment2 = np.random.random(self.shape).astype("float32")
learning_rate = 0.001
beta1 = 0.9
@@ -113,6 +117,11 @@ class TestAdamOp2(OpTest):
self.check_output()
class TestAdamOnlyTailOp(TestAdamOp2):
def set_shape(self):
self.shape = (3)
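# Note, not part of the patch: with only 3 elements (fewer than the 8 floats
# of one YMM register and than the 512-element chunk in adam_op.h), this case
# exercises exclusively the masked tail path added by this PR.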
class TestAdamOpMultipleSteps(OpTest):
def setUp(self):
'''Test Adam Operator with supplied attributes
......