提交 53d8165f 编写于 作者: G guosheng

Make GRU Operator adapt to sequence2batch

上级 83b48ebc
...@@ -66,7 +66,7 @@ class GRUKernel : public framework::OpKernel<T> { ...@@ -66,7 +66,7 @@ class GRUKernel : public framework::OpKernel<T> {
bool is_reverse = context.Attr<bool>("is_reverse"); bool is_reverse = context.Attr<bool>("is_reverse");
math::LoDTensor2BatchFunctor<Place, T> to_batch; math::LoDTensor2BatchFunctor<Place, T> to_batch;
// to_batch(context.device_context(), *input, batch_gate, is_reverse); // to_batch(context.device_context(), *input, batch_gate, is_reverse);
to_batch(context.device_context(), *input, *batch_gate, is_reverse); to_batch(context.device_context(), *input, *batch_gate, true, is_reverse);
int frame_size = hidden_dims[1]; int frame_size = hidden_dims[1];
int batch_size = hidden_dims[0]; int batch_size = hidden_dims[0];
...@@ -172,8 +172,8 @@ class GRUGradKernel : public framework::OpKernel<T> { ...@@ -172,8 +172,8 @@ class GRUGradKernel : public framework::OpKernel<T> {
batch_hidden_grad.set_lod(batch_hidden->lod()); batch_hidden_grad.set_lod(batch_hidden->lod());
// context.ShareLoD(framework::GradVarName("Hidden"), // context.ShareLoD(framework::GradVarName("Hidden"),
// framework::GradVarName("Input")); // framework::GradVarName("Input"));
to_batch(context.device_context(), *hidden_grad, batch_hidden_grad, to_batch(context.device_context(), *hidden_grad, batch_hidden_grad, false,
is_reverse, false); is_reverse);
math::hl_gru_value<T> gru_value; math::hl_gru_value<T> gru_value;
gru_value.gateWeight = const_cast<T*>(weight_data); gru_value.gateWeight = const_cast<T*>(weight_data);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册