diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
index 2b0c1f560f23eee7fbdf14444bf933535b704167..f13c02038606e52337b7ef85545e37054e54b631 100644
--- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
+++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h
@@ -22,7 +22,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/operators/jit/kernels.h"
-#include "paddle/fluid/operators/math/blas.h"
 
 namespace paddle {
 namespace operators {
@@ -47,7 +46,7 @@ struct EmbeddingVSumFunctor {
     auto *output = output_t->mutable_data<T>(context.GetPlace());
 
     PADDLE_ENFORCE_LE(table_width * idx_width, out_width);
-    PADDLE_ENFORCE_GT(ids_lod.size(), 1UL);
+    PADDLE_ENFORCE_GT(ids_lod.size(), 1UL, "The LoD[0] could NOT be empty");
 
     jit::emb_seq_pool_attr_t attr(table_height, table_width, 0, idx_width,
                                   out_width, jit::SeqPoolType::kSum);
@@ -83,11 +82,11 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
         FusedEmbeddingSeqPoolLastDim(table_var->dims(), ids_t->dims());
     const auto &ids_lod = ids_t->lod();
     // in run time, the LoD of ids must be 1
-    PADDLE_ENFORCE(ids_lod.size(), 1u, "The LoD level of Input(Ids) must be 1");
-    PADDLE_ENFORCE_GE(ids_lod[0].size(), 1u, "The LoD could NOT be empty");
+    PADDLE_ENFORCE(ids_lod.size(), 1UL,
+                   "The LoD level of Input(Ids) must be 1");
     int64_t batch_size = ids_lod[0].size() - 1;
     // in run time, the shape from Ids -> output
-    // should be [seq_length, 1] -> [batch_size, embedding_size]
+    // should be [seq_length, 1] -> [batch_size, last_dim]
     output_t->Resize({batch_size, last_dim});
 
     if (combiner_type == "sum") {
@@ -125,7 +124,7 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
       auto *ids_data = ids->data<int64_t>();
       int64_t ids_num = ids->numel();
       auto lod = ids->lod()[0];
-      int64_t row_width = d_output->dims()[1];
+      int64_t out_width = d_output->dims()[1];
 
       framework::Vector<int64_t> *new_rows = d_table->mutable_rows();
       new_rows->resize(ids_num);
@@ -136,15 +135,13 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
       T *d_table_data = d_table_value->mutable_data<T>(context.GetPlace());
       const T *d_output_data = d_output->data<T>();
 
-      auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
+      auto vbroadcast = jit::Get<jit::kVBroadcast, jit::VBroadcastTuples<T>,
+                                 platform::CPUPlace>(out_width);
       for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
         int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
-        int64_t in_offset = lod[i] * row_width;
-        const T *out_pos = d_output_data + i * row_width;
-        T *in_pos = d_table_data + in_offset;
-        for (int r = 0; r != h; ++r) {
-          blas.VCOPY(row_width, out_pos, in_pos + r * row_width);
-        }
+        const T *src = d_output_data + i * out_width;
+        T *dst = d_table_data + lod[i] * out_width;
+        vbroadcast(src, dst, h, out_width);
       }
     } else {
       LOG(ERROR) << "Dense is not supported in fused_embedding_seq_pool_op now";
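
For context, the change above replaces a per-row blas.VCOPY loop with a single kVBroadcast call per LoD segment: row i of d_output is duplicated into the h = lod[i+1] - lod[i] consecutive rows of d_table that segment i's ids occupy. A minimal standalone sketch of that semantics (plain C++; vbroadcast() here is an illustrative stand-in for the jit kernel, not a Paddle API):

#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Stand-in with the same (src, dst, height, width) contract as the kernel.
void vbroadcast(const float* x, float* y, int64_t y_h, int64_t x_len) {
  for (int64_t r = 0; r < y_h; ++r) {
    std::memcpy(y + r * x_len, x, x_len * sizeof(float));
  }
}

int main() {
  const int64_t out_width = 4;
  const std::vector<size_t> lod = {0, 2, 5};  // two segments: 2 ids, 3 ids
  std::vector<float> d_output = {1, 1, 1, 1, 2, 2, 2, 2};  // 2 output rows
  std::vector<float> d_table(5 * out_width, 0.f);          // 5 id rows
  for (size_t i = 0; i + 1 < lod.size(); ++i) {
    int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
    vbroadcast(d_output.data() + i * out_width,
               d_table.data() + lod[i] * out_width, h, out_width);
  }
  // rows 0..1 now hold output row 0; rows 2..4 hold output row 1
  assert(d_table[0] == 1.f && d_table[2 * out_width] == 2.f);
  return 0;
}
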
diff --git a/paddle/fluid/operators/jit/benchmark.cc b/paddle/fluid/operators/jit/benchmark.cc
index 11dc615f5ff8ea78bbbf6eeb655ee88b3a52dc13..3088280bb90174e6195a349c07a3435e131e2b33 100644
--- a/paddle/fluid/operators/jit/benchmark.cc
+++ b/paddle/fluid/operators/jit/benchmark.cc
@@ -474,6 +474,23 @@ void BenchCRFDecodingKernel() {
   }
 }
 
+template <jit::KernelType KT, typename T, typename PlaceType>
+void BenchVBroadcastKernel() {
+  for (int64_t w : {1, 16, 64, 100, 256}) {
+    Tensor x;
+    x.Resize({w});
+    RandomVec<T>(w, x.mutable_data<T>(PlaceType()));
+    const T* x_data = x.data<T>();
+    for (int h : TestSizes()) {
+      Tensor y;
+      y.Resize({h * w});
+      T* y_data = y.mutable_data<T>(PlaceType());
+      BenchAllImpls<KT, jit::VBroadcastTuples<T>, PlaceType>(
+          w, x_data, y_data, static_cast<int64_t>(h), w);
+    }
+  }
+}
+
 using T = float;
 using CPUPlace = paddle::platform::CPUPlace;
 
@@ -498,6 +515,7 @@ BENCH_FP32_CPU(kVSquare) { BenchXYNKernel<jit::kVSquare, T, CPUPlace>(); }
 BENCH_FP32_CPU(kVExp) { BenchXYNKernel<jit::kVExp, T, CPUPlace>(); }
 BENCH_FP32_CPU(kVSigmoid) { BenchXYNKernel<jit::kVSigmoid, T, CPUPlace>(); }
 BENCH_FP32_CPU(kVTanh) { BenchXYNKernel<jit::kVTanh, T, CPUPlace>(); }
+BENCH_FP32_CPU(kVCopy) { BenchXYNKernel<jit::kVCopy, T, CPUPlace>(); }
 
 // lstm and peephole
 BENCH_FP32_CPU(kLSTMCtHt) { BenchLSTMKernel<jit::kLSTMCtHt, T, CPUPlace>(); }
@@ -535,6 +553,11 @@ BENCH_FP32_CPU(kCRFDecoding) {
   BenchCRFDecodingKernel<jit::kCRFDecoding, T, CPUPlace>();
 }
 
+// vbroadcast function
+BENCH_FP32_CPU(kVBroadcast) {
+  BenchVBroadcastKernel<jit::kVBroadcast, T, CPUPlace>();
+}
+
 // Benchmark all jit kernels including jitcode, mkl and refer.
 // To use this tool, run command: ./benchmark [options...]
 // Options:
diff --git a/paddle/fluid/operators/jit/gen/CMakeLists.txt b/paddle/fluid/operators/jit/gen/CMakeLists.txt
index eb0c03568ddddf1c456fec6fcc81f3b40d051844..99244ea9bd919a018732b75d1ab811e8bf338516 100644
--- a/paddle/fluid/operators/jit/gen/CMakeLists.txt
+++ b/paddle/fluid/operators/jit/gen/CMakeLists.txt
@@ -33,3 +33,4 @@ USE_JITKERNEL_GEN(kHMax)
 USE_JITKERNEL_GEN(kHSum)
 USE_JITKERNEL_GEN(kEmbSeqPool)
 USE_JITKERNEL_GEN(kSgd)
+USE_JITKERNEL_GEN(kVBroadcast)
diff --git a/paddle/fluid/operators/jit/gen/vbroadcast.cc b/paddle/fluid/operators/jit/gen/vbroadcast.cc
new file mode 100644
index 0000000000000000000000000000000000000000..3f9fbdbd821acae0940c5a7b8d9a5eb2432712ff
--- /dev/null
+++ b/paddle/fluid/operators/jit/gen/vbroadcast.cc
@@ -0,0 +1,91 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "paddle/fluid/operators/jit/gen/vbroadcast.h"
+#include <memory>
+#include <vector>
+#include "paddle/fluid/operators/jit/registry.h"
+#include "paddle/fluid/platform/cpu_info.h"
+
+namespace paddle {
+namespace operators {
+namespace jit {
+namespace gen {
+
+void VBroadcastJitCode::genCode() {
+  preCode();
+  constexpr int block = YMM_FLOAT_BLOCK;
+  constexpr int max_num_regs = 16;
+  const int num_block = w_ / block;
+  const int num_groups = num_block / max_num_regs;
+  const size_t block_size = sizeof(float) * block;
+  std::vector<int> groups(num_groups, max_num_regs);
+  int rest_num_regs = num_block % max_num_regs;
+  if (rest_num_regs > 0) {
+    groups.push_back(rest_num_regs);
+  }
+
+  // protect param_h
+  mov(reg_height, param_h);
+  Label l_next_h;
+  xor_(reg_h_i, reg_h_i);
+  mov(reg_ptr_dst_i, param_dst);
+  L(l_next_h);
+  {
+    mov(reg_ptr_src_i, param_src);
+    for (int num_regs : groups) {
+      size_t w_offset = 0;
+      for (int reg_i = 0; reg_i < num_regs; ++reg_i) {
+        vmovups(ymm_t(reg_i), ptr[reg_ptr_src_i + w_offset]);
+        w_offset += block_size;
+      }
+      add(reg_ptr_src_i, num_regs * block_size);
+
+      w_offset = 0;
+      for (int reg_i = 0; reg_i < num_regs; ++reg_i) {
+        vmovups(ptr[reg_ptr_dst_i + w_offset], ymm_t(reg_i));
+        w_offset += block_size;
+      }
+      add(reg_ptr_dst_i, num_regs * block_size);
+    }  // end of groups
+    inc(reg_h_i);
+    cmp(reg_h_i, reg_height);
+    jl(l_next_h, T_NEAR);
+  }  // end of l_next_h
+
+  postCode();
+}
+
+class VBroadcastCreator : public JitCodeCreator<int64_t> {
+ public:
+  bool UseMe(const int64_t& w) const override {
+    return platform::MayIUse(platform::avx) && w % YMM_FLOAT_BLOCK == 0;
+  }
+  size_t CodeSize(const int64_t& w) const override {
+    return 96 + (w / YMM_FLOAT_BLOCK) * 16 * 8;
+  }
+  std::unique_ptr<GenBase> CreateJitCode(const int64_t& w) const override {
+    PADDLE_ENFORCE_GT(w, 0);
+    return make_unique<VBroadcastJitCode>(w, CodeSize(w));
+  }
+};
+
+}  // namespace gen
+}  // namespace jit
+}  // namespace operators
+}  // namespace paddle
+
+namespace gen = paddle::operators::jit::gen;
+
+REGISTER_JITKERNEL_GEN(kVBroadcast, gen::VBroadcastCreator);
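
A note on the generated loop above: genCode() loads the whole source row into YMM registers in groups of at most 16 registers, then stores each group once per output row, so the width must be a multiple of YMM_FLOAT_BLOCK (which is what VBroadcastCreator::UseMe enforces on AVX machines). A small standalone sketch of the same grouping arithmetic, handy for sanity-checking the CodeSize() estimate (plain C++, illustrative only):

#include <cstdio>
#include <vector>

// Mirrors the register-grouping arithmetic in VBroadcastJitCode::genCode():
// w floats split into 8-float YMM blocks, at most 16 registers per group.
int main() {
  const int block = 8;          // YMM_FLOAT_BLOCK: 8 floats per 256-bit reg
  const int max_num_regs = 16;  // YMM registers used per group
  for (int w : {8, 64, 128, 256}) {
    const int num_block = w / block;
    const int num_groups = num_block / max_num_regs;
    std::vector<int> groups(num_groups, max_num_regs);
    const int rest = num_block % max_num_regs;
    if (rest > 0) groups.push_back(rest);
    std::printf("w=%d ->", w);
    for (int g : groups) std::printf(" %d regs", g);
    std::printf("\n");  // e.g. w=256 -> 16 regs, 16 regs
  }
  return 0;
}
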
diff --git a/paddle/fluid/operators/jit/gen/vbroadcast.h b/paddle/fluid/operators/jit/gen/vbroadcast.h
new file mode 100644
index 0000000000000000000000000000000000000000..27c75f6f710e9514c7d91181e7f447d9dd997081
--- /dev/null
+++ b/paddle/fluid/operators/jit/gen/vbroadcast.h
@@ -0,0 +1,53 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <string>
+#include "glog/logging.h"
+#include "paddle/fluid/operators/jit/gen/jitcode.h"
+
+namespace paddle {
+namespace operators {
+namespace jit {
+namespace gen {
+
+class VBroadcastJitCode : public JitCode {
+ public:
+  explicit VBroadcastJitCode(const int64_t& w, size_t code_size = 256 * 1024,
+                             void* code_ptr = nullptr)
+      : JitCode(code_size, code_ptr), w_(w) {
+    this->genCode();
+  }
+
+  DECLARE_JIT_CODE(VBroadcastJitCode);
+  void genCode() override;
+
+ private:
+  int w_;
+  reg64_t param_src{abi_param1};
+  reg64_t param_dst{abi_param2};
+  reg64_t param_h{abi_param3};
+  reg64_t param_w{abi_param4};
+
+  reg64_t reg_height{r9};
+  reg64_t reg_h_i{r10};
+  reg64_t reg_ptr_src_i{r11};
+  reg64_t reg_ptr_dst_i{r12};
+};
+
+}  // namespace gen
+}  // namespace jit
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/jit/helper.cc b/paddle/fluid/operators/jit/helper.cc
index 1dc60442d5c5f6acf49b6319223b190f6c81e1a6..eb1c410b6f9a31c3f97a274c5e5ff55bf1c32ea0 100644
--- a/paddle/fluid/operators/jit/helper.cc
+++ b/paddle/fluid/operators/jit/helper.cc
@@ -36,6 +36,8 @@ const char* to_string(KernelType kt) {
     ONE_CASE(kVScal);
     ONE_CASE(kVAddBias);
     ONE_CASE(kVRelu);
+    ONE_CASE(kVBroadcast);
+    ONE_CASE(kVCopy);
     ONE_CASE(kVIdentity);
     ONE_CASE(kVExp);
     ONE_CASE(kVSquare);
diff --git a/paddle/fluid/operators/jit/kernel_base.h b/paddle/fluid/operators/jit/kernel_base.h
index 895e2d4d6f3809a66443ed6d6bfc1ee02d6c529a..96e162a21bff2a5624f35ada615c9a9a17ad3c75 100644
--- a/paddle/fluid/operators/jit/kernel_base.h
+++ b/paddle/fluid/operators/jit/kernel_base.h
@@ -41,6 +41,8 @@ typedef enum {
   kVAdd,
   kVAddBias,
   kVAddRelu,
+  kVBroadcast,
+  kVCopy,
   kVExp,
   kVIdentity,
   kVMul,
@@ -133,6 +135,13 @@ struct GRUTuples {
   typedef void (*func_type)(gru_t*, const gru_attr_t*);
 };
 
+template <typename T>
+struct VBroadcastTuples {
+  typedef T data_type;
+  typedef int64_t attr_type;
+  typedef void (*func_type)(const T*, T*, int64_t, int64_t);
+};
+
 typedef struct seq_pool_attr_s {
   int h, w;  // h should always be the first one
   SeqPoolType type;
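
VBroadcastTuples fixes the contract shared by the refer, MKL, and jitcode implementations: attr_type is just the row width, so JitCodeKey<int64_t> below can cache one generated kernel per width, while the height is an ordinary runtime argument. A minimal sketch of calling through that function-pointer type (standalone C++; vbroadcast_impl is a hypothetical stand-in for whatever jit::Get returns):

#include <cstdint>
#include <cstring>
#include <vector>

// Same shape as jit::VBroadcastTuples<float>::func_type: (src, dst, h, w).
using vbroadcast_f = void (*)(const float*, float*, int64_t, int64_t);

// Hypothetical implementation; in Paddle this pointer would come from
// jit::Get<jit::kVBroadcast, jit::VBroadcastTuples<float>, CPUPlace>(w).
static void vbroadcast_impl(const float* x, float* y, int64_t y_h,
                            int64_t x_len) {
  for (int64_t h = 0; h < y_h; ++h) {
    std::memcpy(y + h * x_len, x, x_len * sizeof(float));
  }
}

int main() {
  const int64_t w = 64;  // the attr: only the width keys the kernel cache
  vbroadcast_f fn = vbroadcast_impl;
  std::vector<float> x(w, 1.f), y(6 * w);
  fn(x.data(), y.data(), /*height=*/6, /*width=*/w);  // height is per-call
  return 0;
}
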
diff --git a/paddle/fluid/operators/jit/kernel_key.cc b/paddle/fluid/operators/jit/kernel_key.cc
index 740d0f850a072a5ad3238e52402141a83c0b7e33..1c2fddcae79d8b89e1169d5bcb364b3ff2e42dd3 100644
--- a/paddle/fluid/operators/jit/kernel_key.cc
+++ b/paddle/fluid/operators/jit/kernel_key.cc
@@ -24,6 +24,11 @@ size_t JitCodeKey<int>(const int& d) {
   return d;
 }
 
+template <>
+size_t JitCodeKey<int64_t>(const int64_t& d) {
+  return d;
+}
+
 // TODO(TJ): refine and benchmark JitCodeKey generatation
 constexpr int act_type_shift = 3;  // suppot 2^3 act types
diff --git a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
index 9a00ad56a6a909a677cb8f60bd80fe399e82952f..f69417c370b653d93cce04a2248ad809168670da 100644
--- a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
+++ b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt
@@ -9,9 +9,11 @@ USE_JITKERNEL_MORE(kVAdd, mkl)
 USE_JITKERNEL_MORE(kVScal, mkl)
 USE_JITKERNEL_MORE(kVExp, mkl)
 USE_JITKERNEL_MORE(kVSquare, mkl)
+USE_JITKERNEL_MORE(kVCopy, mkl)
 USE_JITKERNEL_MORE(kVSigmoid, mkl)
 USE_JITKERNEL_MORE(kVTanh, mkl)
 USE_JITKERNEL_MORE(kSeqPool, mkl)
 USE_JITKERNEL_MORE(kSoftmax, mkl)
 USE_JITKERNEL_MORE(kEmbSeqPool, mkl)
 USE_JITKERNEL_MORE(kSgd, mkl)
+USE_JITKERNEL_MORE(kVBroadcast, mkl)
diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.cc b/paddle/fluid/operators/jit/more/mkl/mkl.cc
index 780fda02c1ff3da2e0b945f9b2fece30484e4519..4f51353bce834325e6c659399a374e4fbc40d4b7 100644
--- a/paddle/fluid/operators/jit/more/mkl/mkl.cc
+++ b/paddle/fluid/operators/jit/more/mkl/mkl.cc
@@ -154,6 +154,21 @@ bool VSquareKernel<float>::UseMe(const int& d) const {
   return d > 7;
 }
 
+template <>
+bool VCopyKernel<float>::UseMe(const int& d) const {
+  return d > 15;
+}
+
+template <>
+bool VBroadcastKernel<float>::UseMe(const int64_t& d) const {
+  return d > 127;
+}
+
+template <>
+bool VBroadcastKernel<double>::UseMe(const int64_t& attr) const {
+  return true;
+}
+
 template <>
 bool VSigmoidKernel<float>::UseMe(const int& d) const {
   return d > 7;
@@ -223,6 +238,7 @@ AWALYS_USE_ME_WITH_DOUBLE(VExp);
 AWALYS_USE_ME_WITH_DOUBLE(VSigmoid);
 AWALYS_USE_ME_WITH_DOUBLE(VTanh);
 AWALYS_USE_ME_WITH_DOUBLE(VSquare);
+AWALYS_USE_ME_WITH_DOUBLE(VCopy);
 AWALYS_USE_ME_WITH_DOUBLE(Softmax);
 
 #undef AWALYS_USE_ME_WITH_DOUBLE
@@ -244,6 +260,8 @@ REGISTER_MKL_KERNEL(kVAdd, VAdd);
 REGISTER_MKL_KERNEL(kVScal, VScal);
 REGISTER_MKL_KERNEL(kVExp, VExp);
 REGISTER_MKL_KERNEL(kVSquare, VSquare);
+REGISTER_MKL_KERNEL(kVCopy, VCopy);
+REGISTER_MKL_KERNEL(kVBroadcast, VBroadcast);
 REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid);
 REGISTER_MKL_KERNEL(kVTanh, VTanh);
 REGISTER_MKL_KERNEL(kSeqPool, SeqPool);
diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.h b/paddle/fluid/operators/jit/more/mkl/mkl.h
index a7bc2de4a3e8e7d8e2a6b00990bfa459b3029c2a..db2d6faed4fdcfebedb9d9eb752831259af30186 100644
--- a/paddle/fluid/operators/jit/more/mkl/mkl.h
+++ b/paddle/fluid/operators/jit/more/mkl/mkl.h
@@ -50,6 +50,13 @@ void VCopy(const T* x, T* y, int n);
 template <typename T>
 void VAXPY(T a, const T* x, T* y, int n);
 
+template <typename T>
+void VBroadcast(const T* x, T* y, int64_t y_h, int64_t x_len) {
+  for (int64_t h = 0; h < y_h; ++h) {
+    VCopy(x, y + h * x_len, x_len);
+  }
+}
+
 template <typename T>
 void VSigmoid(const T* x, T* y, int n) {
   const T min = SIGMOID_THRESHOLD_MIN;
@@ -192,6 +199,7 @@ DECLARE_MKL_KERNEL(VExp, XYNTuples);
 DECLARE_MKL_KERNEL(VSigmoid, XYNTuples);
 DECLARE_MKL_KERNEL(VTanh, XYNTuples);
 DECLARE_MKL_KERNEL(VSquare, XYNTuples);
+DECLARE_MKL_KERNEL(VCopy, XYNTuples);
 
 DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples);
 
@@ -201,6 +209,8 @@ DECLARE_MKL_KERNEL(Softmax, SoftmaxTuples);
 
 DECLARE_MKL_KERNEL(Sgd, SgdTuples);
 
+DECLARE_MKL_KERNEL(VBroadcast, VBroadcastTuples);
+
 #undef DECLARE_MKL_KERNEL
 
 }  // namespace mkl
diff --git a/paddle/fluid/operators/jit/refer/CMakeLists.txt b/paddle/fluid/operators/jit/refer/CMakeLists.txt
index cd19dd169d0bfdfe2cb8157ade29f48ad6428453..ffab9c1457b932b3211e6aa75954bb1435f8e34c 100644
--- a/paddle/fluid/operators/jit/refer/CMakeLists.txt
+++ b/paddle/fluid/operators/jit/refer/CMakeLists.txt
@@ -13,6 +13,7 @@ USE_JITKERNEL_REFER(kVAddRelu)
 USE_JITKERNEL_REFER(kVSub)
 USE_JITKERNEL_REFER(kVScal)
 USE_JITKERNEL_REFER(kVAddBias)
+USE_JITKERNEL_REFER(kVCopy)
 USE_JITKERNEL_REFER(kVRelu)
 USE_JITKERNEL_REFER(kVIdentity)
 USE_JITKERNEL_REFER(kVExp)
@@ -34,3 +35,4 @@ USE_JITKERNEL_REFER(kHMax)
 USE_JITKERNEL_REFER(kSoftmax)
 USE_JITKERNEL_REFER(kEmbSeqPool)
 USE_JITKERNEL_REFER(kSgd)
+USE_JITKERNEL_REFER(kVBroadcast)
diff --git a/paddle/fluid/operators/jit/refer/refer.cc b/paddle/fluid/operators/jit/refer/refer.cc
index 0c434bd2b8cacdf4b8872da66bb8e763a6a45cee..c279d1b2ca4f53bb6bc5da0cab41e9086ed475bd 100644
--- a/paddle/fluid/operators/jit/refer/refer.cc
+++ b/paddle/fluid/operators/jit/refer/refer.cc
@@ -30,6 +30,7 @@ REGISTER_REFER_KERNEL(kVScal, VScal);
 REGISTER_REFER_KERNEL(kVAddBias, VAddBias);
 
 REGISTER_REFER_KERNEL(kVRelu, VRelu);
+REGISTER_REFER_KERNEL(kVCopy, VCopy);
 REGISTER_REFER_KERNEL(kVIdentity, VIdentity);
 REGISTER_REFER_KERNEL(kVSquare, VSquare);
 REGISTER_REFER_KERNEL(kVExp, VExp);
@@ -61,4 +62,6 @@ REGISTER_REFER_KERNEL(kEmbSeqPool, EmbSeqPool);
 
 REGISTER_REFER_KERNEL(kSgd, Sgd);
 
+REGISTER_REFER_KERNEL(kVBroadcast, VBroadcast);
+
 #undef REGISTER_REFER_KERNEL
diff --git a/paddle/fluid/operators/jit/refer/refer.h b/paddle/fluid/operators/jit/refer/refer.h
index 0f714edf85bbbf4838bfe09251bd1c2d5f3b3eb7..b3b2097828c5b6d647fd6bfe14a6e8bff04409e0 100644
--- a/paddle/fluid/operators/jit/refer/refer.h
+++ b/paddle/fluid/operators/jit/refer/refer.h
@@ -70,6 +70,20 @@ void VAddBias(const T* a, const T* x, T* y, int n) {
   }
 }
 
+template <typename T>
+void VCopy(const T* x, T* y, int n) {
+  std::memcpy(y, x, n * sizeof(T));
+}
+
+// x shape: (x_len)
+// y shape: (h, x_len)
+template <typename T>
+void VBroadcast(const T* x, T* y, int64_t y_h, int64_t x_len) {
+  for (int64_t h = 0; h < y_h; ++h) {
+    VCopy(x, y + h * x_len, x_len);
+  }
+}
+
 template <typename T>
 void VRelu(const T* x, T* y, int n) {
   for (int i = 0; i < n; ++i) {
@@ -500,6 +514,7 @@ DECLARE_REFER_KERNEL(VExp, XYNTuples);
 DECLARE_REFER_KERNEL(VSigmoid, XYNTuples);
 DECLARE_REFER_KERNEL(VTanh, XYNTuples);
 DECLARE_REFER_KERNEL(VSquare, XYNTuples);
+DECLARE_REFER_KERNEL(VCopy, XYNTuples);
 
 // lstm_t*, const lstm_attr_t*
 DECLARE_REFER_KERNEL(LSTMCtHt, LSTMTuples);
@@ -528,6 +543,8 @@ DECLARE_REFER_KERNEL(EmbSeqPool, EmbSeqPoolTuples);
 
 DECLARE_REFER_KERNEL(Sgd, SgdTuples);
 
+DECLARE_REFER_KERNEL(VBroadcast, VBroadcastTuples);
+
 #undef DECLARE_REFER_KERNEL
 
 }  // namespace refer
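
The refer implementation above doubles as the ground truth for the unit tests in test.cc below. A compact standalone check of its row semantics (every output row equals x), written against local copies of the two refer functions (plain C++; illustrative):

#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

namespace refer {  // local copies of the refer kernels shown above
template <typename T>
void VCopy(const T* x, T* y, int n) {
  std::memcpy(y, x, n * sizeof(T));
}
template <typename T>
void VBroadcast(const T* x, T* y, int64_t y_h, int64_t x_len) {
  for (int64_t h = 0; h < y_h; ++h) {
    VCopy(x, y + h * x_len, x_len);
  }
}
}  // namespace refer

int main() {
  const int64_t w = 5;
  std::vector<float> x = {0.f, 1.f, 2.f, 3.f, 4.f};
  for (int64_t h : {1, 2, 6}) {  // same heights the unit test uses
    std::vector<float> y(h * w, -1.f);
    refer::VBroadcast(x.data(), y.data(), h, w);
    for (int64_t r = 0; r < h; ++r)
      for (int64_t c = 0; c < w; ++c) assert(y[r * w + c] == x[c]);
  }
  return 0;
}
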
diff --git a/paddle/fluid/operators/jit/test.cc b/paddle/fluid/operators/jit/test.cc
index b618cd6a84be752a052f9d49a4a4c772b1d7eeae..cdec14dc4383897f4ae24fc89b99fe00c713cf42 100644
--- a/paddle/fluid/operators/jit/test.cc
+++ b/paddle/fluid/operators/jit/test.cc
@@ -26,8 +26,8 @@ limitations under the License. */
 DEFINE_double(acc, 1e-5, "Test accuracy threshold.");
 
 template <typename T>
-void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
-               const T upper = static_cast<T>(20.f)) {
+void RandomVec(const int n, T* a, const T lower = static_cast<T>(-2.f),
+               const T upper = static_cast<T>(2.f)) {
   static unsigned int seed = 100;
   std::mt19937 rng(seed++);
   std::uniform_real_distribution<double> uniform_dist(0, 1);
@@ -157,6 +157,26 @@ struct TestFuncWithRefer<jit::XRNTuples<T>, std::vector<T>, T> {
   }
 };
 
+template <typename T>
+struct TestFuncWithRefer<jit::VBroadcastTuples<T>, std::vector<T>,
+                         std::vector<T>, int64_t,
+                         typename jit::VBroadcastTuples<T>::attr_type> {
+  void operator()(const typename jit::VBroadcastTuples<T>::func_type tgt,
+                  const std::vector<T>& x, const std::vector<T>& yref,
+                  int64_t h,
+                  const typename jit::VBroadcastTuples<T>::attr_type& attr) {
+    EXPECT_TRUE(tgt != nullptr);
+    EXPECT_EQ(x.size(), static_cast<size_t>(attr));
+    EXPECT_EQ(yref.size(), x.size() * h);
+    std::vector<T> y(yref.size());
+    const T* x_data = x.data();
+    const T* yref_data = yref.data();
+    T* y_data = y.data();
+    tgt(x_data, y_data, h, attr);
+    ExpectEQ<T>(y_data, yref_data, yref.size());
+  }
+};
+
 template <typename T>
 struct TestFuncWithRefer<jit::XYNTuples<T>, std::vector<T>, std::vector<T>> {
   void operator()(const typename jit::XYNTuples<T>::func_type tgt,
@@ -514,7 +534,7 @@ void TestKernelXRNTuples() {
   auto ref = jit::GetRefer<KT, jit::XRNTuples<T>>();
   EXPECT_TRUE(ref != nullptr);
   std::vector<T> x(d);
-  RandomVec<T>(d, x.data(), -2.f, 2.f);
+  RandomVec<T>(d, x.data());
   T ref_res;
   ref(x.data(), &ref_res, d);
   TestAllImpls<KT, jit::XRNTuples<T>, PlaceType, std::vector<T>, T>(d, x,
@@ -532,7 +552,7 @@ void TestKernelXYNTuples() {
     std::vector<T> x(d), yref(d);
     std::vector<T> xinp(d);  // inplace test
-    RandomVec<T>(d, x.data(), -2.f, 2.f);
+    RandomVec<T>(d, x.data());
     std::copy(x.begin(), x.end(), xinp.begin());
 
     const T* x_data = x.data();
@@ -566,7 +586,7 @@ void TestKernelLSTMTuples() {
         EXPECT_TRUE(ref != nullptr);
         std::vector<T> xsrc(4 * d), wp(3 * d), ct_1(d);
         std::vector<T> ct_ref(d), ht_ref(d), checked(2 * d);
-        RandomVec<T>(4 * d, xsrc.data(), -2.f, 2.f);
+        RandomVec<T>(4 * d, xsrc.data());
         RandomVec<T>(3 * d, wp.data(), -1.f, 1.f);
         RandomVec<T>(d, ct_1.data(), -1.f, 1.f);
         // x could be changed after compute, so copy to save src
@@ -614,8 +634,8 @@ void TestKernelGRUTuples() {
       auto ref = jit::GetRefer<KT, jit::GRUTuples<T>>();
       EXPECT_TRUE(ref != nullptr);
       std::vector<T> xsrc(3 * d), ht_1(d), ht_ref(d);
-      RandomVec<T>(3 * d, xsrc.data(), -2.f, 2.f);
-      RandomVec<T>(d, ht_1.data(), -2.f, 2.f);
+      RandomVec<T>(3 * d, xsrc.data());
+      RandomVec<T>(d, ht_1.data());
       // x could be changed after compute, so copy to save src
       std::vector<T> x(xsrc.size());
       std::copy(xsrc.begin(), xsrc.end(), x.begin());
@@ -651,7 +671,7 @@ void TestKernelSeqPoolTuples() {
       auto ref = jit::GetRefer<KT, jit::SeqPoolTuples<T>>();
       EXPECT_TRUE(ref != nullptr);
       std::vector<T> x(h * w), yref(w);
-      RandomVec<T>(h * w, x.data(), -2.f, 2.f);
+      RandomVec<T>(h * w, x.data());
       const T* x_data = x.data();
       T* yref_data = yref.data();
       ref(x_data, yref_data, &attr);
@@ -676,8 +696,8 @@ void TestKernelMatMulTuples() {
         auto ref = jit::GetRefer<KT, jit::MatMulTuples<T>>();
         EXPECT_TRUE(ref != nullptr);
         std::vector<T> a(m * k), b(k * n), c(m * n);
-        RandomVec<T>(m * k, a.data(), -2.f, 2.f);
-        RandomVec<T>(k * n, b.data(), -2.f, 2.f);
+        RandomVec<T>(m * k, a.data());
+        RandomVec<T>(k * n, b.data());
         const T* a_data = a.data();
         const T* b_data = b.data();
         T* c_data = c.data();
@@ -699,7 +719,7 @@ void TestKernelSoftmaxTuples() {
       auto ref = jit::GetRefer<KT, jit::SoftmaxTuples<T>>();
       EXPECT_TRUE(ref != nullptr);
       std::vector<T> x(bs * n), y(bs * n);
-      RandomVec<T>(bs * n, x.data(), -2.f, 2.f);
+      RandomVec<T>(bs * n, x.data());
       const T* x_data = x.data();
       T* y_data = y.data();
 
@@ -726,7 +746,7 @@ void TestKernelEmbSeqPoolTuples() {
   test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000));
   for (int tbl_w : test_sizes) {
     std::vector<T> table(tbl_h * tbl_w);
-    RandomVec<T>(tbl_h * tbl_w, table.data(), -2.f, 2.f);
+    RandomVec<T>(tbl_h * tbl_w, table.data());
     const T* table_data = table.data();
     for (auto type : pool_types) {
       for (int idx_w : {1, 2, 10, 16}) {
@@ -772,14 +792,14 @@ void TestKernelSgdTuples() {
     for (int grad_w : TestSizes()) {
       std::vector<T> param(param_h * grad_w);
       std::vector<T> param_out(param_h * grad_w);
-      RandomVec<T>(param_h * grad_w, param.data(), -2.f, 2.f);
+      RandomVec<T>(param_h * grad_w, param.data());
      const T* param_data = param.data();
       T* out_data = param_out.data();
       for (int rows_size = 1; rows_size <= param_h; ++rows_size) {
         std::vector<T> grad(rows_size * grad_w);
         std::vector<int64_t> rows =
             UnDuplicatedRandomVec(rows_size, 0, rows_size - 1);
-        RandomVec<T>(rows_size * grad_w, grad.data(), -2.f, 2.f);
+        RandomVec<T>(rows_size * grad_w, grad.data());
         const int64_t* rows_data = rows.data();
         const T* grad_data = grad.data();
         auto ref = jit::GetRefer<KT, jit::SgdTuples<T>>();
@@ -815,8 +835,8 @@ void TestKernelNCHW16CMulNCTuples() {
   int sz = n * c * h * w;
   std::vector<T> x(sz), y(n * c), zref(sz);
   std::vector<T> ztgt(sz), zjit(sz);
-  RandomVec<T>(sz, x.data(), -2.f, 2.f);
-  RandomVec<T>(n * c, y.data(), -2.f, 2.f);
+  RandomVec<T>(sz, x.data());
+  RandomVec<T>(n * c, y.data());
 
   const T* x_data = x.data();
   const T* y_data = y.data();
@@ -873,11 +893,11 @@ void TestKernelLayerNormTuples() {
         int sz = left * right;
         std::vector<T> x(sz), mean(left), var(left), scale(right), bias(right),
             outref(sz);
-        RandomVec<T>(sz, x.data(), -2.f, 2.f);
-        RandomVec<T>(left, mean.data(), -2.f, 2.f);
-        RandomVec<T>(left, var.data(), -2.f, 2.f);
-        RandomVec<T>(right, scale.data(), -2.f, 2.f);
-        RandomVec<T>(right, bias.data(), -2.f, 2.f);
+        RandomVec<T>(sz, x.data());
+        RandomVec<T>(left, mean.data());
+        RandomVec<T>(left, var.data());
+        RandomVec<T>(right, scale.data());
+        RandomVec<T>(right, bias.data());
 
         const T* scale_data = scale.data();
         const T* bias_data = bias.data();
@@ -903,16 +923,16 @@ void TestKernelCRFDecodingTuples() {
   VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
   constexpr int state_trans_base_idx = 2;
   auto test_sizes = TestSizes();
-  test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000));
+  test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 2000));
   for (int seq_len : {1, 11, 17, 50}) {
     for (int tag_num : test_sizes) {
       auto ref = jit::GetRefer<KT, jit::CRFDecodingTuples<T>>();
       EXPECT_TRUE(ref != nullptr);
       int x_sz = seq_len * tag_num;
       int w_sz = (tag_num + state_trans_base_idx) * tag_num;
       std::vector<T> x(x_sz), w(w_sz), alpharef(x_sz);
       std::vector<int> trackref(x_sz);
-      RandomVec<T>(x_sz, x.data(), -2.f, 2.f);
-      RandomVec<T>(w_sz, w.data(), -2.f, 2.f);
+      RandomVec<T>(x_sz, x.data());
+      RandomVec<T>(w_sz, w.data());
       ref(seq_len, (const T*)x.data(), (const T*)w.data(), alpharef.data(),
           trackref.data(), tag_num);
@@ -926,6 +946,27 @@ void TestKernelCRFDecodingTuples() {
   }
 }
 
+template <jit::KernelType KT, typename T, typename PlaceType>
+void TestKernelVBroadcastTuples() {
+  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
+  for (int w : TestSizes()) {
+    std::vector<T> x(w);
+    RandomVec<T>(w, x.data());
+    const T* x_data = x.data();
+    for (int64_t h : {1, 2, 6}) {
+      auto ref = jit::GetRefer<KT, jit::VBroadcastTuples<T>>();
+      EXPECT_TRUE(ref != nullptr);
+      std::vector<T> y(w * h);
+      T* y_data = y.data();
+      ref(x_data, y_data, h, w);
+
+      TestAllImpls<KT, jit::VBroadcastTuples<T>, PlaceType, std::vector<T>,
+                   std::vector<T>, int64_t>(static_cast<int64_t>(w), x, y, h,
+                                            static_cast<int64_t>(w));
+    }
+  }
+}
+
 #define TEST_CPU_KERNEL(test_tuple, kernel_type)             \
   TEST(JITKernel, kernel_type) {                             \
     TestKernel##test_tuple<jit::kernel_type, T, CPUPlace>(); \
@@ -949,6 +990,7 @@ TEST_CPU_KERNEL(XYNTuples, kVSquare);
 TEST_CPU_KERNEL(XYNTuples, kVExp);
 TEST_CPU_KERNEL(XYNTuples, kVSigmoid);
 TEST_CPU_KERNEL(XYNTuples, kVTanh);
+TEST_CPU_KERNEL(XYNTuples, kVCopy);
 
 TEST_CPU_KERNEL(LSTMTuples, kLSTMCtHt);
 TEST_CPU_KERNEL(LSTMTuples, kLSTMC1H1);
@@ -966,6 +1008,7 @@ TEST_CPU_KERNEL(EmbSeqPoolTuples, kEmbSeqPool);
 TEST_CPU_KERNEL(SgdTuples, kSgd);
 TEST_CPU_KERNEL(LayerNormTuples, kLayerNorm);
 TEST_CPU_KERNEL(CRFDecodingTuples, kCRFDecoding);
+TEST_CPU_KERNEL(VBroadcastTuples, kVBroadcast);
 
 TEST(JITKernel_key, lstm) {
   jit::lstm_attr_t attr1(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);