From 5c4eed66fd9ee485fa0dd47fc319aaa0830f3e17 Mon Sep 17 00:00:00 2001 From: lidanqing Date: Sat, 12 Sep 2020 12:58:16 +0200 Subject: [PATCH] Fix GRU mkldnn kernel fail on look_table_v2 (#27198) * Fix the lookup_table_v2 failed on GRU mkldnn kernel issue test=develop * fix according to reviews, removed x_num_col_dims test=develop * update gru model. change according to reviews test=develop * change according to reviews test=develop --- .../fluid/inference/tests/api/CMakeLists.txt | 4 +- paddle/fluid/operators/fused/fusion_gru_op.cc | 59 +++++++++++-------- .../fused/mkldnn/fusion_gru_mkldnn_op.cc | 11 ++-- 3 files changed, 42 insertions(+), 32 deletions(-) diff --git a/paddle/fluid/inference/tests/api/CMakeLists.txt b/paddle/fluid/inference/tests/api/CMakeLists.txt index fd4b1a54d2b..b3ec4b5714e 100644 --- a/paddle/fluid/inference/tests/api/CMakeLists.txt +++ b/paddle/fluid/inference/tests/api/CMakeLists.txt @@ -342,9 +342,9 @@ if(WITH_MKLDNN) ### Lexcial analysis GRU model set(GRU_PATH "${INFERENCE_DEMO_INSTALL_DIR}/gru") download_GRU_data("${GRU_PATH}" "GRU_eval_data.tar.gz") - download_GRU_data("${GRU_PATH}" "GRU_eval_model.tar.gz") + download_GRU_data("${GRU_PATH}" "GRU_eval_model_v2.tar.gz") set(GRU_DATA_PATH "${GRU_PATH}/GRU_eval_data.bin") - set(GRU_MODEL_PATH "${GRU_PATH}/GRU_eval_model") + set(GRU_MODEL_PATH "${GRU_PATH}/GRU_eval_model_v2") set(LEXICAL_TEST_APP "test_analyzer_lexical_analysis") set(LEXICAL_TEST_APP_SRC "analyzer_lexical_analysis_gru_tester.cc") diff --git a/paddle/fluid/operators/fused/fusion_gru_op.cc b/paddle/fluid/operators/fused/fusion_gru_op.cc index d0920098f60..f731a78f77b 100644 --- a/paddle/fluid/operators/fused/fusion_gru_op.cc +++ b/paddle/fluid/operators/fused/fusion_gru_op.cc @@ -30,16 +30,18 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fusion_gru"); OP_INOUT_CHECK(ctx->HasInput("WeightX"), "Input", "WeightX", "fusion_gru"); OP_INOUT_CHECK(ctx->HasInput("WeightH"), "Input", "WeightH", "fusion_gru"); - OP_INOUT_CHECK(ctx->HasOutput("XX"), "Output", "XX", "fusion_gru"); OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "fusion_gru"); - auto x_dims = ctx->GetInputDim("X"); - PADDLE_ENFORCE_EQ(x_dims.size(), 2, - platform::errors::InvalidArgument( - "Input(X)'s rank must be 2, but received input dim " - "size is:%d, input dim is:[%s]", - x_dims.size(), x_dims)); + auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1) + ? framework::flatten_to_2d(x_dims, 1) + : x_dims; + PADDLE_ENFORCE_EQ( + x_mat_dims.size(), 2, + platform::errors::InvalidArgument("The size of input X dims should be 2, " + "or 3 with second dimension equal to " + "1, but now Input X dim is:[%s] ", + x_dims)); auto wx_dims = ctx->GetInputDim("WeightX"); PADDLE_ENFORCE_EQ(wx_dims.size(), 2, @@ -47,12 +49,14 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { "The rank of Input(WeightX) should be 2, but received " "WeightX dim size is:%d, WeightX dim is:[%s] ", wx_dims.size(), wx_dims)); - PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1], - platform::errors::InvalidArgument( - "The first dimension of Input(WeightX) " - "should equal to second dimension of input x, but " - "received WeightX dimension is:%d, x dimension is:%d", - wx_dims[0], x_dims[1])); + PADDLE_ENFORCE_EQ( + wx_dims[0], x_mat_dims[1], + platform::errors::InvalidArgument( + "The first dimension of flattened WeightX" + "should equal to last dimension of flattened input X, but " + "received fattened WeightX dimension is:%d, flattened X dimension " + "is:%d", + wx_dims[0], x_mat_dims[1])); int frame_size = wx_dims[1] / 3; auto wh_dims = ctx->GetInputDim("WeightH"); @@ -102,24 +106,24 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { "received bias dim is:[%s], frame size is:%d", b_dims, frame_size)); } - framework::DDim out_dims({x_dims[0], frame_size}); + framework::DDim out_dims({x_mat_dims[0], frame_size}); ctx->SetOutputDim("Hidden", out_dims); ctx->ShareLoD("X", "Hidden"); int xx_width; if (ctx->Attrs().Get("use_seq")) { xx_width = wx_dims[1]; } else { - xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; + xx_width = x_mat_dims[1] > wx_dims[1] ? wx_dims[1] : x_mat_dims[1]; OP_INOUT_CHECK(ctx->HasOutput("ReorderedH0"), "Output", "ReorderedH0", "fusion_gru"); OP_INOUT_CHECK(ctx->HasOutput("BatchedInput"), "Output", "BatchedInput", "fusion_gru"); OP_INOUT_CHECK(ctx->HasOutput("BatchedOut"), "Output", "BatchedOut", "fusion_gru"); - ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]}); + ctx->SetOutputDim("BatchedInput", {x_mat_dims[0], wx_dims[1]}); ctx->SetOutputDim("BatchedOut", out_dims); } - ctx->SetOutputDim("XX", {x_dims[0], xx_width}); + ctx->SetOutputDim("XX", {x_mat_dims[0], xx_width}); ctx->ShareLoD("X", "XX"); } @@ -220,14 +224,17 @@ class FusionGRUKernel : public framework::OpKernel { } } -#define INIT_BASE_DEFINES \ - auto* x = ctx.Input("X"); \ - auto* wh = ctx.Input("WeightH"); \ - auto* xx = ctx.Output("XX"); \ - auto x_lod = x->lod(); \ - auto x_dims = x->dims(); /* T x M*/ \ - auto wh_dims = wh->dims(); /* D x 3D*/ \ - const int total_T = x_dims[0]; \ +#define INIT_BASE_DEFINES \ + auto* x = ctx.Input("X"); \ + auto* wh = ctx.Input("WeightH"); \ + auto* xx = ctx.Output("XX"); \ + auto x_lod = x->lod(); \ + auto x_dims = x->dims(); /* T x M*/ \ + auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1) \ + ? framework::flatten_to_2d(x_dims, 1) \ + : x_dims; \ + auto wh_dims = wh->dims(); /* D x 3D*/ \ + const int total_T = x_mat_dims[0]; \ const int D3 = wh_dims[1] #define INIT_OTHER_DEFINES \ @@ -236,7 +243,7 @@ class FusionGRUKernel : public framework::OpKernel { auto* bias = ctx.Input("Bias"); \ auto* hidden_out = ctx.Output("Hidden"); \ bool is_reverse = ctx.Attr("is_reverse"); \ - const int M = x_dims[1]; \ + const int M = x_mat_dims[1]; \ const int D = wh_dims[0]; \ const int D2 = D * 2; \ const jit::gru_attr_t attr( \ diff --git a/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc b/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc index 3940aae53b8..a31fe168439 100644 --- a/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc +++ b/paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc @@ -364,13 +364,16 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel { const auto* weight_h = ctx.Input("WeightH"); const auto* bias = ctx.Input("Bias"); auto* hidden = ctx.Output("Hidden"); - + auto x_dims = input->dims(); + auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1) + ? framework::flatten_to_2d(x_dims, 1) + : x_dims; // Get attributes const bool is_reverse = ctx.Attr("is_reverse"); const bool origin_mode = ctx.Attr("origin_mode"); // Get tensor dimensions - const auto x_dims = framework::vectorize(input->dims()); + const auto x_mat_dims_vec = framework::vectorize(x_mat_dims); const auto weight_h_dims = framework::vectorize(weight_h->dims()); const auto& input_lod = input->lod()[0]; @@ -384,8 +387,8 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel { } return res; }(); - const int64_t IC = x_dims[1]; // Input channels - const int64_t OC = weight_h_dims[0]; // Output channels + const int64_t IC = x_mat_dims_vec[1]; // Input channels + const int64_t OC = weight_h_dims[0]; // Output channels GRUMKLDNNHandler handler(ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), input, weight_h, h0, is_reverse, N, Ti, IC, OC, -- GitLab