diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 7ec1e78da4ec642cb1e6248edfbcfed748fa11b8..ccb7fa1f8cce8cc757038904bce762af3b5ff30b 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -296,6 +296,7 @@ op_library(flatten_op DEPS reshape_op) op_library(sequence_pad_op DEPS sequence_padding) op_library(unstack_op DEPS stack_op) op_library(fake_quantize_op DEPS memory) +op_library(fusion_lstm_op DEPS cpu_lstm_compute) if (WITH_GPU) op_library(conv_op DEPS vol2col depthwise_conv im2col) diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index 55e465e3af08c012b8cff7714452ed32b32a5556..8ca79d20ec4f6412b00dbf3990068f81b65e2efd 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include "paddle/fluid/operators/fusion_lstm_op.h" #include #include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/cpu_lstm_compute.h" #include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/fc_compute.h" #include "paddle/fluid/operators/math/sequence2batch.h" @@ -269,7 +270,6 @@ class FuisonLSTMKernel : public framework::OpKernel { blas.GEMM(CblasNoTrans, CblasNoTrans, bs, D4, D, static_cast(1), prev, D, \ wh_data, D4, static_cast(1), out, D4) -// gates: W_ch, W_ih, W_fh, W_oh #define GET_Ct(ct_1, gates, ct) \ /* C_t = C_t-1 * fgated + cand_gated * igated*/ \ act_cand(D, gates, gates); \ @@ -395,11 +395,22 @@ class FuisonLSTMKernel : public framework::OpKernel { } } } else { + // TODO(TJ): ugly workaround, special-cases D == 8, clean me + std::function compute_ctht; + if (platform::jit::MayIUse(platform::jit::avx) && + act_gate_str == "sigmoid" && act_cand_str == "tanh" && + act_cell_str == "tanh" && D == 8) { + compute_ctht = math::lstm_compute_ctht; + } else { + compute_ctht = [&](T* gates, const T* ct_1, T* ct, T* ht) { 
+ COMPUTE_CtHt(gates, ct_1, ct, ht); + }; + } for (int i = 0; i < N; ++i) { PROCESS_H0C0 for (int step = tstart; step < seq_len; ++step) { GEMM_WH_ADDON(1, prev_h_data, xx_data); - COMPUTE_CtHt(xx_data, prev_c_data, c_out_data, h_out_data); + compute_ctht(xx_data, prev_c_data, c_out_data, h_out_data); MOVE_ONE_STEP; } } @@ -532,12 +543,23 @@ class FuisonLSTMKernel : public framework::OpKernel { MOVE_ONE_STEP; } } else { + // TODO(TJ): ugly workaround, special-cases D == 8, clean me + std::function compute_ctht; + if (platform::jit::MayIUse(platform::jit::avx) && + act_gate_str == "sigmoid" && act_cand_str == "tanh" && + act_cell_str == "tanh" && D == 8) { + compute_ctht = math::lstm_compute_ctht; + } else { + compute_ctht = [&](T* gates, const T* ct_1, T* ct, T* ht) { + COMPUTE_CtHt(gates, ct_1, ct, ht); + }; + } for (int step = tstart; step < max_seq_len; ++step) { const int cur_bs = batch_starts[step + 1] - batch_starts[step]; GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data); DEFINE_CUR; for (int i = 0; i < cur_bs; ++i) { - COMPUTE_CtHt(cur_in_data, cur_prev_c_data, cur_c_out_data, + compute_ctht(cur_in_data, cur_prev_c_data, cur_c_out_data, cur_h_out_data); MOVE_ONE_BATCH; } diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index d7f0f3c6280db7d121bf8821ec6d578e22a33da6..91101356436c26171eaca2fe01dfd4d937e71717 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -45,6 +45,8 @@ math_library(im2col) if (NOT WIN32) # windows do not support avx functions yet. 
math_library(gru_compute DEPS activation_functions math_function) math_library(lstm_compute DEPS activation_functions) +# TODO(TJ): ugly workaround, clean me +cc_library(cpu_lstm_compute SRCS cpu_lstm_compute.cc DEPS activation_functions cblas cpu_info) endif (NOT WIN32) cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context) diff --git a/paddle/fluid/operators/math/cpu_lstm_compute.cc b/paddle/fluid/operators/math/cpu_lstm_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..f7c55c215bacdafc99da5fcd0b750a058dfed21c --- /dev/null +++ b/paddle/fluid/operators/math/cpu_lstm_compute.cc @@ -0,0 +1,88 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/math/cpu_lstm_compute.h" +#include "paddle/fluid/operators/math/cpu_vec.h" +#include "paddle/fluid/platform/cpu_info.h" +#ifdef __AVX__ +#include +#endif + +namespace paddle { +namespace operators { +namespace math { + +// TODO(TJ): ugly workaround, one LSTM step hard-coded for D == 8, clean me +template +void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) { + // gates: W_ch, W_ih, W_fh, W_oh (cand, in, forget, out), 8 values each + vec_sigmoid(24, gates + 8, gates + 8); + vec_tanh(8, gates, gates); + const T *i = gates + 8, *f = gates + 16, *o = gates + 24; + const T min = SIGMOID_THRESHOLD_MIN; + const T max = SIGMOID_THRESHOLD_MAX; + for (int d = 0; d < 8; ++d) { + // C_t = C_t-1 * fgated + cand_gated * igated + ct[d] = ct_1[d] * f[d] + gates[d] * i[d]; + // H_t = act_cell(C_t) * ogated, tanh(x) = 2 / (1 + exp(-2x)) - 1 + T tmp = ct[d] * 2; + tmp = static_cast(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp)); + vec_exp(1, &tmp, &tmp); + tmp = static_cast(2) / (static_cast(1) + tmp) - static_cast(1); + ht[d] = tmp * o[d]; + } +} + +#ifdef __AVX__ +namespace detail { +namespace forward { +namespace avx { +__m256 Sigmoid(const __m256 a); +__m256 Tanh(const __m256 a); +} // namespace avx +} // namespace forward +} // namespace detail + +template <> +void lstm_compute_ctht(float* gates, const float* ct_1, float* ct, + float* ht) { + namespace act = detail::forward::avx; + // gates: W_ch, W_ih, W_fh, W_oh (cand, in, forget, out), 8 floats each + __m256 c, i, f, o; + c = _mm256_loadu_ps(gates); + i = _mm256_loadu_ps(gates + 8); + f = _mm256_loadu_ps(gates + 16); + o = _mm256_loadu_ps(gates + 24); + + /* C_t = C_t-1 * fgated + cand_gated * igated*/ + c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i)); + i = _mm256_loadu_ps(ct_1); + f = _mm256_mul_ps(i, act::Sigmoid(f)); + f = _mm256_add_ps(c, f); + _mm256_storeu_ps(ct, f); + + /* H_t = act_cell(C_t) * ogated */ + o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o)); + _mm256_storeu_ps(ht, o); +} +#endif + +template void lstm_compute_ctht(float* gates, const float* ct_1, + float* ct, float* ht); +template void 
lstm_compute_ctht(double* gates, const double* ct_1, + double* ct, double* ht); + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/cpu_lstm_compute.h b/paddle/fluid/operators/math/cpu_lstm_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..244164f08c4bb70833a9bfc884982a4225945bf0 --- /dev/null +++ b/paddle/fluid/operators/math/cpu_lstm_compute.h @@ -0,0 +1,28 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include + +namespace paddle { +namespace operators { +namespace math { + +// TODO(TJ): ugly workaround, clean me. gates: 32 values; ct_1/ct/ht: 8 each +template +void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht); + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/cpu_vec.h b/paddle/fluid/operators/math/cpu_vec.h index 9560e3a3c15ca63892fbe3552679a22f027f11e2..6a059968b79189458349e466079cc7a663a8e5ff 100644 --- a/paddle/fluid/operators/math/cpu_vec.h +++ b/paddle/fluid/operators/math/cpu_vec.h @@ -17,6 +17,7 @@ limitations under the License. 
*/ #include #include #include "paddle/fluid/platform/cpu_info.h" +#include "paddle/fluid/platform/enforce.h" #ifdef __AVX__ #include #endif @@ -476,7 +477,7 @@ class VecActivations { } else if (type == "identity" || type == "") { return vec_identity; } - LOG(FATAL) << "Not support type: " << type; + PADDLE_THROW("Not support type: %s", type); } };