Commit 5ca0bb9a authored by tensor-tang

support more activation types and remove some comments

Parent dd938d0b
paddle/fluid/operators/attention_lstm_op.cc

@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/attention_lstm_op.h"
+#include <sys/time.h>
 #include <string>
 #include "paddle/fluid/operators/math/blas.h"
 #include "paddle/fluid/operators/math/cpu_vec.h"
...
@@ -192,24 +193,23 @@ void AttentionLSTMOpMaker::Make() {
           "(Tensor) the output of LSTM X(1*(D+M))* weight((D+M)*4D) for each step."
           "Shape is (1 x 4D), where M is the x frame size")
       .AsIntermediate();
-  // TODO(TJ): InEnum({"sigmoid", "tanh", "relu", "identity"});
   AddAttr<std::string>("gate_activation",
                        "(string, default: sigmoid)"
                        "The activation for input gate, forget gate and output "
                        "gate, `sigmoid` by default.")
       .SetDefault("sigmoid")
-      .InEnum({"sigmoid"});
+      .InEnum({"sigmoid", "tanh", "relu", "identity"});
   AddAttr<std::string>("cell_activation",
                        "(string, default: tanh)"
                        "The activation for cell output, `tanh` by default.")
       .SetDefault("tanh")
-      .InEnum({"tanh"});
+      .InEnum({"sigmoid", "tanh", "relu", "identity"});
   AddAttr<std::string>("candidate_activation",
                        "(string, default: tanh)"
                        "The activation for candidate hidden state, "
                        "`tanh` by default.")
       .SetDefault("tanh")
-      .InEnum({"tanh"});
+      .InEnum({"sigmoid", "tanh", "relu", "identity"});
   AddComment(R"DOC(
 Attention Long-Short Term Memory (LSTM) Operator.
...
@@ -273,22 +273,23 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     using DeviceContext = paddle::platform::CPUDeviceContext;
-    auto* x = ctx.Input<LoDTensor>("X");                        // T x M
-    auto* h0 = ctx.Input<Tensor>("H0");                         // N x D
-    auto* c0 = ctx.Input<Tensor>("C0");                         // N x D
-    auto* atten_w = ctx.Input<Tensor>("AttentionWeight");       // (M+D) x 1
-    auto* atten_b = ctx.Input<Tensor>("AttentionBias");         // 1x1
-    auto* atten_scalar = ctx.Input<Tensor>("AttentionScalar");  // 1x1
-    auto* atten_scalar_bias = ctx.Input<Tensor>("AttentionScalarBias");  // 1x1
-    auto* lstm_w = ctx.Input<Tensor>("LSTMWeight");  // (D+M) x D*4
-    auto* lstm_b = ctx.Input<Tensor>("LSTMBias");    // 1 x D*4
+
+    auto* x = ctx.Input<LoDTensor>("X");
+    auto* h0 = ctx.Input<Tensor>("H0");
+    auto* c0 = ctx.Input<Tensor>("C0");
+    auto* atten_w = ctx.Input<Tensor>("AttentionWeight");
+    auto* atten_b = ctx.Input<Tensor>("AttentionBias");
+    auto* atten_scalar = ctx.Input<Tensor>("AttentionScalar");
+    auto* atten_scalar_bias = ctx.Input<Tensor>("AttentionScalarBias");
+    auto* lstm_w = ctx.Input<Tensor>("LSTMWeight");
+    auto* lstm_b = ctx.Input<Tensor>("LSTMBias");
 
-    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");  // TxD
-    auto* cell_out = ctx.Output<LoDTensor>("Cell");      // TxD
-    auto* atted_x = ctx.Output<Tensor>("AttentionedX");  // T x 1
-    auto* fc_out = ctx.Output<Tensor>("AttentionFCOut");  // max_seq_len x 1
-    auto* lstm_x = ctx.Output<Tensor>("LSTMX");           // 1 x M
-    auto* lstm_out = ctx.Output<Tensor>("LSTMOUT");       // 1 x 4D
+    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
+    auto* cell_out = ctx.Output<LoDTensor>("Cell");
+    auto* atted_x = ctx.Output<Tensor>("AttentionedX");
+    auto* fc_out = ctx.Output<Tensor>("AttentionFCOut");
+    auto* lstm_x = ctx.Output<Tensor>("LSTMX");
+    auto* lstm_out = ctx.Output<Tensor>("LSTMOUT");
 
     // some shape should be reshape here since infershape can not get lod info
     auto x_lod = x->lod();
...
@@ -310,11 +311,11 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_EQ(c0->dims()[0], N, "C0 dims should be %d x %d.", N, D);
     fc_out->Resize({max_seq_len, 1});
 
-    // TODO(TJ): act functor init here
-    // if (platform::jit::MayIUse(platform::jit::avx2)) {
-    // } else if (platform::jit::MayIUse(platform::jit::avx)) {
-    // } else {
-    // }
+    math::VecActivations<T> act_functor;
+    std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
+    act_gate = act_functor(ctx.Attr<std::string>("gate_activation"));
+    act_cell = act_functor(ctx.Attr<std::string>("cell_activation"));
+    act_cand = act_functor(ctx.Attr<std::string>("candidate_activation"));
 
     const T* x_data = x->data<T>();
     const T* h0_data = h0 ? h0->data<T>() : NULL;
...
@@ -381,9 +382,9 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     blas.VADD(D4, lstm_b_data, lstm_out_data, lstm_out_data);
 
     // gate act: sigmoid
-    math::vec_sigmoid(D3, lstm_out_data, lstm_out_data);
+    act_gate(D3, lstm_out_data, lstm_out_data);
 
     // candidate act: tanh
-    math::vec_tanh(D, lstm_out_data + D3, lstm_out_data + D3);
+    act_cand(D, lstm_out_data + D3, lstm_out_data + D3);
 
     // a = forget * prev_cell
     blas.VMUL(D, lstm_out_data, prev_cell_data, lstm_out_data);
...
@@ -395,7 +396,7 @@ class AttentionLSTMKernel : public framework::OpKernel<T> {
     blas.VADD(D, lstm_out_data, lstm_out_data + D, cur_cell_out_data);
 
     // state act tanh(cell_out) * output_gate
-    math::vec_tanh(D, cur_cell_out_data, lstm_out_data);
+    act_cell(D, cur_cell_out_data, lstm_out_data);
     blas.VMUL(D, lstm_out_data, lstm_out_data + D2, cur_hidden_out_data);
 
     prev_hidden_data = cur_hidden_out_data;
...
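Note: the pattern introduced above resolves each activation attribute to a std::function once per Compute() call, then invokes it inside the per-timestep loop. A minimal standalone sketch of that dispatch pattern follows (illustrative only, not part of the commit; vec_sigmoid_ref, vec_tanh_ref, vec_relu_ref and pick_act are hypothetical names):

// Standalone sketch of name-to-kernel activation dispatch.
#include <cmath>
#include <functional>
#include <stdexcept>
#include <string>

template <typename T>
void vec_sigmoid_ref(const int n, const T* x, T* y) {
  for (int i = 0; i < n; ++i) y[i] = T(1) / (T(1) + std::exp(-x[i]));
}

template <typename T>
void vec_tanh_ref(const int n, const T* x, T* y) {
  for (int i = 0; i < n; ++i) y[i] = std::tanh(x[i]);
}

template <typename T>
void vec_relu_ref(const int n, const T* x, T* y) {
  for (int i = 0; i < n; ++i) y[i] = x[i] > T(0) ? x[i] : T(0);
}

// Resolve the attribute string to a kernel once, outside the hot loop.
template <typename T>
std::function<void(const int, const T*, T*)> pick_act(const std::string& type) {
  if (type == "sigmoid") return vec_sigmoid_ref<T>;
  if (type == "tanh") return vec_tanh_ref<T>;
  if (type == "relu") return vec_relu_ref<T>;
  // identity is a no-op: the caller applies activations in place.
  if (type == "identity" || type.empty()) return [](const int, const T*, T*) {};
  throw std::invalid_argument("unsupported activation: " + type);
}

Doing the string comparison once per Compute() call keeps it out of the per-timestep loop; only the indirect call through std::function remains on the hot path.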
paddle/fluid/operators/math/cpu_vec.h

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <string>
 #include "paddle/fluid/platform/cpu_info.h"
 
 namespace paddle {
...
@@ -34,6 +34,12 @@ inline T tanh(T x) {
   return 2. * sigmoid(2. * x) - 1.;
 }
 
+template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+inline void vec_identity(const int n, const T* x, T* y) {
+  // do nothing
+  return;
+}
+
 template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
 inline void vec_sigmoid(const int n, const T* x, T* y) {
   const T min = SIGMOID_THRESHOLD_MIN;
...
@@ -76,6 +82,24 @@ inline void vec_relu<float, platform::jit::avx>(const int n, const float* x,
   }
 }
 
+template <typename T, platform::jit::cpu_isa_t isa = platform::jit::isa_any>
+class VecActivations {
+ public:
+  std::function<void(const int, const T*, T*)> operator()(
+      const std::string& type) {
+    if (type == "sigmoid") {
+      return vec_sigmoid<T, isa>;
+    } else if (type == "relu") {
+      return vec_relu<T, isa>;
+    } else if (type == "tanh") {
+      return vec_tanh<T, isa>;
+    } else if (type == "identity" || type == "") {
+      return vec_identity<T, isa>;
+    }
+    PADDLE_THROW("Not support type %s.", type);
+  }
+};
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
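For reference, a hypothetical usage sketch of the new VecActivations functor (the example itself is not in the commit and assumes a Paddle build environment):

// Hypothetical usage of VecActivations from cpu_vec.h above.
#include "paddle/fluid/operators/math/cpu_vec.h"

void example() {
  namespace math = paddle::operators::math;
  math::VecActivations<float> act_functor;
  auto act = act_functor("relu");  // resolves to vec_relu<float, isa_any>
  float x[4] = {-1.f, 0.f, 2.f, -3.f};
  float y[4];
  act(4, x, y);  // y == {0, 0, 2, 0}
  // "identity" maps to vec_identity, which deliberately does nothing:
  // the op kernel applies activations in place, so a no-op is correct
  // only when x and y alias the same buffer.
  auto id = act_functor("identity");
  id(4, y, y);
}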
python/paddle/fluid/tests/unittests/test_attention_lstm_op.py

@@ -160,6 +160,15 @@ class TestAttentionOpNonInit(TestAttentionLSTMOp):
         self.has_initial_hidden = False
 
 
+class TestAttentionOpAct(TestAttentionLSTMOp):
+    def set_conf(self):
+        self.M = 3
+        self.D = 2
+        self.act_gate = 'relu'
+        self.act_cell = 'tanh'
+        self.act_cand = 'sigmoid'
+
+
 class TestAttentionOpMD1(TestAttentionLSTMOp):
     def set_conf(self):
         self.M = 36
...