提交 5660d6a3 编写于 作者: M minqiyang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into accelerate_embedding_grad

...@@ -16,10 +16,9 @@ limitations under the License. */ ...@@ -16,10 +16,9 @@ limitations under the License. */
#include <cstring> // for memcpy #include <cstring> // for memcpy
#include <string> #include <string>
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h" #include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/jit_kernel.h"
#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -174,58 +173,44 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -174,58 +173,44 @@ class FusionGRUKernel : public framework::OpKernel<T> {
} }
} }
#define INIT_VEC_FUNC \ #define INIT_BASE_DEFINES \
std::function<void(const int, const T *, T *)> act_gate, act_state; \ auto* x = ctx.Input<LoDTensor>("X"); \
std::function<void(const int, const T*, const T*, const T*, T*)> cross; \ auto* wh = ctx.Input<Tensor>("WeightH"); \
auto& act_gate_str = ctx.Attr<std::string>("gate_activation"); \ auto* xx = ctx.Output<LoDTensor>("XX"); \
auto& act_state_str = ctx.Attr<std::string>("activation"); \ auto x_lod = x->lod(); \
if (platform::jit::MayIUse(platform::jit::avx)) { \ auto x_dims = x->dims(); /* T x M*/ \
math::VecActivations<T, platform::jit::avx> act_functor; \ auto wh_dims = wh->dims(); /* D x 3D*/ \
act_gate = act_functor(act_gate_str); \ const int total_T = x_dims[0]; \
act_state = act_functor(act_state_str); \ const int D3 = wh_dims[1]
cross = math::vec_cross<T, platform::jit::avx>; \
} else { \ #define INIT_OTHER_DEFINES \
math::VecActivations<T, platform::jit::isa_any> act_functor; \ auto* h0 = ctx.Input<Tensor>("H0"); \
act_gate = act_functor(act_gate_str); \ auto* wx = ctx.Input<Tensor>("WeightX"); \
act_state = act_functor(act_state_str); \ auto* bias = ctx.Input<Tensor>("Bias"); \
cross = math::vec_cross<T, platform::jit::isa_any>; \ auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
} bool is_reverse = ctx.Attr<bool>("is_reverse"); \
const int M = x_dims[1]; \
#define INIT_BASE_INPUT_OUTPUT \ const int D = wh_dims[0]; \
auto* h0 = ctx.Input<Tensor>("H0"); \ const int D2 = D * 2; \
auto* wx = ctx.Input<Tensor>("WeightX"); \ const auto& ker = math::jitkernel::KernelPool::Instance() \
auto* wh = ctx.Input<Tensor>("WeightH"); \ .template Get<math::jitkernel::GRUKernel<T>, \
auto* bias = ctx.Input<Tensor>("Bias"); \ const std::string&, const std::string&>( \
auto* xx = ctx.Output<LoDTensor>("XX"); \ ctx.Attr<std::string>("gate_activation"), \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \ ctx.Attr<std::string>("activation"), D); \
bool is_reverse = ctx.Attr<bool>("is_reverse"); const T* x_data = x->data<T>(); \
const T* wx_data = wx->data<T>(); \
#define INIT_BASE_SIZES \ const T* wh_data = wh->data<T>(); \
auto x_dims = x->dims(); /* T x M*/ \ auto place = ctx.GetPlace(); \
auto wh_dims = wh->dims(); /* D x 3D*/ \ T* xx_data = xx->mutable_data<T>(place)
const int total_T = x_dims[0]; \
const int M = x_dims[1]; \
const int D = wh_dims[0]; \
const int D3 = wh_dims[1]; \
const int D2 = D * 2;
void SeqCompute(const framework::ExecutionContext& ctx) const { void SeqCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = paddle::platform::CPUDeviceContext; using DeviceContext = paddle::platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X"); INIT_BASE_DEFINES;
INIT_BASE_INPUT_OUTPUT INIT_OTHER_DEFINES;
INIT_BASE_SIZES
INIT_VEC_FUNC
auto x_lod = x->lod();
const int N = x_lod[0].size() - 1; const int N = x_lod[0].size() - 1;
const T* x_data = x->data<T>();
const T* h0_data = h0 ? h0->data<T>() : nullptr; const T* h0_data = h0 ? h0->data<T>() : nullptr;
const T* wx_data = wx->data<T>();
const T* wh_data = wh->data<T>();
const T* wh_state_data = wh_data + D * D2; const T* wh_state_data = wh_data + D * D2;
T* xx_data = xx->mutable_data<T>(ctx.GetPlace()); T* hidden_out_data = hidden_out->mutable_data<T>(place);
T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(ctx); auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data, math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data,
xx_data, xx_data,
...@@ -252,14 +237,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -252,14 +237,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
if (h0_data) { if (h0_data) {
prev_hidden_data = h0_data + bid * D; prev_hidden_data = h0_data + bid * D;
} else { } else {
// W: {W_update, W_reset; W_state} ker->ComputeH1(xx_data, hidden_out_data);
// update gate
act_gate(D, xx_data, xx_data);
// state gate
act_state(D, xx_data + D2, xx_data + D2);
// out = a*b
blas.VMUL(D, xx_data, xx_data + D2, hidden_out_data);
// save prev
prev_hidden_data = hidden_out_data; prev_hidden_data = hidden_out_data;
tstart = 1; tstart = 1;
move_step(); move_step();
...@@ -269,17 +247,12 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -269,17 +247,12 @@ class FusionGRUKernel : public framework::OpKernel<T> {
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1), blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D2, D, static_cast<T>(1),
prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data, prev_hidden_data, D, wh_data, D2, static_cast<T>(1), xx_data,
D3); D3);
act_gate(D2, xx_data, xx_data); ker->ComputeHtPart1(xx_data, prev_hidden_data, hidden_out_data);
// rt = rt*ht_1 inplace result
blas.VMUL(D, prev_hidden_data, xx_data + D, hidden_out_data);
// gemm rt * Ws // gemm rt * Ws
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1), blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D, D, static_cast<T>(1),
hidden_out_data, D, wh_state_data, D, static_cast<T>(1), hidden_out_data, D, wh_state_data, D, static_cast<T>(1),
xx_data + D2, D3); xx_data + D2, D3);
act_state(D, xx_data + D2, xx_data + D2); ker->ComputeHtPart2(xx_data, prev_hidden_data, hidden_out_data);
// out = zt*ht~ + (1-zt)*ht_1
cross(D, xx_data, xx_data + D2, prev_hidden_data, hidden_out_data);
// save prev // save prev
prev_hidden_data = hidden_out_data; prev_hidden_data = hidden_out_data;
move_step(); move_step();
...@@ -289,28 +262,19 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -289,28 +262,19 @@ class FusionGRUKernel : public framework::OpKernel<T> {
void BatchCompute(const framework::ExecutionContext& ctx) const { void BatchCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = paddle::platform::CPUDeviceContext; using DeviceContext = paddle::platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X"); INIT_BASE_DEFINES;
INIT_BASE_INPUT_OUTPUT if (x_lod[0].size() == 2) {
INIT_BASE_SIZES
if (x->lod()[0].size() == 2) {
xx->Resize({total_T, D3}); xx->Resize({total_T, D3});
SeqCompute(ctx); SeqCompute(ctx);
return; return;
} }
INIT_VEC_FUNC INIT_OTHER_DEFINES;
auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0"); auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
auto* batched_input = ctx.Output<LoDTensor>("BatchedInput"); auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
auto* batched_out = ctx.Output<LoDTensor>("BatchedOut"); auto* batched_out = ctx.Output<LoDTensor>("BatchedOut");
T* batched_input_data = batched_input->mutable_data<T>(place);
const T* x_data = x->data<T>(); T* batched_out_data = batched_out->mutable_data<T>(place);
const T* wx_data = wx->data<T>(); hidden_out->mutable_data<T>(place);
const T* wh_data = wh->data<T>();
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
T* batched_input_data = batched_input->mutable_data<T>(ctx.GetPlace());
T* batched_out_data = batched_out->mutable_data<T>(ctx.GetPlace());
hidden_out->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx); auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch; math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
...@@ -336,7 +300,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -336,7 +300,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T* prev_hidden_data = nullptr; T* prev_hidden_data = nullptr;
if (h0) { if (h0) {
// reorder h0 // reorder h0
T* reordered_h0_data = reordered_h0->mutable_data<T>(ctx.GetPlace()); T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
const T* h0_data = h0->data<T>(); const T* h0_data = h0->data<T>();
prev_hidden_data = reordered_h0_data; prev_hidden_data = reordered_h0_data;
size_t sz = sizeof(T) * D; size_t sz = sizeof(T) * D;
...@@ -350,12 +314,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -350,12 +314,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T* cur_out_data = batched_out_data; T* cur_out_data = batched_out_data;
// W: {W_update, W_reset; W_state} // W: {W_update, W_reset; W_state}
for (int i = 0; i < max_bs; ++i) { for (int i = 0; i < max_bs; ++i) {
// update gate ker->ComputeH1(cur_in_data, cur_out_data);
act_gate(D, cur_in_data, cur_in_data);
// state gate
act_state(D, cur_in_data + D2, cur_in_data + D2);
// out = a*b
blas.VMUL(D, cur_in_data, cur_in_data + D2, cur_out_data);
// add offset // add offset
cur_in_data += D3; cur_in_data += D3;
cur_out_data += D; cur_out_data += D;
...@@ -380,10 +339,8 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -380,10 +339,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
T* cur_out_data = batched_out_data; T* cur_out_data = batched_out_data;
T* cur_prev_hidden_data = prev_hidden_data; T* cur_prev_hidden_data = prev_hidden_data;
for (int i = 0; i < cur_bs; ++i) { for (int i = 0; i < cur_bs; ++i) {
act_gate(D2, cur_batched_data, cur_batched_data); ker->ComputeHtPart1(cur_batched_data, cur_prev_hidden_data,
// rt = rt*ht_1 inplace result cur_out_data);
blas.VMUL(D, cur_prev_hidden_data, cur_batched_data + D, cur_out_data);
cur_batched_data += D3; cur_batched_data += D3;
cur_prev_hidden_data += D; cur_prev_hidden_data += D;
cur_out_data += D; cur_out_data += D;
...@@ -397,12 +354,8 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -397,12 +354,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
cur_prev_hidden_data = prev_hidden_data; cur_prev_hidden_data = prev_hidden_data;
for (int i = 0; i < cur_bs; ++i) { for (int i = 0; i < cur_bs; ++i) {
// ht~ = act_state(...) ker->ComputeHtPart2(cur_batched_data, cur_prev_hidden_data,
act_state(D, cur_batched_data + D2, cur_batched_data + D2); cur_out_data);
// out = zt*ht~ + (1-zt)*ht_1
cross(D, cur_batched_data, cur_batched_data + D2, cur_prev_hidden_data,
cur_out_data);
cur_batched_data += D3; cur_batched_data += D3;
cur_prev_hidden_data += D; cur_prev_hidden_data += D;
cur_out_data += D; cur_out_data += D;
...@@ -416,9 +369,8 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -416,9 +369,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
batched_out->set_lod(batched_lod); batched_out->set_lod(batched_lod);
to_seq(dev_ctx, *batched_out, hidden_out); to_seq(dev_ctx, *batched_out, hidden_out);
} }
#undef INIT_VEC_FUNC #undef INIT_OTHER_DEFINES
#undef INIT_BASE_SIZES #undef INIT_BASE_DEFINES
#undef INIT_BASE_INPUT_OUTPUT
}; };
} // namespace operators } // namespace operators
......
...@@ -68,6 +68,7 @@ cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selec ...@@ -68,6 +68,7 @@ cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selec
cc_test(im2col_test SRCS im2col_test.cc DEPS im2col) cc_test(im2col_test SRCS im2col_test.cc DEPS im2col)
cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col) cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col)
cc_test(sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding) cc_test(sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding)
cc_test(sequence_pooling_test SRCS sequence_pooling_test.cc DEPS sequence_pooling)
if(WITH_GPU) if(WITH_GPU)
nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function) nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function)
nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor math_function) nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor math_function)
...@@ -75,6 +76,6 @@ endif() ...@@ -75,6 +76,6 @@ endif()
cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split) cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
cc_library(jit_kernel cc_library(jit_kernel
SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc
DEPS cpu_info cblas) DEPS cpu_info cblas)
cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel) cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)
...@@ -142,6 +142,15 @@ class LSTMKernel : public Kernel { ...@@ -142,6 +142,15 @@ class LSTMKernel : public Kernel {
const T *wp_data = nullptr) const = 0; const T *wp_data = nullptr) const = 0;
}; };
template <typename T>
class GRUKernel : public Kernel {
public:
// compute h1 without h0
virtual void ComputeH1(T *gates, T *ht) const = 0;
virtual void ComputeHtPart1(T *gates, const T *ht_1, T *ht) const = 0;
virtual void ComputeHtPart2(T *gates, const T *ht_1, T *ht) const = 0;
};
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -136,6 +136,23 @@ static std::shared_ptr<const VActKernel<T>> GetActKernel( ...@@ -136,6 +136,23 @@ static std::shared_ptr<const VActKernel<T>> GetActKernel(
return nullptr; return nullptr;
} }
#ifdef __AVX__
template <jit::cpu_isa_t isa>
static std::unique_ptr<AVXAct> GetAVXAct(const std::string& type) {
if (type == "sigmoid") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>());
} else if (type == "relu") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>());
} else if (type == "tanh") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>());
} else if (type == "identity" || type == "") {
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>());
}
PADDLE_THROW("Not support type: %s", type);
return nullptr;
}
#endif
/* LSTM JitKernel */ /* LSTM JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block> template <typename T, jit::cpu_isa_t isa, jit_block>
class LSTMKernelImpl : public LSTMKernel<T> { class LSTMKernelImpl : public LSTMKernel<T> {
...@@ -192,61 +209,49 @@ class LSTMKernelImpl : public LSTMKernel<T> { ...@@ -192,61 +209,49 @@ class LSTMKernelImpl : public LSTMKernel<T> {
#endif #endif
}; };
#define INTRI8_FLOAT(isa) \ #define INTRI8_FLOAT(isa) \
template <> \ template <> \
LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \ LSTMKernelImpl<float, isa, kEQ8>::LSTMKernelImpl( \
const std::string& act_gate, const std::string& act_cand, \ const std::string& act_gate, const std::string& act_cand, \
const std::string& act_cell, int d) \ const std::string& act_cell, int d) \
: LSTMKernel<float>() { \ : LSTMKernel<float>() { \
auto GetAVXAct = [&](const std::string& type) -> std::unique_ptr<AVXAct> { \ avx_act_gate_ = GetAVXAct<isa>(act_gate); \
if (type == "sigmoid") { \ avx_act_cand_ = GetAVXAct<isa>(act_cand); \
return std::unique_ptr<AVXAct>(new AVXActImpl<kSigmoid, isa>()); \ avx_act_cell_ = GetAVXAct<isa>(act_cell); \
} else if (type == "relu") { \ } \
return std::unique_ptr<AVXAct>(new AVXActImpl<kRelu, isa>()); \ template <> \
} else if (type == "tanh") { \ void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \
return std::unique_ptr<AVXAct>(new AVXActImpl<kTanh, isa>()); \ float* gates, const float* ct_1, float* ct, float* ht, \
} else if (type == "identity" || type == "") { \ const float* wp_data, float* checked) const { \
return std::unique_ptr<AVXAct>(new AVXActImpl<kIdentity, isa>()); \ /* gates: W_ch, W_ih, W_fh, W_oh */ \
} \ __m256 c, i, f, o; \
PADDLE_THROW("Not support type: %s", type); \ c = _mm256_loadu_ps(gates); \
}; \ i = _mm256_loadu_ps(gates + 8); \
avx_act_gate_ = GetAVXAct(act_gate); \ f = _mm256_loadu_ps(gates + 16); \
avx_act_cand_ = GetAVXAct(act_cand); \ o = _mm256_loadu_ps(gates + 24); \
avx_act_cell_ = GetAVXAct(act_cell); \ /* C_t = C_t-1 * fgated + cand_gated * igated*/ \
} \ c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \
template <> \ i = _mm256_loadu_ps(ct_1); \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeCtHt( \ f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \
float* gates, const float* ct_1, float* ct, float* ht, \ f = _mm256_add_ps(c, f); \
const float* wp_data, float* checked) const { \ _mm256_storeu_ps(ct, f); \
/* gates: W_ch, W_ih, W_fh, W_oh */ \ /* H_t = act_cell(C_t) * ogated */ \
__m256 c, i, f, o; \ o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \
c = _mm256_loadu_ps(gates); \ _mm256_storeu_ps(ht, o); \
i = _mm256_loadu_ps(gates + 8); \ } \
f = _mm256_loadu_ps(gates + 16); \ template <> \
o = _mm256_loadu_ps(gates + 24); \ void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
/* C_t = C_t-1 * fgated + cand_gated * igated*/ \ float* gates, float* ct, float* ht, const float* wp_data) const { \
c = _mm256_mul_ps(avx_act_cand_->Compute(c), avx_act_gate_->Compute(i)); \ __m256 c, i, o; \
i = _mm256_loadu_ps(ct_1); \ c = _mm256_loadu_ps(gates); \
f = _mm256_mul_ps(i, avx_act_gate_->Compute(f)); \ i = _mm256_loadu_ps(gates + 8); \
f = _mm256_add_ps(c, f); \ o = _mm256_loadu_ps(gates + 24); \
_mm256_storeu_ps(ct, f); \ /* C_t = igated * cgated*/ \
/* H_t = act_cell(C_t) * ogated */ \ c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
o = _mm256_mul_ps(avx_act_cell_->Compute(f), avx_act_gate_->Compute(o)); \ _mm256_storeu_ps(ct, c); \
_mm256_storeu_ps(ht, o); \ /* H_t = act_cell(C_t) * ogated */ \
} \ o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
template <> \ _mm256_storeu_ps(ht, o); \
void LSTMKernelImpl<float, isa, kEQ8>::ComputeC1H1( \
float* gates, float* ct, float* ht, const float* wp_data) const { \
__m256 c, i, o; \
c = _mm256_loadu_ps(gates); \
i = _mm256_loadu_ps(gates + 8); \
o = _mm256_loadu_ps(gates + 24); \
/* C_t = igated * cgated*/ \
c = _mm256_mul_ps(avx_act_gate_->Compute(i), avx_act_cand_->Compute(c)); \
_mm256_storeu_ps(ct, c); \
/* H_t = act_cell(C_t) * ogated */ \
o = _mm256_mul_ps(avx_act_cell_->Compute(c), avx_act_gate_->Compute(o)); \
_mm256_storeu_ps(ht, o); \
} }
// TODO(TJ): optimize keq16 // TODO(TJ): optimize keq16
...@@ -354,6 +359,126 @@ REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM, ...@@ -354,6 +359,126 @@ REGISTER_JITKERNEL_ARGS(lstm, LSTMKernel, JITKERNEL_DECLARE_LSTM,
#undef JITKERNEL_DECLARE_LSTM #undef JITKERNEL_DECLARE_LSTM
#undef JITKERNEL_KEY_LSTM #undef JITKERNEL_KEY_LSTM
#undef JITKERNEL_NEW_LSTM_IMPL #undef JITKERNEL_NEW_LSTM_IMPL
/* GRU JitKernel */
template <typename T, jit::cpu_isa_t isa, jit_block>
class GRUKernelImpl : public GRUKernel<T> {
public:
explicit GRUKernelImpl(const std::string& act_gate,
const std::string& act_state, int d)
: GRUKernel<T>() {
d_ = d;
d2_ = d * 2;
act_gate_d2_ = GetActKernel<T>(act_gate, d2_);
act_gate_d_ = GetActKernel<T>(act_gate, d);
act_state_d_ = GetActKernel<T>(act_state, d);
vmul_d_ = KernelPool::Instance().template Get<VMulKernel<T>>(d);
}
void ComputeH1(T* gates, T* ht) const override {
act_gate_d_->Compute(gates, gates);
act_state_d_->Compute(gates + d2_, gates + d2_);
vmul_d_->Compute(gates, gates + d2_, ht);
}
void ComputeHtPart1(T* gates, const T* ht_1, T* ht) const override {
// W: {W_update, W_reset; W_state}
act_gate_d2_->Compute(gates, gates);
vmul_d_->Compute(ht_1, gates + d_, ht);
}
void ComputeHtPart2(T* gates, const T* ht_1, T* ht) const override {
T* y = gates + d2_;
act_state_d_->Compute(y, y);
// out = zt*ht~ + (1-zt)*ht_1
for (int i = 0; i < d_; ++i) {
ht[i] = gates[i] * y[i] + (static_cast<T>(1) - gates[i]) * ht_1[i];
}
}
private:
int d_, d2_;
std::shared_ptr<const VActKernel<T>> act_gate_d2_, act_gate_d_, act_state_d_;
std::shared_ptr<const VMulKernel<T>> vmul_d_;
#ifdef __AVX__
std::unique_ptr<const AVXAct> avx_act_gate_, avx_act_state_;
#endif
};
#define INTRI8_FLOAT(isa) \
template <> \
GRUKernelImpl<float, isa, kEQ8>::GRUKernelImpl( \
const std::string& act_gate, const std::string& act_state, int d) \
: GRUKernel<float>() { \
avx_act_gate_ = GetAVXAct<isa>(act_gate); \
avx_act_state_ = GetAVXAct<isa>(act_state); \
} \
template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeH1(float* gates, float* ht) \
const { \
__m256 u, s; \
/* W: {W_update, W_reset; W_state} */ \
u = _mm256_loadu_ps(gates); \
s = _mm256_loadu_ps(gates + 16); \
s = _mm256_mul_ps(avx_act_gate_->Compute(u), avx_act_state_->Compute(s)); \
_mm256_storeu_ps(ht, s); \
} \
template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart1( \
float* gates, const float* ht_1, float* ht) const { \
/* not exactly equal the any implementation */ \
__m256 r, ht0; \
r = _mm256_loadu_ps(gates + 8); \
ht0 = _mm256_loadu_ps(ht_1); \
r = _mm256_mul_ps(avx_act_gate_->Compute(r), ht0); \
_mm256_storeu_ps(ht, r); \
} \
template <> \
void GRUKernelImpl<float, isa, kEQ8>::ComputeHtPart2( \
float* gates, const float* ht_1, float* ht) const { \
/* not exactly equal the any implementation */ \
__m256 u, s, ht0; \
u = _mm256_loadu_ps(gates); \
s = _mm256_loadu_ps(gates + 16); \
ht0 = _mm256_loadu_ps(ht_1); \
u = avx_act_gate_->Compute(u); \
s = _mm256_mul_ps(u, avx_act_state_->Compute(s)); \
u = _mm256_sub_ps(_mm256_set1_ps(1.f), u); \
u = _mm256_mul_ps(u, ht0); \
u = _mm256_add_ps(s, u); \
_mm256_storeu_ps(ht, u); \
}
#ifdef __AVX__
INTRI8_FLOAT(jit::avx);
#endif
#ifdef __AVX2__
INTRI8_FLOAT(jit::avx2);
#endif
#ifdef __AVX512F__
INTRI8_FLOAT(jit::avx512f);
#endif
#define JITKERNEL_DECLARE_GRU(ker_class, ker_dtype) \
template <> \
std::shared_ptr<const GRUKernel<ker_dtype>> KernelPool::Get< \
GRUKernel<ker_dtype>, const std::string&, const std::string&, int>( \
const std::string& act_gate, const std::string& act_state, int d)
#define JITKERNEL_KEY_GRU(ker_key, dtype_key) \
#ker_key #dtype_key + std::to_string(d) + act_gate + act_state
#define JITKERNEL_NEW_GRU_IMPL(ker, dtype, isa, k) \
p = std::dynamic_pointer_cast<ker<dtype>>( \
std::make_shared<ker##Impl<dtype, isa, k>>(act_gate, act_state, d));
REGISTER_JITKERNEL_ARGS(gru, GRUKernel, JITKERNEL_DECLARE_GRU,
JITKERNEL_KEY_GRU, JITKERNEL_NEW_GRU_IMPL);
#undef INTRI8_FLOAT
#undef JITKERNEL_NEW_GRU_IMPL
#undef JITKERNEL_KEY_GRU
#undef JITKERNEL_DECLARE_GRU
} // namespace jitkernel } // namespace jitkernel
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
...@@ -157,6 +157,31 @@ class FirstSeqPoolFunctor { ...@@ -157,6 +157,31 @@ class FirstSeqPoolFunctor {
} }
}; };
template <typename T>
class SumSeqPoolGradFunctor {
public:
void operator()(const platform::CPUDeviceContext& context,
const framework::Tensor& out_grad,
framework::LoDTensor* in_grad) {
auto lod = in_grad->lod()[0];
int64_t out_w = out_grad.numel() / out_grad.dims()[0];
int64_t in_w = in_grad->numel() / in_grad->dims()[0];
PADDLE_ENFORCE(in_w == out_w);
const T* out_g_data = out_grad.data<T>();
T* in_g_data = in_grad->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
int64_t in_offset = lod[i] * in_w;
const T* out_pos = out_g_data + i * out_w;
T* in_pos = in_g_data + in_offset;
for (int r = 0; r != h; ++r) {
blas.VCOPY(in_w, out_pos, in_pos + r * in_w);
}
}
}
};
template <typename T> template <typename T>
class SequencePoolFunctor<platform::CPUDeviceContext, T> { class SequencePoolFunctor<platform::CPUDeviceContext, T> {
public: public:
...@@ -231,9 +256,15 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> { ...@@ -231,9 +256,15 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
math::SetConstant<platform::CPUDeviceContext, T> functor; math::SetConstant<platform::CPUDeviceContext, T> functor;
functor(context, in_grad, 0); functor(context, in_grad, 0);
} }
if (pooltype == "SUM") {
math::SumSeqPoolGradFunctor<T> sum_pool_grad;
sum_pool_grad(context, out_grad, in_grad);
return;
}
auto lod = in_grad->lod()[0]; auto lod = in_grad->lod()[0];
auto& place = *context.eigen_device(); auto& place = *context.eigen_device();
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) { for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
auto in_g_t = in_grad->Slice(static_cast<int>(lod[i]), auto in_g_t = in_grad->Slice(static_cast<int>(lod[i]),
static_cast<int>(lod[i + 1])); static_cast<int>(lod[i + 1]));
...@@ -247,12 +278,6 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> { ...@@ -247,12 +278,6 @@ class SequencePoolGradFunctor<platform::CPUDeviceContext, T> {
if (pooltype == "AVERAGE") { if (pooltype == "AVERAGE") {
in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast); in_g_e.device(place) = (out_g_e / static_cast<T>(h)).broadcast(bcast);
} else if (pooltype == "SUM") {
const T* out_g_data = out_g_t.data<T>();
T* in_g_data = in_g_t.mutable_data<T>(context.GetPlace());
for (int r = 0; r != h; ++r) {
blas.VCOPY(w, out_g_data, in_g_data + r * w);
}
} else if (pooltype == "SQRT") { } else if (pooltype == "SQRT") {
in_g_e.device(place) = in_g_e.device(place) =
(out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast); (out_g_e / std::sqrt(static_cast<T>(h))).broadcast(bcast);
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/sequence_pooling.h"
#include <gtest/gtest.h>
#include <vector>
template <typename DeviceContext, typename Place, typename T>
void TestSequencePoolingSum(const paddle::framework::LoD& lod) {
paddle::framework::LoDTensor cpu_out_grad;
paddle::framework::LoDTensor cpu_in_grad;
paddle::framework::LoDTensor out_grad;
paddle::framework::LoDTensor in_grad;
const size_t second_dim = 128u;
// construct out_grad's tensor in cpu
const size_t out_first_dim = lod[0].size() - 1;
auto out_dims = paddle::framework::make_ddim(
{static_cast<int64_t>(out_first_dim), static_cast<int64_t>(second_dim)});
cpu_out_grad.mutable_data<T>(out_dims, paddle::platform::CPUPlace());
for (int64_t i = 0; i < cpu_out_grad.numel(); ++i) {
cpu_out_grad.data<T>()[i] = static_cast<T>(i);
}
// copy to dst out_grad
auto* place = new Place();
DeviceContext* context = new DeviceContext(*place);
if (paddle::platform::is_cpu_place(*place)) {
out_grad = cpu_out_grad;
} else {
TensorCopySync(cpu_out_grad, *place, &out_grad);
}
// construct in_grad
in_grad.set_lod(lod);
auto in_dims = paddle::framework::make_ddim(
{static_cast<int64_t>(lod[0].back()), static_cast<int64_t>(second_dim)});
in_grad.mutable_data<T>(in_dims, context->GetPlace());
// check tensor contruction result
PADDLE_ENFORCE_EQ(in_grad.dims().size(), out_grad.dims().size());
for (int64_t i = 1; i < out_grad.dims().size(); ++i) {
PADDLE_ENFORCE_EQ(in_grad.dims()[i], out_grad.dims()[i]);
}
// call functor
paddle::operators::math::SequencePoolGradFunctor<DeviceContext, T>()(
*context, "SUM", out_grad, &in_grad);
if (paddle::platform::is_cpu_place(*place)) {
cpu_in_grad = in_grad;
} else {
TensorCopySync(in_grad, paddle::platform::CPUPlace(), &cpu_in_grad);
cpu_in_grad.set_lod(in_grad.lod());
}
EXPECT_EQ(in_grad.numel(), lod[0].back() * second_dim);
EXPECT_EQ(in_grad.lod(), lod);
if (paddle::platform::is_cpu_place(*place)) {
for (int64_t i = 0; i < in_grad.lod()[0].size() - 1; ++i) {
int64_t begin = in_grad.lod()[0][i];
int64_t end = in_grad.lod()[0][i + 1];
paddle::framework::Tensor tmp = in_grad.Slice(begin, end);
for (int64_t j = 0; j != tmp.numel() / second_dim; ++j) {
for (int64_t m = 0; m != second_dim; ++m) {
EXPECT_EQ(tmp.data<T>()[m + j * second_dim],
out_grad.data<T>()[m + i * second_dim]);
}
}
}
} else {
for (int64_t i = 0; i < cpu_in_grad.lod()[0].size() - 1; ++i) {
int64_t begin = cpu_in_grad.lod()[0][i];
int64_t end = cpu_in_grad.lod()[0][i + 1];
paddle::framework::Tensor tmp = cpu_in_grad.Slice(begin, end);
for (int64_t j = 0; j != tmp.numel() / second_dim; ++j) {
for (int64_t m = 0; m != second_dim; ++m) {
EXPECT_EQ(tmp.data<T>()[m + j * second_dim],
cpu_out_grad.data<T>()[m + i * second_dim]);
}
}
}
}
delete place;
delete context;
}
TEST(SequencePoolingGrad, CPU_SUM) {
paddle::framework::LoD lod1;
lod1.push_back(std::vector<size_t>{0, 10});
TestSequencePoolingSum<paddle::platform::CPUDeviceContext,
paddle::platform::CPUPlace, float>(lod1);
paddle::framework::LoD lod2;
lod2.push_back(std::vector<size_t>{0, 2, 7, 10});
TestSequencePoolingSum<paddle::platform::CPUDeviceContext,
paddle::platform::CPUPlace, float>(lod2);
}
#ifdef PADDLE_WITH_CUDA
TEST(SequencePoolingGrad, CUDA_SUM) {
paddle::framework::LoD lod1;
lod1.push_back(std::vector<size_t>{0, 10});
TestSequencePoolingSum<paddle::platform::CUDADeviceContext,
paddle::platform::CUDAPlace, float>(lod1);
paddle::framework::LoD lod2;
lod2.push_back(std::vector<size_t>{0, 2, 7, 10});
TestSequencePoolingSum<paddle::platform::CUDADeviceContext,
paddle::platform::CUDAPlace, float>(lod2);
}
#endif
...@@ -78,9 +78,9 @@ if(WITH_DISTRIBUTE) ...@@ -78,9 +78,9 @@ if(WITH_DISTRIBUTE)
set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200) set_tests_properties(test_dist_word2vec PROPERTIES TIMEOUT 200)
py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext) py_test_modules(test_dist_se_resnext MODULES test_dist_se_resnext)
set_tests_properties(test_dist_se_resnext PROPERTIES TIMEOUT 1000) set_tests_properties(test_dist_se_resnext PROPERTIES TIMEOUT 1000)
# TODO: fix this test
#py_test_modules(test_dist_transformer MODULES test_dist_transformer) py_test_modules(test_dist_transformer MODULES test_dist_transformer)
#set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000) set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000)
endif(NOT APPLE) endif(NOT APPLE)
py_test_modules(test_dist_transpiler MODULES test_dist_transpiler) py_test_modules(test_dist_transpiler MODULES test_dist_transpiler)
endif() endif()
......
...@@ -35,7 +35,7 @@ import paddle ...@@ -35,7 +35,7 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
from paddle.fluid import core from paddle.fluid import core
from test_dist_base import TestDistRunnerBase, runtime_main from test_dist_base import TestDistRunnerBase, runtime_main, RUN_STEP
import paddle.compat as cpt import paddle.compat as cpt
from paddle.compat import long_type from paddle.compat import long_type
...@@ -562,18 +562,12 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, ...@@ -562,18 +562,12 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
for pass_id in six.moves.xrange(TrainTaskConfig.pass_num): for pass_id in six.moves.xrange(TrainTaskConfig.pass_num):
pass_start_time = time.time() pass_start_time = time.time()
for batch_id, data in enumerate(train_data()): for batch_id, data in enumerate(train_data()):
if batch_id >= 5: if batch_id >= RUN_STEP:
break break
feed_list = [] feed_list = []
total_num_token = 0 total_num_token = 0
#if TrainTaskConfig.local:
# lr_rate = lr_scheduler.update_learning_rate()
#for place_id, data_buffer in enumerate(
# split_data(
# data, num_part=dev_count)):
if TrainTaskConfig.local: if TrainTaskConfig.local:
lr_rate = lr_scheduler.update_learning_rate() lr_rate = lr_scheduler.update_learning_rate()
...@@ -619,12 +613,11 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler, ...@@ -619,12 +613,11 @@ def train_loop(exe, train_progm, dev_count, sum_cost, avg_cost, lr_scheduler,
init = True init = True
# Validate and save the model for inference. # Validate and save the model for inference.
if batch_id == 0 or batch_id == 4: if TrainTaskConfig.val_file_pattern is not None:
if TrainTaskConfig.val_file_pattern is not None: val_avg_cost, val_ppl = test()
val_avg_cost, val_ppl = test() print("[%f]" % val_avg_cost)
print("[%f]" % val_avg_cost) else:
else: assert (False)
assert (False)
#import transformer_reader as reader #import transformer_reader as reader
...@@ -1701,7 +1694,7 @@ class DistTransformer2x2(TestDistRunnerBase): ...@@ -1701,7 +1694,7 @@ class DistTransformer2x2(TestDistRunnerBase):
def run_trainer(self, args): def run_trainer(self, args):
TrainTaskConfig.use_gpu = args.use_cuda TrainTaskConfig.use_gpu = args.use_cuda
sum_cost, avg_cost, predict, token_num, local_lr_scheduler = get_model( sum_cost, avg_cost, predict, token_num, local_lr_scheduler, test_program = get_model(
args.is_dist, not args.sync_mode) args.is_dist, not args.sync_mode)
if args.is_dist: if args.is_dist:
......
...@@ -61,7 +61,8 @@ class TestDistTransformer2x2Sync(TestDistBase): ...@@ -61,7 +61,8 @@ class TestDistTransformer2x2Sync(TestDistBase):
def test_dist_train(self): def test_dist_train(self):
download_files() download_files()
self.check_with_place("dist_transformer.py", delta=1e-5) self.check_with_place(
"dist_transformer.py", delta=1e-5, check_error_log=False)
class TestDistTransformer2x2Async(TestDistBase): class TestDistTransformer2x2Async(TestDistBase):
...@@ -70,7 +71,8 @@ class TestDistTransformer2x2Async(TestDistBase): ...@@ -70,7 +71,8 @@ class TestDistTransformer2x2Async(TestDistBase):
def test_dist_train(self): def test_dist_train(self):
download_files() download_files()
self.check_with_place("dist_transformer.py", delta=1.0) self.check_with_place(
"dist_transformer.py", delta=1.0, check_error_log=False)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -125,6 +125,12 @@ class TestFusionGRUOpMD2(TestFusionGRUOp): ...@@ -125,6 +125,12 @@ class TestFusionGRUOpMD2(TestFusionGRUOp):
self.D = 8 self.D = 8
class TestFusionGRUOpMD3(TestFusionGRUOp):
def set_confs(self):
self.M = 17
self.D = 15
class TestFusionGRUOpBS1(TestFusionGRUOp): class TestFusionGRUOpBS1(TestFusionGRUOp):
def set_confs(self): def set_confs(self):
self.lod = [[3]] self.lod = [[3]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册