提交 5ca0bb9a 编写于 作者: T tensor-tang

support more activation type and remove some comments

上级 dd938d0b
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/attention_lstm_op.h"
#include <sys/time.h>
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
......@@ -192,24 +193,23 @@ void AttentionLSTMOpMaker::Make() {
"(Tensor) the output of LSTM X(1*(D+M))* weight((D+M)*4D) for each step."
"Shape is (1 x 4D), where M is the x frame size")
.AsIntermediate();
// TODO(TJ): InEnum({"sigmoid", "tanh", "relu", "identity"});
AddAttr<std::string>("gate_activation",
"(string, default: sigmoid)"
"The activation for input gate, forget gate and output "
"gate, `sigmoid` by default.")
.SetDefault("sigmoid")
.InEnum({"sigmoid"});
.InEnum({"sigmoid", "tanh", "relu", "identity"});
AddAttr<std::string>("cell_activation",
"(string, default: tanh)"
"The activation for cell output, `tanh` by defalut.")
.SetDefault("tanh")
.InEnum({"tanh"});
.InEnum({"sigmoid", "tanh", "relu", "identity"});
AddAttr<std::string>("candidate_activation",
"(string, default: tanh)"
"The activation for candidate hidden state, "
"`tanh` by default.")
.SetDefault("tanh")
.InEnum({"tanh"});
.InEnum({"sigmoid", "tanh", "relu", "identity"});
AddComment(R"DOC(
Attention Long-Short Term Memory (LSTM) Operator.
......@@ -273,22 +273,23 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using DeviceContext = paddle::platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X"); // T x M
auto* h0 = ctx.Input<Tensor>("H0"); // N x D
auto* c0 = ctx.Input<Tensor>("C0"); // N x D
auto* atten_w = ctx.Input<Tensor>("AttentionWeight"); // (M+D) x 1
auto* atten_b = ctx.Input<Tensor>("AttentionBias"); // 1x1
auto* atten_scalar = ctx.Input<Tensor>("AttentionScalar"); // 1x1
auto* atten_scalar_bias = ctx.Input<Tensor>("AttentionScalarBias"); // 1x1
auto* lstm_w = ctx.Input<Tensor>("LSTMWeight"); // (D+M) x D*4
auto* lstm_b = ctx.Input<Tensor>("LSTMBias"); // 1 x D*4
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); // TxD
auto* cell_out = ctx.Output<LoDTensor>("Cell"); // TxD
auto* atted_x = ctx.Output<Tensor>("AttentionedX"); // T x 1
auto* fc_out = ctx.Output<Tensor>("AttentionFCOut"); // max_seq_len x 1
auto* lstm_x = ctx.Output<Tensor>("LSTMX"); // 1 x M
auto* lstm_out = ctx.Output<Tensor>("LSTMOUT"); // 1 x 4D
auto* x = ctx.Input<LoDTensor>("X");
auto* h0 = ctx.Input<Tensor>("H0");
auto* c0 = ctx.Input<Tensor>("C0");
auto* atten_w = ctx.Input<Tensor>("AttentionWeight");
auto* atten_b = ctx.Input<Tensor>("AttentionBias");
auto* atten_scalar = ctx.Input<Tensor>("AttentionScalar");
auto* atten_scalar_bias = ctx.Input<Tensor>("AttentionScalarBias");
auto* lstm_w = ctx.Input<Tensor>("LSTMWeight");
auto* lstm_b = ctx.Input<Tensor>("LSTMBias");
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
auto* cell_out = ctx.Output<LoDTensor>("Cell");
auto* atted_x = ctx.Output<Tensor>("AttentionedX");
auto* fc_out = ctx.Output<Tensor>("AttentionFCOut");
auto* lstm_x = ctx.Output<Tensor>("LSTMX");
auto* lstm_out = ctx.Output<Tensor>("LSTMOUT");
// some shape should be reshape here since infershape can not get lod info
auto x_lod = x->lod();
......@@ -310,11 +311,11 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
fc_out->Resize({max_seq_len, 1});
// TODO(TJ): act functor init here
// if (platform::jit::MayIUse(platform::jit::avx2)) {
// } else if (platform::jit::MayIUse(platform::jit::avx)) {
// } else {
// }
math::VecActivations<T> act_functor;
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
act_gate = act_functor(ctx.Attr<std::string>("gate_activation"));
act_cell = act_functor(ctx.Attr<std::string>("cell_activation"));
act_cand = act_functor(ctx.Attr<std::string>("candidate_activation"));
const T* x_data = x->data<T>();
const T* h0_data = h0 ? h0->data<T>() : NULL;
......@@ -381,9 +382,9 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
blas.VADD(D4, lstm_b_data, lstm_out_data, lstm_out_data);
// gate act: sigmoid
math::vec_sigmoid(D3, lstm_out_data, lstm_out_data);
act_gate(D3, lstm_out_data, lstm_out_data);
// candicate act: tanh
math::vec_tanh(D, lstm_out_data + D3, lstm_out_data + D3);
act_cand(D, lstm_out_data + D3, lstm_out_data + D3);
// a = forget * prev_cell
blas.VMUL(D, lstm_out_data, prev_cell_data, lstm_out_data);
......@@ -395,7 +396,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
blas.VADD(D, lstm_out_data, lstm_out_data + D, cur_cell_out_data);
// state act tanh(cell_out) * output_gate
math::vec_tanh(D, cur_cell_out_data, lstm_out_data);
act_cell(D, cur_cell_out_data, lstm_out_data);
blas.VMUL(D, lstm_out_data, lstm_out_data + D2, cur_hidden_out_data);
prev_hidden_data = cur_hidden_out_data;
......
......@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/platform/cpu_info.h"
namespace paddle {
......@@ -34,6 +34,12 @@ inline T tanh(T x) {
return 2. * sigmoid(2. * x) - 1.;
}
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
inline void vec_identity(const int n, const T* x, T* y) {
// do nothing
return;
}
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
inline void vec_sigmoid(const int n, const T* x, T* y) {
const T min = SIGMOID_THRESHOLD_MIN;
......@@ -76,6 +82,24 @@ inline void vec_relu<float, platform::jit::avx>(const int n, const float* x,
}
}
template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
class VecActivations {
public:
std::function<void(const int, const T*, T*)> operator()(
const std::string& type) {
if (type == "sigmoid") {
return vec_sigmoid<T, isa>;
} else if (type == "relu") {
return vec_relu<T, isa>;
} else if (type == "tanh") {
return vec_tanh<T, isa>;
} else if (type == "identity" || type == "") {
return vec_identity<T, isa>;
}
PADDLE_THROW("Not support type %s.", type);
}
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -160,6 +160,15 @@ class TestAttentionOpNonInit(TestAttentionLSTMOp):
self.has_initial_hidden = False
class TestAttentionOpAct(TestAttentionLSTMOp):
def set_conf(self):
self.M = 3
self.D = 2
self.act_gate = 'relu'
self.act_cell = 'tanh'
self.act_cand = 'sigmoid'
class TestAttentionOpMD1(TestAttentionLSTMOp):
def set_conf(self):
self.M = 36
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册