diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc index 3348778ee782ef0cdd1df4c3c4b24060436d7d79..11dc615f5ff8ea78bbbf6eeb655ee88b3a52dc13 100644 --- a/paddle/fluid/operators/jit/benchmark.cc +++ b/paddle/fluid/operators/jit/benchmark.cc @@ -332,6 +332,45 @@ void BenchEmbSeqPoolKernel() { } } +template +void BenchSgdKernel() { + const T lr = 0.1; + auto UnDuplicatedRandomVec = [](int n, const int64_t lower, + const int64_t upper) -> std::vector { + PADDLE_ENFORCE_LE(static_cast(upper - lower), n - 1); + PADDLE_ENFORCE_GT(n, 0); + std::vector all, out; + for (int i = 0; i < n; ++i) { + all.push_back(i); + } + std::random_shuffle(all.begin(), all.end()); + out.insert(out.begin(), all.begin(), all.begin() + n); + return out; + }; + for (int param_h : {1, 1000}) { + for (int grad_w : {1, 2, 8, 16, 30, 256}) { + // only benchmark inplace + Tensor param; + param.Resize({param_h, grad_w}); + T* param_data = param.mutable_data(PlaceType()); + RandomVec(param_h * grad_w, param_data, -2.f, 2.f); + for (int rows_size = 1; rows_size <= std::min(param_h, 10); ++rows_size) { + Tensor grad; + grad.Resize({rows_size, grad_w}); + std::vector rows = + UnDuplicatedRandomVec(rows_size, 0, rows_size - 1); + RandomVec(rows_size * grad_w, grad.mutable_data(PlaceType()), + -2.f, 2.f); + const T* grad_data = grad.data(); + const int64_t* rows_data = rows.data(); + jit::sgd_attr_t attr(param_h, grad_w, rows_size, grad_w, rows_size); + BenchAllImpls, PlaceType>( + attr, &lr, param_data, grad_data, rows_data, param_data, &attr); + } + } + } +} + template void BenchMatMulKernel() { for (int m : {1, 2, 3, 4}) { @@ -477,6 +516,9 @@ BENCH_FP32_CPU(kEmbSeqPool) { BenchEmbSeqPoolKernel(); } +// sgd function +BENCH_FP32_CPU(kSgd) { BenchSgdKernel(); } + // matmul BENCH_FP32_CPU(kMatMul) { BenchMatMulKernel(); } diff --git a/paddle/fluid/operators/jit/gen/CMakeLists.txt b/paddle/fluid/operators/jit/gen/CMakeLists.txt index 294f73d9646c93132e464a032e93562094663a73..eb0c03568ddddf1c456fec6fcc81f3b40d051844 100644 --- a/paddle/fluid/operators/jit/gen/CMakeLists.txt +++ b/paddle/fluid/operators/jit/gen/CMakeLists.txt @@ -32,3 +32,4 @@ USE_JITKERNEL_GEN(kSeqPool) USE_JITKERNEL_GEN(kHMax) USE_JITKERNEL_GEN(kHSum) USE_JITKERNEL_GEN(kEmbSeqPool) +USE_JITKERNEL_GEN(kSgd) diff --git a/paddle/fluid/operators/jit/gen/jitcode.h b/paddle/fluid/operators/jit/gen/jitcode.h index 689df8b1cbb7a928c9f9175d28a8231b56e2e82e..39847d1b65f771976c4dde5a3e34cc40e33851e6 100644 --- a/paddle/fluid/operators/jit/gen/jitcode.h +++ b/paddle/fluid/operators/jit/gen/jitcode.h @@ -31,7 +31,8 @@ namespace gen { // Application Binary Interface constexpr Xbyak::Operand::Code abi_param1(Xbyak::Operand::RDI), abi_param2(Xbyak::Operand::RSI), abi_param3(Xbyak::Operand::RDX), - abi_param4(Xbyak::Operand::RCX); + abi_param4(Xbyak::Operand::RCX), abi_param5(Xbyak::Operand::R8), + abi_param6(Xbyak::Operand::R9); constexpr Xbyak::Operand::Code g_abi_regs[] = { Xbyak::Operand::RBX, Xbyak::Operand::RBP, Xbyak::Operand::R12, diff --git a/paddle/fluid/operators/jit/gen/sgd.cc b/paddle/fluid/operators/jit/gen/sgd.cc new file mode 100644 index 0000000000000000000000000000000000000000..a745a27f9543a75f6915c9316aad62fa41305bb1 --- /dev/null +++ b/paddle/fluid/operators/jit/gen/sgd.cc @@ -0,0 +1,130 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#include "paddle/fluid/operators/jit/gen/sgd.h" +#include // offsetof +#include +#include "paddle/fluid/operators/jit/registry.h" +#include "paddle/fluid/platform/cpu_info.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace gen { + +void SgdJitCode::genCode() { + preCode(); + constexpr int block = YMM_FLOAT_BLOCK; + constexpr int max_num_regs = 7; + const int num_block = w_ / block; + const int num_groups = num_block / max_num_regs; + const size_t block_size = sizeof(float) * block; + const size_t width_size = w_ * sizeof(float); + std::vector groups(num_groups, max_num_regs); + int rest_num_regs = num_block % max_num_regs; + if (rest_num_regs > 0) { + groups.push_back(rest_num_regs); + } + + vbroadcastss(ymm_lr, ptr[param_lr]); + // protect rdx + mov(reg_ptr_grad_i, param_grad); + mov(reg_ptr_rows_i, param_rows); + + mov(reg_rows_size_in_byte, + qword[param_attr + offsetof(sgd_attr_t, selected_rows_size)]); + mov(rax, sizeof(int64_t)); + mul(reg_rows_size_in_byte); + mov(reg_rows_size_in_byte, rax); + add(reg_rows_size_in_byte, reg_ptr_rows_i); + + Label l_next_row; + L(l_next_row); + { + mov(reg_row, qword[reg_ptr_rows_i]); + mov(rax, width_size); + mul(reg_row); + mov(reg_row, rax); + + mov(reg_ptr_param_i, param_param); + mov(reg_ptr_out_i, param_out); + add(reg_ptr_param_i, reg_row); + add(reg_ptr_out_i, reg_row); + + size_t w_offset = 0; + for (int num_regs : groups) { + // load grad + size_t inner_offfset = w_offset; + for (int reg_i = 0; reg_i < num_regs; ++reg_i) { + vmovups(ymm_t(reg_i), ptr[reg_ptr_grad_i + inner_offfset]); + inner_offfset += block_size; + } + + // load param + inner_offfset = w_offset; + for (int reg_i = 0; reg_i < num_regs; ++reg_i) { + vmovups(ymm_t(reg_i + num_regs), ptr[reg_ptr_param_i + inner_offfset]); + inner_offfset += block_size; + } + + // compute out + for (int reg_i = 0; reg_i < num_regs; ++reg_i) { + vmulps(ymm_t(reg_i), ymm_t(reg_i), ymm_lr); + vsubps(ymm_t(reg_i + num_regs), ymm_t(reg_i + num_regs), ymm_t(reg_i)); + } + + // save out + inner_offfset = w_offset; + for (int reg_i = 0; reg_i < num_regs; ++reg_i) { + vmovups(ptr[reg_ptr_out_i + inner_offfset], ymm_t(reg_i + num_regs)); + inner_offfset += block_size; + } + w_offset += (block_size * num_regs); + } + + add(reg_ptr_grad_i, width_size); + add(reg_ptr_rows_i, sizeof(int64_t)); + cmp(reg_ptr_rows_i, reg_rows_size_in_byte); + jl(l_next_row, T_NEAR); + } + + postCode(); +} + +class SgdCreator : public JitCodeCreator { + public: + bool UseMe(const sgd_attr_t& attr) const override { + return platform::MayIUse(platform::avx) && + attr.grad_width % YMM_FLOAT_BLOCK == 0; + } + size_t CodeSize(const sgd_attr_t& attr) const override { + return 96 + (attr.grad_width / YMM_FLOAT_BLOCK) * 32 * 8; + } + std::unique_ptr CreateJitCode( + const sgd_attr_t& attr) const override { + PADDLE_ENFORCE_EQ(attr.param_width, attr.grad_width); + PADDLE_ENFORCE_LE(attr.selected_rows_size, attr.grad_height); + PADDLE_ENFORCE_GE(attr.selected_rows_size, 0); + return make_unique(attr, CodeSize(attr)); + } +}; + +} // namespace gen +} // namespace jit +} // namespace operators +} // namespace paddle + +namespace gen = paddle::operators::jit::gen; + +REGISTER_JITKERNEL_GEN(kSgd, gen::SgdCreator); diff --git a/paddle/fluid/operators/jit/gen/sgd.h b/paddle/fluid/operators/jit/gen/sgd.h new file mode 100644 index 0000000000000000000000000000000000000000..317edcd2bcb5fea1f14f32260fd16c9c706eaf00 --- /dev/null +++ b/paddle/fluid/operators/jit/gen/sgd.h @@ -0,0 +1,60 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + +#pragma once + +#include +#include "glog/logging.h" +#include "paddle/fluid/operators/jit/gen/jitcode.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace operators { +namespace jit { +namespace gen { + +class SgdJitCode : public JitCode { + public: + explicit SgdJitCode(const sgd_attr_t& attr, size_t code_size = 256 * 1024, + void* code_ptr = nullptr) + : JitCode(code_size, code_ptr), w_(attr.grad_width) { + this->genCode(); + } + + DECLARE_JIT_CODE(SgdJitCode); + void genCode() override; + + private: + int w_; + reg64_t param_lr{abi_param1}; + reg64_t param_param{abi_param2}; + reg64_t param_grad{abi_param3}; + reg64_t param_rows{abi_param4}; + reg64_t param_out{abi_param5}; + reg64_t param_attr{abi_param6}; + + ymm_t ymm_lr = ymm_t(15); + + reg64_t reg_ptr_grad_i{r10}; + reg64_t reg_ptr_rows_i{r11}; + reg64_t reg_rows_size_in_byte{r12}; + reg64_t reg_row{r13}; + reg64_t reg_ptr_param_i{r14}; + reg64_t reg_ptr_out_i{r15}; +}; + +} // namespace gen +} // namespace jit +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/jit/helper.cc b/paddle/fluid/operators/jit/helper.cc index a76653613289892c4bb41596f998c5f4cc131fd7..1dc60442d5c5f6acf49b6319223b190f6c81e1a6 100644 --- a/paddle/fluid/operators/jit/helper.cc +++ b/paddle/fluid/operators/jit/helper.cc @@ -55,6 +55,7 @@ const char* to_string(KernelType kt) { ONE_CASE(kHSum); ONE_CASE(kSoftmax); ONE_CASE(kEmbSeqPool); + ONE_CASE(kSgd); default: PADDLE_THROW("Not support type: %d, or forget to add it.", kt); return "NOT JITKernel"; diff --git a/paddle/fluid/operators/jit/helper.h b/paddle/fluid/operators/jit/helper.h index 07998588a5a560f9c2ad7cc765b66e76e87da6f6..d85c719c1c58c88ec244f1f6ad8343d66391241d 100644 --- a/paddle/fluid/operators/jit/helper.h +++ b/paddle/fluid/operators/jit/helper.h @@ -181,6 +181,14 @@ inline std::ostream& operator<<(std::ostream& os, return os; } +inline std::ostream& operator<<(std::ostream& os, const sgd_attr_t& attr) { + os << "param_height[" << attr.param_height << "],param_width[" + << attr.param_width << "],grad_height[" << attr.grad_height + << "],grad_width[" << attr.grad_width << "],selected_rows_size[" + << attr.selected_rows_size << "]"; + return os; +} + inline std::ostream& operator<<(std::ostream& os, const matmul_attr_t& attr) { os << "M[" << attr.m << "],N[" << attr.n << "],K[" << attr.k << "]"; return os; diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h index 20b6a32bef9860c52ab4423395a8e00f719b0210..895e2d4d6f3809a66443ed6d6bfc1ee02d6c529a 100644 --- a/paddle/fluid/operators/jit/kernel_base.h +++ b/paddle/fluid/operators/jit/kernel_base.h @@ -46,6 +46,7 @@ typedef enum { kVMul, kVRelu, kVScal, + kSgd, kVSigmoid, kVSquare, kVSub, @@ -173,6 +174,28 @@ struct EmbSeqPoolTuples { const emb_seq_pool_attr_t*); }; +typedef struct sgd_attr_s { + int64_t param_height, param_width; + int64_t grad_height, grad_width; + int64_t selected_rows_size; + sgd_attr_s() = default; + explicit sgd_attr_s(int64_t param_h, int64_t param_w, int64_t grad_h, + int64_t grad_w, int64_t selected_rows_sz) + : param_height(param_h), + param_width(param_w), + grad_height(grad_h), + grad_width(grad_w), + selected_rows_size(selected_rows_sz) {} +} sgd_attr_t; + +template +struct SgdTuples { + typedef T data_type; + typedef sgd_attr_t attr_type; + typedef void (*func_type)(const T*, const T*, const T*, const int64_t*, T*, + const sgd_attr_t*); +}; + typedef struct matmul_attr_s { int m, n, k; void* packed_weight{nullptr}; diff --git a/paddle/fluid/operators/jit/kernel_key.cc b/paddle/fluid/operators/jit/kernel_key.cc index e659c6d254391f09ac8692e0b7602c65e1afd47d..740d0f850a072a5ad3238e52402141a83c0b7e33 100644 --- a/paddle/fluid/operators/jit/kernel_key.cc +++ b/paddle/fluid/operators/jit/kernel_key.cc @@ -13,6 +13,7 @@ * limitations under the License. */ #include "paddle/fluid/operators/jit/kernel_key.h" +#include "paddle/fluid/platform/enforce.h" namespace paddle { namespace operators { @@ -23,14 +24,30 @@ size_t JitCodeKey(const int& d) { return d; } +// TODO(TJ): refine and benchmark JitCodeKey generatation constexpr int act_type_shift = 3; // suppot 2^3 act types +static inline int act_type_convert(KernelType type) { + if (type == kVIdentity) { + return 0; + } else if (type == kVExp) { + return 1; + } else if (type == kVRelu) { + return 2; + } else if (type == kVSigmoid) { + return 3; + } else if (type == kVTanh) { + return 4; + } + PADDLE_THROW("Unsupported act type %d", type); + return 0; +} template <> size_t JitCodeKey(const lstm_attr_t& attr) { size_t key = attr.d; - int gate_key = static_cast(attr.act_gate) << 1; - int cand_key = static_cast(attr.act_cand) << (1 + act_type_shift); - int cell_key = static_cast(attr.act_cell) << (1 + act_type_shift * 2); + int gate_key = act_type_convert(attr.act_gate) << 1; + int cand_key = act_type_convert(attr.act_cand) << (1 + act_type_shift); + int cell_key = act_type_convert(attr.act_cell) << (1 + act_type_shift * 2); return (key << (1 + act_type_shift * 3)) + gate_key + cand_key + cell_key + attr.use_peephole; } @@ -38,8 +55,8 @@ size_t JitCodeKey(const lstm_attr_t& attr) { template <> size_t JitCodeKey(const gru_attr_t& attr) { size_t key = attr.d; - return (key << (act_type_shift * 2)) + static_cast(attr.act_gate) + - (static_cast(attr.act_cand) << act_type_shift); + return (key << (act_type_shift * 2)) + act_type_convert(attr.act_gate) + + (act_type_convert(attr.act_cand) << act_type_shift); } template <> @@ -61,6 +78,11 @@ size_t JitCodeKey(const emb_seq_pool_attr_t& attr) { return attr.table_width; } +template <> +size_t JitCodeKey(const sgd_attr_t& attr) { + return attr.grad_width; +} + } // namespace jit } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt index d209f31007255b3a90fdeeb4d609311b80bdc7b5..9a00ad56a6a909a677cb8f60bd80fe399e82952f 100644 --- a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt +++ b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt @@ -14,3 +14,4 @@ USE_JITKERNEL_MORE(kVTanh, mkl) USE_JITKERNEL_MORE(kSeqPool, mkl) USE_JITKERNEL_MORE(kSoftmax, mkl) USE_JITKERNEL_MORE(kEmbSeqPool, mkl) +USE_JITKERNEL_MORE(kSgd, mkl) diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.cc b/paddle/fluid/operators/jit/more/mkl/mkl.cc index 29a451f832fa745f8e1f5a45fd934f09e1f41e76..780fda02c1ff3da2e0b945f9b2fece30484e4519 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.cc +++ b/paddle/fluid/operators/jit/more/mkl/mkl.cc @@ -184,6 +184,16 @@ bool EmbSeqPoolKernel::UseMe(const emb_seq_pool_attr_t& attr) const { return true; } +template <> +bool SgdKernel::UseMe(const sgd_attr_t& attr) const { + return true; +} + +template <> +bool SgdKernel::UseMe(const sgd_attr_t& attr) const { + return true; +} + template <> bool MatMulKernel::UseMe(const matmul_attr_t& attr) const { return platform::MayIUse(platform::avx); @@ -239,5 +249,6 @@ REGISTER_MKL_KERNEL(kVTanh, VTanh); REGISTER_MKL_KERNEL(kSeqPool, SeqPool); REGISTER_MKL_KERNEL(kEmbSeqPool, EmbSeqPool); REGISTER_MKL_KERNEL(kSoftmax, Softmax); +REGISTER_MKL_KERNEL(kSgd, Sgd); #undef REGISTER_MKL_KERNEL diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.h b/paddle/fluid/operators/jit/more/mkl/mkl.h index 9a72ba83022de2beeb760772ee8489477befdd7e..a7bc2de4a3e8e7d8e2a6b00990bfa459b3029c2a 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.h +++ b/paddle/fluid/operators/jit/more/mkl/mkl.h @@ -142,6 +142,32 @@ void Softmax(const T* x, T* y, int n, int bs) { } } +template +void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows, + T* out, const sgd_attr_t* attr) { + PADDLE_ENFORCE_EQ(attr->param_width, attr->grad_width); + PADDLE_ENFORCE_LE(attr->selected_rows_size, attr->grad_height); + T scalar = -lr[0]; + int width = attr->grad_width; + if (out == param) { + for (int64_t i = 0; i < attr->selected_rows_size; ++i) { + auto h_idx = rows[i]; + PADDLE_ENFORCE_LT(h_idx, attr->param_height); + PADDLE_ENFORCE_GE(h_idx, 0); + VAXPY(scalar, grad + i * width, out + h_idx * width, width); + } + } else { + for (int64_t i = 0; i < attr->selected_rows_size; ++i) { + auto h_idx = rows[i]; + PADDLE_ENFORCE_LT(h_idx, attr->param_height); + PADDLE_ENFORCE_GE(h_idx, 0); + VScal(&scalar, grad + i * width, out + h_idx * width, width); + VAdd(param + h_idx * width, out + h_idx * width, out + h_idx * width, + width); + } + } +} + #define DECLARE_MKL_KERNEL(name, tuples) \ template \ class name##Kernel : public KernelMore> { \ @@ -173,6 +199,8 @@ DECLARE_MKL_KERNEL(EmbSeqPool, EmbSeqPoolTuples); DECLARE_MKL_KERNEL(Softmax, SoftmaxTuples); +DECLARE_MKL_KERNEL(Sgd, SgdTuples); + #undef DECLARE_MKL_KERNEL } // namespace mkl diff --git a/paddle/fluid/operators/jit/refer/CMakeLists.txt b/paddle/fluid/operators/jit/refer/CMakeLists.txt index 218d801c084be455538628d1c1028d8e52142894..cd19dd169d0bfdfe2cb8157ade29f48ad6428453 100644 --- a/paddle/fluid/operators/jit/refer/CMakeLists.txt +++ b/paddle/fluid/operators/jit/refer/CMakeLists.txt @@ -33,3 +33,4 @@ USE_JITKERNEL_REFER(kHSum) USE_JITKERNEL_REFER(kHMax) USE_JITKERNEL_REFER(kSoftmax) USE_JITKERNEL_REFER(kEmbSeqPool) +USE_JITKERNEL_REFER(kSgd) diff --git a/paddle/fluid/operators/jit/refer/refer.cc b/paddle/fluid/operators/jit/refer/refer.cc index 7e7dd6960b66e4e2f77eca6e96604f2a86553120..0c434bd2b8cacdf4b8872da66bb8e763a6a45cee 100644 --- a/paddle/fluid/operators/jit/refer/refer.cc +++ b/paddle/fluid/operators/jit/refer/refer.cc @@ -59,4 +59,6 @@ REGISTER_REFER_KERNEL(kSoftmax, Softmax); REGISTER_REFER_KERNEL(kEmbSeqPool, EmbSeqPool); +REGISTER_REFER_KERNEL(kSgd, Sgd); + #undef REGISTER_REFER_KERNEL diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h index fd1193aa41e50e3ede7f61588dc72389279bb95d..0f714edf85bbbf4838bfe09251bd1c2d5f3b3eb7 100644 --- a/paddle/fluid/operators/jit/refer/refer.h +++ b/paddle/fluid/operators/jit/refer/refer.h @@ -446,6 +446,36 @@ void EmbSeqPool(const T* table, const int64_t* idx, T* out, } } +// SGD algorithm: +// lr is pointor of learning rate scalar +// param is an input matrix with (param_h, param_w) +// grad is an input matrix with (grad_h, grad_w), here grad_w == param_w +// selected_rows is a vectot with size selected_rows_size( <= grad_h ) +// out is an output matrix with (param_h, param_w) +// +// support both regular and sparse grad +// regular SGD: out[:] = param[:] - lr[0] * grad[:]; +// sparse SGD: out[rows[i]][:] = param[rows[i]][:] - lr[0] * grad[i][:] +// +// Note: when use sparse SGD, and if out != param, +// the out rows which are not selected have not beed changed, which maybe empty +template +void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows, + T* out, const sgd_attr_t* attr) { + PADDLE_ENFORCE_EQ(attr->param_width, attr->grad_width); + PADDLE_ENFORCE_LE(attr->selected_rows_size, attr->grad_height); + for (int64_t i = 0; i < attr->selected_rows_size; ++i) { + auto h_idx = rows[i]; + PADDLE_ENFORCE_LT(h_idx, attr->param_height); + PADDLE_ENFORCE_GE(h_idx, 0); + for (int64_t j = 0; j < attr->grad_width; ++j) { + out[h_idx * attr->grad_width + j] = + param[h_idx * attr->grad_width + j] - + lr[0] * grad[i * attr->grad_width + j]; + } + } +} + #define DECLARE_REFER_KERNEL(name, tuples) \ template \ class name##Kernel : public ReferKernel> { \ @@ -496,6 +526,8 @@ DECLARE_REFER_KERNEL(Softmax, SoftmaxTuples); DECLARE_REFER_KERNEL(EmbSeqPool, EmbSeqPoolTuples); +DECLARE_REFER_KERNEL(Sgd, SgdTuples); + #undef DECLARE_REFER_KERNEL } // namespace refer diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc index 356eba6f86ad180c7d23bf7fa91eb5d455ff5f08..b618cd6a84be752a052f9d49a4a4c772b1d7eeae 100644 --- a/paddle/fluid/operators/jit/test.cc +++ b/paddle/fluid/operators/jit/test.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include #include #include @@ -36,14 +37,14 @@ void RandomVec(const int n, T* a, const T lower = static_cast(-20.f), } template -void ExpectEQ(const T* target, const T* refer, int n) { +void ExpectEQ(const T* target, const T* refer, size_t n) { if (std::is_floating_point::value) { - for (int i = 0; i < n; ++i) { - EXPECT_NEAR(target[i], refer[i], FLAGS_acc); + for (size_t i = 0; i < n; ++i) { + EXPECT_NEAR(target[i], refer[i], FLAGS_acc) << " at index : " << i; } } else { - for (int i = 0; i < n; ++i) { - EXPECT_EQ(target[i], refer[i]); + for (size_t i = 0; i < n; ++i) { + EXPECT_EQ(target[i], refer[i]) << " at index : " << i; } } } @@ -296,6 +297,45 @@ struct TestFuncWithRefer, std::vector, } }; +template +struct TestFuncWithRefer, T, std::vector, std::vector, + std::vector, std::vector, + typename jit::SgdTuples::attr_type> { + void operator()(const typename jit::SgdTuples::func_type tgt, const T lr, + const std::vector& param, const std::vector& grad, + const std::vector& rows, const std::vector& oref, + const typename jit::SgdTuples::attr_type& attr) { + EXPECT_TRUE(tgt != nullptr); + EXPECT_EQ(param.size(), + static_cast(attr.param_height * attr.param_width)); + EXPECT_EQ(grad.size(), + static_cast(attr.grad_height * attr.grad_width)); + EXPECT_EQ(rows.size(), static_cast(attr.selected_rows_size)); + EXPECT_EQ(param.size(), oref.size()); + const T* param_data = param.data(); + const T* grad_data = grad.data(); + const int64_t* rows_data = rows.data(); + const T* oref_data = oref.data(); + + std::vector out(oref.size()); + T* o_data = out.data(); + tgt(&lr, param_data, grad_data, rows_data, o_data, &attr); + // only the selected rows should be equal + for (size_t i = 0; i < rows.size(); ++i) { + ExpectEQ(o_data + rows[i] * attr.grad_width, + oref_data + rows[i] * attr.grad_width, attr.grad_width); + } + + // inplace + std::copy(param.begin(), param.end(), out.begin()); + tgt(&lr, o_data, grad_data, rows_data, o_data, &attr); + for (size_t i = 0; i < rows.size(); ++i) { + ExpectEQ(o_data + rows[i] * attr.grad_width, + oref_data + rows[i] * attr.grad_width, attr.grad_width); + } + } +}; + template struct TestFuncWithRefer, std::vector, std::vector, std::vector, @@ -407,7 +447,7 @@ void TestAllImpls(const typename KernelTuples::attr_type& attr, Args... args) { } template -void TestXYZNKernel() { +void TestKernelXYZNTuples() { VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); for (int d : TestSizes()) { auto ref = jit::GetRefer>(); @@ -440,7 +480,7 @@ void TestXYZNKernel() { } template -void TestAXYNKernel() { +void TestKernelAXYNTuples() { VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); for (int d : TestSizes()) { auto ref = jit::GetRefer>(); @@ -466,7 +506,7 @@ void TestAXYNKernel() { } template -void TestXRNKernel() { +void TestKernelXRNTuples() { VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); auto last_acc = FLAGS_acc; FLAGS_acc = 1e-4; @@ -484,7 +524,7 @@ void TestXRNKernel() { } template -void TestXYNKernel() { +void TestKernelXYNTuples() { VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); for (int d : TestSizes()) { auto ref = jit::GetRefer>(); @@ -509,10 +549,12 @@ void TestXYNKernel() { } template -void TestLSTMKernel() { +void TestKernelLSTMTuples() { VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); std::vector all_acts = {"sigmoid", "tanh", "relu", "identity"}; - for (int d : TestSizes()) { + auto test_sizes = TestSizes(); + test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000)); + for (int d : test_sizes) { for (bool use_peephole : {true, false}) { for (auto& act_gate : all_acts) { for (auto& act_cand : all_acts) { @@ -559,10 +601,12 @@ void TestLSTMKernel() { } template -void TestGRUKernel() { +void TestKernelGRUTuples() { VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); std::vector all_acts = {"sigmoid", "tanh", "relu", "identity"}; - for (int d : TestSizes()) { + auto test_sizes = TestSizes(); + test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000)); + for (int d : test_sizes) { for (auto& act_gate : all_acts) { for (auto& act_cand : all_acts) { const jit::gru_attr_t attr(d, jit::to_kerneltype(act_gate), @@ -593,14 +637,16 @@ void TestGRUKernel() { } template -void TestSeqPoolKernel() { +void TestKernelSeqPoolTuples() { VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); std::vector pool_types = { jit::SeqPoolType::kSum, jit::SeqPoolType::kAvg, jit::SeqPoolType::kSqrt}; + auto test_sizes = TestSizes(); + test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000)); for (auto type : pool_types) { - for (int w : TestSizes()) { + for (int w : test_sizes) { jit::seq_pool_attr_t attr(w, type); - for (int h : TestSizes()) { + for (int h : test_sizes) { attr.h = h; auto ref = jit::GetRefer>(); EXPECT_TRUE(ref != nullptr); @@ -618,11 +664,11 @@ void TestSeqPoolKernel() { } template -void TestMatMulKernel() { +void TestKernelMatMulTuples() { VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); auto last_acc = FLAGS_acc; - // TODO(intel): fix MKL acc issue - // https://github.com/PaddlePaddle/Paddle/issues/15447 + // export MKL_CBWR=AVX would make MKL force to use AVX + // export KMP_DETERMINISTIC_REDUCTION=yes would make the result deterministic FLAGS_acc = 1e-3; for (int m : {1, 2, 3, 4}) { for (int n : {1, 2, 3, 4}) { @@ -646,7 +692,7 @@ void TestMatMulKernel() { } template -void TestSoftmaxKernel() { +void TestKernelSoftmaxTuples() { VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); for (int bs : {1, 2, 10}) { for (int n : TestSizes()) { @@ -671,12 +717,14 @@ void TestSoftmaxKernel() { } template -void TestEmbSeqPoolKernel() { +void TestKernelEmbSeqPoolTuples() { VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); int64_t tbl_h = 1e4; std::vector pool_types = { jit::SeqPoolType::kSum}; // only support sum yet - for (int tbl_w : TestSizes()) { + auto test_sizes = TestSizes(); + test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000)); + for (int tbl_w : test_sizes) { std::vector table(tbl_h * tbl_w); RandomVec(tbl_h * tbl_w, table.data(), -2.f, 2.f); const T* table_data = table.data(); @@ -705,7 +753,61 @@ void TestEmbSeqPoolKernel() { } template -void TestNCHW16CMulNCKernel() { +void TestKernelSgdTuples() { + VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); + const T lr = 0.1; + auto UnDuplicatedRandomVec = [](int n, const int64_t lower, + const int64_t upper) -> std::vector { + PADDLE_ENFORCE_LE(static_cast(upper - lower), n - 1); + PADDLE_ENFORCE_GT(n, 0); + std::vector all, out; + for (int i = 0; i < n; ++i) { + all.push_back(i); + } + std::random_shuffle(all.begin(), all.end()); + out.insert(out.begin(), all.begin(), all.begin() + n); + return out; + }; + for (int param_h : {1, 10}) { + for (int grad_w : TestSizes()) { + std::vector param(param_h * grad_w); + std::vector param_out(param_h * grad_w); + RandomVec(param_h * grad_w, param.data(), -2.f, 2.f); + const T* param_data = param.data(); + T* out_data = param_out.data(); + for (int rows_size = 1; rows_size <= param_h; ++rows_size) { + std::vector grad(rows_size * grad_w); + std::vector rows = + UnDuplicatedRandomVec(rows_size, 0, rows_size - 1); + RandomVec(rows_size * grad_w, grad.data(), -2.f, 2.f); + const int64_t* rows_data = rows.data(); + const T* grad_data = grad.data(); + auto ref = jit::GetRefer>(); + EXPECT_TRUE(ref != nullptr); + jit::sgd_attr_t attr(param_h, grad_w, rows_size, grad_w, rows_size); + ref(&lr, param_data, grad_data, rows_data, out_data, &attr); + + // inplace test + std::vector inp(param.size()); + std::copy(param.begin(), param.end(), inp.begin()); + T* inp_data = inp.data(); + ref(&lr, inp_data, grad_data, rows_data, inp_data, &attr); + // only the selected rows should be equal + for (int i = 0; i < rows_size; ++i) { + ExpectEQ(inp_data + rows[i] * grad_w, out_data + rows[i] * grad_w, + grad_w); + } + + TestAllImpls, PlaceType, T, std::vector, + std::vector, std::vector, std::vector>( + attr, lr, param, grad, rows, param_out, attr); + } + } + } +} + +template +void TestKernelNCHW16CMulNCTuples() { VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); const int n = 3, c = 16 * 4, h = 10, w = 10; auto ref = jit::GetRefer>(); @@ -758,7 +860,7 @@ void TestNCHW16CMulNCKernel() { } template -void TestLayerNormKernel() { +void TestKernelLayerNormTuples() { VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); const T epsilon = 9.99999975e-06; for (int n : {1, 2, 10}) { @@ -797,11 +899,13 @@ void TestLayerNormKernel() { } template -void TestCRFDecodingKernel() { +void TestKernelCRFDecodingTuples() { VLOG(10) << "===== Test JITKernel " << jit::to_string(KT); constexpr int state_trans_base_idx = 2; + auto test_sizes = TestSizes(); + test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000)); for (int seq_len : {1, 11, 17, 50}) { - for (int tag_num : TestSizes()) { + for (int tag_num : test_sizes) { auto ref = jit::GetRefer>(); EXPECT_TRUE(ref != nullptr); int x_sz = seq_len * tag_num; @@ -822,143 +926,76 @@ void TestCRFDecodingKernel() { } } -// XYZNTuple -TEST(JITKernel, kVMul) { - TestXYZNKernel(); - TestXYZNKernel(); -} - -TEST(JITKernel, kVAdd) { - TestXYZNKernel(); - TestXYZNKernel(); -} - -TEST(JITKernel, kVAddRelu) { - TestXYZNKernel(); - TestXYZNKernel(); -} - -TEST(JITKernel, kVSub) { - TestXYZNKernel(); - TestXYZNKernel(); -} - -// AXYNTuples -TEST(JITKernel, kVScal) { - TestAXYNKernel(); - TestAXYNKernel(); -} - -TEST(JITKernel, kVAddBias) { - TestAXYNKernel(); - TestAXYNKernel(); -} - -// XRNTuples -TEST(JITKernel, kHMax) { - TestXRNKernel(); - TestXRNKernel(); -} - -TEST(JITKernel, kHSum) { - TestXRNKernel(); - TestXRNKernel(); -} - -// XYNTuples -TEST(JITKernel, kVRelu) { - TestXYNKernel(); - TestXYNKernel(); -} - -TEST(JITKernel, kVIdentity) { - TestXYNKernel(); - TestXYNKernel(); -} - -TEST(JITKernel, kVSquare) { - TestXYNKernel(); - TestXYNKernel(); -} - -TEST(JITKernel, kVExp) { - TestXYNKernel(); - TestXYNKernel(); -} - -TEST(JITKernel, kVSigmoid) { - TestXYNKernel(); - TestXYNKernel(); -} +#define TEST_CPU_KERNEL(test_tuple, kernel_type) \ + TEST(JITKernel, kernel_type) { \ + TestKernel##test_tuple(); \ + TestKernel##test_tuple(); \ + } -TEST(JITKernel, kVTanh) { - TestXYNKernel(); - TestXYNKernel(); -} +TEST_CPU_KERNEL(XYZNTuples, kVMul); +TEST_CPU_KERNEL(XYZNTuples, kVAdd); +TEST_CPU_KERNEL(XYZNTuples, kVAddRelu); +TEST_CPU_KERNEL(XYZNTuples, kVSub); -// LSTM -TEST(JITKernel, kLSTMCtHt) { - TestLSTMKernel(); - TestLSTMKernel(); -} +TEST_CPU_KERNEL(AXYNTuples, kVScal); +TEST_CPU_KERNEL(AXYNTuples, kVAddBias); -TEST(JITKernel, kLSTMC1H1) { - TestLSTMKernel(); - TestLSTMKernel(); -} +TEST_CPU_KERNEL(XRNTuples, kHMax); +TEST_CPU_KERNEL(XRNTuples, kHSum); -// GRU -TEST(JITKernel, kGRUH1) { - TestGRUKernel(); - TestGRUKernel(); -} +TEST_CPU_KERNEL(XYNTuples, kVRelu); +TEST_CPU_KERNEL(XYNTuples, kVIdentity); +TEST_CPU_KERNEL(XYNTuples, kVSquare); +TEST_CPU_KERNEL(XYNTuples, kVExp); +TEST_CPU_KERNEL(XYNTuples, kVSigmoid); +TEST_CPU_KERNEL(XYNTuples, kVTanh); -TEST(JITKernel, kGRUHtPart1) { - TestGRUKernel(); - TestGRUKernel(); -} +TEST_CPU_KERNEL(LSTMTuples, kLSTMCtHt); +TEST_CPU_KERNEL(LSTMTuples, kLSTMC1H1); -TEST(JITKernel, kGRUHtPart2) { - TestGRUKernel(); - TestGRUKernel(); -} +TEST_CPU_KERNEL(GRUTuples, kGRUH1); +TEST_CPU_KERNEL(GRUTuples, kGRUHtPart1); +TEST_CPU_KERNEL(GRUTuples, kGRUHtPart2); -TEST(JITKernel, kSeqPool) { - TestSeqPoolKernel(); - TestSeqPoolKernel(); -} +TEST_CPU_KERNEL(NCHW16CMulNCTuples, kNCHW16CMulNC); -TEST(JITKernel, kMatMul) { - TestMatMulKernel(); - TestMatMulKernel(); -} +TEST_CPU_KERNEL(SeqPoolTuples, kSeqPool); +TEST_CPU_KERNEL(MatMulTuples, kMatMul); +TEST_CPU_KERNEL(SoftmaxTuples, kSoftmax); +TEST_CPU_KERNEL(EmbSeqPoolTuples, kEmbSeqPool); +TEST_CPU_KERNEL(SgdTuples, kSgd); +TEST_CPU_KERNEL(LayerNormTuples, kLayerNorm); +TEST_CPU_KERNEL(CRFDecodingTuples, kCRFDecoding); -TEST(JITKernel, kSoftmax) { - TestSoftmaxKernel(); - TestSoftmaxKernel(); -} +TEST(JITKernel_key, lstm) { + jit::lstm_attr_t attr1(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh); + jit::lstm_attr_t attr2(9, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh); + jit::lstm_attr_t attr3(9, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh); + jit::lstm_attr_t attr4(9, jit::kVRelu, jit::kVSigmoid, jit::kVTanh); -TEST(JITKernel, kEmbSeqPool) { - TestEmbSeqPoolKernel(); - TestEmbSeqPoolKernel(); -} + auto key1 = jit::JitCodeKey(attr1); + auto key2 = jit::JitCodeKey(attr2); + auto key3 = jit::JitCodeKey(attr3); + auto key4 = jit::JitCodeKey(attr4); -TEST(JITKernel, kNCHW16CMulNC) { - TestNCHW16CMulNCKernel(); - TestNCHW16CMulNCKernel(); + EXPECT_TRUE(key1 != key2); + EXPECT_TRUE(key2 == key3); + EXPECT_TRUE(key3 != key4); } -TEST(JITKernel, kLayerNorm) { - TestLayerNormKernel(); - TestLayerNormKernel(); -} +TEST(JITKernel_key, gru) { + jit::gru_attr_t attr1(8, jit::kVSigmoid, jit::kVTanh); + jit::gru_attr_t attr2(9, jit::kVSigmoid, jit::kVTanh); + jit::gru_attr_t attr3(9, jit::kVSigmoid, jit::kVTanh); + jit::gru_attr_t attr4(9, jit::kVSigmoid, jit::kVIdentity); -TEST(JITKernel, kCRFDecoding) { - TestCRFDecodingKernel(); - TestCRFDecodingKernel(); -} + auto key1 = jit::JitCodeKey(attr1); + auto key2 = jit::JitCodeKey(attr2); + auto key3 = jit::JitCodeKey(attr3); + auto key4 = jit::JitCodeKey(attr4); -TEST(JITKernel, pool) { - // TODO(TJ): add some test + EXPECT_TRUE(key1 != key2); + EXPECT_TRUE(key2 == key3); + EXPECT_TRUE(key3 != key4); } +// TODO(TJ): add more test about key and pool diff --git a/paddle/fluid/operators/optimizers/sgd_op.h b/paddle/fluid/operators/optimizers/sgd_op.h index 98bae5e1d329005f9463fd7bb0751c44952dea88..c9c9f530fe846c1713ad176e05a377996d04470b 100644 --- a/paddle/fluid/operators/optimizers/sgd_op.h +++ b/paddle/fluid/operators/optimizers/sgd_op.h @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/operators/jit/kernels.h" namespace paddle { namespace operators { @@ -32,53 +33,57 @@ class SGDOpKernel : public framework::OpKernel { if (param_var->IsType()) { const auto *param = ctx.Input("Param"); auto *param_out = ctx.Output("ParamOut"); - // Actually, all tensors are LoDTensor except SelectedRows. if (grad_var->IsType()) { - param_out->mutable_data(ctx.GetPlace()); const auto *grad = ctx.Input("Grad"); - - auto p = framework::EigenVector::Flatten(*param); - auto g = framework::EigenVector::Flatten(*grad); - auto o = framework::EigenVector::Flatten(*param_out); - auto *lr = learning_rate->data(); - - o = p - lr[0] * g; + auto sz = param_out->numel(); + PADDLE_ENFORCE_EQ(param->numel(), sz); + PADDLE_ENFORCE_EQ(grad->numel(), sz); + + jit::sgd_attr_t attr(1, sz, 1, sz, 1); + const T *lr = learning_rate->data(); + const T *param_data = param->data(); + const T *grad_data = grad->data(); + int64_t rows_idx = 0; + T *out_data = param_out->mutable_data(ctx.GetPlace()); + + auto sgd = + jit::Get, platform::CPUPlace>(attr); + sgd(lr, param_data, grad_data, &rows_idx, out_data, &attr); } else if (grad_var->IsType()) { // TODO(qijun): In Sparse SGD operator, in-place update is enforced. // This manual optimization brings difficulty to track data dependency. // It's better to find a more elegant solution. PADDLE_ENFORCE_EQ(param, param_out); const auto *grad = ctx.Input("Grad"); + auto &grad_rows = grad->rows(); // for distributed training, a sparse var may be empty, // just skip updating. - if (grad->rows().size() == 0) { + if (grad_rows.size() == 0) { return; } - auto grad_height = grad->height(); auto out_dims = param_out->dims(); - PADDLE_ENFORCE_EQ(grad_height, out_dims[0]); - + PADDLE_ENFORCE_EQ(grad->height(), out_dims[0]); auto &grad_value = grad->value(); - auto &grad_rows = grad->rows(); - - size_t grad_row_numel = grad_value.numel() / grad_rows.size(); - PADDLE_ENFORCE_EQ(static_cast(grad_row_numel), - param_out->numel() / grad_height); - - auto *grad_data = grad_value.data(); - auto *out_data = param_out->data(); - auto *lr = learning_rate->data(); - for (size_t i = 0; i < grad_rows.size(); i++) { - PADDLE_ENFORCE(grad_rows[i] < grad_height, - "Input rows index should less than height"); - for (size_t j = 0; j < grad_row_numel; j++) { - out_data[grad_rows[i] * grad_row_numel + j] -= - lr[0] * grad_data[i * grad_row_numel + j]; - } - } + const T *param_data = param->data(); + const T *grad_data = grad_value.data(); + const T *lr = learning_rate->data(); + const int64_t *rows_data = grad_rows.data(); + T *out_data = param_out->mutable_data(ctx.GetPlace()); + + jit::sgd_attr_t attr; + attr.param_height = out_dims[0]; + attr.param_width = param_out->numel() / attr.param_height; + attr.grad_height = grad_rows.size(); // note: it is not grad->height() + attr.grad_width = grad_value.numel() / attr.grad_height; + attr.selected_rows_size = grad_rows.size(); + PADDLE_ENFORCE_EQ(attr.grad_width, attr.param_width); + + auto sgd = + jit::Get, platform::CPUPlace>(attr); + sgd(lr, param_data, grad_data, rows_data, out_data, &attr); } else { PADDLE_THROW("Unsupported Variable Type of Grad"); } diff --git a/python/paddle/fluid/tests/unittests/test_sgd_op.py b/python/paddle/fluid/tests/unittests/test_sgd_op.py index b46e4bfb86bd5dc9c74375693328f2506281be3e..162e6d1938c8174d342d8e4af1e4b6c424afc521 100644 --- a/python/paddle/fluid/tests/unittests/test_sgd_op.py +++ b/python/paddle/fluid/tests/unittests/test_sgd_op.py @@ -24,17 +24,28 @@ from op_test import OpTest class TestSGDOp(OpTest): def setUp(self): self.op_type = "sgd" - w = np.random.random((102, 105)).astype("float32") - g = np.random.random((102, 105)).astype("float32") + self.conf() + w = np.random.random((self.h, self.w)).astype("float32") + g = np.random.random((self.h, self.w)).astype("float32") lr = np.array([0.1]).astype("float32") self.inputs = {'Param': w, 'Grad': g, 'LearningRate': lr} self.outputs = {'ParamOut': w - lr * g} + def conf(self): + self.h = 102 + self.w = 105 + def test_check_output(self): self.check_output() +class TestSGDOpCase8X(TestSGDOp): + def conf(self): + self.h = 10 + self.w = 64 + + class TestSparseSGDOp(unittest.TestCase): def check_with_place(self, place): scope = core.Scope() @@ -42,12 +53,12 @@ class TestSparseSGDOp(unittest.TestCase): # create and initialize Grad Variable height = 10 rows = [0, 4, 7] - row_numel = 12 + self.conf() grad_selected_rows = scope.var('Grad').get_selected_rows() grad_selected_rows.set_height(height) grad_selected_rows.set_rows(rows) - np_array = np.ones((len(rows), row_numel)).astype("float32") + np_array = np.ones((len(rows), self.row_numel)).astype("float32") np_array[0, 0] = 2.0 np_array[2, 8] = 4.0 @@ -56,7 +67,7 @@ class TestSparseSGDOp(unittest.TestCase): # create and initialize Param Variable param = scope.var('Param').get_tensor() - param_array = np.full((height, row_numel), 5.0).astype("float32") + param_array = np.full((height, self.row_numel), 5.0).astype("float32") param.set(param_array, place) # create and initialize LeraningRate Variable @@ -98,6 +109,14 @@ class TestSparseSGDOp(unittest.TestCase): for place in places: self.check_with_place(place) + def conf(self): + self.row_numel = 12 + + +class TestSparseSGDOpCase8X(TestSparseSGDOp): + def conf(self): + self.row_numel = 16 + class TestSGDOpOptimizeSelectedRows(unittest.TestCase): def check_with_place(self, place):