提交 5f586e22 编写于 作者: T tensor-tang

Merge remote-tracking branch 'ups/develop' into refine/op/fusion_lstm

...@@ -11,7 +11,6 @@ ...@@ -11,7 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h" #include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
#include <string> #include <string>
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
......
...@@ -14,7 +14,7 @@ else ...@@ -14,7 +14,7 @@ else
fi fi
PREFIX=inference-vis-demos%2F PREFIX=inference-vis-demos%2F
URL_ROOT=http://paddlemodels.bj.bcebos.com/${PREFIX} URL_ROOT=http://paddlemodels.cdn.bcebos.com/${PREFIX}
# download vis_demo data # download vis_demo data
function download() { function download() {
......
...@@ -39,19 +39,6 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -39,19 +39,6 @@ bool RequestSendHandler::Handle(const std::string& varname,
const std::string& out_var_name) { const std::string& out_var_name) {
VLOG(4) << "RequestSendHandler:" << varname; VLOG(4) << "RequestSendHandler:" << varname;
// Async
if (!sync_mode_) {
rpc_server_->Profiler().OneStep();
try {
executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
scope);
} catch (std::exception& e) {
LOG(ERROR) << "async: run sub program error " << e.what();
return false;
}
return true;
}
// Sync // Sync
if (varname == BATCH_BARRIER_MESSAGE) { if (varname == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "sync: recv BATCH_BARRIER_MESSAGE"; VLOG(3) << "sync: recv BATCH_BARRIER_MESSAGE";
...@@ -60,17 +47,31 @@ bool RequestSendHandler::Handle(const std::string& varname, ...@@ -60,17 +47,31 @@ bool RequestSendHandler::Handle(const std::string& varname,
VLOG(3) << "sync: recv complete message"; VLOG(3) << "sync: recv complete message";
rpc_server_->Complete(); rpc_server_->Complete();
} else { } else {
VLOG(3) << "sync: received var_name: " << varname; // Async
rpc_server_->WaitCond(kRequestSend); if (!sync_mode_) {
VLOG(3) << "sync: processing received var: " << varname; VLOG(3) << "async process var: " << varname;
rpc_server_->Profiler().OneStep();
if (invar == nullptr) { try {
LOG(FATAL) << "sync: Can not find server side var: " << varname; executor_->RunPreparedContext((*grad_to_prepared_ctx_)[varname].get(),
return false; scope);
} } catch (std::exception& e) {
if (invar->IsType<framework::SelectedRows>()) { LOG(ERROR) << "async: run sub program error " << e.what();
std::unique_lock<std::mutex> lock(mutex_sparse_vars_); return false;
sparse_vars_.push_back(invar); }
return true;
} else { // sync
rpc_server_->WaitCond(kRequestSend);
VLOG(3) << "sync: processing received var: " << varname;
if (invar == nullptr) {
LOG(FATAL) << "sync: Can not find server side var: " << varname;
return false;
}
if (invar->IsType<framework::SelectedRows>()) {
std::unique_lock<std::mutex> lock(mutex_sparse_vars_);
sparse_vars_.push_back(invar);
}
} }
} }
return true; return true;
......
...@@ -79,12 +79,12 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -79,12 +79,12 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ(b_dims[0], 1, PADDLE_ENFORCE_EQ(b_dims[0], 1,
"The first dimension of Input(Bias) should be 1."); "The first dimension of Input(Bias) should be 1.");
PADDLE_ENFORCE(!ctx->Attrs().Get<bool>("use_peepholes"), auto use_peepholes = ctx->Attrs().Get<bool>("use_peepholes");
"Do not support peephole yet."); PADDLE_ENFORCE_EQ(b_dims[1], (use_peepholes ? 7 : 4) * frame_size,
PADDLE_ENFORCE_EQ(b_dims[1], 4 * frame_size,
"The second dimension of Input(Bias) should be " "The second dimension of Input(Bias) should be "
"4 * %d if disable peepholes connection", "7 * %d if enable peepholes connection or"
frame_size); "4 * %d if disable peepholes",
frame_size, frame_size);
framework::DDim out_dims({x_dims[0], frame_size}); framework::DDim out_dims({x_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims); ctx->SetOutputDim("Hidden", out_dims);
...@@ -231,16 +231,17 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -231,16 +231,17 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
act_cand = act_functor(act_cand_str); \ act_cand = act_functor(act_cand_str); \
} }
#define INIT_BASE_INPUT_OUTPUT \ #define INIT_BASE_INPUT_OUTPUT \
auto* x = ctx.Input<LoDTensor>("X"); \ auto* x = ctx.Input<LoDTensor>("X"); \
auto* h0 = ctx.Input<Tensor>("H0"); \ auto* h0 = ctx.Input<Tensor>("H0"); \
auto* c0 = ctx.Input<Tensor>("C0"); \ auto* c0 = ctx.Input<Tensor>("C0"); \
auto* wx = ctx.Input<Tensor>("WeightX"); \ auto* wx = ctx.Input<Tensor>("WeightX"); \
auto* wh = ctx.Input<Tensor>("WeightH"); \ auto* wh = ctx.Input<Tensor>("WeightH"); \
auto* bias = ctx.Input<Tensor>("Bias"); \ auto* bias = ctx.Input<Tensor>("Bias"); \
auto* xx = ctx.Output<LoDTensor>("XX"); \ auto* xx = ctx.Output<LoDTensor>("XX"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \ auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
auto* cell_out = ctx.Output<LoDTensor>("Cell"); \ auto* cell_out = ctx.Output<LoDTensor>("Cell"); \
bool use_peepholes = ctx.Attr<bool>("use_peepholes"); \
bool is_reverse = ctx.Attr<bool>("is_reverse"); bool is_reverse = ctx.Attr<bool>("is_reverse");
#define INIT_BASE_SIZES \ #define INIT_BASE_SIZES \
...@@ -265,12 +266,21 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -265,12 +266,21 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
const T* x_data = x->data<T>(); const T* x_data = x->data<T>();
const T* h0_data = h0 ? h0->data<T>() : nullptr; const T* h0_data = h0 ? h0->data<T>() : nullptr;
const T* c0_data = c0 ? c0->data<T>() : nullptr; const T* c0_data = c0 ? c0->data<T>() : nullptr;
const T* bias_data = bias->data<T>();
const T* wc_data = bias_data + D4; // w_ic, w_fc, w_oc
const T* wx_data = wx->data<T>(); const T* wx_data = wx->data<T>();
const T* wh_data = wh->data<T>(); const T* wh_data = wh->data<T>();
T* xx_data = xx->mutable_data<T>(ctx.GetPlace()); T* xx_data = xx->mutable_data<T>(ctx.GetPlace());
T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace()); T* hidden_out_data = hidden_out->mutable_data<T>(ctx.GetPlace());
T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace()); T* cell_out_data = cell_out->mutable_data<T>(ctx.GetPlace());
// use local variable
framework::DDim check_dims({3, D});
Tensor checked_cell; // w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct
auto checked_cell_data =
checked_cell.mutable_data<T>(check_dims, ctx.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(ctx); auto blas = math::GetBlas<DeviceContext, T>(ctx);
math::FCCompute<DeviceContext, T>(blas, total_T, D4, M, x_data, wx_data, math::FCCompute<DeviceContext, T>(blas, total_T, D4, M, x_data, wx_data,
xx_data, bias->data<T>()); xx_data, bias->data<T>());
...@@ -296,46 +306,86 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -296,46 +306,86 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
int seq_len = x_lod[0][bid + 1] - x_lod[0][bid]; int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
const T* prev_c_data = nullptr; const T* prev_c_data = nullptr;
const T* prev_h_data = nullptr; const T* prev_h_data = nullptr;
int tstart = 0; int tstart = 0;
if (h0_data) { if (h0_data) {
prev_h_data = h0_data + bid * D; prev_h_data = h0_data + bid * D;
prev_c_data = c0_data + bid * D; prev_c_data = c0_data + bid * D;
} else { } else {
// W_ch, W_ih, W_fh, W_oh // If step == 0 and there is no initialized hidden state, that is to say
act_gate(D3, xx_data + D, xx_data + D); // the H0 is zeros. Then W_h * H_t-1 can be skipped
// ~C_t
act_cand(D, xx_data, xx_data); act_cand(D, xx_data, xx_data);
// cell out= input*tilde if (use_peepholes) {
// I_t, F_t
act_gate(D2, xx_data + D, xx_data + D);
} else {
// I_t, F_t, O_t
act_gate(D3, xx_data + D, xx_data + D);
}
// C_t = I_t * ~C_t
blas.VMUL(D, xx_data, xx_data + D, cell_out_data); blas.VMUL(D, xx_data, xx_data + D, cell_out_data);
if (use_peepholes) {
// + W_oc * C_t for peephole connection
blas.VMUL(D, wc_data + D2, cell_out_data, checked_cell_data + D2);
blas.VADD(D, xx_data + D3, checked_cell_data + D2, xx_data + D3);
// O_t
act_gate(D, xx_data + D3, xx_data + D3);
}
// hidden out= act_state(cellout) * outgate // hidden out= act_state(cellout) * outgate
act_cell(D, cell_out_data, xx_data + D2); act_cell(D, cell_out_data, xx_data + D2);
// H_t = O_t * act_state(C_t)
blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
// prev // prev
prev_h_data = hidden_out_data; prev_h_data = hidden_out_data;
prev_c_data = cell_out_data; prev_c_data = cell_out_data;
tstart = 1;
tstart = 1;
move_step(); move_step();
} }
for (int step = tstart; step < seq_len; ++step) { for (int step = tstart; step < seq_len; ++step) {
// + W_h * H_t-1
blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1), blas.GEMM(CblasNoTrans, CblasNoTrans, 1, D4, D, static_cast<T>(1),
prev_h_data, D, wh_data, D4, static_cast<T>(1), xx_data, D4); prev_h_data, D, wh_data, D4, static_cast<T>(1), xx_data, D4);
// W_ch, W_ih, W_fh, W_oh // ~C_t
act_gate(D3, xx_data + D, xx_data + D);
act_cand(D, xx_data, xx_data); act_cand(D, xx_data, xx_data);
// a = forget * prev_cell if (use_peepholes) {
// + W_ic|W_fc * C_t-1 for peephole connection
blas.VMUL(D, wc_data, prev_c_data, checked_cell_data);
blas.VMUL(D, wc_data + D, prev_c_data, checked_cell_data + D);
blas.VADD(D2, xx_data + D, checked_cell_data, xx_data + D);
// I_t, F_t
act_gate(D2, xx_data + D, xx_data + D);
} else {
// I_t, F_t, O_t
act_gate(D3, xx_data + D, xx_data + D);
}
// F_t * C_t-1
blas.VMUL(D, xx_data + D2, prev_c_data, xx_data + D2); blas.VMUL(D, xx_data + D2, prev_c_data, xx_data + D2);
// I_t * ~C_t
// b = input * tilde
blas.VMUL(D, xx_data, xx_data + D, xx_data + D); blas.VMUL(D, xx_data, xx_data + D, xx_data + D);
// C_t = F_t * C_t-1 + I_t * ~C_t
// cell out= a+b
blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data); blas.VADD(D, xx_data + D, xx_data + D2, cell_out_data);
if (use_peepholes) {
// + W_oc * C_t for peephole connection
blas.VMUL(D, wc_data + D2, cell_out_data, checked_cell_data + D2);
blas.VADD(D, xx_data + D3, checked_cell_data + D2, xx_data + D3);
// O_t
act_gate(D, xx_data + D3, xx_data + D3);
}
// hidden out= act_state(cellout) * outgate // hidden out= act_state(cellout) * outgate
act_cell(D, cell_out_data, xx_data + D2); act_cell(D, cell_out_data, xx_data + D2);
// H_t = O_t * act_state(C_t)
blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data); blas.VMUL(D, xx_data + D2, xx_data + D3, hidden_out_data);
// prev // prev
...@@ -343,14 +393,14 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -343,14 +393,14 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
prev_c_data = cell_out_data; prev_c_data = cell_out_data;
move_step(); move_step();
} } // for each step in batch
} } // for each batch
} }
void BatchCompute(const framework::ExecutionContext& ctx) const { void BatchCompute(const framework::ExecutionContext& ctx) const {
using DeviceContext = platform::CPUDeviceContext; using DeviceContext = platform::CPUDeviceContext;
INIT_BASE_INPUT_OUTPUT INIT_BASE_INPUT_OUTPUT
if (x->lod()[0].size() == 2) { if (x->lod()[0].size() == 2) { // batch size == 1
SeqCompute(ctx); SeqCompute(ctx);
return; return;
} }
...@@ -366,6 +416,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -366,6 +416,8 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
const T* x_data = x->data<T>(); const T* x_data = x->data<T>();
const T* wx_data = wx->data<T>(); const T* wx_data = wx->data<T>();
const T* wh_data = wh->data<T>(); const T* wh_data = wh->data<T>();
const T* bias_data = bias->data<T>();
const T* wc_data = bias_data + D4; // w_ic, w_fc, w_oc
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
T* xx_data = xx->mutable_data<T>(place); T* xx_data = xx->mutable_data<T>(place);
T* batched_input_data = batched_input->mutable_data<T>(place); T* batched_input_data = batched_input->mutable_data<T>(place);
...@@ -374,6 +426,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -374,6 +426,12 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
hidden_out->mutable_data<T>(place); hidden_out->mutable_data<T>(place);
cell_out->mutable_data<T>(place); cell_out->mutable_data<T>(place);
// use local variable
framework::DDim check_dims({3, D});
Tensor checked_cell; // w_ic * Ct-1, w_fc * Ct-1, w_oc * Ct
auto checked_cell_data =
checked_cell.mutable_data<T>(check_dims, ctx.GetPlace());
math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch; math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx); auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
...@@ -395,17 +453,27 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -395,17 +453,27 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
reordered_h0->Resize({max_bs, D}); reordered_h0->Resize({max_bs, D});
reordered_c0->Resize({max_bs, D}); reordered_c0->Resize({max_bs, D});
T* prev_batch_h_data = nullptr;
T* prev_batch_c_data = nullptr;
T* cur_batch_in_data = batched_input_data;
T* cur_batch_h_out_data = batched_h_out_data;
T* cur_batch_c_out_data = batched_c_out_data;
auto move_step = [&](int bs) {
cur_batch_in_data += bs * D4;
cur_batch_c_out_data += bs * D;
cur_batch_h_out_data += bs * D;
};
int tstart = 0; int tstart = 0;
T* prev_h_data = nullptr;
T* prev_c_data = nullptr;
if (h0) { if (h0) {
// reorder h0, c0 // reorder h0, c0
T* reordered_h0_data = reordered_h0->mutable_data<T>(place); T* reordered_h0_data = reordered_h0->mutable_data<T>(place);
T* reordered_c0_data = reordered_c0->mutable_data<T>(place); T* reordered_c0_data = reordered_c0->mutable_data<T>(place);
const T* h0_data = h0->data<T>(); const T* h0_data = h0->data<T>();
const T* c0_data = c0->data<T>(); const T* c0_data = c0->data<T>();
prev_h_data = reordered_h0_data; prev_batch_h_data = reordered_h0_data;
prev_c_data = reordered_c0_data; prev_batch_c_data = reordered_c0_data;
size_t sz = sizeof(T) * D; size_t sz = sizeof(T) * D;
for (int i = 0; i < max_bs; ++i) { for (int i = 0; i < max_bs; ++i) {
std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz); std::memcpy(reordered_h0_data, h0_data + seq_order[i] * D, sz);
...@@ -414,71 +482,122 @@ class FuisonLSTMKernel : public framework::OpKernel<T> { ...@@ -414,71 +482,122 @@ class FuisonLSTMKernel : public framework::OpKernel<T> {
reordered_c0_data += D; reordered_c0_data += D;
} }
} else { } else {
// compute without h0, c0 // Compute with no H0/C0
T* cur_in_data = batched_input_data; T* cur_in_data = cur_batch_in_data;
T* cur_h_out_data = batched_h_out_data; T* cur_c_out_data = cur_batch_c_out_data;
T* cur_c_out_data = batched_c_out_data; T* cur_h_out_data = cur_batch_h_out_data;
// W_ch, W_ih, W_fh, W_oh
for (int i = 0; i < max_bs; ++i) { // If step == 0 and there is no initialized hidden state, that is to say
act_gate(D3, cur_in_data + D, cur_in_data + D); // the H0 is zeros. Then W_h * H_t-1 can be skiped
for (int i = 0; i < max_bs; ++i) { // iterate each data in 1st batch
// ~C_t
act_cand(D, cur_in_data, cur_in_data); act_cand(D, cur_in_data, cur_in_data);
// cell out= input*tilde
if (use_peepholes) {
// I_t, F_t
act_gate(D2, cur_in_data + D, cur_in_data + D);
} else {
// I_t, F_t, O_t
act_gate(D3, cur_in_data + D, cur_in_data + D);
}
// C_t = I_t * ~C_t
blas.VMUL(D, cur_in_data, cur_in_data + D, cur_c_out_data); blas.VMUL(D, cur_in_data, cur_in_data + D, cur_c_out_data);
if (use_peepholes) {
// + W_oc * C_t for peephole connection
blas.VMUL(D, wc_data + D2, cur_c_out_data, checked_cell_data + D2);
blas.VADD(D, cur_in_data + D3, checked_cell_data + D2,
cur_in_data + D3);
// O_t
act_gate(D, cur_in_data + D3, cur_in_data + D3);
}
// hidden out= act_state(cellout) * outgate // hidden out= act_state(cellout) * outgate
act_cell(D, cur_c_out_data, cur_in_data + D2); act_cell(D, cur_c_out_data, cur_in_data + D2);
// H_t = O_t * act_state(C_t)
blas.VMUL(D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data); blas.VMUL(D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data);
// add offset // move to next data in the same batch
cur_in_data += D4; cur_in_data += D4;
cur_c_out_data += D; cur_c_out_data += D;
cur_h_out_data += D; cur_h_out_data += D;
} }
// move to data for next timestep
prev_batch_h_data = cur_batch_h_out_data;
prev_batch_c_data = cur_batch_c_out_data;
move_step(max_bs);
tstart = 1; tstart = 1;
prev_h_data = batched_h_out_data;
prev_c_data = batched_c_out_data;
} }
// Then start from next
const auto& batch_starts = batched_lod[0]; const auto& batch_starts = batched_lod[0];
const int max_seq_len = batch_starts.size() - 1; const int max_seq_len = batch_starts.size() - 1;
const int offset = tstart * max_bs * D;
batched_input_data = batched_input_data + offset * 4;
batched_h_out_data = batched_h_out_data + offset;
batched_c_out_data = batched_c_out_data + offset;
for (int step = tstart; step < max_seq_len; ++step) { for (int step = tstart; step < max_seq_len; ++step) {
const int cur_bs = batch_starts[step + 1] - batch_starts[step]; const int cur_bs = batch_starts[step + 1] - batch_starts[step];
// + W_h * H_t-1
blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D4, D, static_cast<T>(1), blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D4, D, static_cast<T>(1),
prev_h_data, D, wh_data, D4, static_cast<T>(1), prev_batch_h_data, D, wh_data, D4, static_cast<T>(1),
batched_input_data, D4); cur_batch_in_data, D4);
T* cur_in_data = batched_input_data; T* cur_in_data = cur_batch_in_data;
T* cur_prev_c_data = prev_c_data; T* cur_c_out_data = cur_batch_c_out_data;
T* cur_c_out_data = batched_c_out_data; T* cur_h_out_data = cur_batch_h_out_data;
T* cur_h_out_data = batched_h_out_data; T* prev_c_data = prev_batch_c_data; // NULL if no C0 in step0
for (int i = 0; i < cur_bs; ++i) { T* prev_h_data = prev_batch_h_data; // NULL if no H0 in step0
// W_ch, W_ih, W_fh, W_oh auto next_data_in_batch = [&]() {
act_gate(D3, cur_in_data + D, cur_in_data + D); cur_in_data += D4;
cur_c_out_data += D;
cur_h_out_data += D;
prev_c_data = prev_c_data ? prev_c_data + D : nullptr;
prev_h_data = prev_h_data ? prev_h_data + D : nullptr;
};
for (int i = 0; i < cur_bs; ++i) { // iterate each data in same batch
// ~C_t
act_cand(D, cur_in_data, cur_in_data); act_cand(D, cur_in_data, cur_in_data);
// a = forget * prev_cell
blas.VMUL(D, cur_in_data + D2, cur_prev_c_data, cur_in_data + D2); if (use_peepholes) {
// b = input * tilde // + W_ic|W_fc * C_t-1 for peephole connection
blas.VMUL(D, wc_data, prev_c_data, checked_cell_data);
blas.VMUL(D, wc_data + D, prev_c_data, checked_cell_data + D);
blas.VADD(D2, cur_in_data + D, checked_cell_data, cur_in_data + D);
// I_t, F_t
act_gate(D2, cur_in_data + D, cur_in_data + D);
} else {
// I_t, F_t, O_t
act_gate(D3, cur_in_data + D, cur_in_data + D);
}
// F_t * C_t-1
blas.VMUL(D, cur_in_data + D2, prev_c_data, cur_in_data + D2);
// I_t * ~C_t
blas.VMUL(D, cur_in_data, cur_in_data + D, cur_in_data + D); blas.VMUL(D, cur_in_data, cur_in_data + D, cur_in_data + D);
// cell out= a+b // C_t = F_t * C_t-1 + I_t * ~C_t
blas.VADD(D, cur_in_data + D, cur_in_data + D2, cur_c_out_data); blas.VADD(D, cur_in_data + D, cur_in_data + D2, cur_c_out_data);
if (use_peepholes) {
// + W_oc * C_t for peephole connection
blas.VMUL(D, wc_data + D2, cur_c_out_data, checked_cell_data + D2);
blas.VADD(D, cur_in_data + D3, checked_cell_data + D2,
cur_in_data + D3);
// O_t
act_gate(D, cur_in_data + D3, cur_in_data + D3);
}
// hidden out= act_state(cellout) * outgate // hidden out= act_state(cellout) * outgate
act_cell(D, cur_c_out_data, cur_in_data + D2); act_cell(D, cur_c_out_data, cur_in_data + D2);
// H_t = O_t * act_state(C_t)
blas.VMUL(D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data); blas.VMUL(D, cur_in_data + D2, cur_in_data + D3, cur_h_out_data);
cur_in_data += D4; // move to next data in same batch
cur_prev_c_data += D; next_data_in_batch();
cur_c_out_data += D;
cur_h_out_data += D;
} }
// move to data for next timestep
prev_c_data = batched_c_out_data; prev_batch_h_data = cur_batch_h_out_data;
prev_h_data = batched_h_out_data; prev_batch_c_data = cur_batch_c_out_data;
batched_c_out_data = cur_c_out_data; move_step(cur_bs);
batched_h_out_data = cur_h_out_data;
batched_input_data = cur_in_data;
} }
math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq; math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
......
...@@ -3546,11 +3546,6 @@ def topk(input, k, name=None): ...@@ -3546,11 +3546,6 @@ def topk(input, k, name=None):
top5_values, top5_indices = layers.topk(input, k=5) top5_values, top5_indices = layers.topk(input, k=5)
""" """
shape = input.shape
if k < 1 or k >= shape[-1]:
raise ValueError("k must be greater than 0 and less than %d." %
(shape[-1]))
helper = LayerHelper("top_k", **locals()) helper = LayerHelper("top_k", **locals())
values = helper.create_tmp_variable(dtype=input.dtype) values = helper.create_tmp_variable(dtype=input.dtype)
indices = helper.create_tmp_variable(dtype="int64") indices = helper.create_tmp_variable(dtype="int64")
......
...@@ -58,6 +58,7 @@ class TestFusionLSTMOp(OpTest): ...@@ -58,6 +58,7 @@ class TestFusionLSTMOp(OpTest):
self.act_cell = 'tanh' self.act_cell = 'tanh'
self.act_cand = 'tanh' self.act_cand = 'tanh'
self.use_peepholes = False self.use_peepholes = False
self.use_seq = False
self.set_conf() self.set_conf()
T = sum(self.lod[0]) T = sum(self.lod[0])
...@@ -107,6 +108,7 @@ class TestFusionLSTMOp(OpTest): ...@@ -107,6 +108,7 @@ class TestFusionLSTMOp(OpTest):
} }
self.attrs = { self.attrs = {
'use_peepholes': self.use_peepholes, 'use_peepholes': self.use_peepholes,
'use_seq': self.use_seq,
'is_reverse': self.is_reverse, 'is_reverse': self.is_reverse,
'gate_activation': self.act_gate, 'gate_activation': self.act_gate,
'cell_activation': self.act_cell, 'cell_activation': self.act_cell,
...@@ -159,5 +161,68 @@ class TestFusionLSTMOpBS1(TestFusionLSTMOp): ...@@ -159,5 +161,68 @@ class TestFusionLSTMOpBS1(TestFusionLSTMOp):
self.D = 16 self.D = 16
class TestFusionLSTMOpPeepholes(TestFusionLSTMOp):
def set_conf(self):
self.use_peepholes = True
class TestFusionLSTMOpPeepholesInit(TestFusionLSTMOp):
def set_conf(self):
self.use_peepholes = True
self.has_initial_state = True
class TestFusionLSTMOpPeepholesReverse(TestFusionLSTMOp):
def set_conf(self):
self.use_peepholes = True
self.is_reverse = True
class TestFusionLSTMOpPoopholesBS1(TestFusionLSTMOp):
def set_conf(self):
self.use_peepholes = True
self.lod = [[3]]
self.D = 16
class TestFusionLSTMOpSeqInit(TestFusionLSTMOp):
def set_conf(self):
self.use_seq = True
self.has_initial_state = True
class TestFusionLSTMOpSeqReverse(TestFusionLSTMOp):
def set_conf(self):
self.use_seq = True
self.is_reverse = True
class TestFusionLSTMOpSeqInitReverse(TestFusionLSTMOp):
def set_conf(self):
self.use_seq = True
self.has_initial_state = True
self.is_reverse = True
class TestFusionLSTMOpSeqPeepholes(TestFusionLSTMOp):
def set_conf(self):
self.use_seq = True
self.use_peepholes = True
class TestFusionLSTMOpSeqPeepholesInit(TestFusionLSTMOp):
def set_conf(self):
self.use_seq = True
self.use_peepholes = True
self.has_initial_state = True
class TestFusionLSTMOpSeqPeepholesReverse(TestFusionLSTMOp):
def set_conf(self):
self.use_seq = True
self.use_peepholes = True
self.is_reverse = True
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册