From 9131a35676ce36e0aa943567c60fa39a024ef6c9 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Tue, 9 Oct 2018 22:38:26 +0800 Subject: [PATCH] replace the lstm compute with jitkernel test=develop --- paddle/fluid/operators/CMakeLists.txt | 2 +- paddle/fluid/operators/fusion_lstm_op.cc | 50 ++++++--------- paddle/fluid/operators/math/CMakeLists.txt | 8 +-- .../fluid/operators/math/cpu_lstm_compute.cc | 43 ------------- .../fluid/operators/math/cpu_lstm_compute.h | 64 ------------------- 5 files changed, 22 insertions(+), 145 deletions(-) delete mode 100644 paddle/fluid/operators/math/cpu_lstm_compute.cc delete mode 100644 paddle/fluid/operators/math/cpu_lstm_compute.h diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 2ef13b72ed3..4d8dd0df195 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -299,7 +299,7 @@ op_library(flatten_op DEPS reshape_op) op_library(sequence_pad_op DEPS sequence_padding) op_library(unstack_op DEPS stack_op) op_library(fake_quantize_op DEPS memory) -op_library(fusion_lstm_op DEPS cpu_lstm_compute) +op_library(fusion_lstm_op DEPS jit_kernel) if (WITH_GPU) op_library(conv_op DEPS vol2col depthwise_conv im2col) op_library(layer_norm_op DEPS cub) diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index ae1f6d8e489..abaa9237c09 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -15,9 +15,9 @@ limitations under the License. */ #include "paddle/fluid/operators/fusion_lstm_op.h" #include #include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/operators/math/cpu_lstm_compute.h" #include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/fc_compute.h" +#include "paddle/fluid/operators/math/jit_kernel.h" #include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/fluid/platform/cpu_info.h" @@ -309,11 +309,6 @@ class FuisonLSTMKernel : public framework::OpKernel { act_gate(D, gates + D3, gates + D3); \ GET_Ht(ct, gates, ht) -#define COMPUTE_CtHt(gates, ct_1, ct, ht) \ - act_gate(D3, gates + D, gates + D); \ - GET_Ct(ct_1, gates, ct); \ - GET_Ht(ct, gates, ht) - #define COMPUTE_CtHt_PEEPHOLE(gates, ct_1, ct, ht) \ /* get fgated and igated*/ \ blas.VMUL(D, wc_data, ct_1, checked_cell_data); \ @@ -403,22 +398,18 @@ class FuisonLSTMKernel : public framework::OpKernel { } } } else { - // TODO(TJ): unly workaround, clean me - std::function compute_ctht; - if (platform::jit::MayIUse(platform::jit::avx) && - act_gate_str == "sigmoid" && act_cand_str == "tanh" && - act_cell_str == "tanh" && D == 8) { - compute_ctht = math::lstm_compute_ctht; - } else { - compute_ctht = [&](T* gates, const T* ct_1, T* ct, T* ht) { - COMPUTE_CtHt(gates, ct_1, ct, ht); - }; - } + const auto& ker = + math::jitkernel::KernelPool::Instance() + .template Get, int, + const std::string&, const std::string&, + const std::string&>(D, act_gate_str, act_cand_str, + act_cell_str); + for (int i = 0; i < N; ++i) { PROCESS_H0C0 for (int step = tstart; step < seq_len; ++step) { GEMM_WH_ADDON(1, prev_h_data, xx_data); - compute_ctht(xx_data, prev_c_data, c_out_data, h_out_data); + ker->ComputeCtHt(xx_data, prev_c_data, c_out_data, h_out_data); MOVE_ONE_STEP; } } @@ -552,24 +543,20 @@ class FuisonLSTMKernel : public framework::OpKernel { MOVE_ONE_STEP; } } else { - // TODO(TJ): unly workaround, clean me - std::function compute_ctht; - if (platform::jit::MayIUse(platform::jit::avx) && - act_gate_str == "sigmoid" && act_cand_str == "tanh" && - act_cell_str == "tanh" && D == 8) { - compute_ctht = math::lstm_compute_ctht; - } else { - compute_ctht = [&](T* gates, const T* ct_1, T* ct, T* ht) { - COMPUTE_CtHt(gates, ct_1, ct, ht); - }; - } + const auto& ker = + math::jitkernel::KernelPool::Instance() + .template Get, int, + const std::string&, const std::string&, + const std::string&>(D, act_gate_str, act_cand_str, + act_cell_str); + for (int step = tstart; step < max_seq_len; ++step) { const int cur_bs = batch_starts[step + 1] - batch_starts[step]; GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data); DEFINE_CUR; for (int i = 0; i < cur_bs; ++i) { - compute_ctht(cur_in_data, cur_prev_c_data, cur_c_out_data, - cur_h_out_data); + ker->ComputeCtHt(cur_in_data, cur_prev_c_data, cur_c_out_data, + cur_h_out_data); MOVE_ONE_BATCH; } MOVE_ONE_STEP; @@ -595,7 +582,6 @@ class FuisonLSTMKernel : public framework::OpKernel { } #undef COMPUTE_CtHt_PEEPHOLE -#undef COMPUTE_CtHt #undef GET_Ct_NOH0C0 #undef COMPUTE_CtHt_NOH0C0 #undef COMPUTE_CtHt_PEEPHOLE_NOH0C0 diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index 16e1dc40f16..b859636f762 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -45,8 +45,6 @@ math_library(im2col) if (NOT WIN32) # windows do not support avx functions yet. math_library(gru_compute DEPS activation_functions math_function) math_library(lstm_compute DEPS activation_functions) -# TODO(TJ): ugly workaround, clean me -cc_library(cpu_lstm_compute SRCS cpu_lstm_compute.cc DEPS activation_functions cblas cpu_info) endif (NOT WIN32) cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context) @@ -76,7 +74,7 @@ if(WITH_GPU) endif() cc_test(concat_test SRCS concat_test.cc DEPS concat) cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info) -cc_library(jit_kernel_exp SRCS jit_kernel_exp.cc DEPS cpu_info cblas activation_functions) -cc_library(jit_kernel_lstm SRCS jit_kernel_lstm.cc DEPS cpu_info cblas activation_functions) -cc_library(jit_kernel SRCS jit_kernel.cc jit_kernel_blas.cc DEPS cpu_info cblas jit_kernel_exp jit_kernel_lstm) +cc_library(jit_kernel + SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_lstm.cc + DEPS cpu_info cblas activation_functions) cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel) diff --git a/paddle/fluid/operators/math/cpu_lstm_compute.cc b/paddle/fluid/operators/math/cpu_lstm_compute.cc deleted file mode 100644 index e96d1879331..00000000000 --- a/paddle/fluid/operators/math/cpu_lstm_compute.cc +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/math/cpu_lstm_compute.h" - -namespace paddle { -namespace operators { -namespace math { -#ifdef __AVX__ -template <> -void lstm_compute_ctht(float* gates, const float* ct_1, float* ct, - float* ht) { - namespace act = detail::forward::avx; - // gates: W_ch, W_ih, W_fh, W_oh - __m256 c, i, f, o; - c = _mm256_loadu_ps(gates); - i = _mm256_loadu_ps(gates + 8); - f = _mm256_loadu_ps(gates + 16); - o = _mm256_loadu_ps(gates + 24); - - /* C_t = C_t-1 * fgated + cand_gated * igated*/ - c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i)); - i = _mm256_loadu_ps(ct_1); - f = _mm256_mul_ps(i, act::Sigmoid(f)); - f = _mm256_add_ps(c, f); - _mm256_storeu_ps(ct, f); - - /* H_t = act_cell(C_t) * ogated */ - o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o)); - _mm256_storeu_ps(ht, o); -} -#endif -} // namespace math -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/math/cpu_lstm_compute.h b/paddle/fluid/operators/math/cpu_lstm_compute.h deleted file mode 100644 index 169a9e4b47f..00000000000 --- a/paddle/fluid/operators/math/cpu_lstm_compute.h +++ /dev/null @@ -1,64 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at -http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include "paddle/fluid/operators/math/cpu_vec.h" -#include "paddle/fluid/platform/cpu_info.h" -#ifdef __AVX__ -#include -#endif - -namespace paddle { -namespace operators { -namespace math { - -// TODO(TJ): ugly workaround, clean me -template -void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) { - // gates: W_ch, W_ih, W_fh, W_oh - vec_sigmoid(24, gates + 8, gates + 8); - vec_tanh(8, gates, gates); - const T *i = gates + 8, *f = gates + 16, *o = gates + 24; - const T min = SIGMOID_THRESHOLD_MIN; - const T max = SIGMOID_THRESHOLD_MAX; - for (int d = 0; d < 8; ++d) { - // C_t = C_t-1 * fgated + cand_gated * igated - ct[d] = ct_1[d] * f[d] + gates[d] * i[d]; - // H_t = act_cell(C_t) * ogated - T tmp = ct[d] * 2; - tmp = static_cast(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp)); - vec_exp(1, &tmp, &tmp); - tmp = static_cast(2) / (static_cast(1) + tmp) - static_cast(1); - ht[d] = tmp * o[d]; - } -} - -#ifdef __AVX__ -namespace detail { -namespace forward { -namespace avx { -__m256 Sigmoid(const __m256 a); -__m256 Tanh(const __m256 a); - -} // namespace avx -} // namespace forward -} // namespace detail - -template <> -void lstm_compute_ctht(float* gates, const float* ct_1, float* ct, - float* ht); - -#endif - -} // namespace math -} // namespace operators -} // namespace paddle -- GitLab