From 612ba41aeefe9e7650cce5f9fb0eeab68d9b1eb3 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Thu, 13 Sep 2018 17:06:09 +0800 Subject: [PATCH] add simple lstm compute --- paddle/fluid/operators/CMakeLists.txt | 1 + paddle/fluid/operators/fusion_lstm_op.cc | 15 ++++- paddle/fluid/operators/math/CMakeLists.txt | 2 + .../fluid/operators/math/cpu_lstm_compute.cc | 57 +++++++++++++++++++ .../fluid/operators/math/cpu_lstm_compute.h | 50 ++++++++++++++++ 5 files changed, 123 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/operators/math/cpu_lstm_compute.cc create mode 100644 paddle/fluid/operators/math/cpu_lstm_compute.h diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 7ec1e78da..ccb7fa1f8 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -296,6 +296,7 @@ op_library(flatten_op DEPS reshape_op) op_library(sequence_pad_op DEPS sequence_padding) op_library(unstack_op DEPS stack_op) op_library(fake_quantize_op DEPS memory) +op_library(fusion_lstm_op DEPS cpu_lstm_compute) if (WITH_GPU) op_library(conv_op DEPS vol2col depthwise_conv im2col) diff --git a/paddle/fluid/operators/fusion_lstm_op.cc b/paddle/fluid/operators/fusion_lstm_op.cc index 55e465e3a..6949cf55c 100644 --- a/paddle/fluid/operators/fusion_lstm_op.cc +++ b/paddle/fluid/operators/fusion_lstm_op.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include "paddle/fluid/operators/fusion_lstm_op.h" #include #include "paddle/fluid/operators/math/blas.h" +#include "paddle/fluid/operators/math/cpu_lstm_compute.h" #include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/fc_compute.h" #include "paddle/fluid/operators/math/sequence2batch.h" @@ -269,7 +270,6 @@ class FuisonLSTMKernel : public framework::OpKernel { blas.GEMM(CblasNoTrans, CblasNoTrans, bs, D4, D, static_cast(1), prev, D, \ wh_data, D4, static_cast(1), out, D4) -// gates: W_ch, W_ih, W_fh, W_oh #define GET_Ct(ct_1, gates, ct) \ /* C_t = C_t-1 * fgated + cand_gated * igated*/ \ act_cand(D, gates, gates); \ @@ -395,11 +395,22 @@ class FuisonLSTMKernel : public framework::OpKernel { } } } else { + // TODO(TJ): unly workaround, clean me + std::function compute_ctht; + if (platform::jit::MayIUse(platform::jit::avx) && + act_gate_str == "sigmoid" && act_cand_str == "tanh" && + act_cell_str == "tanh" && D == 8) { + compute_ctht = math::lstm_compute_ctht; + } else { + compute_ctht = [&](const T* gates, const T* ct_1, T* ct, T* ht) { + COMPUTE_CtHt(gates, ct_1, ct, ht); + } + } for (int i = 0; i < N; ++i) { PROCESS_H0C0 for (int step = tstart; step < seq_len; ++step) { GEMM_WH_ADDON(1, prev_h_data, xx_data); - COMPUTE_CtHt(xx_data, prev_c_data, c_out_data, h_out_data); + compute_ctht(xx_data, prev_c_data, c_out_data, h_out_data); MOVE_ONE_STEP; } } diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index d7f0f3c62..c7b627c4a 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -45,6 +45,8 @@ math_library(im2col) if (NOT WIN32) # windows do not support avx functions yet. math_library(gru_compute DEPS activation_functions math_function) math_library(lstm_compute DEPS activation_functions) +# TODO(TJ): ugly workaround, clean me +cc_library(cpu_lstm_compute SRCS cpu_lstm_compute.cc DEPS activation_functions cblas) endif (NOT WIN32) cc_library(blas SRCS blas.cc DEPS cblas framework_proto device_context) diff --git a/paddle/fluid/operators/math/cpu_lstm_compute.cc b/paddle/fluid/operators/math/cpu_lstm_compute.cc new file mode 100644 index 000000000..7e487079d --- /dev/null +++ b/paddle/fluid/operators/math/cpu_lstm_compute.cc @@ -0,0 +1,57 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/cpu_lstm_compute.h" +#ifdef __AVX__ +#include +#endif +namespace paddle { +namespace operators { +namespace math { + +#ifdef __AVX__ +// TODO(TJ): ugly workaround, clean me + +namespace detail { +namespace forward { +namespace avx {} // namespace avx +} // namespace forward +} // namespace detail + +template <> +void lstm_compute_ctht(const float* gates, const float* ct_1, float* ct, + float* ht) { + namespace act = detail::forward::avx; + // gates: W_ch, W_ih, W_fh, W_oh + __m256 c, i, f, o; + c = _mm256_loadu_ps(gates); + i = _mm256_loadu_ps(gates + 8); + f = _mm256_loadu_ps(gates + 16); + o = _mm256_loadu_ps(gates + 24); + + /* C_t = C_t-1 * fgated + cand_gated * igated*/ + c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i)); + i = _mm256_loadu_ps(ct_1); + f = _mm256_mul_ps(i, act::Sigmoid(f)); + f = _mm256_add_ps(c, f); + _mm256_storeu_ps(ct, f); + + /* H_t = act_cell(C_t) * ogated */ + o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o)); + _mm256_storeu_ps(ht, o); +} +#endif +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/cpu_lstm_compute.h b/paddle/fluid/operators/math/cpu_lstm_compute.h new file mode 100644 index 000000000..7b803b6c8 --- /dev/null +++ b/paddle/fluid/operators/math/cpu_lstm_compute.h @@ -0,0 +1,50 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/operators/math/cpu_vec.h" +#include "paddle/fluid/platform/cpu_info.h" + +namespace paddle { +namespace operators { +namespace math { + +// TODO(TJ): ugly workaround, clean me +template +void lstm_compute_ctht(const T* gates, const T* ct_1, T* ct, T* ht) { + // gates: W_ch, W_ih, W_fh, W_oh + vec_sigmoid(24, gates + 8, gates + 8); + vec_tanh(8, gates, gates); + const T *i = gates + 8, *f = gates + 16, *o = gates + 24; + for (int d = 0; d < 8; ++d) { + // C_t = C_t-1 * fgated + cand_gated * igated + ct[d] = ct_1[d] * f[d] + gates[d] * i[d]; + + // H_t = act_cell(C_t) * ogated + T tmp = ct[d] * 2; + tmp = static_cast(0) - (tmp < static_cast(SIGMOID_THRESHOLD_MIN)) + ? min + : ((tmp > static_cast(SIGMOID_THRESHOLD_MAX)) + ? static_cast(SIGMOID_THRESHOLD_MAX) + : tmp); + vec_exp(1, &tmp, &tmp); + tmp = static_cast(2) / (static_cast(1) + tmp) - static_cast(1); + ht[d] = tmp * o[d]; + } +} + +} // namespace math +} // namespace operators +} // namespace paddle -- GitLab