diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h index 744e83541d3d5bc02ee5570460ac960d40408e7b..92345b3c0eda42f03bd48c5f64880c56ddbe351b 100644 --- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h +++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/selected_rows.h" +#include "paddle/fluid/operators/jit/kernels.h" #include "paddle/fluid/operators/math/blas.h" namespace paddle { @@ -31,35 +32,6 @@ using LoDTensor = framework::LoDTensor; using SelectedRows = framework::SelectedRows; using DDim = framework::DDim; -template -void emb_seqpool(const framework::ExecutionContext &context, const T *table, - const int64_t *idx, T *out, int64_t table_height, - int64_t table_width, int64_t idx_height, int64_t idx_width, - int64_t out_width) { // pool type == sum - PADDLE_ENFORCE_EQ(table_width * idx_width, out_width); - - auto check_idx_value_valid = [&](int i) { - PADDLE_ENFORCE_LT(idx[i], table_height, "idx value: %d, i: %d", idx[i], i); - PADDLE_ENFORCE_GE(idx[i], 0, "idx value: %d, i: %d", idx[i], i); - }; - auto blas = math::GetBlas(context); - - for (int w = 0; w != idx_width; ++w) { - check_idx_value_valid(w); - blas.VCOPY(table_width, table + idx[w] * table_width, - out + w * table_width); - } - - for (int h = 1; h < idx_height; ++h) { - for (int w = 0; w < idx_width; ++w) { - int i = h * idx_width + w; - check_idx_value_valid(i); - blas.AXPY(table_width, static_cast(1), table + idx[i] * table_width, - out + w * table_width); - } - } -} - template struct EmbeddingVSumFunctor { void operator()(const framework::ExecutionContext &context, @@ -75,10 +47,15 @@ struct EmbeddingVSumFunctor { auto *output = output_t->mutable_data(context.GetPlace()); PADDLE_ENFORCE_LE(table_width * idx_width, out_width); + + jit::emb_seq_pool_attr_t attr(table_height, table_width, 0, idx_width, + out_width, jit::SeqPoolType::kSum); for (int64_t i = 0; i != ids_lod.size() - 1; ++i) { - emb_seqpool(context, table, ids + ids_lod[i] * idx_width, - output + i * out_width, table_height, table_width, - ids_lod[i + 1] - ids_lod[i], idx_width, out_width); + attr.index_height = ids_lod[i + 1] - ids_lod[i]; + auto emb_seqpool = jit::Get, + platform::CPUPlace>(attr); + emb_seqpool(table, ids + ids_lod[i] * idx_width, output + i * out_width, + &attr); } } }; diff --git a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt index f9e5aea32e7cd48e9b39c4c3ee0e30f4a5c84f6f..d209f31007255b3a90fdeeb4d609311b80bdc7b5 100644 --- a/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt +++ b/paddle/fluid/operators/jit/more/mkl/CMakeLists.txt @@ -13,3 +13,4 @@ USE_JITKERNEL_MORE(kVSigmoid, mkl) USE_JITKERNEL_MORE(kVTanh, mkl) USE_JITKERNEL_MORE(kSeqPool, mkl) USE_JITKERNEL_MORE(kSoftmax, mkl) +USE_JITKERNEL_MORE(kEmbSeqPool, mkl) diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.cc b/paddle/fluid/operators/jit/more/mkl/mkl.cc index 4c999131ab116ebe3484355158993558b02cc4b2..29a451f832fa745f8e1f5a45fd934f09e1f41e76 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.cc +++ b/paddle/fluid/operators/jit/more/mkl/mkl.cc @@ -174,6 +174,16 @@ bool SeqPoolKernel::UseMe(const seq_pool_attr_t& attr) const { return true; } +template <> +bool EmbSeqPoolKernel::UseMe(const emb_seq_pool_attr_t& attr) const { + return true; +} + +template <> +bool EmbSeqPoolKernel::UseMe(const emb_seq_pool_attr_t& attr) const { + return true; +} + template <> bool MatMulKernel::UseMe(const matmul_attr_t& attr) const { return platform::MayIUse(platform::avx); @@ -227,6 +237,7 @@ REGISTER_MKL_KERNEL(kVSquare, VSquare); REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid); REGISTER_MKL_KERNEL(kVTanh, VTanh); REGISTER_MKL_KERNEL(kSeqPool, SeqPool); +REGISTER_MKL_KERNEL(kEmbSeqPool, EmbSeqPool); REGISTER_MKL_KERNEL(kSoftmax, Softmax); #undef REGISTER_MKL_KERNEL diff --git a/paddle/fluid/operators/jit/more/mkl/mkl.h b/paddle/fluid/operators/jit/more/mkl/mkl.h index 8130b87326f1887f232022ab30fa7bf42b0723e7..9a72ba83022de2beeb760772ee8489477befdd7e 100644 --- a/paddle/fluid/operators/jit/more/mkl/mkl.h +++ b/paddle/fluid/operators/jit/more/mkl/mkl.h @@ -18,6 +18,7 @@ #include #include #include "paddle/fluid/operators/jit/kernel_base.h" +#include "paddle/fluid/platform/enforce.h" namespace paddle { namespace operators { @@ -91,6 +92,32 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) { } } +template +void EmbSeqPool(const T* table, const int64_t* idx, T* out, + const emb_seq_pool_attr_t* attr) { + PADDLE_ENFORCE_EQ(attr->table_width * attr->index_width, attr->out_width); + auto check_idx_value_valid = [&](int64_t i) { + PADDLE_ENFORCE_LT(idx[i], attr->table_height, "idx value: %d, i: %d", + idx[i], i); + PADDLE_ENFORCE_GE(idx[i], 0, "idx value: %d, i: %d", idx[i], i); + }; + + for (int64_t w = 0; w != attr->index_width; ++w) { + check_idx_value_valid(w); + VCopy(table + idx[w] * attr->table_width, out + w * attr->table_width, + attr->table_width); + } + + for (int64_t h = 1; h < attr->index_height; ++h) { + for (int64_t w = 0; w < attr->index_width; ++w) { + int64_t i = h * attr->index_width + w; + check_idx_value_valid(i); + VAXPY(static_cast(1), table + idx[i] * attr->table_width, + out + w * attr->table_width, attr->table_width); + } + } +} + template void ASum(const T* x, T* res, int n); @@ -142,6 +169,8 @@ DECLARE_MKL_KERNEL(VSquare, XYNTuples); DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples); +DECLARE_MKL_KERNEL(EmbSeqPool, EmbSeqPoolTuples); + DECLARE_MKL_KERNEL(Softmax, SoftmaxTuples); #undef DECLARE_MKL_KERNEL