Commit c7adb99a authored by tensor-tang

follow comment and refine code

Parent f38905a6
@@ -21,8 +21,6 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/sequence2batch.h"
 #include "paddle/fluid/platform/cpu_info.h"
-DEFINE_bool(gru_use_seq, true, "Use sequence mode");
 namespace paddle {
 namespace operators {
@@ -87,7 +85,7 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
   ctx->ShareLoD("X", "Hidden");
   int xx_width;
-  if (FLAGS_gru_use_seq) {
+  if (ctx->Attrs().Get<bool>("use_seq")) {
     xx_width = wx_dims[1];
   } else {
     xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
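A note on the width logic above: in seq mode XX always holds the fully connected projection of X, so its width is wx_dims[1] (i.e. 3D); in batch mode XX is reused either for that projection or for a batched copy of X (see BatchCompute below), so its width is whichever of x_dims[1] (M) and wx_dims[1] (3D) is smaller.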
@@ -136,7 +134,10 @@ void FusionGRUOpMaker::Make() {
            " where T is the total time steps in this mini-batch,"
            " D is the hidden size, M is the dim size of x input.")
       .AsIntermediate();
-  AddOutput("BatchedInput", "(LoDTensor) (T x 3D)").AsIntermediate();
+  AddOutput("BatchedInput",
+            "(LoDTensor) This is the batched result of input X "
+            "or the batched result after fc, shape (T x 3D).")
+      .AsIntermediate();
   AddOutput("BatchedOut", "(LoDTensor) (T x D) save batched hidden.")
       .AsIntermediate();
   AddOutput("Hidden", "(LoDTensor) (T x D) Same as GRUOp");
@@ -153,6 +154,10 @@ void FusionGRUOpMaker::Make() {
                 "(bool, default: False) "
                 "whether to compute reversed GRU.")
       .SetDefault(false);
+  AddAttr<bool>("use_seq",
+                "(bool, default: True) "
+                "whether to use seq mode to compute GRU.")
+      .SetDefault(true);
   AddComment(R"DOC(
 The Fusion complete GRU Operator.
 This operator fuses the fully-connected operator into GRU,
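To make the fusion concrete: the kernel below first evaluates the input projection for all time steps at once, XX = X * WeightX + Bias of shape T x 3D (the FCCompute calls), and then runs the recurrence one step at a time using WeightH, whose first 2D columns hold the gate weights and last D columns the candidate weights (see wh_state_data = wh_data + D * D2). The equations below are the standard GRU recurrence written against that layout; the gate order, activation functions, and the exact interpolation convention are those of PaddlePaddle's existing GRU op and are only assumed here, not restated in this diff:

% Assumed split of the T x 3D projection xx_t into update (u), reset (r) and
% candidate (c) parts; sigma and tanh are the usual default activations.
\begin{aligned}
u_t &= \sigma\bigl(xx_t^{(u)} + h_{t-1} W_h^{(u)}\bigr) \\
r_t &= \sigma\bigl(xx_t^{(r)} + h_{t-1} W_h^{(r)}\bigr) \\
\tilde{h}_t &= \tanh\bigl(xx_t^{(c)} + (r_t \odot h_{t-1}) W_h^{(c)}\bigr) \\
h_t &= u_t \odot h_{t-1} + (1 - u_t) \odot \tilde{h}_t
\end{aligned}
% Whether r_t gates h_{t-1} before or after the matmul, and the direction of
% the final interpolation, follow the GRU op's own definition.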
@@ -164,7 +169,7 @@ template <typename T>
 class FusionGRUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    if (FLAGS_gru_use_seq) {
+    if (ctx.Attr<bool>("use_seq")) {
       SeqCompute(ctx);
     } else {
       BatchCompute(ctx);
@@ -188,31 +193,35 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     cross = math::vec_cross<T, platform::jit::isa_any>;        \
   }
 
+#define INIT_BASE_INPUT_OUTPUT                        \
+  auto* h0 = ctx.Input<Tensor>("H0");                 \
+  auto* wx = ctx.Input<Tensor>("WeightX");            \
+  auto* wh = ctx.Input<Tensor>("WeightH");            \
+  auto* bias = ctx.Input<Tensor>("Bias");             \
+  auto* xx = ctx.Output<LoDTensor>("XX");             \
+  auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
+  bool is_reverse = ctx.Attr<bool>("is_reverse");
+
+#define INIT_BASE_SIZES                   \
+  auto x_dims = x->dims();   /* T x M*/   \
+  auto wh_dims = wh->dims(); /* D x 3D*/  \
+  const int total_T = x_dims[0];          \
+  const int M = x_dims[1];                \
+  const int D = wh_dims[0];               \
+  const int D3 = wh_dims[1];              \
+  const int D2 = D * 2;
+
   void SeqCompute(const framework::ExecutionContext& ctx) const {
     using DeviceContext = paddle::platform::CPUDeviceContext;
     auto* x = ctx.Input<LoDTensor>("X");
-    auto* h0 = ctx.Input<Tensor>("H0");
-    auto* wx = ctx.Input<Tensor>("WeightX");
-    auto* wh = ctx.Input<Tensor>("WeightH");
-    auto* bias = ctx.Input<Tensor>("Bias");
-    auto* xx = ctx.Output<LoDTensor>("XX");
-    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
-    bool is_reverse = ctx.Attr<bool>("is_reverse");
+    INIT_BASE_INPUT_OUTPUT
+    INIT_BASE_SIZES
     INIT_VEC_FUNC
     auto x_lod = x->lod();
-    auto x_dims = x->dims();    // T x M
-    auto wh_dims = wh->dims();  // D x 3D
     const int N = x_lod[0].size() - 1;
-    const int total_T = x_dims[0];
-    const int M = x_dims[1];
-    const int D3 = wh_dims[1];
-    const int D = wh_dims[0];
-    const int D2 = D * 2;
     const T* x_data = x->data<T>();
-    const T* h0_data = h0 ? h0->data<T>() : NULL;
+    const T* h0_data = h0 ? h0->data<T>() : nullptr;
     const T* wx_data = wx->data<T>();
     const T* wh_data = wh->data<T>();
     const T* wh_state_data = wh_data + D * D2;
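For reference, the sizes declared by INIT_BASE_SIZES line up with the shapes documented in the op maker; this glossary only restates the comments above, it is not code from the commit:

// x       : T x M  -> total_T = x_dims[0], time steps summed over the batch;
//                     M = x_dims[1], the input feature size.
// WeightH : D x 3D -> D  = wh_dims[0], the hidden size;
//                     D3 = wh_dims[1] = 3 * D, the update/reset/candidate
//                     blocks laid out side by side;
//                     D2 = 2 * D, the two gate blocks preceding the candidate.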
@@ -221,7 +230,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     auto blas = math::GetBlas<DeviceContext, T>(ctx);
     math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data,
-                                      xx_data, bias ? bias->data<T>() : NULL);
+                                      xx_data,
+                                      bias ? bias->data<T>() : nullptr);
     int xx_offset = D3;
     int gate_offset = D;
@@ -239,7 +249,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     for (int i = 0; i < N; ++i) {
       int bid = is_reverse ? N - 1 - i : i;
       int seq_len = x_lod[0][bid + 1] - x_lod[0][bid];
-      const T* prev_hidden_data = NULL;
+      const T* prev_hidden_data = nullptr;
       int tstart = 0;
       if (h0_data) {
         prev_hidden_data = h0_data + bid * D;
@@ -282,19 +292,17 @@ class FusionGRUKernel : public framework::OpKernel<T> {
   void BatchCompute(const framework::ExecutionContext& ctx) const {
     using DeviceContext = paddle::platform::CPUDeviceContext;
     auto* x = ctx.Input<LoDTensor>("X");
-    auto* wx = ctx.Input<Tensor>("WeightX");
-    auto* wh = ctx.Input<Tensor>("WeightH");
-    auto* bias = ctx.Input<Tensor>("Bias");
-    auto* h0 = ctx.Input<Tensor>("H0");
+    if (x->lod()[0].size() == 2) {
+      SeqCompute(ctx);
+      return;
+    }
+    INIT_BASE_INPUT_OUTPUT
+    INIT_BASE_SIZES
+    INIT_VEC_FUNC
     auto* reordered_h0 = ctx.Output<Tensor>("ReorderedH0");
-    auto* xx = ctx.Output<LoDTensor>("XX");
     auto* batched_input = ctx.Output<LoDTensor>("BatchedInput");
     auto* batched_out = ctx.Output<LoDTensor>("BatchedOut");
-    auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
-    bool is_reverse = ctx.Attr<bool>("is_reverse");
-    INIT_VEC_FUNC
 
     const T* x_data = x->data<T>();
     const T* wx_data = wx->data<T>();
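One clarification on the new early return: the level-0 LoD of X stores N + 1 offsets for a mini-batch of N sequences (SeqCompute computes N as x_lod[0].size() - 1), e.g. {0, 4, 9} describes two sequences of lengths 4 and 5. A size of 2 therefore means the batch holds a single sequence, in which case batching the time steps buys nothing and the kernel falls back to SeqCompute.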
@@ -304,25 +312,20 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     T* batched_out_data = batched_out->mutable_data<T>(ctx.GetPlace());
     hidden_out->mutable_data<T>(ctx.GetPlace());
-    auto x_dims = x->dims();
-    auto wx_dims = wx->dims();
-    const int D3 = wx_dims[1];
-    const int D = D3 / 3;
-    const int D2 = D * 2;
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
     auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
     math::LoDTensor2BatchFunctor<DeviceContext, T> to_batch;
-    if (x_dims[1] > wx_dims[1]) {
-      math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
-                                        x_data, wx_data, xx_data,
-                                        bias ? bias->data<T>() : NULL);
+    if (M > D3) {
+      math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, x_data, wx_data,
+                                        xx_data,
+                                        bias ? bias->data<T>() : nullptr);
       to_batch(dev_ctx, *xx, batched_input, true, is_reverse);
     } else {
       to_batch(dev_ctx, *x, xx, true, is_reverse);
       batched_input->set_lod(xx->lod());
-      math::FCCompute<DeviceContext, T>(blas, x_dims[0], wx_dims[1], x_dims[1],
-                                        xx_data, wx_data, batched_input_data,
-                                        bias ? bias->data<T>() : NULL);
+      math::FCCompute<DeviceContext, T>(blas, total_T, D3, M, xx_data, wx_data,
+                                        batched_input_data,
+                                        bias ? bias->data<T>() : nullptr);
     }
     auto batched_lod = batched_input->lod();
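The M > D3 branch mirrors the xx_width logic in InferShape: when the input width M exceeds 3D, the projection runs first so that the LoDTensor-to-batch reordering copies rows of width 3D; otherwise the raw input (width M) is reordered first and the projection writes straight into BatchedInput. Presumably the intent is that the reordering copy always touches the narrower of the two widths; for example, with M = 1024 and D = 256 (so 3D = 768, illustrative numbers only), projecting first shrinks the rows being shuffled from 1024 to 768 elements.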
@@ -331,7 +334,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     reordered_h0->Resize({max_bs, D});
     int tstart = 0;
-    T* prev_hidden_data = NULL;
+    T* prev_hidden_data = nullptr;
     if (h0) {
       // reorder h0
       T* reordered_h0_data = reordered_h0->mutable_data<T>(ctx.GetPlace());
@@ -415,6 +418,8 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     to_seq(dev_ctx, *batched_out, hidden_out);
   }
 #undef INIT_VEC_FUNC
+#undef INIT_BASE_SIZES
+#undef INIT_BASE_INPUT_OUTPUT
 };
 
 }  // namespace operators
...