Unverified commit d292ad85, authored by tensor-tang, committed by GitHub

Merge pull request #12958 from tensor-tang/refine/op/fusion_lstm

refine fusion lstm
@@ -15,10 +15,14 @@ limitations under the License. */
#include "paddle/fluid/operators/fusion_lstm_op.h"
#include <string>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/fc_compute.h" #include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/operators/math/lstm_compute.h" #include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/cpu_info.h"
DEFINE_bool(seq_mode, true, "Use sequence mode");
namespace paddle {
namespace operators {
@@ -98,7 +102,12 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
ctx->ShareLoD("X", "Hidden");
ctx->ShareLoD("X", "Cell");
- int xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
+ int xx_width;
if (FLAGS_seq_mode) {
xx_width = wx_dims[1];
} else {
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
}
ctx->SetOutputDim("XX", {x_dims[0], xx_width}); ctx->SetOutputDim("XX", {x_dims[0], xx_width});
ctx->ShareLoD("X", "XX"); ctx->ShareLoD("X", "XX");
} }
...@@ -205,10 +214,138 @@ inline void ReorderInitState(const DeviceContext& ctx, ...@@ -205,10 +214,138 @@ inline void ReorderInitState(const DeviceContext& ctx,
row_shuffle(ctx, src, index_lod, dst, indexed_src); row_shuffle(ctx, src, index_lod, dst, indexed_src);
} }
template <typename DeviceContext, typename T> template <typename T>
class FuisonLSTMKernel : public framework::OpKernel<T> { class FuisonLSTMKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void SeqCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = paddle::platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X");
auto* h0 = ctx.Input<Tensor>("H0");
auto* c0 = ctx.Input<Tensor>("C0");
auto* wx = ctx.Input<Tensor>("WeightX");
auto* wh = ctx.Input<Tensor>("WeightH");
auto* bias = ctx.Input<Tensor>("Bias");
auto* xx = ctx.Output<LoDTensor>("XX");
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
auto* cell_out = ctx.Output<LoDTensor>("Cell");
bool is_reverse = ctx.Attr<bool>("is_reverse");
std::function<void(const int, const T *, T *)> act_gate, act_cell, act_cand;
auto& act_gate_str = ctx.Attr<std::string>("gate_activation");
auto& act_cell_str = ctx.Attr<std::string>("cell_activation");
auto& act_cand_str = ctx.Attr<std::string>("candidate_activation");
if (platform::jit::MayIUse(platform::jit::avx)) {
math::VecActivations<T, platform::jit::avx> act_functor;
act_gate = act_functor(act_gate_str);
act_cell = act_functor(act_cell_str);
act_cand = act_functor(act_cand_str);
} else {
math::VecActivations<T, platform::jit::isa_any> act_functor;
act_gate = act_functor(act_gate_str);
act_cell = act_functor(act_cell_str);
act_cand = act_functor(act_cand_str);
}
auto x_lod = x->lod();
auto x_dims = x->dims(); // T x M
auto wh_dims = wh->dims(); // D x 4D
const int total_T = x_dims[0];
const int N = x_lod[0].size() - 1; // batch size
const int M = x_dims[1]; // x frame size
const int D = wh_dims[0];
const int D2 = D * 2;
const int D3 = D * 3;
const int D4 = wh_dims[1];
const T* x_data = x->data<T>();
const T* h0_data = h0 ? h0->data<T>() : NULL;
const T* c0_data = c0 ? c0->data<T>() : NULL;
const T* wx_data = wx->data<T>();
const T* wh_data = wh->data<T>();
T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, total_T, D4, M, x_data, wx_data,
xx_data, bias->data<T>());
int xx_offset = D4;
int gate_offset = D;
if (is_reverse) {
const int offset = (total_T - 1) * D;
xx_data = xx_data + offset * 4;
hidden_out_data = hidden_out_data + offset;
cell_out_data = cell_out_data + offset;
xx_offset = -D4;
gate_offset = -D;
}
auto move_step = [&]() {
xx_data = xx_data + xx_offset;
hidden_out_data = hidden_out_data + gate_offset;
cell_out_data = cell_out_data + gate_offset;
};
for (int i = 0; i < N; ++i) {
int bid = is_reverse ? N - 1 - i : i;
int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
const T* prev_cell_data = NULL;
const T* prev_hidden_data = NULL;
int tstart = 0;
if (h0_data) {
prev_hidden_data = h0_data + bid * D;
prev_cell_data = c0_data + bid * D;
} else {
// W_ch, W_ih, W_fh, W_oh
act_gate(D3, xx_data + D, xx_data + D);
act_cand(D, xx_data, xx_data);
// cell out= input*tilde
blas.VMUL(D, xx_data, xx_data + D, cell_out_data);
// hidden out= act_state(cellout) * outgate
act_cell(D, cell_out_data, xx_data + D2);
blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
// prev
prev_hidden_data = hidden_out_data;
prev_cell_data = cell_out_data;
tstart = 1;
move_step();
}
for (int step = tstart; step < seq_len; ++step) {
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1),
prev_hidden_data, D, wh_data, D4, static_cast<T>(1), xx_data,
D4);
// W_ch, W_ih, W_fh, W_oh
act_gate(D3, xx_data + D, xx_data + D);
act_cand(D, xx_data, xx_data);
// a = forget * prev_cell
blas.VMUL(D, xx_data + D2, prev_cell_data, xx_data + D2);
// b = input * tilde
blas.VMUL(D, xx_data, xx_data + D, xx_data + D);
// cell out= a+b
blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data);
// hidden out= act_state(cellout) * outgate
act_cell(D, cell_out_data, xx_data + D2);
blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
// prev
prev_hidden_data = hidden_out_data;
prev_cell_data = cell_out_data;
move_step();
}
}
}
void BatchCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = platform::CPUDeviceContext;
auto* x = ctx.Input<LoDTensor>("X");
auto* wx = ctx.Input<Tensor>("WeightX");
auto* wh = ctx.Input<Tensor>("WeightH");
@@ -339,6 +476,13 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
// restore the output cell state in LoDTensor from the batch cell
to_seq(dev_ctx, batch_cell, cell_out);
}
void Compute(const framework::ExecutionContext& ctx) const override {
if (FLAGS_seq_mode) {
SeqCompute(ctx);
} else {
BatchCompute(ctx);
}
}
};
} // namespace operators
@@ -348,7 +492,5 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_lstm, ops::FusionLSTMOp, ops::FusionLSTMOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
- REGISTER_OP_CPU_KERNEL(
- fusion_lstm,
- ops::FuisonLSTMKernel<paddle::platform::CPUDeviceContext, float>,
- ops::FuisonLSTMKernel<paddle::platform::CPUDeviceContext, double>);
+ REGISTER_OP_CPU_KERNEL(fusion_lstm, ops::FuisonLSTMKernel<float>,
+ ops::FuisonLSTMKernel<double>);
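For readers skimming the new SeqCompute path: each row of XX holds four D-wide blocks in the order [candidate, input gate, forget gate, output gate] (the W_ch, W_ih, W_fh, W_oh comment), and the act_*/VMUL/VADD calls implement the usual recurrence c_t = f * c_{t-1} + i * cand and h_t = o * act_cell(c_t). Below is a minimal NumPy sketch of one timestep under that layout; it is illustrative only, and the function name and default sigmoid/tanh activations are assumptions, not code from this patch.

```python
import numpy as np

def lstm_step(gates, prev_c,
              act_gate=lambda v: 1.0 / (1.0 + np.exp(-v)),  # sigmoid
              act_cell=np.tanh, act_cand=np.tanh):
    # gates: one row of XX after the GEMM with WeightH, shape (4 * D,),
    # laid out as [candidate, input gate, forget gate, output gate].
    D = gates.shape[0] // 4
    cand = act_cand(gates[:D])                  # act_cand(D, xx_data, xx_data)
    i, f, o = np.split(act_gate(gates[D:]), 3)  # act_gate(D3, xx_data + D, ...)
    c = f * prev_c + i * cand                   # two VMULs and one VADD
    h = o * act_cell(c)                         # act_cell, then VMUL with outgate
    return h, c
```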
@@ -43,13 +43,13 @@ def fusion_lstm(
act_cell, act_cand)
- class TestLstmOp(OpTest):
+ class TestFusionLSTMOp(OpTest):
- def set_argument(self):
+ def set_conf(self):
- self.lod = [[2, 3, 2]]
+ pass
def setUp(self):
self.op_type = 'fusion_lstm'
- self.lod = [[2, 3, 2]]
+ self.lod = [[2, 3, 5, 4]]
self.M = 8
self.D = 16
self.has_initial_state = False
@@ -58,33 +58,33 @@ class TestLstmOp(OpTest):
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.use_peepholes = False
- self.set_argument()
+ self.set_conf()
T = sum(self.lod[0])
bs = len(self.lod[0])
- x = np.random.normal(size=(T, self.M)).astype('float64')
+ x = np.random.normal(size=(T, self.M)).astype('float32')
if self.has_initial_state:
- h0 = np.random.normal(size=(bs, self.D)).astype('float64')
+ h0 = np.random.normal(size=(bs, self.D)).astype('float32')
- c0 = np.random.normal(size=(bs, self.D)).astype('float64')
+ c0 = np.random.normal(size=(bs, self.D)).astype('float32')
else:
- h0 = np.zeros((bs, self.D)).astype('float64')
+ h0 = np.zeros((bs, self.D)).astype('float32')
- c0 = np.zeros((bs, self.D)).astype('float64')
+ c0 = np.zeros((bs, self.D)).astype('float32')
- wh = np.random.normal(size=(self.D, 4 * self.D)).astype('float64')
+ wh = np.random.normal(size=(self.D, 4 * self.D)).astype('float32')
if self.use_peepholes:
- b = np.random.normal(size=(1, 7 * self.D)).astype('float64')
+ b = np.random.normal(size=(1, 7 * self.D)).astype('float32')
else:
- b = np.random.normal(size=(1, 4 * self.D)).astype('float64')
+ b = np.random.normal(size=(1, 4 * self.D)).astype('float32')
w_b = np.copy(b[:, 0:4 * self.D])
w_c = b[:, 4 * self.D:] if self.use_peepholes else None
# this is the weight of fc
- wx = np.random.normal(size=(self.M, 4 * self.D)).astype('float64')
+ wx = np.random.normal(size=(self.M, 4 * self.D)).astype('float32')
# this is the bias of fc
# and it should be manually added into the bias of this fusion LSTM
- bx = np.random.normal(size=(1, 4 * self.D)).astype('float64')
+ bx = np.random.normal(size=(1, 4 * self.D)).astype('float32')
b[0, 0:4 * self.D] += bx[0, :]
h, c = fusion_lstm(x, self.lod, wx, bx, h0, c0, wh, w_b, w_c,
self.is_reverse, ACTIVATION[self.act_gate],
@@ -114,35 +114,45 @@ class TestLstmOp(OpTest):
}
def test_check_output(self):
- self.check_output(atol=1e-8)
+ self.check_output()
- class TestLstmOpInitReverse(TestLstmOp):
+ class TestFusionLSTMOpInit(TestFusionLSTMOp):
- def set_argument(self):
+ def set_conf(self):
self.has_initial_state = True
class TestFusionLSTMOpReverse(TestFusionLSTMOp):
def set_conf(self):
self.is_reverse = True
class TestFusionLSTMOpInitReverse(TestFusionLSTMOp):
def set_conf(self):
self.has_initial_state = True
self.is_reverse = True
- class TestLstmOpMD1(TestLstmOp):
+ class TestFusionLSTMOpMD1(TestFusionLSTMOp):
- def set_argument(self):
+ def set_conf(self):
self.M = 36
self.D = 8
- class TestLstmOpMD2(TestLstmOp):
+ class TestFusionLSTMOpMD2(TestFusionLSTMOp):
- def set_argument(self):
+ def set_conf(self):
self.M = 8
self.D = 8
- class TestLstmOpMD3(TestLstmOp):
+ class TestFusionLSTMOpMD3(TestFusionLSTMOp):
- def set_argument(self):
+ def set_conf(self):
self.M = 15
self.D = 3
- class TestLstmOpBS1(TestLstmOp):
+ class TestFusionLSTMOpBS1(TestFusionLSTMOp):
- def set_argument(self):
+ def set_conf(self):
self.lod = [[3]]
self.D = 16
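For context on the updated test configuration: a length-based LoD such as [[2, 3, 5, 4]] describes four sequences of lengths 2, 3, 5, and 4 packed row-wise into x, so T = 14 and the batch size is 4. A small illustrative sketch of that packing follows; the variable names and the offsets helper are assumptions for illustration, not code from this patch.

```python
import numpy as np

lod = [[2, 3, 5, 4]]                    # sequence lengths, as in set_conf/setUp
T, bs, M = sum(lod[0]), len(lod[0]), 8  # total time steps, batch size, x width
offsets = np.cumsum([0] + lod[0])       # [0, 2, 5, 10, 14]
x = np.random.normal(size=(T, M)).astype('float32')
seqs = [x[offsets[i]:offsets[i + 1]] for i in range(bs)]
assert [s.shape[0] for s in seqs] == lod[0]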