From 4f5e4540f88d3d8d933cef30547b587c17ed52f1 Mon Sep 17 00:00:00 2001 From: Adam <38704900+grygielski@users.noreply.github.com> Date: Mon, 23 Mar 2020 09:59:42 +0100 Subject: [PATCH] Improve SGD jit code to work with large data (#23120) --- paddle/fluid/operators/jit/gen/sgd.cc | 83 ++++++++++++++------------- paddle/fluid/operators/jit/gen/sgd.h | 1 + 2 files changed, 43 insertions(+), 41 deletions(-) diff --git a/paddle/fluid/operators/jit/gen/sgd.cc b/paddle/fluid/operators/jit/gen/sgd.cc index e65d3500b49..40f8298af39 100644 --- a/paddle/fluid/operators/jit/gen/sgd.cc +++ b/paddle/fluid/operators/jit/gen/sgd.cc @@ -24,22 +24,41 @@ namespace operators { namespace jit { namespace gen { +void SgdJitCode::mainCode(int num_regs) { + constexpr size_t block_size = sizeof(float) * YMM_FLOAT_BLOCK; + // load grad + for (int reg_i = 0; reg_i < num_regs; ++reg_i) { + vmovups(ymm_t(reg_i), ptr[reg_ptr_grad_i]); + add(reg_ptr_grad_i, block_size); + } + // load param + for (int reg_i = 0; reg_i < num_regs; ++reg_i) { + vmovups(ymm_t(reg_i + num_regs), ptr[reg_ptr_param_i]); + add(reg_ptr_param_i, block_size); + } + // compute out + for (int reg_i = 0; reg_i < num_regs; ++reg_i) { + vmulps(ymm_t(reg_i), ymm_t(reg_i), ymm_lr); + vsubps(ymm_t(reg_i + num_regs), ymm_t(reg_i + num_regs), ymm_t(reg_i)); + } + // save out + for (int reg_i = 0; reg_i < num_regs; ++reg_i) { + vmovups(ptr[reg_ptr_out_i], ymm_t(reg_i + num_regs)); + add(reg_ptr_out_i, block_size); + } +} + void SgdJitCode::genCode() { preCode(); constexpr int block = YMM_FLOAT_BLOCK; constexpr int max_num_regs = 7; const int num_block = w_ / block; const int num_groups = num_block / max_num_regs; - const size_t block_size = sizeof(float) * block; - const size_t width_size = w_ * sizeof(float); - std::vector groups(num_groups, max_num_regs); int rest_num_regs = num_block % max_num_regs; - if (rest_num_regs > 0) { - groups.push_back(rest_num_regs); - } + const size_t width_size = w_ * sizeof(float); vbroadcastss(ymm_lr, ptr[param_lr]); - // protect rdx + mov(reg_ptr_grad_i, param_grad); mov(reg_ptr_rows_i, param_rows); @@ -63,43 +82,27 @@ void SgdJitCode::genCode() { add(reg_ptr_param_i, reg_row); add(reg_ptr_out_i, reg_row); - size_t w_offset = 0; - for (int num_regs : groups) { - // load grad - size_t inner_offfset = w_offset; - for (int reg_i = 0; reg_i < num_regs; ++reg_i) { - vmovups(ymm_t(reg_i), ptr[reg_ptr_grad_i + inner_offfset]); - inner_offfset += block_size; - } - - // load param - inner_offfset = w_offset; - for (int reg_i = 0; reg_i < num_regs; ++reg_i) { - vmovups(ymm_t(reg_i + num_regs), ptr[reg_ptr_param_i + inner_offfset]); - inner_offfset += block_size; - } - - // compute out - for (int reg_i = 0; reg_i < num_regs; ++reg_i) { - vmulps(ymm_t(reg_i), ymm_t(reg_i), ymm_lr); - vsubps(ymm_t(reg_i + num_regs), ymm_t(reg_i + num_regs), ymm_t(reg_i)); - } - - // save out - inner_offfset = w_offset; - for (int reg_i = 0; reg_i < num_regs; ++reg_i) { - vmovups(ptr[reg_ptr_out_i + inner_offfset], ymm_t(reg_i + num_regs)); - inner_offfset += block_size; - } - w_offset += (block_size * num_regs); + Label inner_loop; + Label escape_loop; + mov(rax, 0); + L(inner_loop); + { + cmp(rax, num_groups); + jnb(escape_loop, T_NEAR); + + mainCode(max_num_regs); + + inc(rax); + jmp(inner_loop, T_NEAR); } + L(escape_loop); + mainCode(rest_num_regs); - add(reg_ptr_grad_i, width_size); add(reg_ptr_rows_i, sizeof(int64_t)); + cmp(reg_ptr_rows_i, reg_rows_size_in_byte); jl(l_next_row, T_NEAR); } - postCode(); } @@ -109,9 +112,7 @@ class SgdCreator : public JitCodeCreator { return platform::MayIUse(platform::avx) && attr.grad_width % YMM_FLOAT_BLOCK == 0; } - size_t CodeSize(const sgd_attr_t& attr) const override { - return 96 + (attr.grad_width / YMM_FLOAT_BLOCK) * 32 * 8; - } + size_t CodeSize(const sgd_attr_t& attr) const override { return 96 + 32 * 8; } std::unique_ptr CreateJitCode( const sgd_attr_t& attr) const override { PADDLE_ENFORCE_EQ(attr.param_width, attr.grad_width); diff --git a/paddle/fluid/operators/jit/gen/sgd.h b/paddle/fluid/operators/jit/gen/sgd.h index 317edcd2bcb..80b1809bbbf 100644 --- a/paddle/fluid/operators/jit/gen/sgd.h +++ b/paddle/fluid/operators/jit/gen/sgd.h @@ -34,6 +34,7 @@ class SgdJitCode : public JitCode { DECLARE_JIT_CODE(SgdJitCode); void genCode() override; + void mainCode(int num_regs); private: int w_; -- GitLab