提交 ff858d35 编写于 作者: T tensor-tang

fix bug and enable on batch mode as well

上级 8dea07f2
...@@ -543,12 +543,23 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -543,12 +543,23 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
MOVE_ONE_STEP; MOVE_ONE_STEP;
} }
} else { } else {
// TODO(TJ): unly workaround, clean me
std::function<void(T*, const T*, T*, T*)> compute_ctht;
if (platform::jit::MayIUse(platform::jit::avx) &&
act_gate_str == "sigmoid" && act_cand_str == "tanh" &&
act_cell_str == "tanh" && D == 8) {
compute_ctht = math::lstm_compute_ctht<T>;
} else {
compute_ctht = [&](T* gates, const T* ct_1, T* ct, T* ht) {
COMPUTE_CtHt(gates, ct_1, ct, ht);
};
}
for (int step = tstart; step < max_seq_len; ++step) { for (int step = tstart; step < max_seq_len; ++step) {
const int cur_bs = batch_starts[step + 1] - batch_starts[step]; const int cur_bs = batch_starts[step + 1] - batch_starts[step];
GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data); GEMM_WH_ADDON(cur_bs, prev_h_data, batched_input_data);
DEFINE_CUR; DEFINE_CUR;
for (int i = 0; i < cur_bs; ++i) { for (int i = 0; i < cur_bs; ++i) {
COMPUTE_CtHt(cur_in_data, cur_prev_c_data, cur_c_out_data, compute_ctht(cur_in_data, cur_prev_c_data, cur_c_out_data,
cur_h_out_data); cur_h_out_data);
MOVE_ONE_BATCH; MOVE_ONE_BATCH;
} }
......
...@@ -13,49 +13,9 @@ See the License for the specific language governing permissions and ...@@ -13,49 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/math/cpu_lstm_compute.h" #include "paddle/fluid/operators/math/cpu_lstm_compute.h"
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {} // namespace math
#ifdef __AVX__
// TODO(TJ): ugly workaround, clean me
namespace detail {
namespace forward {
namespace avx {
__m256 Sigmoid(const __m256 a);
__m256 Tanh(const __m256 a);
} // namespace avx
} // namespace forward
} // namespace detail
template <>
void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
float* ht) {
namespace act = detail::forward::avx;
// gates: W_ch, W_ih, W_fh, W_oh
__m256 c, i, f, o;
c = _mm256_loadu_ps(gates);
i = _mm256_loadu_ps(gates + 8);
f = _mm256_loadu_ps(gates + 16);
o = _mm256_loadu_ps(gates + 24);
/* C_t = C_t-1 * fgated + cand_gated * igated*/
c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i));
i = _mm256_loadu_ps(ct_1);
f = _mm256_mul_ps(i, act::Sigmoid(f));
f = _mm256_add_ps(c, f);
_mm256_storeu_ps(ct, f);
/* H_t = act_cell(C_t) * ogated */
o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o));
_mm256_storeu_ps(ht, o);
}
#endif
} // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -16,6 +16,9 @@ limitations under the License. */ ...@@ -16,6 +16,9 @@ limitations under the License. */
#include <string> #include <string>
#include "paddle/fluid/operators/math/cpu_vec.h" #include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
#ifdef __AVX__
#include <immintrin.h>
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -35,13 +38,47 @@ void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) { ...@@ -35,13 +38,47 @@ void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) {
ct[d] = ct_1[d] * f[d] + gates[d] * i[d]; ct[d] = ct_1[d] * f[d] + gates[d] * i[d];
// H_t = act_cell(C_t) * ogated // H_t = act_cell(C_t) * ogated
T tmp = ct[d] * 2; T tmp = ct[d] * 2;
tmp = static_cast<T>(0) - (tmp < min) ? min : ((tmp > max) ? max : tmp); tmp = static_cast<T>(0) - ((tmp < min) ? min : ((tmp > max) ? max : tmp));
vec_exp<T>(1, &tmp, &tmp); vec_exp<T>(1, &tmp, &tmp);
tmp = static_cast<T>(2) / (static_cast<T>(1) + tmp) - static_cast<T>(1); tmp = static_cast<T>(2) / (static_cast<T>(1) + tmp) - static_cast<T>(1);
ht[d] = tmp * o[d]; ht[d] = tmp * o[d];
} }
} }
#ifdef __AVX__
namespace detail {
namespace forward {
namespace avx {
__m256 Sigmoid(const __m256 a);
__m256 Tanh(const __m256 a);
} // namespace avx
} // namespace forward
} // namespace detail
template <>
void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
float* ht) {
namespace act = detail::forward::avx;
// gates: W_ch, W_ih, W_fh, W_oh
__m256 c, i, f, o;
c = _mm256_loadu_ps(gates);
i = _mm256_loadu_ps(gates + 8);
f = _mm256_loadu_ps(gates + 16);
o = _mm256_loadu_ps(gates + 24);
/* C_t = C_t-1 * fgated + cand_gated * igated*/
c = _mm256_mul_ps(act::Tanh(c), act::Sigmoid(i));
i = _mm256_loadu_ps(ct_1);
f = _mm256_mul_ps(i, act::Sigmoid(f));
f = _mm256_add_ps(c, f);
_mm256_storeu_ps(ct, f);
/* H_t = act_cell(C_t) * ogated */
o = _mm256_mul_ps(act::Tanh(f), act::Sigmoid(o));
_mm256_storeu_ps(ht, o);
}
#endif
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册