提交 bb753814 编写于 作者: G guosheng

Clean code of GRU Operator

上级 53d8165f
......@@ -51,26 +51,16 @@ class GRUKernel : public framework::OpKernel<T> {
auto* hidden = context.Output<LoDTensor>("Hidden");
hidden->mutable_data<T>(context.GetPlace());
// context.ShareLoD("Input", "Gate");
// context.ShareLoD("Input", "ResetHiddenPrev");
context.ShareLoD("Input", "Hidden");
// auto gate_dims = gate->dims();
auto hidden_dims = hidden->dims();
// LoDTensor batch_gate, batch_reset_hidden_prev, batch_hidden;
// batch_gate.mutable_data<T>(gate_dims, context.GetPlace());
// batch_reset_hidden_prev.mutable_data<T>(hidden_dims, context.GetPlace());
// batch_hidden.mutable_data<T>(hidden_dims, context.GetPlace());
bool is_reverse = context.Attr<bool>("is_reverse");
math::LoDTensor2BatchFunctor<Place, T> to_batch;
// to_batch(context.device_context(), *input, batch_gate, is_reverse);
to_batch(context.device_context(), *input, *batch_gate, true, is_reverse);
int frame_size = hidden_dims[1];
int batch_size = hidden_dims[0];
// auto g = EigenMatrix<T>::From(batch_gate);
auto g = EigenMatrix<T>::From(*batch_gate);
auto place = context.GetEigenDevice<Place>();
if (bias) {
......@@ -85,20 +75,13 @@ class GRUKernel : public framework::OpKernel<T> {
gru_value.stateWeight =
const_cast<T*>(weight_data + 2 * frame_size * frame_size);
gru_value.prevOutValue = const_cast<T*>(h0_data);
// auto batch_starts = batch_gate.lod()[0];
auto batch_starts = batch_gate->lod()[0];
// for (auto i = batch_gate->lod()[1].begin(); i !=
// batch_gate->lod()[1].end(); ++i)
// std::cout << static_cast<int>(*i) << ' ';
size_t num_batch = batch_starts.size() - 1;
for (size_t n = 0; n < num_batch; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart;
// Tensor gate_t = batch_gate.Slice(bstart, bend);
// Tensor reset_hidden_prev_t = batch_reset_hidden_prev.Slice(bstart,
// bend);
Tensor gate_t = batch_gate->Slice(bstart, bend);
Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
Tensor hidden_t = batch_hidden->Slice(bstart, bend);
......@@ -113,13 +96,6 @@ class GRUKernel : public framework::OpKernel<T> {
}
math::Batch2LoDTensorFunctor<Place, T> to_seq;
// batch_gate.set_lod(batch_gate.lod());
// to_seq(context.device_context(), batch_gate, *gate);
// batch_reset_hidden_prev.set_lod(batch_gate.lod());
// to_seq(context.device_context(), batch_reset_hidden_prev,
// *reset_hidden_prev);
// batch_hidden.set_lod(batch_gate.lod());
// to_seq(context.device_context(), batch_hidden, *hidden);
batch_hidden->set_lod(batch_gate->lod());
to_seq(context.device_context(), *batch_hidden, *hidden);
}
......@@ -167,11 +143,8 @@ class GRUGradKernel : public framework::OpKernel<T> {
zero(context.device_context(), &batch_reset_hidden_prev_grad,
static_cast<T>(0.0));
// batch_hidden.set_lod(batch_gate->lod());
bool is_reverse = context.Attr<bool>("is_reverse");
batch_hidden_grad.set_lod(batch_hidden->lod());
// context.ShareLoD(framework::GradVarName("Hidden"),
// framework::GradVarName("Input"));
to_batch(context.device_context(), *hidden_grad, batch_hidden_grad, false,
is_reverse);
......
......@@ -62,7 +62,6 @@ class TestGRUOp(OpTest):
return idx_in_seq_list
def gru_step(self, x, h_p, w, b):
# print x.shape, h_p.shape, w.shape, b.shape
batch_size = x.shape[0]
frame_size = w.shape[0]
g = x + np.tile(b, (batch_size, 1))
......@@ -96,7 +95,6 @@ class TestGRUOp(OpTest):
num_batch = len(idx_in_seq_list)
end_idx = 0
for batch_idx in range(num_batch):
# print idx_in_seq_list[batch_idx]
x = input[idx_in_seq_list[batch_idx]]
g, r_h_p, h = self.gru_step(x, h_p, w, b)
if batch_idx < (num_batch - 1):
......@@ -110,9 +108,8 @@ class TestGRUOp(OpTest):
return batch_gate, batch_reset_hidden_prev, hidden
def set_data(self):
lod = [[0, 2, 6, 9]] #[[0, 1, 2, 3]]
lod = [[0, 2, 6, 9]]
self.idx_in_seq_list = self.seq_to_batch(lod, self.is_reverse)
# print self.idx_in_seq_list
batch_size = self.batch_size
frame_size = self.frame_size
input = np.random.rand(batch_size, frame_size * 3).astype('float64')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册