Unverified · Commit 6057f362 authored by tensor-tang, committed by GitHub

Merge pull request #15996 from tensor-tang/op/embgrad

refine embeddingseqpool grad
@@ -22,7 +22,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/jit/kernels.h"
-#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace operators {
@@ -47,7 +46,7 @@ struct EmbeddingVSumFunctor {
    auto *output = output_t->mutable_data<T>(context.GetPlace());
    PADDLE_ENFORCE_LE(table_width * idx_width, out_width);
-    PADDLE_ENFORCE_GT(ids_lod.size(), 1UL);
+    PADDLE_ENFORCE_GT(ids_lod.size(), 1UL, "The LoD[0] could NOT be empty");
    jit::emb_seq_pool_attr_t attr(table_height, table_width, 0, idx_width,
                                  out_width, jit::SeqPoolType::kSum);
@@ -83,11 +82,11 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
        FusedEmbeddingSeqPoolLastDim(table_var->dims(), ids_t->dims());
    const auto &ids_lod = ids_t->lod();
    // in run time, the LoD of ids must be 1
-    PADDLE_ENFORCE(ids_lod.size(), 1u, "The LoD level of Input(Ids) must be 1");
-    PADDLE_ENFORCE_GE(ids_lod[0].size(), 1u, "The LoD could NOT be empty");
+    PADDLE_ENFORCE(ids_lod.size(), 1UL,
+                   "The LoD level of Input(Ids) must be 1");
    int64_t batch_size = ids_lod[0].size() - 1;
    // in run time, the shape from Ids -> output
-    // should be [seq_length, 1] -> [batch_size, embedding_size]
+    // should be [seq_length, 1] -> [batch_size, last_dim]
    output_t->Resize({batch_size, last_dim});
    if (combiner_type == "sum") {
@@ -125,7 +124,7 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
      auto *ids_data = ids->data<int64_t>();
      int64_t ids_num = ids->numel();
      auto lod = ids->lod()[0];
-      int64_t row_width = d_output->dims()[1];
+      int64_t out_width = d_output->dims()[1];
      framework::Vector<int64_t> *new_rows = d_table->mutable_rows();
      new_rows->resize(ids_num);
@@ -136,15 +135,13 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
      T *d_table_data = d_table_value->mutable_data<T>(context.GetPlace());
      const T *d_output_data = d_output->data<T>();
-      auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
+      auto vbroadcast = jit::Get<jit::kVBroadcast, jit::VBroadcastTuples<T>,
+                                 platform::CPUPlace>(out_width);
      for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
        int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
-        int64_t in_offset = lod[i] * row_width;
-        const T *out_pos = d_output_data + i * row_width;
-        T *in_pos = d_table_data + in_offset;
-        for (int r = 0; r != h; ++r) {
-          blas.VCOPY(row_width, out_pos, in_pos + r * row_width);
-        }
+        const T *src = d_output_data + i * out_width;
+        T *dst = d_table_data + lod[i] * out_width;
+        vbroadcast(src, dst, h, out_width);
      }
    } else {
      LOG(ERROR) << "Dense is not supported in fused_embedding_seq_pool_op now";
......
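The backward hunk above is the core of this change: the inner blas.VCOPY loop is replaced by one kVBroadcast call per LoD segment, which repeats the pooled output gradient of segment i into every table-gradient row of that segment. A standalone sketch of that semantics (illustrative names and sizes only, not operator code):

#include <cstdio>
#include <vector>

// Reference semantics of kVBroadcast: copy one row of x_len values into
// each of the y_h rows of y.
template <typename T>
void VBroadcastRef(const T* x, T* y, int64_t y_h, int64_t x_len) {
  for (int64_t h = 0; h < y_h; ++h) {
    for (int64_t j = 0; j < x_len; ++j) y[h * x_len + j] = x[j];
  }
}

int main() {
  // Two LoD segments, rows [0, 3) and [3, 5); out_width = 2.
  std::vector<size_t> lod = {0, 3, 5};
  const int64_t out_width = 2;
  std::vector<float> d_output = {0.1f, 0.2f,   // gradient of pooled row 0
                                 0.3f, 0.4f};  // gradient of pooled row 1
  std::vector<float> d_table(lod.back() * out_width, 0.f);
  for (size_t i = 0; i + 1 < lod.size(); ++i) {
    int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
    VBroadcastRef(d_output.data() + i * out_width,
                  d_table.data() + lod[i] * out_width, h, out_width);
  }
  // Prints: 0.1 0.2 0.1 0.2 0.1 0.2 0.3 0.4 0.3 0.4
  for (float v : d_table) std::printf("%.1f ", v);
  std::printf("\n");
  return 0;
}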
@@ -474,6 +474,23 @@ void BenchCRFDecodingKernel() {
  }
}

template <jit::KernelType KT, typename T, typename PlaceType>
void BenchVBroadcastKernel() {
  for (int64_t w : {1, 16, 64, 100, 256}) {
    Tensor x;
    x.Resize({w});
    RandomVec<T>(w, x.mutable_data<T>(PlaceType()));
    const T* x_data = x.data<T>();
    for (int h : TestSizes()) {
      Tensor y;
      y.Resize({h * w});
      T* y_data = y.mutable_data<T>(PlaceType());
      BenchAllImpls<KT, jit::VBroadcastTuples<T>, PlaceType>(
          w, x_data, y_data, static_cast<int64_t>(h), w);
    }
  }
}
using T = float;
using CPUPlace = paddle::platform::CPUPlace;
@@ -498,6 +515,7 @@ BENCH_FP32_CPU(kVSquare) { BenchXYNKernel<jit::kVSquare, T, CPUPlace>(); }
BENCH_FP32_CPU(kVExp) { BenchXYNKernel<jit::kVExp, T, CPUPlace>(); }
BENCH_FP32_CPU(kVSigmoid) { BenchXYNKernel<jit::kVSigmoid, T, CPUPlace>(); }
BENCH_FP32_CPU(kVTanh) { BenchXYNKernel<jit::kVTanh, T, CPUPlace>(); }
BENCH_FP32_CPU(kVCopy) { BenchXYNKernel<jit::kVCopy, T, CPUPlace>(); }
// lstm and peephole
BENCH_FP32_CPU(kLSTMCtHt) { BenchLSTMKernel<jit::kLSTMCtHt, T, CPUPlace>(); }
@@ -535,6 +553,11 @@ BENCH_FP32_CPU(kCRFDecoding) {
  BenchCRFDecodingKernel<jit::kCRFDecoding, T, CPUPlace>();
}

// vbroadcast function
BENCH_FP32_CPU(kVBroadcast) {
  BenchVBroadcastKernel<jit::kVBroadcast, T, CPUPlace>();
}

// Benchmark all jit kernels including jitcode, mkl and refer.
// To use this tool, run command: ./benchmark [options...]
// Options:
......
@@ -33,3 +33,4 @@ USE_JITKERNEL_GEN(kHMax)
USE_JITKERNEL_GEN(kHSum)
USE_JITKERNEL_GEN(kEmbSeqPool)
USE_JITKERNEL_GEN(kSgd)
USE_JITKERNEL_GEN(kVBroadcast)
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/jit/gen/vbroadcast.h"
#include <memory>
#include <vector>
#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
void VBroadcastJitCode::genCode() {
  preCode();
  constexpr int block = YMM_FLOAT_BLOCK;
  constexpr int max_num_regs = 16;
  const int num_block = w_ / block;
  const int num_groups = num_block / max_num_regs;
  const size_t block_size = sizeof(float) * block;
  // The source row holds w_ floats, a multiple of YMM_FLOAT_BLOCK (see
  // VBroadcastCreator::UseMe below). It is copied in passes of at most
  // max_num_regs YMM registers; `groups` records the register count of
  // each pass.
  std::vector<int> groups(num_groups, max_num_regs);
  int rest_num_regs = num_block % max_num_regs;
  if (rest_num_regs > 0) {
    groups.push_back(rest_num_regs);
  }

  // protect param_h
  mov(reg_height, param_h);
  Label l_next_h;
  xor_(reg_h_i, reg_h_i);
  mov(reg_ptr_dst_i, param_dst);
  L(l_next_h);
  {
    // For each destination row, restart from the beginning of the source
    // row and copy it group by group.
    mov(reg_ptr_src_i, param_src);
    for (int num_regs : groups) {
      size_t w_offset = 0;
      for (int reg_i = 0; reg_i < num_regs; ++reg_i) {
        vmovups(ymm_t(reg_i), ptr[reg_ptr_src_i + w_offset]);
        w_offset += block_size;
      }
      add(reg_ptr_src_i, num_regs * block_size);
      w_offset = 0;
      for (int reg_i = 0; reg_i < num_regs; ++reg_i) {
        vmovups(ptr[reg_ptr_dst_i + w_offset], ymm_t(reg_i));
        w_offset += block_size;
      }
      add(reg_ptr_dst_i, num_regs * block_size);
    }  // end of groups
    inc(reg_h_i);
    cmp(reg_h_i, reg_height);
    jl(l_next_h, T_NEAR);
  }  // end of l_next_h

  postCode();
}
class VBroadcastCreator : public JitCodeCreator<int64_t> {
 public:
  bool UseMe(const int64_t& w) const override {
    return platform::MayIUse(platform::avx) && w % YMM_FLOAT_BLOCK == 0;
  }
  size_t CodeSize(const int64_t& w) const override {
    return 96 + (w / YMM_FLOAT_BLOCK) * 16 * 8;
  }
  std::unique_ptr<GenBase> CreateJitCode(const int64_t& w) const override {
    PADDLE_ENFORCE_GT(w, 0);
    return make_unique<VBroadcastJitCode>(w, CodeSize(w));
  }
};
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
namespace gen = paddle::operators::jit::gen;
REGISTER_JITKERNEL_GEN(kVBroadcast, gen::VBroadcastCreator);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
namespace paddle {
namespace operators {
namespace jit {
namespace gen {
class VBroadcastJitCode : public JitCode {
 public:
  explicit VBroadcastJitCode(const int64_t& w, size_t code_size = 256 * 1024,
                             void* code_ptr = nullptr)
      : JitCode(code_size, code_ptr), w_(w) {
    this->genCode();
  }

  DECLARE_JIT_CODE(VBroadcastJitCode);
  void genCode() override;

 private:
  int w_;
  reg64_t param_src{abi_param1};
  reg64_t param_dst{abi_param2};
  reg64_t param_h{abi_param3};
  reg64_t param_w{abi_param4};

  reg64_t reg_height{r9};
  reg64_t reg_h_i{r10};
  reg64_t reg_ptr_src_i{r11};
  reg64_t reg_ptr_dst_i{r12};
};
} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
@@ -36,6 +36,8 @@ const char* to_string(KernelType kt) {
    ONE_CASE(kVScal);
    ONE_CASE(kVAddBias);
    ONE_CASE(kVRelu);
    ONE_CASE(kVBroadcast);
    ONE_CASE(kVCopy);
    ONE_CASE(kVIdentity);
    ONE_CASE(kVExp);
    ONE_CASE(kVSquare);
......
@@ -41,6 +41,8 @@ typedef enum {
  kVAdd,
  kVAddBias,
  kVAddRelu,
  kVBroadcast,
  kVCopy,
  kVExp,
  kVIdentity,
  kVMul,
@@ -133,6 +135,13 @@ struct GRUTuples {
  typedef void (*func_type)(gru_t*, const gru_attr_t*);
};

template <typename T>
struct VBroadcastTuples {
  typedef T data_type;
  typedef int64_t attr_type;
  typedef void (*func_type)(const T*, T*, int64_t, int64_t);
};

typedef struct seq_pool_attr_s {
  int h, w;  // h should always be the first one
  SeqPoolType type;
......
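VBroadcastTuples keys the kernel on an int64_t attribute, the row width in elements. For reference, this is how a caller fetches and invokes the dispatched implementation through jit::Get, mirroring the backward-kernel call above (a minimal sketch, not part of the PR):

#include <vector>
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/platform/place.h"

void BroadcastRowExample() {
  namespace jit = paddle::operators::jit;
  const int64_t out_width = 64, h = 5;
  std::vector<float> src(out_width, 1.f), dst(h * out_width, 0.f);
  // Picks the best registered kVBroadcast implementation (jitcode, mkl,
  // or refer) for rows of `out_width` floats.
  auto vbroadcast = jit::Get<jit::kVBroadcast, jit::VBroadcastTuples<float>,
                             paddle::platform::CPUPlace>(out_width);
  vbroadcast(src.data(), dst.data(), h, out_width);  // each dst row == src
}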
@@ -24,6 +24,11 @@ size_t JitCodeKey<int>(const int& d) {
  return d;
}

template <>
size_t JitCodeKey<int64_t>(const int64_t& d) {
  return d;
}

// TODO(TJ): refine and benchmark JitCodeKey generatation
constexpr int act_type_shift = 3;  // suppot 2^3 act types
static inline int act_type_convert(KernelType type) {
......
@@ -9,9 +9,11 @@ USE_JITKERNEL_MORE(kVAdd, mkl)
USE_JITKERNEL_MORE(kVScal, mkl)
USE_JITKERNEL_MORE(kVExp, mkl)
USE_JITKERNEL_MORE(kVSquare, mkl)
USE_JITKERNEL_MORE(kVCopy, mkl)
USE_JITKERNEL_MORE(kVSigmoid, mkl)
USE_JITKERNEL_MORE(kVTanh, mkl)
USE_JITKERNEL_MORE(kSeqPool, mkl)
USE_JITKERNEL_MORE(kSoftmax, mkl)
USE_JITKERNEL_MORE(kEmbSeqPool, mkl)
USE_JITKERNEL_MORE(kSgd, mkl)
USE_JITKERNEL_MORE(kVBroadcast, mkl)
@@ -154,6 +154,21 @@ bool VSquareKernel<float>::UseMe(const int& d) const {
  return d > 7;
}

template <>
bool VCopyKernel<float>::UseMe(const int& d) const {
  return d > 15;
}

template <>
bool VBroadcastKernel<float>::UseMe(const int64_t& d) const {
  return d > 127;
}

template <>
bool VBroadcastKernel<double>::UseMe(const int64_t& attr) const {
  return true;
}

template <>
bool VSigmoidKernel<float>::UseMe(const int& d) const {
  return d > 7;
@@ -223,6 +238,7 @@ AWALYS_USE_ME_WITH_DOUBLE(VExp);
AWALYS_USE_ME_WITH_DOUBLE(VSigmoid);
AWALYS_USE_ME_WITH_DOUBLE(VTanh);
AWALYS_USE_ME_WITH_DOUBLE(VSquare);
AWALYS_USE_ME_WITH_DOUBLE(VCopy);
AWALYS_USE_ME_WITH_DOUBLE(Softmax);
#undef AWALYS_USE_ME_WITH_DOUBLE
@@ -244,6 +260,8 @@ REGISTER_MKL_KERNEL(kVAdd, VAdd);
REGISTER_MKL_KERNEL(kVScal, VScal);
REGISTER_MKL_KERNEL(kVExp, VExp);
REGISTER_MKL_KERNEL(kVSquare, VSquare);
REGISTER_MKL_KERNEL(kVCopy, VCopy);
REGISTER_MKL_KERNEL(kVBroadcast, VBroadcast);
REGISTER_MKL_KERNEL(kVSigmoid, VSigmoid);
REGISTER_MKL_KERNEL(kVTanh, VTanh);
REGISTER_MKL_KERNEL(kSeqPool, SeqPool);
......
@@ -50,6 +50,13 @@ void VCopy(const T* x, T* y, int n);
template <typename T>
void VAXPY(T a, const T* x, T* y, int n);

template <typename T>
void VBroadcast(const T* x, T* y, int64_t y_h, int64_t x_len) {
  for (int64_t h = 0; h < y_h; ++h) {
    VCopy(x, y + h * x_len, x_len);
  }
}

template <typename T>
void VSigmoid(const T* x, T* y, int n) {
  const T min = SIGMOID_THRESHOLD_MIN;
@@ -192,6 +199,7 @@ DECLARE_MKL_KERNEL(VExp, XYNTuples);
DECLARE_MKL_KERNEL(VSigmoid, XYNTuples);
DECLARE_MKL_KERNEL(VTanh, XYNTuples);
DECLARE_MKL_KERNEL(VSquare, XYNTuples);
DECLARE_MKL_KERNEL(VCopy, XYNTuples);
DECLARE_MKL_KERNEL(SeqPool, SeqPoolTuples);
@@ -201,6 +209,8 @@ DECLARE_MKL_KERNEL(Softmax, SoftmaxTuples);
DECLARE_MKL_KERNEL(Sgd, SgdTuples);

DECLARE_MKL_KERNEL(VBroadcast, VBroadcastTuples);

#undef DECLARE_MKL_KERNEL

}  // namespace mkl
......
@@ -13,6 +13,7 @@ USE_JITKERNEL_REFER(kVAddRelu)
USE_JITKERNEL_REFER(kVSub)
USE_JITKERNEL_REFER(kVScal)
USE_JITKERNEL_REFER(kVAddBias)
USE_JITKERNEL_REFER(kVCopy)
USE_JITKERNEL_REFER(kVRelu)
USE_JITKERNEL_REFER(kVIdentity)
USE_JITKERNEL_REFER(kVExp)
@@ -34,3 +35,4 @@ USE_JITKERNEL_REFER(kHMax)
USE_JITKERNEL_REFER(kSoftmax)
USE_JITKERNEL_REFER(kEmbSeqPool)
USE_JITKERNEL_REFER(kSgd)
USE_JITKERNEL_REFER(kVBroadcast)
@@ -30,6 +30,7 @@ REGISTER_REFER_KERNEL(kVScal, VScal);
REGISTER_REFER_KERNEL(kVAddBias, VAddBias);
REGISTER_REFER_KERNEL(kVRelu, VRelu);
REGISTER_REFER_KERNEL(kVCopy, VCopy);
REGISTER_REFER_KERNEL(kVIdentity, VIdentity);
REGISTER_REFER_KERNEL(kVSquare, VSquare);
REGISTER_REFER_KERNEL(kVExp, VExp);
@@ -61,4 +62,6 @@ REGISTER_REFER_KERNEL(kEmbSeqPool, EmbSeqPool);
REGISTER_REFER_KERNEL(kSgd, Sgd);

REGISTER_REFER_KERNEL(kVBroadcast, VBroadcast);

#undef REGISTER_REFER_KERNEL
@@ -70,6 +70,20 @@ void VAddBias(const T* a, const T* x, T* y, int n) {
  }
}

template <typename T>
void VCopy(const T* x, T* y, int n) {
  std::memcpy(y, x, n * sizeof(T));
}

// x shape: (x_len)
// y shape: (h, x_len)
template <typename T>
void VBroadcast(const T* x, T* y, int64_t y_h, int64_t x_len) {
  for (int64_t h = 0; h < y_h; ++h) {
    VCopy(x, y + h * x_len, x_len);
  }
}
template <typename T>
void VRelu(const T* x, T* y, int n) {
  for (int i = 0; i < n; ++i) {
@@ -500,6 +514,7 @@ DECLARE_REFER_KERNEL(VExp, XYNTuples);
DECLARE_REFER_KERNEL(VSigmoid, XYNTuples);
DECLARE_REFER_KERNEL(VTanh, XYNTuples);
DECLARE_REFER_KERNEL(VSquare, XYNTuples);
DECLARE_REFER_KERNEL(VCopy, XYNTuples);

// lstm_t*, const lstm_attr_t*
DECLARE_REFER_KERNEL(LSTMCtHt, LSTMTuples);
@@ -528,6 +543,8 @@ DECLARE_REFER_KERNEL(EmbSeqPool, EmbSeqPoolTuples);
DECLARE_REFER_KERNEL(Sgd, SgdTuples);

DECLARE_REFER_KERNEL(VBroadcast, VBroadcastTuples);

#undef DECLARE_REFER_KERNEL

}  // namespace refer
......
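A tiny usage example of the reference VBroadcast kernel declared in the hunk above (illustrative only, assuming the usual refer/refer.h include path):

#include <vector>
#include "paddle/fluid/operators/jit/refer/refer.h"

void ReferVBroadcastExample() {
  namespace refer = paddle::operators::jit::refer;
  std::vector<float> x = {1.f, 2.f, 3.f};
  std::vector<float> y(2 * x.size());
  // Repeat the 3-element row twice: y becomes {1, 2, 3, 1, 2, 3}.
  refer::VBroadcast(x.data(), y.data(), /*y_h=*/2, /*x_len=*/3);
}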
@@ -26,8 +26,8 @@ limitations under the License. */
DEFINE_double(acc, 1e-5, "Test accuracy threshold.");

template <typename T>
-void RandomVec(const int n, T* a, const T lower = static_cast<T>(-20.f),
-               const T upper = static_cast<T>(20.f)) {
+void RandomVec(const int n, T* a, const T lower = static_cast<T>(-2.f),
+               const T upper = static_cast<T>(2.f)) {
  static unsigned int seed = 100;
  std::mt19937 rng(seed++);
  std::uniform_real_distribution<double> uniform_dist(0, 1);
@@ -157,6 +157,26 @@ struct TestFuncWithRefer<jit::XRNTuples<T>, std::vector<T>, T> {
  }
};
template <typename T>
struct TestFuncWithRefer<jit::VBroadcastTuples<T>, std::vector<T>,
                         std::vector<T>, int64_t,
                         typename jit::VBroadcastTuples<T>::attr_type> {
  void operator()(const typename jit::VBroadcastTuples<T>::func_type tgt,
                  const std::vector<T>& x, const std::vector<T>& yref,
                  int64_t h,
                  const typename jit::VBroadcastTuples<T>::attr_type& attr) {
    EXPECT_TRUE(tgt != nullptr);
    EXPECT_EQ(x.size(), static_cast<size_t>(attr));
    EXPECT_EQ(yref.size(), x.size() * h);
    std::vector<T> y(yref.size());
    const T* x_data = x.data();
    const T* yref_data = yref.data();
    T* y_data = y.data();
    tgt(x_data, y_data, h, attr);
    ExpectEQ<T>(y_data, yref_data, yref.size());
  }
};
template <typename T>
struct TestFuncWithRefer<jit::XYNTuples<T>, std::vector<T>, std::vector<T>> {
  void operator()(const typename jit::XYNTuples<T>::func_type tgt,
@@ -514,7 +534,7 @@ void TestKernelXRNTuples() {
    auto ref = jit::GetRefer<KT, jit::XRNTuples<T>>();
    EXPECT_TRUE(ref != nullptr);
    std::vector<T> x(d);
-    RandomVec<T>(d, x.data(), -2.f, 2.f);
+    RandomVec<T>(d, x.data());
    T ref_res;
    ref(x.data(), &ref_res, d);
    TestAllImpls<KT, jit::XRNTuples<T>, PlaceType, std::vector<T>, T>(d, x,
@@ -532,7 +552,7 @@ void TestKernelXYNTuples() {
    std::vector<T> x(d), yref(d);
    std::vector<T> xinp(d);  // inplace test
-    RandomVec<T>(d, x.data(), -2.f, 2.f);
+    RandomVec<T>(d, x.data());
    std::copy(x.begin(), x.end(), xinp.begin());
    const T* x_data = x.data();
@@ -566,7 +586,7 @@ void TestKernelLSTMTuples() {
      EXPECT_TRUE(ref != nullptr);
      std::vector<T> xsrc(4 * d), wp(3 * d), ct_1(d);
      std::vector<T> ct_ref(d), ht_ref(d), checked(2 * d);
-      RandomVec<T>(4 * d, xsrc.data(), -2.f, 2.f);
+      RandomVec<T>(4 * d, xsrc.data());
      RandomVec<T>(3 * d, wp.data(), -1.f, 1.f);
      RandomVec<T>(d, ct_1.data(), -1.f, 1.f);
      // x could be changed after compute, so copy to save src
@@ -614,8 +634,8 @@ void TestKernelGRUTuples() {
      auto ref = jit::GetRefer<KT, jit::GRUTuples<T>>();
      EXPECT_TRUE(ref != nullptr);
      std::vector<T> xsrc(3 * d), ht_1(d), ht_ref(d);
-      RandomVec<T>(3 * d, xsrc.data(), -2.f, 2.f);
-      RandomVec<T>(d, ht_1.data(), -2.f, 2.f);
+      RandomVec<T>(3 * d, xsrc.data());
+      RandomVec<T>(d, ht_1.data());
      // x could be changed after compute, so copy to save src
      std::vector<T> x(xsrc.size());
      std::copy(xsrc.begin(), xsrc.end(), x.begin());
@@ -651,7 +671,7 @@ void TestKernelSeqPoolTuples() {
      auto ref = jit::GetRefer<KT, jit::SeqPoolTuples<T>>();
      EXPECT_TRUE(ref != nullptr);
      std::vector<T> x(h * w), yref(w);
-      RandomVec<T>(h * w, x.data(), -2.f, 2.f);
+      RandomVec<T>(h * w, x.data());
      const T* x_data = x.data();
      T* yref_data = yref.data();
      ref(x_data, yref_data, &attr);
@@ -676,8 +696,8 @@ void TestKernelMatMulTuples() {
        auto ref = jit::GetRefer<KT, jit::MatMulTuples<T>>();
        EXPECT_TRUE(ref != nullptr);
        std::vector<T> a(m * k), b(k * n), c(m * n);
-        RandomVec<T>(m * k, a.data(), -2.f, 2.f);
-        RandomVec<T>(k * n, b.data(), -2.f, 2.f);
+        RandomVec<T>(m * k, a.data());
+        RandomVec<T>(k * n, b.data());
        const T* a_data = a.data();
        const T* b_data = b.data();
        T* c_data = c.data();
@@ -699,7 +719,7 @@ void TestKernelSoftmaxTuples() {
      auto ref = jit::GetRefer<KT, jit::SoftmaxTuples<T>>();
      EXPECT_TRUE(ref != nullptr);
      std::vector<T> x(bs * n), y(bs * n);
-      RandomVec<T>(bs * n, x.data(), -2.f, 2.f);
+      RandomVec<T>(bs * n, x.data());
      const T* x_data = x.data();
      T* y_data = y.data();
@@ -726,7 +746,7 @@ void TestKernelEmbSeqPoolTuples() {
  test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000));
  for (int tbl_w : test_sizes) {
    std::vector<T> table(tbl_h * tbl_w);
-    RandomVec<T>(tbl_h * tbl_w, table.data(), -2.f, 2.f);
+    RandomVec<T>(tbl_h * tbl_w, table.data());
    const T* table_data = table.data();
    for (auto type : pool_types) {
      for (int idx_w : {1, 2, 10, 16}) {
@@ -772,14 +792,14 @@ void TestKernelSgdTuples() {
    for (int grad_w : TestSizes()) {
      std::vector<T> param(param_h * grad_w);
      std::vector<T> param_out(param_h * grad_w);
-      RandomVec<T>(param_h * grad_w, param.data(), -2.f, 2.f);
+      RandomVec<T>(param_h * grad_w, param.data());
      const T* param_data = param.data();
      T* out_data = param_out.data();
      for (int rows_size = 1; rows_size <= param_h; ++rows_size) {
        std::vector<T> grad(rows_size * grad_w);
        std::vector<int64_t> rows =
            UnDuplicatedRandomVec(rows_size, 0, rows_size - 1);
-        RandomVec<T>(rows_size * grad_w, grad.data(), -2.f, 2.f);
+        RandomVec<T>(rows_size * grad_w, grad.data());
        const int64_t* rows_data = rows.data();
        const T* grad_data = grad.data();
        auto ref = jit::GetRefer<KT, jit::SgdTuples<T>>();
@@ -815,8 +835,8 @@ void TestKernelNCHW16CMulNCTuples() {
  int sz = n * c * h * w;
  std::vector<T> x(sz), y(n * c), zref(sz);
  std::vector<T> ztgt(sz), zjit(sz);
-  RandomVec<T>(sz, x.data(), -2.f, 2.f);
-  RandomVec<T>(n * c, y.data(), -2.f, 2.f);
+  RandomVec<T>(sz, x.data());
+  RandomVec<T>(n * c, y.data());
  const T* x_data = x.data();
  const T* y_data = y.data();
@@ -873,11 +893,11 @@ void TestKernelLayerNormTuples() {
        int sz = left * right;
        std::vector<T> x(sz), mean(left), var(left), scale(right), bias(right),
            outref(sz);
-        RandomVec<T>(sz, x.data(), -2.f, 2.f);
-        RandomVec<T>(left, mean.data(), -2.f, 2.f);
-        RandomVec<T>(left, var.data(), -2.f, 2.f);
-        RandomVec<T>(right, scale.data(), -2.f, 2.f);
-        RandomVec<T>(right, bias.data(), -2.f, 2.f);
+        RandomVec<T>(sz, x.data());
+        RandomVec<T>(left, mean.data());
+        RandomVec<T>(left, var.data());
+        RandomVec<T>(right, scale.data());
+        RandomVec<T>(right, bias.data());
        const T* scale_data = scale.data();
        const T* bias_data = bias.data();
@@ -903,7 +923,7 @@ void TestKernelCRFDecodingTuples() {
  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
  constexpr int state_trans_base_idx = 2;
  auto test_sizes = TestSizes();
-  test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 1000));
+  test_sizes.erase(std::remove(test_sizes.begin(), test_sizes.end(), 2000));
  for (int seq_len : {1, 11, 17, 50}) {
    for (int tag_num : test_sizes) {
      auto ref = jit::GetRefer<KT, jit::CRFDecodingTuples<T>>();
@@ -912,8 +932,8 @@ void TestKernelCRFDecodingTuples() {
      int w_sz = (tag_num + state_trans_base_idx) * tag_num;
      std::vector<T> x(x_sz), w(w_sz), alpharef(x_sz);
      std::vector<int> trackref(x_sz);
-      RandomVec<T>(x_sz, x.data(), -2.f, 2.f);
-      RandomVec<T>(w_sz, w.data(), -2.f, 2.f);
+      RandomVec<T>(x_sz, x.data());
+      RandomVec<T>(w_sz, w.data());
      ref(seq_len, (const T*)x.data(), (const T*)w.data(), alpharef.data(),
          trackref.data(), tag_num);
@@ -926,6 +946,27 @@ void TestKernelCRFDecodingTuples() {
  }
}
template <jit::KernelType KT, typename T, typename PlaceType>
void TestKernelVBroadcastTuples() {
  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
  for (int w : TestSizes()) {
    std::vector<T> x(w);
    RandomVec<T>(w, x.data());
    const T* x_data = x.data();
    for (int64_t h : {1, 2, 6}) {
      auto ref = jit::GetRefer<KT, jit::VBroadcastTuples<T>>();
      EXPECT_TRUE(ref != nullptr);
      std::vector<T> y(w * h);
      T* y_data = y.data();
      ref(x_data, y_data, h, w);

      TestAllImpls<KT, jit::VBroadcastTuples<T>, PlaceType, std::vector<T>,
                   std::vector<T>, int64_t>(static_cast<int64_t>(w), x, y, h,
                                            static_cast<int64_t>(w));
    }
  }
}
#define TEST_CPU_KERNEL(test_tuple, kernel_type)                   \
  TEST(JITKernel, kernel_type) {                                   \
    TestKernel##test_tuple<jit::kernel_type, float, CPUPlace>();   \
@@ -949,6 +990,7 @@ TEST_CPU_KERNEL(XYNTuples, kVSquare);
TEST_CPU_KERNEL(XYNTuples, kVExp);
TEST_CPU_KERNEL(XYNTuples, kVSigmoid);
TEST_CPU_KERNEL(XYNTuples, kVTanh);
TEST_CPU_KERNEL(XYNTuples, kVCopy);
TEST_CPU_KERNEL(LSTMTuples, kLSTMCtHt);
TEST_CPU_KERNEL(LSTMTuples, kLSTMC1H1);
@@ -966,6 +1008,7 @@ TEST_CPU_KERNEL(EmbSeqPoolTuples, kEmbSeqPool);
TEST_CPU_KERNEL(SgdTuples, kSgd);
TEST_CPU_KERNEL(LayerNormTuples, kLayerNorm);
TEST_CPU_KERNEL(CRFDecodingTuples, kCRFDecoding);
TEST_CPU_KERNEL(VBroadcastTuples, kVBroadcast);
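Per the TEST_CPU_KERNEL macro earlier in this file, the invocation above expands to roughly the following (only the float instantiation is visible in this excerpt; the rest of the macro body is elided in the diff), so the case registers as JITKernel.kVBroadcast and can be selected with gtest's --gtest_filter=JITKernel.kVBroadcast on the jit test binary:

TEST(JITKernel, kVBroadcast) {
  TestKernelVBroadcastTuples<jit::kVBroadcast, float, CPUPlace>();
  // ... remainder of the macro body, not shown in this diff
}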
TEST(JITKernel_key, lstm) {
  jit::lstm_attr_t attr1(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
......