提交 8dea07f2 编写于 作者: T tensor-tang

fix comopile

上级 612ba41a
......@@ -396,15 +396,15 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
}
} else {
// TODO(TJ): unly workaround, clean me
std::function<void(const T*, const T*, T*, T*)> compute_ctht;
std::function<void(T*, const T*, T*, T*)> compute_ctht;
if (platform::jit::MayIUse(platform::jit::avx) &&
act_gate_str == "sigmoid" && act_cand_str == "tanh" &&
act_cell_str == "tanh" && D == 8) {
compute_ctht = math::lstm_compute_ctht<T>;
} else {
compute_ctht = [&](const T* gates, const T* ct_1, T* ct, T* ht) {
compute_ctht = [&](T* gates, const T* ct_1, T* ct, T* ht) {
COMPUTE_CtHt(gates, ct_1, ct, ht);
}
};
}
for (int i = 0; i < N; ++i) {
PROCESS_H0C0
......
......@@ -25,12 +25,15 @@ namespace math {
namespace detail {
namespace forward {
namespace avx {} // namespace avx
namespace avx {
__m256 Sigmoid(const __m256 a);
__m256 Tanh(const __m256 a);
} // namespace avx
} // namespace forward
} // namespace detail
template <>
void lstm_compute_ctht<float>(const float* gates, const float* ct_1, float* ct,
void lstm_compute_ctht<float>(float* gates, const float* ct_1, float* ct,
float* ht) {
namespace act = detail::forward::avx;
// gates: W_ch, W_ih, W_fh, W_oh
......@@ -52,6 +55,7 @@ void lstm_compute_ctht<float>(const float* gates, const float* ct_1, float* ct,
_mm256_storeu_ps(ht, o);
}
#endif
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -23,22 +23,19 @@ namespace math {
// TODO(TJ): ugly workaround, clean me
template <typename T>
void lstm_compute_ctht(const T* gates, const T* ct_1, T* ct, T* ht) {
void lstm_compute_ctht(T* gates, const T* ct_1, T* ct, T* ht) {
// gates: W_ch, W_ih, W_fh, W_oh
vec_sigmoid<T, platform::jit::avx>(24, gates + 8, gates + 8);
vec_tanh<T, platform::jit::avx>(8, gates, gates);
const T *i = gates + 8, *f = gates + 16, *o = gates + 24;
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
for (int d = 0; d < 8; ++d) {
// C_t = C_t-1 * fgated + cand_gated * igated
ct[d] = ct_1[d] * f[d] + gates[d] * i[d];
// H_t = act_cell(C_t) * ogated
T tmp = ct[d] * 2;
tmp = static_cast<T>(0) - (tmp < static_cast<T>(SIGMOID_THRESHOLD_MIN))
? min
: ((tmp > static_cast<T>(SIGMOID_THRESHOLD_MAX))
? static_cast<T>(SIGMOID_THRESHOLD_MAX)
: tmp);
tmp = static_cast<T>(0) - (tmp < min) ? min : ((tmp > max) ? max : tmp);
vec_exp<T>(1, &tmp, &tmp);
tmp = static_cast<T>(2) / (static_cast<T>(1) + tmp) - static_cast<T>(1);
ht[d] = tmp * o[d];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册