diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc index 3348778ee782ef0cdd1df4c3c4b24060436d7d79..11dc615f5ff8ea78bbbf6eeb655ee88b3a52dc13 100644 --- a/paddle/fluid/operators/jit/benchmark.cc +++ b/paddle/fluid/operators/jit/benchmark.cc @@ -332,6 +332,45 @@ void BenchEmbSeqPoolKernel() { } } +template +void BenchSgdKernel() { /* Benchmark driver for the kSgd kernel: for each param_h x grad_w x rows_size combination below it fills a param tensor and a grad tensor with RandomVec, builds a row-index list, and passes them to BenchAllImpls with param_data as both the param and the out argument. NOTE(review): the template parameter list and other angle-bracket arguments look stripped from this patch text; confirm against the real file. */ + const T lr = 0.1; /* learning rate; passed to BenchAllImpls by address */ + auto UnDuplicatedRandomVec = [](int n, const int64_t lower, + const int64_t upper) -> std::vector { /* Returns 0..n-1 in shuffled order. NOTE(review): lower and upper feed only the enforce check on the next line; the values produced do not depend on them. */ + PADDLE_ENFORCE_LE(static_cast(upper - lower), n - 1); + PADDLE_ENFORCE_GT(n, 0); + std::vector all, out; + for (int i = 0; i < n; ++i) { + all.push_back(i); + } + std::random_shuffle(all.begin(), all.end()); /* NOTE(review): std::random_shuffle is deprecated since C++14 and removed in C++17; std::shuffle with an explicit engine is the standard replacement. */ + out.insert(out.begin(), all.begin(), all.begin() + n); /* copies all n shuffled values, so out holds the same elements as all */ + return out; + }; + for (int param_h : {1, 1000}) { /* param_h is the first dimension given to param.Resize, i.e. the number of param rows */ + for (int grad_w : {1, 2, 8, 16, 30, 256}) { + // only benchmark inplace + Tensor param; + param.Resize({param_h, grad_w}); + T* param_data = param.mutable_data(PlaceType()); + RandomVec(param_h * grad_w, param_data, -2.f, 2.f); + for (int rows_size = 1; rows_size <= std::min(param_h, 10); ++rows_size) { + Tensor grad; + grad.Resize({rows_size, grad_w}); + std::vector rows = + UnDuplicatedRandomVec(rows_size, 0, rows_size - 1); + RandomVec(rows_size * grad_w, grad.mutable_data(PlaceType()), + -2.f, 2.f); + const T* grad_data = grad.data(); + const int64_t* rows_data = rows.data(); + jit::sgd_attr_t attr(param_h, grad_w, rows_size, grad_w, rows_size); + BenchAllImpls, PlaceType>( + attr, &lr, param_data, grad_data, rows_data, param_data, &attr); + } + } + } +} + template void BenchMatMulKernel() { for (int m : {1, 2, 3, 4}) { @@ -477,6 +516,9 @@ BENCH_FP32_CPU(kEmbSeqPool) { BenchEmbSeqPoolKernel(); } +// sgd function +BENCH_FP32_CPU(kSgd) { BenchSgdKernel(); } + // matmul BENCH_FP32_CPU(kMatMul) { BenchMatMulKernel(); } diff --git a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt index 
d209f31007255b3a90fdeeb4d609311b80bdc7b5..9a00ad56a6a909a677cb8f60bd80fe399e82952f 100644 --- a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt +++ b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt @@ -14,3 +14,4 @@ USE_JITKERNEL_MORE(kVTanh, mkl) USE_JITKERNEL_MORE(kSeqPool, mkl) USE_JITKERNEL_MORE(kSoftmax, mkl) USE_JITKERNEL_MORE(kEmbSeqPool, mkl) +USE_JITKERNEL_MORE(kSgd, mkl) diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.cc b/paddle/fluid/operators/jit/more/mkl/mkl.cc index 29a451f832fa745f8e1f5a45fd934f09e1f41e76..780fda02c1ff3da2e0b945f9b2fece30484e4519 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.cc +++ b/paddle/fluid/operators/jit/more/mkl/mkl.cc @@ -184,6 +184,16 @@ bool EmbSeqPoolKernel::UseMe(const emb_seq_pool_attr_t& attr) const { return true; } +template <> +bool SgdKernel::UseMe(const sgd_attr_t& attr) const { /* unconditionally reports this kernel as usable; attr is not inspected */ + return true; +} + +template <> +bool SgdKernel::UseMe(const sgd_attr_t& attr) const { /* NOTE(review): textually identical to the specialization above; the type arguments that would distinguish the two look stripped from this patch text. Confirm against the real file. */ + return true; +} + template <> bool MatMulKernel::UseMe(const matmul_attr_t& attr) const { return platform::MayIUse(platform::avx); @@ -239,5 +249,6 @@ REGISTER_MKL_KERNEL(kVTanh, VTanh); REGISTER_MKL_KERNEL(kSeqPool, SeqPool); REGISTER_MKL_KERNEL(kEmbSeqPool, EmbSeqPool); REGISTER_MKL_KERNEL(kSoftmax, Softmax); +REGISTER_MKL_KERNEL(kSgd, Sgd); #undef REGISTER_MKL_KERNEL diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.h b/paddle/fluid/operators/jit/more/mkl/mkl.h index 9a72ba83022de2beeb760772ee8489477befdd7e..a7bc2de4a3e8e7d8e2a6b00990bfa459b3029c2a 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.h +++ b/paddle/fluid/operators/jit/more/mkl/mkl.h @@ -142,6 +142,32 @@ void Softmax(const T* x, T* y, int n, int bs) { } } +template +void Sgd(const T* lr, const T* param, const T* grad, const int64_t* rows, + T* out, const sgd_attr_t* attr) { /* Applies an SGD update to the rows of out named by rows[0 .. attr->selected_rows_size): presumably out_row = param_row - lr[0] * grad_row, assuming VAXPY is y += a*x, VScal is y = a*x and VAdd is z = x + y (TODO confirm against their definitions). The i-th grad row is paired with row rows[i] of param/out; rows not listed in rows are never written. */ + PADDLE_ENFORCE_EQ(attr->param_width, attr->grad_width); /* a single width is used below to index param, grad and out rows */ + PADDLE_ENFORCE_LE(attr->selected_rows_size, attr->grad_height); + T scalar = -lr[0]; /* negated learning rate: the multiplier applied to each grad row */ + int width = attr->grad_width; + if (out == 
param) { /* out aliases param: the selected rows are updated in place */ + for (int64_t i = 0; i < attr->selected_rows_size; ++i) { + auto h_idx = rows[i]; /* row of param/out that the i-th grad row is applied to */ + PADDLE_ENFORCE_LT(h_idx, attr->param_height); + PADDLE_ENFORCE_GE(h_idx, 0); + VAXPY(scalar, grad + i * width, out + h_idx * width, width); /* presumably out_row += scalar * grad_row; confirm against VAXPY */ + } + } else { /* out is a separate buffer: each selected row is built from the grad row and the matching param row */ + for (int64_t i = 0; i < attr->selected_rows_size; ++i) { + auto h_idx = rows[i]; + PADDLE_ENFORCE_LT(h_idx, attr->param_height); + PADDLE_ENFORCE_GE(h_idx, 0); + VScal(&scalar, grad + i * width, out + h_idx * width, width); /* presumably out_row = scalar * grad_row; confirm against VScal */ + VAdd(param + h_idx * width, out + h_idx * width, out + h_idx * width, + width); /* presumably out_row = param_row + out_row; confirm against VAdd */ + } + } +} + #define DECLARE_MKL_KERNEL(name, tuples) \ template \ class name##Kernel : public KernelMore> { \ @@ -173,6 +199,8 @@ DECLARE_MKL_KERNEL(EmbSeqPool, EmbSeqPoolTuples); DECLARE_MKL_KERNEL(Softmax, SoftmaxTuples); +DECLARE_MKL_KERNEL(Sgd, SgdTuples); + #undef DECLARE_MKL_KERNEL } // namespace mkl