Unverified · Commit 5c4eed66 authored by lidanqing, committed by GitHub

Fix GRU mkldnn kernel failure on lookup_table_v2 (#27198)

* Fix the lookup_table_v2 failure in the GRU mkldnn kernel
test=develop

* Fix according to reviews; removed x_num_col_dims
test=develop

* Update GRU model; change according to reviews
test=develop

* Change according to reviews
test=develop
Parent: 7745ad55
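Editor's note: the heart of this patch is treating a rank-3 input of shape [T, 1, M] as the rank-2 [T, M] that fusion_gru expects. A minimal standalone sketch of that flatten rule follows, with an illustrative stand-in for framework::flatten_to_2d (names and sizes here are for illustration only, not Paddle's API):

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Illustrative stand-in for framework::flatten_to_2d(dims, num_col_dims):
// collapse dims[0..axis) into the first output dim and dims[axis..) into
// the second, so [T, 1, M] with axis = 1 becomes [T, 1 * M] = [T, M].
std::vector<int64_t> flatten_to_2d(const std::vector<int64_t>& dims, int axis) {
  auto prod = [](auto first, auto last) {
    return std::accumulate(first, last, int64_t{1}, std::multiplies<int64_t>());
  };
  return {prod(dims.begin(), dims.begin() + axis),
          prod(dims.begin() + axis, dims.end())};
}

int main() {
  // lookup_table_v2 can hand fusion_gru a [T, 1, M] tensor; the patch
  // flattens it to [T, M] only when rank == 3 and the middle dim is 1.
  std::vector<int64_t> x_dims = {40, 1, 128};  // T = 40 steps, M = 128 (assumed)
  auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1)
                        ? flatten_to_2d(x_dims, 1)
                        : x_dims;
  assert(x_mat_dims[0] == 40 && x_mat_dims[1] == 128);
  std::cout << x_mat_dims[0] << " x " << x_mat_dims[1] << "\n";  // 40 x 128
}
```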
@@ -342,9 +342,9 @@ if(WITH_MKLDNN)
     ### Lexcial analysis GRU model
     set(GRU_PATH "${INFERENCE_DEMO_INSTALL_DIR}/gru")
     download_GRU_data("${GRU_PATH}" "GRU_eval_data.tar.gz")
-    download_GRU_data("${GRU_PATH}" "GRU_eval_model.tar.gz")
+    download_GRU_data("${GRU_PATH}" "GRU_eval_model_v2.tar.gz")
     set(GRU_DATA_PATH "${GRU_PATH}/GRU_eval_data.bin")
-    set(GRU_MODEL_PATH "${GRU_PATH}/GRU_eval_model")
+    set(GRU_MODEL_PATH "${GRU_PATH}/GRU_eval_model_v2")
     set(LEXICAL_TEST_APP "test_analyzer_lexical_analysis")
     set(LEXICAL_TEST_APP_SRC "analyzer_lexical_analysis_gru_tester.cc")
@@ -30,16 +30,18 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
   OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "fusion_gru");
   OP_INOUT_CHECK(ctx->HasInput("WeightX"), "Input", "WeightX", "fusion_gru");
   OP_INOUT_CHECK(ctx->HasInput("WeightH"), "Input", "WeightH", "fusion_gru");
   OP_INOUT_CHECK(ctx->HasOutput("XX"), "Output", "XX", "fusion_gru");
   OP_INOUT_CHECK(ctx->HasOutput("Hidden"), "Output", "Hidden", "fusion_gru");
   auto x_dims = ctx->GetInputDim("X");
-  PADDLE_ENFORCE_EQ(x_dims.size(), 2,
-                    platform::errors::InvalidArgument(
-                        "Input(X)'s rank must be 2, but received input dim "
-                        "size is:%d, input dim is:[%s]",
-                        x_dims.size(), x_dims));
+  auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1)
+                        ? framework::flatten_to_2d(x_dims, 1)
+                        : x_dims;
+  PADDLE_ENFORCE_EQ(
+      x_mat_dims.size(), 2,
+      platform::errors::InvalidArgument("The size of input X dims should be 2, "
+                                        "or 3 with second dimension equal to "
+                                        "1, but now Input X dim is:[%s] ",
+                                        x_dims));
   auto wx_dims = ctx->GetInputDim("WeightX");
   PADDLE_ENFORCE_EQ(wx_dims.size(), 2,
@@ -47,12 +49,14 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
                         "The rank of Input(WeightX) should be 2, but received "
                         "WeightX dim size is:%d, WeightX dim is:[%s] ",
                         wx_dims.size(), wx_dims));
-  PADDLE_ENFORCE_EQ(wx_dims[0], x_dims[1],
-                    platform::errors::InvalidArgument(
-                        "The first dimension of Input(WeightX) "
-                        "should equal to second dimension of input x, but "
-                        "received WeightX dimension is:%d, x dimension is:%d",
-                        wx_dims[0], x_dims[1]));
+  PADDLE_ENFORCE_EQ(
+      wx_dims[0], x_mat_dims[1],
+      platform::errors::InvalidArgument(
+          "The first dimension of flattened WeightX"
+          "should equal to last dimension of flattened input X, but "
+          "received fattened WeightX dimension is:%d, flattened X dimension "
+          "is:%d",
+          wx_dims[0], x_mat_dims[1]));
   int frame_size = wx_dims[1] / 3;
   auto wh_dims = ctx->GetInputDim("WeightH");
@@ -102,24 +106,24 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
                           "received bias dim is:[%s], frame size is:%d",
                           b_dims, frame_size));
   }
-  framework::DDim out_dims({x_dims[0], frame_size});
+  framework::DDim out_dims({x_mat_dims[0], frame_size});
   ctx->SetOutputDim("Hidden", out_dims);
   ctx->ShareLoD("X", "Hidden");
   int xx_width;
   if (ctx->Attrs().Get<bool>("use_seq")) {
     xx_width = wx_dims[1];
   } else {
-    xx_width = x_dims[1] > wx_dims[1] ? wx_dims[1] : x_dims[1];
+    xx_width = x_mat_dims[1] > wx_dims[1] ? wx_dims[1] : x_mat_dims[1];
     OP_INOUT_CHECK(ctx->HasOutput("ReorderedH0"), "Output", "ReorderedH0",
                    "fusion_gru");
     OP_INOUT_CHECK(ctx->HasOutput("BatchedInput"), "Output", "BatchedInput",
                    "fusion_gru");
     OP_INOUT_CHECK(ctx->HasOutput("BatchedOut"), "Output", "BatchedOut",
                    "fusion_gru");
-    ctx->SetOutputDim("BatchedInput", {x_dims[0], wx_dims[1]});
+    ctx->SetOutputDim("BatchedInput", {x_mat_dims[0], wx_dims[1]});
     ctx->SetOutputDim("BatchedOut", out_dims);
   }
-  ctx->SetOutputDim("XX", {x_dims[0], xx_width});
+  ctx->SetOutputDim("XX", {x_mat_dims[0], xx_width});
   ctx->ShareLoD("X", "XX");
 }
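Editor's note: with the flattened dims in place, every derived shape in InferShape now reads from x_mat_dims instead of the raw x_dims. A worked trace with hypothetical sizes (T = 40 timesteps, M = 128 input width, D = 64 hidden units; none of these numbers come from the patch itself):

```cpp
// Shape trace through the patched FusionGRUOp::InferShape (illustrative sizes):
//   X           : [40, 1, 128] -> x_mat_dims = [40, 128]          // T x M
//   WeightX     : [128, 192]; check wx_dims[0] == x_mat_dims[1]   // M x 3D
//   frame_size  : wx_dims[1] / 3 = 192 / 3 = 64                   // D
//   Hidden      : [x_mat_dims[0], frame_size] = [40, 64]          // T x D
//   XX (non-seq): width = min(x_mat_dims[1], wx_dims[1]) = 128 -> [40, 128]
//   BatchedInput: [x_mat_dims[0], wx_dims[1]] = [40, 192]
```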
@@ -220,14 +224,17 @@ class FusionGRUKernel : public framework::OpKernel<T> {
     }
   }

-#define INIT_BASE_DEFINES                    \
-  auto* x = ctx.Input<LoDTensor>("X");       \
-  auto* wh = ctx.Input<Tensor>("WeightH");   \
-  auto* xx = ctx.Output<LoDTensor>("XX");    \
-  auto x_lod = x->lod();                     \
-  auto x_dims = x->dims();   /* T x M*/      \
-  auto wh_dims = wh->dims(); /* D x 3D*/     \
-  const int total_T = x_dims[0];             \
+#define INIT_BASE_DEFINES                                     \
+  auto* x = ctx.Input<LoDTensor>("X");                        \
+  auto* wh = ctx.Input<Tensor>("WeightH");                    \
+  auto* xx = ctx.Output<LoDTensor>("XX");                     \
+  auto x_lod = x->lod();                                      \
+  auto x_dims = x->dims();   /* T x M*/                       \
+  auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1)    \
+                        ? framework::flatten_to_2d(x_dims, 1) \
+                        : x_dims;                             \
+  auto wh_dims = wh->dims(); /* D x 3D*/                      \
+  const int total_T = x_mat_dims[0];                          \
   const int D3 = wh_dims[1]

 #define INIT_OTHER_DEFINES \
@@ -236,7 +243,7 @@ class FusionGRUKernel : public framework::OpKernel<T> {
   auto* bias = ctx.Input<Tensor>("Bias");             \
   auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); \
   bool is_reverse = ctx.Attr<bool>("is_reverse");     \
-  const int M = x_dims[1];                            \
+  const int M = x_mat_dims[1];                        \
   const int D = wh_dims[0];                           \
   const int D2 = D * 2;                               \
   const jit::gru_attr_t attr(                         \
@@ -364,13 +364,16 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> {
     const auto* weight_h = ctx.Input<Tensor>("WeightH");
     const auto* bias = ctx.Input<Tensor>("Bias");
     auto* hidden = ctx.Output<LoDTensor>("Hidden");
+    auto x_dims = input->dims();
+    auto x_mat_dims = (x_dims.size() == 3 && x_dims[1] == 1)
+                          ? framework::flatten_to_2d(x_dims, 1)
+                          : x_dims;
     // Get attributes
     const bool is_reverse = ctx.Attr<bool>("is_reverse");
     const bool origin_mode = ctx.Attr<bool>("origin_mode");

     // Get tensor dimensions
-    const auto x_dims = framework::vectorize(input->dims());
+    const auto x_mat_dims_vec = framework::vectorize(x_mat_dims);
     const auto weight_h_dims = framework::vectorize(weight_h->dims());
     const auto& input_lod = input->lod()[0];
@@ -384,8 +387,8 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> {
       }
       return res;
     }();
-    const int64_t IC = x_dims[1];         // Input channels
-    const int64_t OC = weight_h_dims[0];  // Output channels
+    const int64_t IC = x_mat_dims_vec[1];  // Input channels
+    const int64_t OC = weight_h_dims[0];   // Output channels

     GRUMKLDNNHandler<T> handler(ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(),
                                 input, weight_h, h0, is_reverse, N, Ti, IC, OC,
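Editor's note on the root cause, inferred from the x_dims[1] == 1 guard above: lookup_table (v1) requires a trailing dimension of 1 on its ids and replaces it with the embedding width, so [T, 1] ids produce a [T, M] output; lookup_table_v2 instead appends the width to the full ids shape, turning the same [T, 1] ids into [T, 1, M] and tripping fusion_gru's old rank-2 check. A hedged sketch of that shape rule (my reconstruction for illustration, not Paddle source):

```cpp
#include <cstdint>
#include <vector>

// Assumed lookup_table_v2 shape rule: out = ids_shape + [emb_width].
// (Reconstructed for illustration; the operator's InferShape is authoritative.)
std::vector<int64_t> lookup_table_v2_out_dims(std::vector<int64_t> ids_dims,
                                              int64_t emb_width) {
  ids_dims.push_back(emb_width);  // ids [T, 1] -> out [T, 1, M]
  return ids_dims;
}
// The resulting rank-3, middle-dim-1 tensor is exactly the case this patch
// flattens to [T, M] before fusion_gru consumes it.
```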