提交 a3a3d3d8 编写于 作者: T tensor-tang

Add an MKL implementation of the EmbSeqPool JIT kernel and use it in EmbeddingVSumFunctor

test=develop
上级 15da2f9a
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
namespace paddle { namespace paddle {
...@@ -31,35 +32,6 @@ using LoDTensor = framework::LoDTensor; ...@@ -31,35 +32,6 @@ using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows; using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim; using DDim = framework::DDim;
// Sum-pools embedding rows looked up from `table` into `out`.
//
// `idx` is read as an idx_height x idx_width row-major matrix of row ids into
// `table`, whose rows are `table_width` elements long. For every column w the
// rows table[idx[h][w]] are summed over h and written to the w-th
// `table_width`-sized slice of `out`.
//
// Enforces that table_width * idx_width == out_width and that every id that is
// read lies in [0, table_height).
//
// NOTE(review): the first index row is read unconditionally, so the caller
// must pass idx_height >= 1.
//
// All index arithmetic is done in int64_t so that large idx_height/idx_width
// values are not truncated to int.
template <typename T>
void emb_seqpool(const framework::ExecutionContext &context, const T *table,
                 const int64_t *idx, T *out, int64_t table_height,
                 int64_t table_width, int64_t idx_height, int64_t idx_width,
                 int64_t out_width) {  // pool type == sum
  PADDLE_ENFORCE_EQ(table_width * idx_width, out_width);

  // Fails if idx[i] is outside [0, table_height).
  auto check_idx_value_valid = [&](int64_t i) {
    PADDLE_ENFORCE_LT(idx[i], table_height, "idx value: %d, i: %d", idx[i], i);
    PADDLE_ENFORCE_GE(idx[i], 0, "idx value: %d, i: %d", idx[i], i);
  };

  auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);

  // The first index row initializes the output by plain copy.
  for (int64_t w = 0; w != idx_width; ++w) {
    check_idx_value_valid(w);
    blas.VCOPY(table_width, table + idx[w] * table_width,
               out + w * table_width);
  }

  // Every remaining index row is accumulated onto the output (out += 1 * row).
  for (int64_t h = 1; h < idx_height; ++h) {
    for (int64_t w = 0; w < idx_width; ++w) {
      const int64_t i = h * idx_width + w;
      check_idx_value_valid(i);
      blas.AXPY(table_width, static_cast<T>(1), table + idx[i] * table_width,
                out + w * table_width);
    }
  }
}
template <typename T> template <typename T>
struct EmbeddingVSumFunctor { struct EmbeddingVSumFunctor {
void operator()(const framework::ExecutionContext &context, void operator()(const framework::ExecutionContext &context,
...@@ -75,10 +47,15 @@ struct EmbeddingVSumFunctor { ...@@ -75,10 +47,15 @@ struct EmbeddingVSumFunctor {
auto *output = output_t->mutable_data<T>(context.GetPlace()); auto *output = output_t->mutable_data<T>(context.GetPlace());
PADDLE_ENFORCE_LE(table_width * idx_width, out_width); PADDLE_ENFORCE_LE(table_width * idx_width, out_width);
jit::emb_seq_pool_attr_t attr(table_height, table_width, 0, idx_width,
out_width, jit::SeqPoolType::kSum);
for (int64_t i = 0; i != ids_lod.size() - 1; ++i) { for (int64_t i = 0; i != ids_lod.size() - 1; ++i) {
emb_seqpool(context, table, ids + ids_lod[i] * idx_width, attr.index_height = ids_lod[i + 1] - ids_lod[i];
output + i * out_width, table_height, table_width, auto emb_seqpool = jit::Get<jit::kEmbSeqPool, jit::EmbSeqPoolTuples<T>,
ids_lod[i + 1] - ids_lod[i], idx_width, out_width); platform::CPUPlace>(attr);
emb_seqpool(table, ids + ids_lod[i] * idx_width, output + i * out_width,
&attr);
} }
} }
}; };
......
...@@ -13,3 +13,4 @@ USE_JITKERNEL_MORE(kVSigmoid, mkl) ...@@ -13,3 +13,4 @@ USE_JITKERNEL_MORE(kVSigmoid, mkl)
USE_JITKERNEL_MORE(kVTanh, mkl) USE_JITKERNEL_MORE(kVTanh, mkl)
USE_JITKERNEL_MORE(kSeqPool, mkl) USE_JITKERNEL_MORE(kSeqPool, mkl)
USE_JITKERNEL_MORE(kSoftmax, mkl) USE_JITKERNEL_MORE(kSoftmax, mkl)
USE_JITKERNEL_MORE(kEmbSeqPool, mkl)
...@@ -174,6 +174,16 @@ bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const { ...@@ -174,6 +174,16 @@ bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const {
return true; return true;
} }
// float specialization: reports the kernel as usable for every attribute
// configuration; `attr` is not inspected.
template <>
bool EmbSeqPoolKernel<float>::UseMe(const emb_seq_pool_attr_t& attr) const {
return true;
}
// double specialization: reports the kernel as usable for every attribute
// configuration; `attr` is not inspected.
template <>
bool EmbSeqPoolKernel<double>::UseMe(const emb_seq_pool_attr_t& attr) const {
return true;
}
template <> template <>
bool MatMulKernel<float>::UseMe(const matmul_attr_t& attr) const { bool MatMulKernel<float>::UseMe(const matmul_attr_t& attr) const {
return platform::MayIUse(platform::avx); return platform::MayIUse(platform::avx);
...@@ -227,6 +237,7 @@ REGISTER_MKL_KERNEL(kVSquare, VSquare); ...@@ -227,6 +237,7 @@ REGISTER_MKL_KERNEL(kVSquare, VSquare);
REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid); REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid);
REGISTER_MKL_KERNEL(kVTanh, VTanh); REGISTER_MKL_KERNEL(kVTanh, VTanh);
REGISTER_MKL_KERNEL(kSeqPool, SeqPool); REGISTER_MKL_KERNEL(kSeqPool, SeqPool);
REGISTER_MKL_KERNEL(kEmbSeqPool, EmbSeqPool);
REGISTER_MKL_KERNEL(kSoftmax, Softmax); REGISTER_MKL_KERNEL(kSoftmax, Softmax);
#undef REGISTER_MKL_KERNEL #undef REGISTER_MKL_KERNEL
...@@ -18,6 +18,7 @@ ...@@ -18,6 +18,7 @@
#include <type_traits> #include <type_traits>
#include <vector> #include <vector>
#include "paddle/fluid/operators/jit/kernel_base.h" #include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -91,6 +92,32 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) { ...@@ -91,6 +92,32 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
} }
} }
// Embedding lookup fused with sum pooling.
//
// `idx` is read as an attr->index_height x attr->index_width row-major matrix
// of row ids into `table` (rows of attr->table_width elements). For each
// column, the looked-up table rows are summed over all index rows and stored
// in the matching attr->table_width-sized slice of `out`.
//
// Enforces table_width * index_width == out_width and that every id read lies
// in [0, attr->table_height).
template <typename T>
void EmbSeqPool(const T* table, const int64_t* idx, T* out,
                const emb_seq_pool_attr_t* attr) {
  PADDLE_ENFORCE_EQ(attr->table_width * attr->index_width, attr->out_width);

  const auto row_len = attr->table_width;
  const auto ids_per_row = attr->index_width;

  // Range check for the id stored at flat position i.
  const auto validate_id = [&](int64_t i) {
    PADDLE_ENFORCE_LT(idx[i], attr->table_height, "idx value: %d, i: %d",
                      idx[i], i);
    PADDLE_ENFORCE_GE(idx[i], 0, "idx value: %d, i: %d", idx[i], i);
  };

  // Phase 1: seed each output slice with the row selected by the first
  // index row.
  int64_t col = 0;
  while (col != ids_per_row) {
    validate_id(col);
    const T* src = table + idx[col] * row_len;
    T* dst = out + col * row_len;
    VCopy<T>(src, dst, row_len);
    ++col;
  }

  // Phase 2: add the rows selected by every following index row.
  const T one = static_cast<T>(1);
  for (int64_t row = 1; row < attr->index_height; ++row) {
    const int64_t row_begin = row * ids_per_row;
    for (int64_t c = 0; c < ids_per_row; ++c) {
      const int64_t pos = row_begin + c;
      validate_id(pos);
      const T* src = table + idx[pos] * row_len;
      T* dst = out + c * row_len;
      VAXPY<T>(one, src, dst, row_len);
    }
  }
}
template <typename T> template <typename T>
void ASum(const T* x, T* res, int n); void ASum(const T* x, T* res, int n);
...@@ -142,6 +169,8 @@ DECLARE_MKL_KERNEL(VSquare, XYNTuples); ...@@ -142,6 +169,8 @@ DECLARE_MKL_KERNEL(VSquare, XYNTuples);
DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples); DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples);
DECLARE_MKL_KERNEL(EmbSeqPool, EmbSeqPoolTuples);
DECLARE_MKL_KERNEL(Softmax, SoftmaxTuples); DECLARE_MKL_KERNEL(Softmax, SoftmaxTuples);
#undef DECLARE_MKL_KERNEL #undef DECLARE_MKL_KERNEL
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册