未验证 提交 5c4eed66 编写于 作者: L lidanqing 提交者: GitHub

Fix GRU mkldnn kernel fail on look_table_v2 (#27198)

* Fix the lookup_table_v2 failed on GRU mkldnn kernel issue
test=develop

* fix according to reviews, removed x_num_col_dims
test=develop

* update gru model. change according to reviews
test=develop

* change according to reviews
test=develop
上级 7745ad55
...@@ -342,9 +342,9 @@ if(WITH_MKLDNN) ...@@ -342,9 +342,9 @@ if(WITH_MKLDNN)
### Lexcial analysis GRU model ### Lexcial analysis GRU model
set(GRU_PATH "${INFERENCE_DEMO_INSTALL_DIR}/gru") set(GRU_PATH "${INFERENCE_DEMO_INSTALL_DIR}/gru")
download_GRU_data("${GRU_PATH}" "GRU_eval_data.tar.gz") download_GRU_data("${GRU_PATH}" "GRU_eval_data.tar.gz")
download_GRU_data("${GRU_PATH}" "GRU_eval_model.tar.gz") download_GRU_data("${GRU_PATH}" "GRU_eval_model_v2.tar.gz")
set(GRU_DATA_PATH "${GRU_PATH}/GRU_eval_data.bin") set(GRU_DATA_PATH "${GRU_PATH}/GRU_eval_data.bin")
set(GRU_MODEL_PATH "${GRU_PATH}/GRU_eval_model") set(GRU_MODEL_PATH "${GRU_PATH}/GRU_eval_model_v2")
set(LEXICAL_TEST_APP "test_analyzer_lexical_analysis") set(LEXICAL_TEST_APP "test_analyzer_lexical_analysis")
set(LEXICAL_TEST_APP_SRC "analyzer_lexical_analysis_gru_tester.cc") set(LEXICAL_TEST_APP_SRC "analyzer_lexical_analysis_gru_tester.cc")
......
...@@ -30,16 +30,18 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -30,16 +30,18 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fusion_gru"); OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fusion_gru");
OP_INOUT_CHECK(ctx->HasInput("WeightX"), "Input", "WeightX", "fusion_gru"); OP_INOUT_CHECK(ctx->HasInput("WeightX"), "Input", "WeightX", "fusion_gru");
OP_INOUT_CHECK(ctx->HasInput("WeightH"), "Input", "WeightH", "fusion_gru"); OP_INOUT_CHECK(ctx->HasInput("WeightH"), "Input", "WeightH", "fusion_gru");
OP_INOUT_CHECK(ctx->HasOutput("XX"), "Output", "XX", "fusion_gru"); OP_INOUT_CHECK(ctx->HasOutput("XX"), "Output", "XX", "fusion_gru");
OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "fusion_gru"); OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "fusion_gru");
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1)
platform::errors::InvalidArgument( ? framework::flatten_to_2d(x_dims, 1)
"Input(X)'s rank must be 2, but received input dim " : x_dims;
"size is:%d, input dim is:[%s]", PADDLE_ENFORCE_EQ(
x_dims.size(), x_dims)); x_mat_dims.size(), 2,
platform::errors::InvalidArgument("The size of input X dims should be 2, "
"or 3 with second dimension equal to "
"1, but now Input X dim is:[%s] ",
x_dims));
auto wx_dims = ctx->GetInputDim("WeightX"); auto wx_dims = ctx->GetInputDim("WeightX");
PADDLE_ENFORCE_EQ(wx_dims.size(), 2, PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
...@@ -47,12 +49,14 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -47,12 +49,14 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
"The rank of Input(WeightX) should be 2, but received " "The rank of Input(WeightX) should be 2, but received "
"WeightX dim size is:%d, WeightX dim is:[%s] ", "WeightX dim size is:%d, WeightX dim is:[%s] ",
wx_dims.size(), wx_dims)); wx_dims.size(), wx_dims));
PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1], PADDLE_ENFORCE_EQ(
platform::errors::InvalidArgument( wx_dims[0], x_mat_dims[1],
"The first dimension of Input(WeightX) " platform::errors::InvalidArgument(
"should equal to second dimension of input x, but " "The first dimension of flattened WeightX"
"received WeightX dimension is:%d, x dimension is:%d", "should equal to last dimension of flattened input X, but "
wx_dims[0], x_dims[1])); "received fattened WeightX dimension is:%d, flattened X dimension "
"is:%d",
wx_dims[0], x_mat_dims[1]));
int frame_size = wx_dims[1] / 3; int frame_size = wx_dims[1] / 3;
auto wh_dims = ctx->GetInputDim("WeightH"); auto wh_dims = ctx->GetInputDim("WeightH");
...@@ -102,24 +106,24 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -102,24 +106,24 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
"received bias dim is:[%s], frame size is:%d", "received bias dim is:[%s], frame size is:%d",
b_dims, frame_size)); b_dims, frame_size));
} }
framework::DDim out_dims({x_dims[0], frame_size}); framework::DDim out_dims({x_mat_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims); ctx->SetOutputDim("Hidden", out_dims);
ctx->ShareLoD("X", "Hidden"); ctx->ShareLoD("X", "Hidden");
int xx_width; int xx_width;
if (ctx->Attrs().Get<bool>("use_seq")) { if (ctx->Attrs().Get<bool>("use_seq")) {
xx_width = wx_dims[1]; xx_width = wx_dims[1];
} else { } else {
xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1]; xx_width = x_mat_dims[1] > wx_dims[1] ? wx_dims[1] : x_mat_dims[1];
OP_INOUT_CHECK(ctx->HasOutput("ReorderedH0"), "Output", "ReorderedH0", OP_INOUT_CHECK(ctx->HasOutput("ReorderedH0"), "Output", "ReorderedH0",
"fusion_gru"); "fusion_gru");
OP_INOUT_CHECK(ctx->HasOutput("BatchedInput"), "Output", "BatchedInput", OP_INOUT_CHECK(ctx->HasOutput("BatchedInput"), "Output", "BatchedInput",
"fusion_gru"); "fusion_gru");
OP_INOUT_CHECK(ctx->HasOutput("BatchedOut"), "Output", "BatchedOut", OP_INOUT_CHECK(ctx->HasOutput("BatchedOut"), "Output", "BatchedOut",
"fusion_gru"); "fusion_gru");
ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]}); ctx->SetOutputDim("BatchedInput", {x_mat_dims[0], wx_dims[1]});
ctx->SetOutputDim("BatchedOut", out_dims); ctx->SetOutputDim("BatchedOut", out_dims);
} }
ctx->SetOutputDim("XX", {x_dims[0], xx_width}); ctx->SetOutputDim("XX", {x_mat_dims[0], xx_width});
ctx->ShareLoD("X", "XX"); ctx->ShareLoD("X", "XX");
} }
...@@ -220,14 +224,17 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -220,14 +224,17 @@ class FusionGRUKernel : public framework::OpKernel<T> {
} }
} }
#define INIT_BASE_DEFINES \ #define INIT_BASE_DEFINES \
auto* x = ctx.Input<LoDTensor>("X"); \ auto* x = ctx.Input<LoDTensor>("X"); \
auto* wh = ctx.Input<Tensor>("WeightH"); \ auto* wh = ctx.Input<Tensor>("WeightH"); \
auto* xx = ctx.Output<LoDTensor>("XX"); \ auto* xx = ctx.Output<LoDTensor>("XX"); \
auto x_lod = x->lod(); \ auto x_lod = x->lod(); \
auto x_dims = x->dims(); /* T x M*/ \ auto x_dims = x->dims(); /* T x M*/ \
auto wh_dims = wh->dims(); /* D x 3D*/ \ auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1) \
const int total_T = x_dims[0]; \ ? framework::flatten_to_2d(x_dims, 1) \
: x_dims; \
auto wh_dims = wh->dims(); /* D x 3D*/ \
const int total_T = x_mat_dims[0]; \
const int D3 = wh_dims[1] const int D3 = wh_dims[1]
#define INIT_OTHER_DEFINES \ #define INIT_OTHER_DEFINES \
...@@ -236,7 +243,7 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -236,7 +243,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
auto* bias = ctx.Input<Tensor>("Bias"); \ auto* bias = ctx.Input<Tensor>("Bias"); \
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \ auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
bool is_reverse = ctx.Attr<bool>("is_reverse"); \ bool is_reverse = ctx.Attr<bool>("is_reverse"); \
const int M = x_dims[1]; \ const int M = x_mat_dims[1]; \
const int D = wh_dims[0]; \ const int D = wh_dims[0]; \
const int D2 = D * 2; \ const int D2 = D * 2; \
const jit::gru_attr_t attr( \ const jit::gru_attr_t attr( \
......
...@@ -364,13 +364,16 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> { ...@@ -364,13 +364,16 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> {
const auto* weight_h = ctx.Input<Tensor>("WeightH"); const auto* weight_h = ctx.Input<Tensor>("WeightH");
const auto* bias = ctx.Input<Tensor>("Bias"); const auto* bias = ctx.Input<Tensor>("Bias");
auto* hidden = ctx.Output<LoDTensor>("Hidden"); auto* hidden = ctx.Output<LoDTensor>("Hidden");
auto x_dims = input->dims();
auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1)
? framework::flatten_to_2d(x_dims, 1)
: x_dims;
// Get attributes // Get attributes
const bool is_reverse = ctx.Attr<bool>("is_reverse"); const bool is_reverse = ctx.Attr<bool>("is_reverse");
const bool origin_mode = ctx.Attr<bool>("origin_mode"); const bool origin_mode = ctx.Attr<bool>("origin_mode");
// Get tensor dimensions // Get tensor dimensions
const auto x_dims = framework::vectorize(input->dims()); const auto x_mat_dims_vec = framework::vectorize(x_mat_dims);
const auto weight_h_dims = framework::vectorize(weight_h->dims()); const auto weight_h_dims = framework::vectorize(weight_h->dims());
const auto& input_lod = input->lod()[0]; const auto& input_lod = input->lod()[0];
...@@ -384,8 +387,8 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> { ...@@ -384,8 +387,8 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> {
} }
return res; return res;
}(); }();
const int64_t IC = x_dims[1]; // Input channels const int64_t IC = x_mat_dims_vec[1]; // Input channels
const int64_t OC = weight_h_dims[0]; // Output channels const int64_t OC = weight_h_dims[0]; // Output channels
GRUMKLDNNHandler<T> handler(ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), GRUMKLDNNHandler<T> handler(ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(),
input, weight_h, h0, is_reverse, N, Ti, IC, OC, input, weight_h, h0, is_reverse, N, Ti, IC, OC,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册