Commit 64a90b2f authored by T tensor-tang

use vadd, vaddrelu, lstm and gru jitkernel

Parent 3713d08d
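This commit replaces the old math::jitkernel::KernelPool lookups with the new jit::Get interface across the fused GRU/LSTM operators and FCCompute. The sketch below is not part of the patch; it only restates the lookup pattern visible in the hunks that follow (the helper name GruH1Sketch and the extra place.h include are illustrative), to show how a kernel is now obtained and invoked.

```cpp
// Sketch only, not part of the patch: the kernel lookup pattern this commit
// adopts, mirrored from the hunks below.
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace operators {

void GruH1Sketch(const jit::gru_attr_t& attr, jit::gru_t* one_step) {
  // Before: a kernel object was fetched from KernelPool and used as
  // ker->ComputeH1(...). After: jit::Get returns a callable keyed by
  // <kernel tag, tuple type, place>, invoked directly on gru_t/gru_attr_t.
  auto ComputeH1 =
      jit::Get<jit::gruh1, jit::GRUTuples, platform::CPUPlace>(attr);
  ComputeH1(one_step, &attr);
}

}  // namespace operators
}  // namespace paddle
```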
--- a/paddle/fluid/operators/fused/fusion_gru_op.cc
+++ b/paddle/fluid/operators/fused/fusion_gru_op.cc
@@ -15,9 +15,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/fused/fusion_gru_op.h"
 #include <cstring>  // for memcpy
 #include <string>
+#include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/fc_compute.h"
-#include "paddle/fluid/operators/math/jit_kernel.h"
 #include "paddle/fluid/operators/math/sequence2batch.h"
 
 namespace paddle {
@@ -192,14 +192,16 @@ class FusionGRUKernel : public framework::OpKernel<T> {
   const int M = x_dims[1]; \
   const int D = wh_dims[0]; \
   const int D2 = D * 2; \
-  const math::jitkernel::gru_attr_t attr( \
-      D, ctx.Attr<std::string>("gate_activation"), \
-      ctx.Attr<std::string>("activation")); \
-  math::jitkernel::gru_t one_step; \
-  const auto& ker = \
-      math::jitkernel::KernelPool::Instance() \
-          .template Get<math::jitkernel::GRUKernel<T>, \
-                        const math::jitkernel::gru_attr_t&>(attr); \
+  const jit::gru_attr_t attr( \
+      D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
+      jit::to_kerneltype(ctx.Attr<std::string>("activation"))); \
+  jit::gru_t one_step; \
+  auto ComputeH1 = \
+      jit::Get<jit::gruh1, jit::GRUTuples, platform::CPUPlace>(attr); \
+  auto ComputeHtPart1 = \
+      jit::Get<jit::gruhtpart1, jit::GRUTuples, platform::CPUPlace>(attr); \
+  auto ComputeHtPart2 = \
+      jit::Get<jit::gruhtpart2, jit::GRUTuples, platform::CPUPlace>(attr); \
   const T* x_data = x->data<T>(); \
   const T* wx_data = wx->data<T>(); \
   const T* wh_data = wh->data<T>(); \
@@ -242,7 +244,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     } else {
       one_step.gates = xx_data;
       one_step.ht = hidden_out_data;
-      ker->ComputeH1(&one_step, &attr);
+      ComputeH1(&one_step, &attr);
       prev_hidden_data = hidden_out_data;
       tstart = 1;
       move_step();
@@ -255,12 +257,12 @@ class FusionGRUKernel : public framework::OpKernel<T> {
       one_step.gates = xx_data;
       one_step.ht_1 = prev_hidden_data;
       one_step.ht = hidden_out_data;
-      ker->ComputeHtPart1(&one_step, &attr);
+      ComputeHtPart1(&one_step, &attr);
       // gemm rt * Ws
       blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
                 hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
                 xx_data + D2, D3);
-      ker->ComputeHtPart2(&one_step, &attr);
+      ComputeHtPart2(&one_step, &attr);
       // save prev
       prev_hidden_data = hidden_out_data;
       move_step();
@@ -324,7 +326,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     for (int i = 0; i < max_bs; ++i) {
       one_step.gates = cur_in_data;
       one_step.ht = cur_out_data;
-      ker->ComputeH1(&one_step, &attr);
+      ComputeH1(&one_step, &attr);
       // add offset
       cur_in_data += D3;
       cur_out_data += D;
@@ -352,7 +354,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
       one_step.gates = cur_batched_data;
       one_step.ht_1 = cur_prev_hidden_data;
       one_step.ht = cur_out_data;
-      ker->ComputeHtPart1(&one_step, &attr);
+      ComputeHtPart1(&one_step, &attr);
       cur_batched_data += D3;
       cur_prev_hidden_data += D;
@@ -370,7 +372,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
       one_step.gates = cur_batched_data;
       one_step.ht_1 = cur_prev_hidden_data;
       one_step.ht = cur_out_data;
-      ker->ComputeHtPart2(&one_step, &attr);
+      ComputeHtPart2(&one_step, &attr);
       cur_batched_data += D3;
       cur_prev_hidden_data += D;
       cur_out_data += D;
......
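Between the first timestep and the rest, the rewritten fusion_gru path above splits one GRU step into three calls. Below is a condensed sketch of that order; it is not part of the patch. The kernel lookups, the GEMM arguments, and the D/D2/D3 offsets are copied from the hunks above, while the function and parameter names are placeholders.

```cpp
// Sketch only: per-timestep order used by the sequence-mode GRU path above.
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace operators {

template <typename T, typename BlasT>
void GruStepSketch(const BlasT& blas, const jit::gru_attr_t& attr,
                   jit::gru_t* one_step, T* xx_data, T* hidden_out_data,
                   T* prev_hidden_data, const T* wh_state_data, int D, int D2,
                   int D3) {
  auto ComputeHtPart1 =
      jit::Get<jit::gruhtpart1, jit::GRUTuples, platform::CPUPlace>(attr);
  auto ComputeHtPart2 =
      jit::Get<jit::gruhtpart2, jit::GRUTuples, platform::CPUPlace>(attr);

  one_step->gates = xx_data;          // gate buffer of width D3 = 3 * D
  one_step->ht_1 = prev_hidden_data;  // hidden state of the previous step
  one_step->ht = hidden_out_data;     // hidden state output of this step
  ComputeHtPart1(one_step, &attr);    // gates and reset-scaled hidden state
  // gemm rt * Ws, accumulated into the candidate slot of the gate buffer
  blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
            hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
            xx_data + D2, D3);
  ComputeHtPart2(one_step, &attr);    // candidate activation and final ht
}

}  // namespace operators
}  // namespace paddle
```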
--- a/paddle/fluid/operators/fused/fusion_lstm_op.cc
+++ b/paddle/fluid/operators/fused/fusion_lstm_op.cc
@@ -14,9 +14,9 @@ limitations under the License. */
 #include "paddle/fluid/operators/fused/fusion_lstm_op.h"
 #include <string>
+#include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/fc_compute.h"
-#include "paddle/fluid/operators/math/jit_kernel.h"
 #include "paddle/fluid/operators/math/sequence2batch.h"
 
 namespace paddle {
@@ -250,17 +250,19 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
     auto* checked_cell = ctx.Output<Tensor>("CheckedCell"); \
     checked_cell_data = checked_cell->mutable_data<T>(place); \
   } \
-  const math::jitkernel::lstm_attr_t attr( \
-      D, ctx.Attr<std::string>("gate_activation"), \
-      ctx.Attr<std::string>("candidate_activation"), \
-      ctx.Attr<std::string>("cell_activation"), use_peepholes); \
+  const jit::lstm_attr_t attr( \
+      D, jit::to_kerneltype(ctx.Attr<std::string>("gate_activation")), \
+      jit::to_kerneltype(ctx.Attr<std::string>("candidate_activation")), \
+      jit::to_kerneltype(ctx.Attr<std::string>("cell_activation")), \
+      use_peepholes); \
-  math::jitkernel::lstm_t one_step; \
+  jit::lstm_t one_step; \
   one_step.wp = wp_data; \
   one_step.checked = checked_cell_data; \
-  const auto& ker = \
-      math::jitkernel::KernelPool::Instance() \
-          .template Get<math::jitkernel::LSTMKernel<T>, \
-                        const math::jitkernel::lstm_attr_t&>(attr)
+  auto ComputeC1H1 = \
+      jit::Get<jit::lstmc1h1, jit::LSTMTuples, platform::CPUPlace>(attr); \
+  auto ComputeCtHt = \
+      jit::Get<jit::lstmctht, jit::LSTMTuples, platform::CPUPlace>(attr)
 // Wh GEMM
 #define GEMM_WH_ADDON(bs, prev, out) \
@@ -306,7 +308,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
       one_step.gates = xx_data;
       one_step.ct = c_out_data;
       one_step.ht = h_out_data;
-      ker->ComputeC1H1(&one_step, &attr);
+      ComputeC1H1(&one_step, &attr);
       tstart = 1;
       // move one step
       prev_h_data = h_out_data;
@@ -322,7 +324,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
       one_step.ct_1 = prev_c_data;
       one_step.ct = c_out_data;
       one_step.ht = h_out_data;
-      ker->ComputeCtHt(&one_step, &attr);
+      ComputeCtHt(&one_step, &attr);
       // move one step
       prev_h_data = h_out_data;
       prev_c_data = c_out_data;
@@ -402,7 +404,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
       one_step.gates = cur_in_data;
       one_step.ct = cur_c_out_data;
       one_step.ht = cur_h_out_data;
-      ker->ComputeC1H1(&one_step, &attr);
+      ComputeC1H1(&one_step, &attr);
       cur_in_data += D4;
       cur_c_out_data += D;
@@ -432,7 +434,7 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
       one_step.ct_1 = cur_prev_c_data;
       one_step.ct = cur_c_out_data;
       one_step.ht = cur_h_out_data;
-      ker->ComputeCtHt(&one_step, &attr);
+      ComputeCtHt(&one_step, &attr);
       // move one batch
       cur_in_data += D4;
......
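The fusion_lstm changes above follow the same shape: one jit kernel for the first step (no previous cell or hidden state) and one for every later step. A minimal sketch under the same assumptions as the earlier ones; the field assignments mirror the hunks above, and the function and parameter names are placeholders.

```cpp
// Sketch only: the two LSTM step kernels used by the rewritten operator above.
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace operators {

template <typename T>
void LstmStepsSketch(const jit::lstm_attr_t& attr, jit::lstm_t* one_step,
                     T* gates, T* prev_c, T* prev_h, T* c_out, T* h_out) {
  auto ComputeC1H1 =
      jit::Get<jit::lstmc1h1, jit::LSTMTuples, platform::CPUPlace>(attr);
  auto ComputeCtHt =
      jit::Get<jit::lstmctht, jit::LSTMTuples, platform::CPUPlace>(attr);

  // First step: only the gate buffer and the outputs are needed.
  one_step->gates = gates;
  one_step->ct = c_out;
  one_step->ht = h_out;
  ComputeC1H1(one_step, &attr);

  // Later steps: the previous cell/hidden state feed the full CtHt kernel.
  one_step->ct_1 = prev_c;
  one_step->ht_1 = prev_h;
  ComputeCtHt(one_step, &attr);
}

}  // namespace operators
}  // namespace paddle
```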
--- a/paddle/fluid/operators/math/CMakeLists.txt
+++ b/paddle/fluid/operators/math/CMakeLists.txt
@@ -73,12 +73,3 @@ if(WITH_GPU)
 endif()
 cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
 cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
-# set(JIT_KERNEL_SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc jit_kernel_layer_norm.cc)
-# set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce)
-# if(WITH_XBYAK)
-# list(APPEND JIT_KERNEL_SRCS jit_gen.cc jit_code.cc)
-# list(APPEND JIT_KERNEL_DEPS xbyak)
-# endif()
-# cc_library(jit_kernel SRCS ${JIT_KERNEL_SRCS} DEPS ${JIT_KERNEL_DEPS})
-# cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
--- a/paddle/fluid/operators/math/fc_compute.h
+++ b/paddle/fluid/operators/math/fc_compute.h
@@ -14,8 +14,8 @@ limitations under the License. */
 #pragma once
 
+#include "paddle/fluid/operators/jit/kernels.h"
 #include "paddle/fluid/operators/math/blas.h"
-#include "paddle/fluid/operators/math/jit_kernel.h"
 
 namespace paddle {
 namespace operators {
@@ -30,22 +30,20 @@ inline void FCCompute(const BlasT<DeviceContext, T>& blas, const int M,
     return;
   }
   if (relu) {
-    const auto& vaddrelu = jitkernel::KernelPool::Instance()
-                               .template Get<jitkernel::VAddReluKernel<T>>(N);
+    auto compute =
+        jit::Get<jit::vaddrelu, jit::XYZNTuples, platform::CPUPlace>(N);
     for (int i = 0; i < M; i++) {
       T* dst = Y + i * N;
-      vaddrelu->Compute(B, dst, dst, N);
+      compute(B, dst, dst, N);
     }
   } else {
-    const auto& vadd = jitkernel::KernelPool::Instance()
-                           .template Get<jitkernel::VAddKernel<T>>(N);
+    auto compute = jit::Get<jit::vadd, jit::XYZNTuples, platform::CPUPlace>(N);
 #ifdef PADDLE_WITH_MKLML
 #pragma omp parallel for
 #endif
     for (int i = 0; i < M; i++) {
       T* dst = Y + i * N;
-      vadd->Compute(B, dst, dst, N);
+      compute(B, dst, dst, N);
     }
   }
 }
......
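In fc_compute.h the bias add after the GEMM uses the same lookup: one callable from jit::Get, reused for every row of the M x N output. A compact sketch of that loop follows; it mirrors the hunk above, and the function name is only for illustration.

```cpp
// Sketch only: row-wise bias add used by FCCompute above.
// vaddrelu computes dst = max(B + dst, 0); vadd computes dst = B + dst.
#include "paddle/fluid/operators/jit/kernels.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace operators {

template <typename T>
void AddBiasRowwiseSketch(const T* B, T* Y, int M, int N, bool relu) {
  auto compute =
      relu ? jit::Get<jit::vaddrelu, jit::XYZNTuples, platform::CPUPlace>(N)
           : jit::Get<jit::vadd, jit::XYZNTuples, platform::CPUPlace>(N);
  for (int i = 0; i < M; i++) {
    T* dst = Y + i * N;       // i-th row of the M x N output
    compute(B, dst, dst, N);  // add the length-N bias into the row in place
  }
}

}  // namespace operators
}  // namespace paddle
```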