From 41a1270856c6472c0f7e86e9d506636d6fb01490 Mon Sep 17 00:00:00 2001
From: tensor-tang
Date: Fri, 1 Mar 2019 06:16:56 +0000
Subject: [PATCH] add vbroadcast jitkernel refer code and use it

test=develop
---
 .../fused/fused_embedding_seq_pool_op.h       | 23 +++++-----
 paddle/fluid/operators/jit/benchmark.cc       | 23 ++++++++++
 paddle/fluid/operators/jit/helper.cc          |  1 +
 paddle/fluid/operators/jit/kernel_base.h      |  8 ++++
 paddle/fluid/operators/jit/kernel_key.cc      |  5 +++
 .../fluid/operators/jit/refer/CMakeLists.txt  |  1 +
 paddle/fluid/operators/jit/refer/refer.cc     |  2 +
 paddle/fluid/operators/jit/refer/refer.h      | 11 +++++
 paddle/fluid/operators/jit/test.cc            | 42 +++++++++++++++++++
 9 files changed, 103 insertions(+), 13 deletions(-)

diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
index 2b0c1f560f2..f13c0203860 100644
--- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
+++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
@@ -22,7 +22,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/operators/jit/kernels.h"
-#include "paddle/fluid/operators/math/blas.h"
 
 namespace paddle {
 namespace operators {
@@ -47,7 +46,7 @@ struct EmbeddingVSumFunctor {
     auto *output = output_t->mutable_data<T>(context.GetPlace());
 
     PADDLE_ENFORCE_LE(table_width * idx_width, out_width);
-    PADDLE_ENFORCE_GT(ids_lod.size(), 1UL);
+    PADDLE_ENFORCE_GT(ids_lod.size(), 1UL, "The LoD[0] could NOT be empty");
 
     jit::emb_seq_pool_attr_t attr(table_height, table_width, 0, idx_width,
                                   out_width, jit::SeqPoolType::kSum);
@@ -83,11 +82,11 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
         FusedEmbeddingSeqPoolLastDim(table_var->dims(), ids_t->dims());
     const auto &ids_lod = ids_t->lod();
     // in run time, the LoD of ids must be 1
-    PADDLE_ENFORCE(ids_lod.size(), 1u, "The LoD level of Input(Ids) must be 1");
-    PADDLE_ENFORCE_GE(ids_lod[0].size(), 1u, "The LoD could NOT be empty");
+    PADDLE_ENFORCE(ids_lod.size(), 1UL,
+                   "The LoD level of Input(Ids) must be 1");
     int64_t batch_size = ids_lod[0].size() - 1;
     // in run time, the shape from Ids -> output
-    // should be [seq_length, 1] -> [batch_size, embedding_size]
+    // should be [seq_length, 1] -> [batch_size, last_dim]
     output_t->Resize({batch_size, last_dim});
 
     if (combiner_type == "sum") {
@@ -125,7 +124,7 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
       auto *ids_data = ids->data<int64_t>();
       int64_t ids_num = ids->numel();
       auto lod = ids->lod()[0];
-      int64_t row_width = d_output->dims()[1];
+      int64_t out_width = d_output->dims()[1];
 
       framework::Vector<int64_t> *new_rows = d_table->mutable_rows();
       new_rows->resize(ids_num);
@@ -136,15 +135,13 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
       T *d_table_data = d_table_value->mutable_data<T>(context.GetPlace());
       const T *d_output_data = d_output->data<T>();
 
-      auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
+      auto vbroadcast = jit::Get<jit::kVBroadcast, jit::VBroadcastTuples<T>,
+                                 platform::CPUPlace>(out_width);
       for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
         int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
-        int64_t in_offset = lod[i] * row_width;
-        const T *out_pos = d_output_data + i * row_width;
-        T *in_pos = d_table_data + in_offset;
-        for (int r = 0; r != h; ++r) {
-          blas.VCOPY(row_width, out_pos, in_pos + r * row_width);
-        }
+        const T *src = d_output_data + i * out_width;
+        T *dst = d_table_data + lod[i] * out_width;
+        vbroadcast(src, dst, h, out_width);
       }
     } else {
      LOG(ERROR) << "Dense is not supported in fused_embedding_seq_pool_op now";
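
A note on the hunk above: in the backward pass of sum pooling, every table row inside one LoD segment receives the same pooled gradient row, so the old per-row blas.VCOPY loop collapses into a single broadcast per segment, which is exactly what the new kVBroadcast kernel does. A minimal standalone sketch of that access pattern (plain C++; segment_grad is an illustrative name, not Paddle code):

#include <cstdint>
#include <cstring>
#include <vector>

// Scatter pooled gradients back to the table: segment i spans table rows
// [lod[i], lod[i+1]), and each of those rows receives d_out row i verbatim.
void segment_grad(const float* d_out, float* d_table,
                  const std::vector<size_t>& lod, int64_t w) {
  for (size_t i = 0; i + 1 < lod.size(); ++i) {
    int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
    const float* src = d_out + i * w;   // pooled gradient of segment i
    float* dst = d_table + lod[i] * w;  // first table row of segment i
    for (int64_t r = 0; r < h; ++r) {   // the broadcast vbroadcast performs
      std::memcpy(dst + r * w, src, static_cast<size_t>(w) * sizeof(float));
    }
  }
}

For example, with lod = {0, 2, 3} and w = 4, table rows 0 and 1 both receive d_out row 0, and table row 2 receives d_out row 1.
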
diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc
index dcee2215291..93ebb1faa75 100644
--- a/paddle/fluid/operators/jit/benchmark.cc
+++ b/paddle/fluid/operators/jit/benchmark.cc
@@ -474,6 +474,24 @@ void BenchCRFDecodingKernel() {
   }
 }
 
+template <jit::KernelType KT, typename T, typename PlaceType>
+void BenchVBroadcastKernel() {
+  for (int w : TestSizes()) {
+    Tensor x;
+    x.Resize({w});
+    RandomVec<T>(w, x.mutable_data<T>(PlaceType()));
+    const T* x_data = x.data<T>();
+    for (int64_t h : {1, 3, 6}) {
+      Tensor y;
+      y.Resize({h * w});
+      T* y_data = y.mutable_data<T>(PlaceType());
+
+      BenchAllImpls<KT, jit::VBroadcastTuples<T>, PlaceType>(
+          static_cast<int64_t>(w), x_data, y_data, h, static_cast<int64_t>(w));
+    }
+  }
+}
+
 using T = float;
 using CPUPlace = paddle::platform::CPUPlace;
 
@@ -536,6 +554,11 @@ BENCH_FP32_CPU(kCRFDecoding) {
   BenchCRFDecodingKernel<jit::kCRFDecoding, T, CPUPlace>();
 }
 
+// vbroadcast function
+BENCH_FP32_CPU(kVBroadcast) {
+  BenchVBroadcastKernel<jit::kVBroadcast, T, CPUPlace>();
+}
+
 // Benchmark all jit kernels including jitcode, mkl and refer.
 // To use this tool, run command: ./benchmark [options...]
 // Options:
diff --git a/paddle/fluid/operators/jit/helper.cc b/paddle/fluid/operators/jit/helper.cc
index b15d956b9f1..eb1c410b6f9 100644
--- a/paddle/fluid/operators/jit/helper.cc
+++ b/paddle/fluid/operators/jit/helper.cc
@@ -36,6 +36,7 @@ const char* to_string(KernelType kt) {
     ONE_CASE(kVScal);
     ONE_CASE(kVAddBias);
     ONE_CASE(kVRelu);
+    ONE_CASE(kVBroadcast);
     ONE_CASE(kVCopy);
     ONE_CASE(kVIdentity);
     ONE_CASE(kVExp);
diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h
index df24b1bea6e..96e162a21bf 100644
--- a/paddle/fluid/operators/jit/kernel_base.h
+++ b/paddle/fluid/operators/jit/kernel_base.h
@@ -41,6 +41,7 @@ typedef enum {
   kVAdd,
   kVAddBias,
   kVAddRelu,
+  kVBroadcast,
   kVCopy,
   kVExp,
   kVIdentity,
@@ -134,6 +135,13 @@ struct GRUTuples {
   typedef void (*func_type)(gru_t*, const gru_attr_t*);
 };
 
+template <typename T>
+struct VBroadcastTuples {
+  typedef T data_type;
+  typedef int64_t attr_type;
+  typedef void (*func_type)(const T*, T*, int64_t, int64_t);
+};
+
 typedef struct seq_pool_attr_s {
   int h, w;  // h should always be the first one
   SeqPoolType type;
diff --git a/paddle/fluid/operators/jit/kernel_key.cc b/paddle/fluid/operators/jit/kernel_key.cc
index 740d0f850a0..1c2fddcae79 100644
--- a/paddle/fluid/operators/jit/kernel_key.cc
+++ b/paddle/fluid/operators/jit/kernel_key.cc
@@ -24,6 +24,11 @@ size_t JitCodeKey<int>(const int& d) {
   return d;
 }
 
+template <>
+size_t JitCodeKey<int64_t>(const int64_t& d) {
+  return d;
+}
+
 // TODO(TJ): refine and benchmark JitCodeKey generatation
 constexpr int act_type_shift = 3;  // suppot 2^3 act types
 static inline int act_type_convert(KernelType type) {
diff --git a/paddle/fluid/operators/jit/refer/CMakeLists.txt b/paddle/fluid/operators/jit/refer/CMakeLists.txt
index 44ea944cf57..ffab9c1457b 100644
--- a/paddle/fluid/operators/jit/refer/CMakeLists.txt
+++ b/paddle/fluid/operators/jit/refer/CMakeLists.txt
@@ -35,3 +35,4 @@ USE_JITKERNEL_REFER(kHMax)
 USE_JITKERNEL_REFER(kSoftmax)
 USE_JITKERNEL_REFER(kEmbSeqPool)
 USE_JITKERNEL_REFER(kSgd)
+USE_JITKERNEL_REFER(kVBroadcast)
diff --git a/paddle/fluid/operators/jit/refer/refer.cc b/paddle/fluid/operators/jit/refer/refer.cc
index 01a521942bb..c279d1b2ca4 100644
--- a/paddle/fluid/operators/jit/refer/refer.cc
+++ b/paddle/fluid/operators/jit/refer/refer.cc
@@ -62,4 +62,6 @@ REGISTER_REFER_KERNEL(kEmbSeqPool, EmbSeqPool);
 
 REGISTER_REFER_KERNEL(kSgd, Sgd);
 
+REGISTER_REFER_KERNEL(kVBroadcast, VBroadcast);
+
 #undef REGISTER_REFER_KERNEL
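
The VBroadcastTuples entry above follows the existing tuple convention: it pins the element type (data_type), the attribute that doubles as the code-cache key (attr_type, here just the destination row width as int64_t, which is why kernel_key.cc gains a JitCodeKey<int64_t> specialization returning the width itself), and the kernel's function signature (func_type). A simplified, self-contained model of how such a tuple and its key drive kernel lookup; GetKernel and its cache are illustrative stand-ins, not Paddle's jit::Get implementation:

#include <cstdint>
#include <cstring>
#include <unordered_map>

template <typename T>
struct VBroadcastTuples {
  typedef T data_type;
  typedef int64_t attr_type;  // the attribute value doubles as the cache key
  typedef void (*func_type)(const T*, T*, int64_t, int64_t);
};

// Same body as the refer VBroadcast: write y_h copies of x into y.
template <typename T>
void RefVBroadcast(const T* x, T* y, int64_t y_h, int64_t x_len) {
  for (int64_t h = 0; h < y_h; ++h) {
    std::memcpy(y + h * x_len, x, static_cast<size_t>(x_len) * sizeof(T));
  }
}

// Stand-in for jit::Get: map the attr through an identity key (what the new
// JitCodeKey<int64_t> specialization does) and fall back to the reference
// kernel when no specialized implementation is cached for that width.
template <typename Tuples>
typename Tuples::func_type GetKernel(typename Tuples::attr_type attr) {
  static std::unordered_map<size_t, typename Tuples::func_type> cache;
  typename Tuples::func_type& fn = cache[static_cast<size_t>(attr)];
  if (fn == nullptr) fn = &RefVBroadcast<typename Tuples::data_type>;
  return fn;
}

Usage mirrors the grad kernel above: auto f = GetKernel<VBroadcastTuples<float>>(out_width); then f(src, dst, h, out_width);.
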
diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h
index bef4ca9cbb9..b3b2097828c 100644
--- a/paddle/fluid/operators/jit/refer/refer.h
+++ b/paddle/fluid/operators/jit/refer/refer.h
@@ -75,6 +75,15 @@ void VCopy(const T* x, T* y, int n) {
   std::memcpy(y, x, n * sizeof(T));
 }
 
+// x shape: (x_len)
+// y shape: (h, x_len)
+template <typename T>
+void VBroadcast(const T* x, T* y, int64_t y_h, int64_t x_len) {
+  for (int64_t h = 0; h < y_h; ++h) {
+    VCopy(x, y + h * x_len, x_len);
+  }
+}
+
 template <typename T>
 void VRelu(const T* x, T* y, int n) {
   for (int i = 0; i < n; ++i) {
@@ -534,6 +543,8 @@ DECLARE_REFER_KERNEL(EmbSeqPool, EmbSeqPoolTuples);
 
 DECLARE_REFER_KERNEL(Sgd, SgdTuples);
 
+DECLARE_REFER_KERNEL(VBroadcast, VBroadcastTuples);
+
 #undef DECLARE_REFER_KERNEL
 
 }  // namespace refer
diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc
index c9e0f170219..cdec14dc438 100644
--- a/paddle/fluid/operators/jit/test.cc
+++ b/paddle/fluid/operators/jit/test.cc
@@ -157,6 +157,26 @@ struct TestFuncWithRefer<jit::XRNTuples<T>, std::vector<T>, T> {
   }
 };
 
+template <typename T>
+struct TestFuncWithRefer<jit::VBroadcastTuples<T>, std::vector<T>,
+                         std::vector<T>, int64_t,
+                         typename jit::VBroadcastTuples<T>::attr_type> {
+  void operator()(const typename jit::VBroadcastTuples<T>::func_type tgt,
+                  const std::vector<T>& x, const std::vector<T>& yref,
+                  int64_t h,
+                  const typename jit::VBroadcastTuples<T>::attr_type& attr) {
+    EXPECT_TRUE(tgt != nullptr);
+    EXPECT_EQ(x.size(), static_cast<size_t>(attr));
+    EXPECT_EQ(yref.size(), x.size() * h);
+    std::vector<T> y(yref.size());
+    const T* x_data = x.data();
+    const T* yref_data = yref.data();
+    T* y_data = y.data();
+    tgt(x_data, y_data, h, attr);
+    ExpectEQ<T>(y_data, yref_data, yref.size());
+  }
+};
+
 template <typename T>
 struct TestFuncWithRefer<jit::XYNTuples<T>, std::vector<T>, std::vector<T>> {
   void operator()(const typename jit::XYNTuples<T>::func_type tgt,
@@ -926,6 +946,27 @@ void TestKernelCRFDecodingTuples() {
   }
 }
 
+template <jit::KernelType KT, typename T, typename PlaceType>
+void TestKernelVBroadcastTuples() {
+  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
+  for (int w : TestSizes()) {
+    std::vector<T> x(w);
+    RandomVec<T>(w, x.data());
+    const T* x_data = x.data();
+    for (int64_t h : {1, 2, 6}) {
+      auto ref = jit::GetRefer<KT, jit::VBroadcastTuples<T>>();
+      EXPECT_TRUE(ref != nullptr);
+      std::vector<T> y(w * h);
+      T* y_data = y.data();
+      ref(x_data, y_data, h, w);
+
+      TestAllImpls<KT, jit::VBroadcastTuples<T>, PlaceType, std::vector<T>,
+                   std::vector<T>, int64_t>(static_cast<int64_t>(w), x, y, h,
+                                            static_cast<int64_t>(w));
+    }
+  }
+}
+
 #define TEST_CPU_KERNEL(test_tuple, kernel_type)             \
   TEST(JITKernel, kernel_type) {                             \
     TestKernel##test_tuple<jit::kernel_type, T, CPUPlace>(); \
@@ -967,6 +1008,7 @@ TEST_CPU_KERNEL(EmbSeqPoolTuples, kEmbSeqPool);
 TEST_CPU_KERNEL(SgdTuples, kSgd);
 TEST_CPU_KERNEL(LayerNormTuples, kLayerNorm);
 TEST_CPU_KERNEL(CRFDecodingTuples, kCRFDecoding);
+TEST_CPU_KERNEL(VBroadcastTuples, kVBroadcast);
 
 TEST(JITKernel_key, lstm) {
   jit::lstm_attr_t attr1(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
-- 
GitLab
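
To make the contract exercised by the test and benchmark concrete, namely x of length w broadcast into y of shape (h, w), here is a runnable standalone demo with the same semantics as the refer kernel (illustrative only; no Paddle dependencies):

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// y[r * x_len .. (r + 1) * x_len) = x for every row r in [0, y_h).
template <typename T>
void VBroadcast(const T* x, T* y, int64_t y_h, int64_t x_len) {
  for (int64_t h = 0; h < y_h; ++h) {
    std::memcpy(y + h * x_len, x, static_cast<size_t>(x_len) * sizeof(T));
  }
}

int main() {
  const int64_t w = 4, h = 3;  // same shape family the unit test exercises
  std::vector<float> x = {0.1f, 0.2f, 0.3f, 0.4f};
  std::vector<float> y(static_cast<size_t>(h * w));
  VBroadcast(x.data(), y.data(), h, w);
  for (int64_t r = 0; r < h; ++r) {
    for (int64_t c = 0; c < w; ++c) {
      std::printf("%.1f ", y[r * w + c]);
    }
    std::printf("\n");  // each of the 3 rows prints: 0.1 0.2 0.3 0.4
  }
  return 0;
}
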