未验证 提交 5aea2cd2 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #15652 from tensor-tang/refine/pyramiddnn

refine fused emb seq pool
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
......@@ -37,32 +38,24 @@ struct EmbeddingVSumFunctor {
const LoDTensor *table_t, const LoDTensor *ids_t,
LoDTensor *output_t) {
auto *table = table_t->data<T>();
int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1];
int64_t last_dim = output_t->dims()[1];
int64_t table_height = table_t->dims()[0];
int64_t table_width = table_t->dims()[1];
int64_t out_width = output_t->dims()[1];
const int64_t *ids = ids_t->data<int64_t>();
auto ids_lod = ids_t->lod()[0];
int64_t ids_count = ids_t->numel() / ids_lod.back();
int64_t idx_width = ids_t->numel() / ids_lod.back();
auto *output = output_t->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int64_t i = 0; i != ids_lod.size() - 1; ++i) {
size_t begin = ids_lod[i] * ids_count;
for (int64_t j = 0; j != ids_count; ++j) {
PADDLE_ENFORCE_LT(ids[begin], row_number);
PADDLE_ENFORCE_GE(ids[begin], 0, "ids %d", i);
blas.VCOPY(row_width, table + ids[begin + j] * row_width,
output + i * last_dim + j * row_width);
}
PADDLE_ENFORCE_LE(table_width * idx_width, out_width);
for (int64_t r = (ids_lod[i] + 1) * ids_count;
r < ids_lod[i + 1] * ids_count; ++r) {
PADDLE_ENFORCE_LT(ids[r], row_number);
PADDLE_ENFORCE_GE(ids[r], 0, "ids %d", i);
blas.AXPY(row_width, 1., table + ids[r] * row_width,
output + i * last_dim + (r % ids_count) * row_width);
}
jit::emb_seq_pool_attr_t attr(table_height, table_width, 0, idx_width,
out_width, jit::SeqPoolType::kSum);
for (int64_t i = 0; i != ids_lod.size() - 1; ++i) {
attr.index_height = ids_lod[i + 1] - ids_lod[i];
auto emb_seqpool = jit::Get<jit::kEmbSeqPool, jit::EmbSeqPoolTuples<T>,
platform::CPUPlace>(attr);
emb_seqpool(table, ids + ids_lod[i] * idx_width, output + i * out_width,
&attr);
}
}
};
......
......@@ -301,6 +301,37 @@ void BenchSeqPoolKernel() {
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchEmbSeqPoolKernel() {
std::vector<jit::SeqPoolType> pool_types = {jit::SeqPoolType::kSum};
int64_t tbl_h = 1e4;
for (int tbl_w : {10, 16, 256}) {
Tensor table;
table.Resize({tbl_h, tbl_w});
RandomVec<T>(tbl_h * tbl_w, table.mutable_data<T>(PlaceType()), -2.f, 2.f);
const T* table_data = table.data<T>();
for (auto type : pool_types) {
for (int idx_w : {1, 2, 10, 16}) {
for (int idx_h : {1, 2, 9, 13, 16}) {
int64_t out_w = tbl_w * idx_w;
jit::emb_seq_pool_attr_t attr(tbl_h, tbl_w, idx_h, idx_w, out_w,
type);
Tensor idx, out;
idx.Resize({idx_h, idx_w});
out.Resize({out_w});
RandomVec<int64_t>(idx_h * idx_w,
idx.mutable_data<int64_t>(PlaceType()), 0,
tbl_h - 1);
const int64_t* idx_data = idx.data<int64_t>();
T* o_data = out.mutable_data<T>(PlaceType());
BenchAllImpls<KT, jit::EmbSeqPoolTuples<T>, PlaceType>(
attr, table_data, idx_data, o_data, &attr);
}
}
}
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void BenchMatMulKernel() {
for (int m : {1, 2, 3, 4}) {
......@@ -441,6 +472,11 @@ BENCH_FP32_CPU(kGRUHtPart2) { BenchGRUKernel<jit::kGRUHtPart2, T, CPUPlace>(); }
// seq pool function
BENCH_FP32_CPU(kSeqPool) { BenchSeqPoolKernel<jit::kSeqPool, T, CPUPlace>(); }
// embedding seq pool function
BENCH_FP32_CPU(kEmbSeqPool) {
BenchEmbSeqPoolKernel<jit::kEmbSeqPool, T, CPUPlace>();
}
// matmul
BENCH_FP32_CPU(kMatMul) { BenchMatMulKernel<jit::kMatMul, T, CPUPlace>(); }
......
......@@ -31,3 +31,4 @@ USE_JITKERNEL_GEN(kNCHW16CMulNC)
USE_JITKERNEL_GEN(kSeqPool)
USE_JITKERNEL_GEN(kHMax)
USE_JITKERNEL_GEN(kHSum)
USE_JITKERNEL_GEN(kEmbSeqPool)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/embseqpool.h"
#include <stddef.h> // offsetof
#include <vector>
#include "paddle/fluid/operators/jit/gen/act.h" // for exp_float_consts ones
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
void EmbSeqPoolJitCode::genCode() {
preCode();
constexpr int block = YMM_FLOAT_BLOCK;
constexpr int max_num_regs = 8;
const int num_block = tbl_w_ / block;
const int num_groups = num_block / max_num_regs;
const size_t block_size = sizeof(float) * block;
std::vector<int> groups(num_groups, max_num_regs);
int rest_num_regs = num_block % max_num_regs;
if (rest_num_regs > 0) {
groups.push_back(rest_num_regs);
}
// protect param_dst
mov(reg_ptr_param_dst, param_dst);
mov(reg_idx_width_in_byte,
qword[param_attr + offsetof(emb_seq_pool_attr_t, index_width)]);
mov(reg_idx_height,
qword[param_attr + offsetof(emb_seq_pool_attr_t, index_height)]);
mov(rax, sizeof(int64_t));
mul(reg_idx_width_in_byte);
mov(reg_idx_width_in_byte, rax);
const size_t tbl_width_in_byte = sizeof(float) * tbl_w_;
int acc_num_regs = 0;
for (int num_regs : groups) {
Label l_next_idx_w, l_next_idx_h, l_save_now;
xor_(reg_idx_w_i_in_byte, reg_idx_w_i_in_byte);
mov(reg_ptr_dst_i, reg_ptr_param_dst);
add(reg_ptr_dst_i, acc_num_regs * block_size);
L(l_next_idx_w);
{
// h == 0
mov(reg_ptr_idx_i, param_idx);
add(reg_ptr_idx_i, reg_idx_w_i_in_byte);
mov(reg_idx, qword[reg_ptr_idx_i]);
mov(rax, tbl_width_in_byte);
mul(reg_idx);
mov(reg_ptr_tbl_i, rax); // reg is offset now
add(reg_ptr_tbl_i, param_tbl); // reg is ptr_i now
size_t w_offset = 0;
for (int reg_i = 0; reg_i < num_regs; ++reg_i) {
vmovups(ymm_t(reg_i + num_regs), ptr[reg_ptr_tbl_i + w_offset]);
w_offset += block_size;
}
add(reg_ptr_idx_i, reg_idx_width_in_byte);
// end condition of idx h
mov(reg_idx_h_end, reg_idx_height);
mov(rax, reg_idx_width_in_byte);
mul(reg_idx_h_end);
mov(reg_idx_h_end, rax);
add(reg_idx_h_end, reg_idx_w_i_in_byte);
add(reg_idx_h_end, param_idx);
cmp(reg_ptr_idx_i, reg_idx_h_end);
jge(l_save_now, T_NEAR);
L(l_next_idx_h);
{
mov(reg_idx, qword[reg_ptr_idx_i]);
mov(reg_ptr_tbl_i, reg_idx);
mov(rax, tbl_width_in_byte);
mul(reg_idx);
mov(reg_ptr_tbl_i, rax);
add(reg_ptr_tbl_i, param_tbl);
size_t w_offset = 0;
for (int reg_i = 0; reg_i < num_regs; ++reg_i) {
vmovups(ymm_t(reg_i), ptr[reg_ptr_tbl_i + w_offset]);
vaddps(ymm_t(reg_i + num_regs), ymm_t(reg_i + num_regs),
ymm_t(reg_i));
w_offset += block_size;
}
add(reg_ptr_idx_i, reg_idx_width_in_byte);
cmp(reg_ptr_idx_i, reg_idx_h_end);
jl(l_next_idx_h, T_NEAR);
} // end of idx h
L(l_save_now);
// avg or sqrt here, if needed
w_offset = 0;
for (int reg_i = 0; reg_i < num_regs; ++reg_i) {
vmovups(ptr[reg_ptr_dst_i + w_offset], ymm_t(reg_i + num_regs));
w_offset += block_size;
}
add(reg_ptr_dst_i, tbl_width_in_byte);
add(reg_idx_w_i_in_byte, sizeof(int64_t));
cmp(reg_idx_w_i_in_byte, reg_idx_width_in_byte);
jl(l_next_idx_w, T_NEAR);
} // end of idx w
acc_num_regs += num_regs;
add(param_tbl, num_regs * block_size); // do not use acc_num_regs
} // end of groups
postCode();
}
class EmbSeqPoolCreator : public JitCodeCreator<emb_seq_pool_attr_t> {
public:
bool UseMe(const emb_seq_pool_attr_t& attr) const override {
return platform::MayIUse(platform::avx) &&
attr.table_width % YMM_FLOAT_BLOCK == 0;
}
size_t CodeSize(const emb_seq_pool_attr_t& attr) const override {
return 96 + (attr.table_width / YMM_FLOAT_BLOCK) * 96 * 8;
}
std::unique_ptr<GenBase> CreateJitCode(
const emb_seq_pool_attr_t& attr) const override {
PADDLE_ENFORCE_GT(attr.table_height, 0);
PADDLE_ENFORCE_GT(attr.table_width, 0);
PADDLE_ENFORCE_GT(attr.index_height, 0);
PADDLE_ENFORCE_GT(attr.index_width, 0);
PADDLE_ENFORCE_GT(attr.out_width, 0);
return make_unique<EmbSeqPoolJitCode>(attr, CodeSize(attr));
}
};
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kEmbSeqPool, gen::EmbSeqPoolCreator);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
class EmbSeqPoolJitCode : public JitCode {
public:
explicit EmbSeqPoolJitCode(const emb_seq_pool_attr_t& attr,
size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr),
tbl_w_(attr.table_width),
type_(attr.pool_type) {
if (type_ != SeqPoolType::kSum) {
LOG(FATAL) << "Only support sum pool yet ";
}
this->genCode();
}
std::string name() const override {
std::string base = "EmbSeqPoolJitCode";
if (type_ == SeqPoolType::kSum) {
base += "_Sum";
} else if (type_ == SeqPoolType::kAvg) {
base += "_Avg";
} else if (type_ == SeqPoolType::kSqrt) {
base += "_Sqrt";
}
base += ("_W" + std::to_string(tbl_w_));
return base;
}
void genCode() override;
private:
int tbl_w_;
SeqPoolType type_;
reg64_t param_tbl{abi_param1};
reg64_t param_idx{abi_param2};
reg64_t param_dst{abi_param3};
reg64_t param_attr{abi_param4};
reg64_t reg_tmp{rax};
reg64_t reg_idx_width_in_byte{r8};
reg64_t reg_idx_height{r9};
reg64_t reg_ptr_tbl_i{r10};
reg64_t reg_idx{r10}; // could use same of reg_ptr_tbl_i
reg64_t reg_ptr_idx_i{r11};
reg64_t reg_ptr_dst_i{r12};
reg64_t reg_ptr_param_dst{r13}; // rdx is used in mul so protect param_dst
reg64_t reg_idx_w_i_in_byte{r14};
reg64_t reg_idx_h_end{r15};
};
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
......@@ -32,7 +32,7 @@ class SeqPoolJitCode : public JitCode {
: JitCode(code_size, code_ptr), w_(attr.w), type_(attr.type) {
if (!(type_ == SeqPoolType::kSum || type_ == SeqPoolType::kAvg ||
type_ == SeqPoolType::kSqrt)) {
LOG(FATAL) << "Only support sum pool yet ";
LOG(FATAL) << "Only supported pool type: sum, avg and sqrt.";
}
fp_h_[0] = 1.f;
this->genCode();
......
......@@ -54,6 +54,7 @@ const char* to_string(KernelType kt) {
ONE_CASE(kHMax);
ONE_CASE(kHSum);
ONE_CASE(kSoftmax);
ONE_CASE(kEmbSeqPool);
default:
PADDLE_THROW("Not support type: %d, or forget to add it.", kt);
return "NOT JITKernel";
......
......@@ -172,6 +172,15 @@ inline std::ostream& operator<<(std::ostream& os, const seq_pool_attr_t& attr) {
return os;
}
inline std::ostream& operator<<(std::ostream& os,
const emb_seq_pool_attr_t& attr) {
os << "table_height[" << attr.table_height << "],table_width["
<< attr.table_width << "],index_height[" << attr.index_height
<< "],index_width[" << attr.index_width << "],output_width["
<< attr.out_width << "],pool_type[" << to_string(attr.pool_type) << "]";
return os;
}
inline std::ostream& operator<<(std::ostream& os, const matmul_attr_t& attr) {
os << "M[" << attr.m << "],N[" << attr.n << "],K[" << attr.k << "]";
return os;
......
......@@ -13,6 +13,7 @@
* limitations under the License. */
#pragma once
#include <cstdint>
#include "paddle/fluid/operators/jit/macro.h"
#include "paddle/fluid/platform/macros.h"
......@@ -20,34 +21,35 @@ namespace paddle {
namespace operators {
namespace jit {
// TODO(TJ): reorder by alphabet
typedef enum {
kNone = 0,
kVMul = 1,
kVAdd = 2,
kVAddRelu,
kVSub,
kVScal,
kVAddBias,
kVRelu,
kVIdentity,
kVSquare,
kVExp,
kVSigmoid,
kVTanh,
kLSTMCtHt,
kLSTMC1H1,
// sort by alphabet
kCRFDecoding = 1,
kEmbSeqPool = 2,
kGRUH1,
kGRUHtPart1,
kGRUHtPart2,
kCRFDecoding,
kHSum, // horizontal max
kHMax, // horizontal sum
kLSTMCtHt,
kLSTMC1H1,
kLayerNorm,
kMatMul,
kNCHW16CMulNC,
kSeqPool,
kMatMul,
kHSum, // horizontal max
kHMax, // horizontal sum
kSoftmax,
kVAdd,
kVAddBias,
kVAddRelu,
kVExp,
kVIdentity,
kVMul,
kVRelu,
kVScal,
kVSigmoid,
kVSquare,
kVSub,
kVTanh,
} KernelType;
typedef enum {
......@@ -145,6 +147,32 @@ struct SeqPoolTuples {
typedef void (*func_type)(const T*, T*, const seq_pool_attr_t*);
};
typedef struct emb_seq_pool_attr_s {
int64_t table_height, table_width;
int64_t index_height, index_width;
int64_t out_width;
SeqPoolType pool_type;
emb_seq_pool_attr_s() = default;
explicit emb_seq_pool_attr_s(int64_t tbl_height, int64_t tbl_width,
int64_t idx_height, int64_t idx_width,
int64_t output_width,
SeqPoolType seqpool_type = SeqPoolType::kSum)
: table_height(tbl_height),
table_width(tbl_width),
index_height(idx_height),
index_width(idx_width),
out_width(output_width),
pool_type(seqpool_type) {}
} emb_seq_pool_attr_t;
template <typename T>
struct EmbSeqPoolTuples {
typedef T data_type;
typedef emb_seq_pool_attr_t attr_type;
typedef void (*func_type)(const T*, const int64_t*, T*,
const emb_seq_pool_attr_t*);
};
typedef struct matmul_attr_s {
int m, n, k;
void* packed_weight{nullptr};
......
......@@ -56,6 +56,11 @@ size_t JitCodeKey<matmul_attr_t>(const matmul_attr_t& attr) {
return (key << shift * 2) + ((static_cast<size_t>(attr.n)) << shift) + attr.k;
}
template <>
size_t JitCodeKey<emb_seq_pool_attr_t>(const emb_seq_pool_attr_t& attr) {
return attr.table_width;
}
} // namespace jit
} // namespace operators
} // namespace paddle
......@@ -13,3 +13,4 @@ USE_JITKERNEL_MORE(kVSigmoid, mkl)
USE_JITKERNEL_MORE(kVTanh, mkl)
USE_JITKERNEL_MORE(kSeqPool, mkl)
USE_JITKERNEL_MORE(kSoftmax, mkl)
USE_JITKERNEL_MORE(kEmbSeqPool, mkl)
......@@ -174,6 +174,16 @@ bool SeqPoolKernel<double>::UseMe(const seq_pool_attr_t& attr) const {
return true;
}
template <>
bool EmbSeqPoolKernel<float>::UseMe(const emb_seq_pool_attr_t& attr) const {
return true;
}
template <>
bool EmbSeqPoolKernel<double>::UseMe(const emb_seq_pool_attr_t& attr) const {
return true;
}
template <>
bool MatMulKernel<float>::UseMe(const matmul_attr_t& attr) const {
return platform::MayIUse(platform::avx);
......@@ -227,6 +237,7 @@ REGISTER_MKL_KERNEL(kVSquare, VSquare);
REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid);
REGISTER_MKL_KERNEL(kVTanh, VTanh);
REGISTER_MKL_KERNEL(kSeqPool, SeqPool);
REGISTER_MKL_KERNEL(kEmbSeqPool, EmbSeqPool);
REGISTER_MKL_KERNEL(kSoftmax, Softmax);
#undef REGISTER_MKL_KERNEL
......@@ -18,6 +18,7 @@
#include <type_traits>
#include <vector>
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
......@@ -91,6 +92,32 @@ void SeqPool(const T* x, T* y, const seq_pool_attr_t* attr) {
}
}
template <typename T>
void EmbSeqPool(const T* table, const int64_t* idx, T* out,
const emb_seq_pool_attr_t* attr) {
PADDLE_ENFORCE_EQ(attr->table_width * attr->index_width, attr->out_width);
auto check_idx_value_valid = [&](int64_t i) {
PADDLE_ENFORCE_LT(idx[i], attr->table_height, "idx value: %d, i: %d",
idx[i], i);
PADDLE_ENFORCE_GE(idx[i], 0, "idx value: %d, i: %d", idx[i], i);
};
for (int64_t w = 0; w != attr->index_width; ++w) {
check_idx_value_valid(w);
VCopy<T>(table + idx[w] * attr->table_width, out + w * attr->table_width,
attr->table_width);
}
for (int64_t h = 1; h < attr->index_height; ++h) {
for (int64_t w = 0; w < attr->index_width; ++w) {
int64_t i = h * attr->index_width + w;
check_idx_value_valid(i);
VAXPY<T>(static_cast<T>(1), table + idx[i] * attr->table_width,
out + w * attr->table_width, attr->table_width);
}
}
}
template <typename T>
void ASum(const T* x, T* res, int n);
......@@ -142,6 +169,8 @@ DECLARE_MKL_KERNEL(VSquare, XYNTuples);
DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples);
DECLARE_MKL_KERNEL(EmbSeqPool, EmbSeqPoolTuples);
DECLARE_MKL_KERNEL(Softmax, SoftmaxTuples);
#undef DECLARE_MKL_KERNEL
......
......@@ -32,3 +32,4 @@ USE_JITKERNEL_REFER(kVSquare)
USE_JITKERNEL_REFER(kHSum)
USE_JITKERNEL_REFER(kHMax)
USE_JITKERNEL_REFER(kSoftmax)
USE_JITKERNEL_REFER(kEmbSeqPool)
......@@ -57,4 +57,6 @@ REGISTER_REFER_KERNEL(kHSum, HSum);
REGISTER_REFER_KERNEL(kSoftmax, Softmax);
REGISTER_REFER_KERNEL(kEmbSeqPool, EmbSeqPool);
#undef REGISTER_REFER_KERNEL
......@@ -16,6 +16,7 @@
#include <cmath>
#include <limits>
#include <string>
#include "paddle/fluid/operators/jit/helper.h"
#include "paddle/fluid/operators/jit/kernel_base.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -414,6 +415,37 @@ void Softmax(const T* x, T* y, int n, int bs = 1) {
}
}
// embedding seq pool
// table is a matrix with (tbl_h, tbl_w)
// idx is a matrix with (idx_h, idx_w)
// output is a vector with length tbl_w * idx_w
template <typename T>
void EmbSeqPool(const T* table, const int64_t* idx, T* out,
const emb_seq_pool_attr_t* attr) {
PADDLE_ENFORCE_EQ(attr->table_width * attr->index_width, attr->out_width);
auto check_idx_value_valid = [&](int64_t i) {
PADDLE_ENFORCE_LT(idx[i], attr->table_height, "idx value: %d, i: %d",
idx[i], i);
PADDLE_ENFORCE_GE(idx[i], 0, "idx value: %d, i: %d", idx[i], i);
};
for (int64_t w = 0; w != attr->index_width; ++w) {
check_idx_value_valid(w);
std::memcpy(out + w * attr->table_width, table + idx[w] * attr->table_width,
attr->table_width * sizeof(T));
}
for (int64_t h = 1; h < attr->index_height; ++h) {
for (int64_t w = 0; w < attr->index_width; ++w) {
int64_t i = h * attr->index_width + w;
check_idx_value_valid(i);
VAdd(table + idx[i] * attr->table_width, out + w * attr->table_width,
out + w * attr->table_width, attr->table_width);
}
}
}
#define DECLARE_REFER_KERNEL(name, tuples) \
template <typename T> \
class name##Kernel : public ReferKernel<tuples<T>> { \
......@@ -462,6 +494,8 @@ DECLARE_REFER_KERNEL(HSum, XRNTuples);
DECLARE_REFER_KERNEL(Softmax, SoftmaxTuples);
DECLARE_REFER_KERNEL(EmbSeqPool, EmbSeqPoolTuples);
#undef DECLARE_REFER_KERNEL
} // namespace refer
......
......@@ -270,6 +270,32 @@ struct TestFuncWithRefer<jit::SeqPoolTuples<T>, std::vector<T>, std::vector<T>,
}
};
template <typename T>
struct TestFuncWithRefer<jit::EmbSeqPoolTuples<T>, std::vector<T>,
std::vector<int64_t>, std::vector<T>,
typename jit::EmbSeqPoolTuples<T>::attr_type> {
void operator()(const typename jit::EmbSeqPoolTuples<T>::func_type tgt,
const std::vector<T>& table, const std::vector<int64_t>& idx,
const std::vector<T>& oref,
const typename jit::EmbSeqPoolTuples<T>::attr_type& attr) {
EXPECT_TRUE(tgt != nullptr);
EXPECT_EQ(table.size(),
static_cast<size_t>(attr.table_height * attr.table_width));
EXPECT_EQ(idx.size(),
static_cast<size_t>(attr.index_height * attr.index_width));
EXPECT_EQ(oref.size(),
static_cast<size_t>(attr.table_width * attr.index_width));
const T* table_data = table.data();
const int64_t* idx_data = idx.data();
const T* oref_data = oref.data();
int o_w = oref.size();
std::vector<T> out(o_w);
T* o_data = out.data();
tgt(table_data, idx_data, o_data, &attr);
ExpectEQ<T>(o_data, oref_data, o_w);
}
};
template <typename T>
struct TestFuncWithRefer<jit::MatMulTuples<T>, std::vector<T>, std::vector<T>,
std::vector<T>,
......@@ -644,6 +670,40 @@ void TestSoftmaxKernel() {
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestEmbSeqPoolKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
int64_t tbl_h = 1e4;
std::vector<jit::SeqPoolType> pool_types = {
jit::SeqPoolType::kSum}; // only support sum yet
for (int tbl_w : TestSizes()) {
std::vector<T> table(tbl_h * tbl_w);
RandomVec<T>(tbl_h * tbl_w, table.data(), -2.f, 2.f);
const T* table_data = table.data();
for (auto type : pool_types) {
for (int idx_w : {1, 2, 10, 16}) {
for (int idx_h : {1, 2, 9, 13, 16}) {
auto ref = jit::GetRefer<KT, jit::EmbSeqPoolTuples<T>>();
EXPECT_TRUE(ref != nullptr);
std::vector<int64_t> idx(idx_h * idx_w);
RandomVec<int64_t>(idx_h * idx_w, idx.data(), 0, tbl_h - 1);
int64_t out_w = tbl_w * idx_w;
std::vector<T> oref(out_w);
const int64_t* idx_data = idx.data();
T* o_data = oref.data();
jit::emb_seq_pool_attr_t attr(tbl_h, tbl_w, idx_h, idx_w, out_w,
type);
ref(table_data, idx_data, o_data, &attr);
TestAllImpls<KT, jit::EmbSeqPoolTuples<T>, PlaceType, std::vector<T>,
std::vector<int64_t>, std::vector<T>>(attr, table, idx,
oref, attr);
}
}
}
}
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestNCHW16CMulNCKernel() {
VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
......@@ -878,6 +938,11 @@ TEST(JITKernel, kSoftmax) {
TestSoftmaxKernel<jit::kSoftmax, double, CPUPlace>();
}
TEST(JITKernel, kEmbSeqPool) {
TestEmbSeqPoolKernel<jit::kEmbSeqPool, float, CPUPlace>();
TestEmbSeqPoolKernel<jit::kEmbSeqPool, double, CPUPlace>();
}
TEST(JITKernel, kNCHW16CMulNC) {
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, float, CPUPlace>();
TestNCHW16CMulNCKernel<jit::kNCHW16CMulNC, double, CPUPlace>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册