Commit 2e96da45, authored by tensor-tang, committed by ceci3

add vbroadcast jitkernel refer code and use it

test=develop
Parent 02054094
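For context before the diff: the new kVBroadcast kernel replicates a row vector of length x_len into each of the y_h rows of the output, which is exactly what the refer implementation added below does with repeated VCopy calls. A minimal standalone sketch of that semantics follows; the function name, main() and the example values are illustrative only and not part of the commit.

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

// Copy the length-x_len vector x into every one of the y_h rows of y,
// so y ends up with shape (y_h, x_len); mirrors the refer VBroadcast below.
template <typename T>
void VBroadcastSketch(const T* x, T* y, int64_t y_h, int64_t x_len) {
  for (int64_t h = 0; h < y_h; ++h) {
    std::memcpy(y + h * x_len, x, x_len * sizeof(T));
  }
}

int main() {
  std::vector<float> x = {1.f, 2.f, 3.f};
  const int64_t h = 2;
  std::vector<float> y(h * x.size());
  VBroadcastSketch(x.data(), y.data(), h, static_cast<int64_t>(x.size()));
  for (float v : y) std::cout << v << ' ';  // prints: 1 2 3 1 2 3
  std::cout << std::endl;
  return 0;
}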
@@ -22,7 +22,6 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/operators/jit/kernels.h"
-#include "paddle/fluid/operators/math/blas.h"
 namespace paddle {
 namespace operators {
@@ -47,7 +46,7 @@ struct EmbeddingVSumFunctor {
     auto *output = output_t->mutable_data<T>(context.GetPlace());
     PADDLE_ENFORCE_LE(table_width * idx_width, out_width);
-    PADDLE_ENFORCE_GT(ids_lod.size(), 1UL);
+    PADDLE_ENFORCE_GT(ids_lod.size(), 1UL, "The LoD[0] could NOT be empty");
     jit::emb_seq_pool_attr_t attr(table_height, table_width, 0, idx_width,
                                   out_width, jit::SeqPoolType::kSum);
@@ -83,11 +82,11 @@ class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
         FusedEmbeddingSeqPoolLastDim(table_var->dims(), ids_t->dims());
     const auto &ids_lod = ids_t->lod();
     // in run time, the LoD of ids must be 1
-    PADDLE_ENFORCE(ids_lod.size(), 1u, "The LoD level of Input(Ids) must be 1");
-    PADDLE_ENFORCE_GE(ids_lod[0].size(), 1u, "The LoD could NOT be empty");
+    PADDLE_ENFORCE(ids_lod.size(), 1UL,
+                   "The LoD level of Input(Ids) must be 1");
     int64_t batch_size = ids_lod[0].size() - 1;
     // in run time, the shape from Ids -> output
-    // should be [seq_length, 1] -> [batch_size, embedding_size]
+    // should be [seq_length, 1] -> [batch_size, last_dim]
     output_t->Resize({batch_size, last_dim});
     if (combiner_type == "sum") {
@@ -125,7 +124,7 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
       auto *ids_data = ids->data<int64_t>();
      int64_t ids_num = ids->numel();
      auto lod = ids->lod()[0];
-      int64_t row_width = d_output->dims()[1];
+      int64_t out_width = d_output->dims()[1];
      framework::Vector<int64_t> *new_rows = d_table->mutable_rows();
      new_rows->resize(ids_num);
@@ -136,15 +135,13 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
      T *d_table_data = d_table_value->mutable_data<T>(context.GetPlace());
      const T *d_output_data = d_output->data<T>();
-      auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
+      auto vbroadcast = jit::Get<jit::kVBroadcast, jit::VBroadcastTuples<T>,
+                                 platform::CPUPlace>(out_width);
      for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
        int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
-        int64_t in_offset = lod[i] * row_width;
-        const T *out_pos = d_output_data + i * row_width;
-        T *in_pos = d_table_data + in_offset;
-        for (int r = 0; r != h; ++r) {
-          blas.VCOPY(row_width, out_pos, in_pos + r * row_width);
-        }
+        const T *src = d_output_data + i * out_width;
+        T *dst = d_table_data + lod[i] * out_width;
+        vbroadcast(src, dst, h, out_width);
      }
    } else {
      LOG(ERROR) << "Dense is not supported in fused_embedding_seq_pool_op now";
......
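Functionally, the rewritten gradient loop above is unchanged: for every LoD segment i of height h, the single gradient row of d_output is copied into the h corresponding rows of the sparse table gradient; the per-row blas.VCOPY loop is simply collapsed into one vbroadcast call per segment. A hedged sketch of that equivalence outside the Paddle code base is below (ScatterRowGrads is a made-up name; lod, d_output_data and d_table_data mimic the variables above).

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Plain-C++ equivalent of the loop in the gradient kernel above, with the
// vbroadcast(src, dst, h, out_width) call expanded back into the row copies
// it performs.
void ScatterRowGrads(const float* d_output_data, float* d_table_data,
                     const std::vector<std::size_t>& lod, int64_t out_width) {
  for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
    int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
    const float* src = d_output_data + i * out_width;  // one gradient row
    float* dst = d_table_data + lod[i] * out_width;    // h destination rows
    for (int64_t r = 0; r < h; ++r) {
      std::memcpy(dst + r * out_width, src, out_width * sizeof(float));
    }
  }
}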
@@ -474,6 +474,24 @@ void BenchCRFDecodingKernel() {
   }
 }
+template <jit::KernelType KT, typename T, typename PlaceType>
+void BenchVBroadcastKernel() {
+  for (int w : TestSizes()) {
+    Tensor x;
+    x.Resize({w});
+    RandomVec<T>(w, x.mutable_data<T>(PlaceType()));
+    const T* x_data = x.data<T>();
+    for (int64_t h : {1, 3, 6}) {
+      Tensor y;
+      y.Resize({h * w});
+      T* y_data = y.mutable_data<T>(PlaceType());
+      BenchAllImpls<KT, jit::VBroadcastTuples<T>, PlaceType>(
+          static_cast<int64_t>(w), x_data, y_data, h, static_cast<int64_t>(w));
+    }
+  }
+}
 using T = float;
 using CPUPlace = paddle::platform::CPUPlace;
@@ -536,6 +554,11 @@ BENCH_FP32_CPU(kCRFDecoding) {
   BenchCRFDecodingKernel<jit::kCRFDecoding, T, CPUPlace>();
 }
+// vbroadcast function
+BENCH_FP32_CPU(kVBroadcast) {
+  BenchVBroadcastKernel<jit::kVBroadcast, T, CPUPlace>();
+}
 // Benchmark all jit kernels including jitcode, mkl and refer.
 // To use this tool, run command: ./benchmark [options...]
 // Options:
......
@@ -36,6 +36,7 @@ const char* to_string(KernelType kt) {
     ONE_CASE(kVScal);
     ONE_CASE(kVAddBias);
     ONE_CASE(kVRelu);
+    ONE_CASE(kVBroadcast);
     ONE_CASE(kVCopy);
     ONE_CASE(kVIdentity);
     ONE_CASE(kVExp);
......
@@ -41,6 +41,7 @@ typedef enum {
   kVAdd,
   kVAddBias,
   kVAddRelu,
+  kVBroadcast,
   kVCopy,
   kVExp,
   kVIdentity,
@@ -134,6 +135,13 @@ struct GRUTuples {
   typedef void (*func_type)(gru_t*, const gru_attr_t*);
 };
+template <typename T>
+struct VBroadcastTuples {
+  typedef T data_type;
+  typedef int64_t attr_type;
+  typedef void (*func_type)(const T*, T*, int64_t, int64_t);
+};
 typedef struct seq_pool_attr_s {
   int h, w;  // h should always be the first one
   SeqPoolType type;
......
@@ -24,6 +24,11 @@ size_t JitCodeKey<int>(const int& d) {
   return d;
 }
+template <>
+size_t JitCodeKey<int64_t>(const int64_t& d) {
+  return d;
+}
 // TODO(TJ): refine and benchmark JitCodeKey generatation
 constexpr int act_type_shift = 3;  // suppot 2^3 act types
 static inline int act_type_convert(KernelType type) {
......
@@ -35,3 +35,4 @@ USE_JITKERNEL_REFER(kHMax)
 USE_JITKERNEL_REFER(kSoftmax)
 USE_JITKERNEL_REFER(kEmbSeqPool)
 USE_JITKERNEL_REFER(kSgd)
+USE_JITKERNEL_REFER(kVBroadcast)
@@ -62,4 +62,6 @@ REGISTER_REFER_KERNEL(kEmbSeqPool, EmbSeqPool);
 REGISTER_REFER_KERNEL(kSgd, Sgd);
+REGISTER_REFER_KERNEL(kVBroadcast, VBroadcast);
 #undef REGISTER_REFER_KERNEL
@@ -75,6 +75,15 @@ void VCopy(const T* x, T* y, int n) {
   std::memcpy(y, x, n * sizeof(T));
 }
+// x shape: (x_len)
+// y shape: (h, x_len)
+template <typename T>
+void VBroadcast(const T* x, T* y, int64_t y_h, int64_t x_len) {
+  for (int64_t h = 0; h < y_h; ++h) {
+    VCopy(x, y + h * x_len, x_len);
+  }
+}
 template <typename T>
 void VRelu(const T* x, T* y, int n) {
   for (int i = 0; i < n; ++i) {
@@ -534,6 +543,8 @@ DECLARE_REFER_KERNEL(EmbSeqPool, EmbSeqPoolTuples);
 DECLARE_REFER_KERNEL(Sgd, SgdTuples);
+DECLARE_REFER_KERNEL(VBroadcast, VBroadcastTuples);
 #undef DECLARE_REFER_KERNEL
 }  // namespace refer
......
@@ -157,6 +157,26 @@ struct TestFuncWithRefer<jit::XRNTuples<T>, std::vector<T>, T> {
   }
 };
+template <typename T>
+struct TestFuncWithRefer<jit::VBroadcastTuples<T>, std::vector<T>,
+                         std::vector<T>, int64_t,
+                         typename jit::VBroadcastTuples<T>::attr_type> {
+  void operator()(const typename jit::VBroadcastTuples<T>::func_type tgt,
+                  const std::vector<T>& x, const std::vector<T>& yref,
+                  int64_t h,
+                  const typename jit::VBroadcastTuples<T>::attr_type& attr) {
+    EXPECT_TRUE(tgt != nullptr);
+    EXPECT_EQ(x.size(), static_cast<size_t>(attr));
+    EXPECT_EQ(yref.size(), x.size() * h);
+    std::vector<T> y(yref.size());
+    const T* x_data = x.data();
+    const T* yref_data = yref.data();
+    T* y_data = y.data();
+    tgt(x_data, y_data, h, attr);
+    ExpectEQ<T>(y_data, yref_data, yref.size());
+  }
+};
 template <typename T>
 struct TestFuncWithRefer<jit::XYNTuples<T>, std::vector<T>, std::vector<T>> {
   void operator()(const typename jit::XYNTuples<T>::func_type tgt,
@@ -926,6 +946,27 @@ void TestKernelCRFDecodingTuples() {
   }
 }
+template <jit::KernelType KT, typename T, typename PlaceType>
+void TestKernelVBroadcastTuples() {
+  VLOG(10) << "===== Test JITKernel " << jit::to_string(KT);
+  for (int w : TestSizes()) {
+    std::vector<T> x(w);
+    RandomVec<T>(w, x.data());
+    const T* x_data = x.data();
+    for (int64_t h : {1, 2, 6}) {
+      auto ref = jit::GetRefer<KT, jit::VBroadcastTuples<T>>();
+      EXPECT_TRUE(ref != nullptr);
+      std::vector<T> y(w * h);
+      T* y_data = y.data();
+      ref(x_data, y_data, h, w);
+      TestAllImpls<KT, jit::VBroadcastTuples<T>, PlaceType, std::vector<T>,
+                   std::vector<T>, int64_t>(static_cast<int64_t>(w), x, y, h,
+                                            static_cast<int64_t>(w));
+    }
+  }
+}
 #define TEST_CPU_KERNEL(test_tuple, kernel_type)                 \
   TEST(JITKernel, kernel_type) {                                 \
     TestKernel##test_tuple<jit::kernel_type, float, CPUPlace>(); \
@@ -967,6 +1008,7 @@ TEST_CPU_KERNEL(EmbSeqPoolTuples, kEmbSeqPool);
 TEST_CPU_KERNEL(SgdTuples, kSgd);
 TEST_CPU_KERNEL(LayerNormTuples, kLayerNorm);
 TEST_CPU_KERNEL(CRFDecodingTuples, kCRFDecoding);
+TEST_CPU_KERNEL(VBroadcastTuples, kVBroadcast);
 TEST(JITKernel_key, lstm) {
   jit::lstm_attr_t attr1(8, jit::kVIdentity, jit::kVSigmoid, jit::kVTanh);
......