提交 a3a3d3d8 编写于 作者: T tensor-tang

add embseqpool jitkernel mkl impl and use it

test=develop
上级 15da2f9a
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
......@@ -31,35 +32,6 @@ using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
template <typename T>
void emb_seqpool(const framework::ExecutionContext &context, const T *table,
const int64_t *idx, T *out, int64_t table_height,
int64_t table_width, int64_t idx_height, int64_t idx_width,
int64_t out_width) { // pool type == sum
PADDLE_ENFORCE_EQ(table_width * idx_width, out_width);
auto check_idx_value_valid = [&](int i) {
PADDLE_ENFORCE_LT(idx[i], table_height, "idx value: %d, i: %d", idx[i], i);
PADDLE_ENFORCE_GE(idx[i], 0, "idx value: %d, i: %d", idx[i], i);
};
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int w = 0; w != idx_width; ++w) {
check_idx_value_valid(w);
blas.VCOPY(table_width, table + idx[w] * table_width,
out + w * table_width);
}
for (int h = 1; h < idx_height; ++h) {
for (int w = 0; w < idx_width; ++w) {
int i = h * idx_width + w;
check_idx_value_valid(i);
blas.AXPY(table_width, static_cast<T>(1), table + idx[i] * table_width,
out + w * table_width);
}
}
}
template <typename T>
struct EmbeddingVSumFunctor {
void operator()(const framework::ExecutionContext &context,
......@@ -75,10 +47,15 @@ struct EmbeddingVSumFunctor {
auto *output = output_t->mutable_data<T>(context.GetPlace());
PADDLE_ENFORCE_LE(table_width * idx_width, out_width);
jit::emb_seq_pool_attr_t attr(table_height, table_width, 0, idx_width,
out_width, jit::SeqPoolType::kSum);
for (int64_t i = 0; i != ids_lod.size() - 1; ++i) {
emb_seqpool(context, table, ids + ids_lod[i] * idx_width,
output + i * out_width, table_height, table_width,
ids_lod[i + 1] - ids_lod[i], idx_width, out_width);
attr.index_height = ids_lod[i + 1] - ids_lod[i];
auto emb_seqpool = jit::Get<jit::kEmbSeqPool, jit::EmbSeqPoolTuples<T>,
platform::CPUPlace>(attr);
emb_seqpool(table, ids + ids_lod[i] * idx_width, output + i * out_width,
&attr);
}
}
};
......
......@@ -13,3 +13,4 @@ USE_JITKERNEL_MORE(kVSigmoid, mkl)
USE_JITKERNEL_MORE(kVTanh, mkl)
USE_JITKERNEL_MORE(kSeqPool, mkl)
USE_JITKERNEL_MORE(kSoftmax, mkl)
USE_JITKERNEL_MORE(kEmbSeqPool, mkl)
......@@ -174,6 +174,16 @@ bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const {
return true;
}
template <>
bool EmbSeqPoolKernel<float>::UseMe(const emb_seq_pool_attr_t& attr) const {
return true;
}
template <>
bool EmbSeqPoolKernel<double>::UseMe(const emb_seq_pool_attr_t& attr) const {
return true;
}
template <>
bool MatMulKernel<float>::UseMe(const matmul_attr_t& attr) const {
return platform::MayIUse(platform::avx);
......@@ -227,6 +237,7 @@ REGISTER_MKL_KERNEL(kVSquare, VSquare);
REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid);
REGISTER_MKL_KERNEL(kVTanh, VTanh);
REGISTER_MKL_KERNEL(kSeqPool, SeqPool);
REGISTER_MKL_KERNEL(kEmbSeqPool, EmbSeqPool);
REGISTER_MKL_KERNEL(kSoftmax, Softmax);
#undef REGISTER_MKL_KERNEL
......@@ -18,6 +18,7 @@
#include <type_traits>
#include <vector>
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
......@@ -91,6 +92,32 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
}
}
template <typename T>
void EmbSeqPool(const T* table, const int64_t* idx, T* out,
const emb_seq_pool_attr_t* attr) {
PADDLE_ENFORCE_EQ(attr->table_width * attr->index_width, attr->out_width);
auto check_idx_value_valid = [&](int64_t i) {
PADDLE_ENFORCE_LT(idx[i], attr->table_height, "idx value: %d, i: %d",
idx[i], i);
PADDLE_ENFORCE_GE(idx[i], 0, "idx value: %d, i: %d", idx[i], i);
};
for (int64_t w = 0; w != attr->index_width; ++w) {
check_idx_value_valid(w);
VCopy<T>(table + idx[w] * attr->table_width, out + w * attr->table_width,
attr->table_width);
}
for (int64_t h = 1; h < attr->index_height; ++h) {
for (int64_t w = 0; w < attr->index_width; ++w) {
int64_t i = h * attr->index_width + w;
check_idx_value_valid(i);
VAXPY<T>(static_cast<T>(1), table + idx[i] * attr->table_width,
out + w * attr->table_width, attr->table_width);
}
}
}
template <typename T>
void ASum(const T* x, T* res, int n);
......@@ -142,6 +169,8 @@ DECLARE_MKL_KERNEL(VSquare, XYNTuples);
DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples);
DECLARE_MKL_KERNEL(EmbSeqPool, EmbSeqPoolTuples);
DECLARE_MKL_KERNEL(Softmax, SoftmaxTuples);
#undef DECLARE_MKL_KERNEL
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册