Unverified commit b080d986, authored by baoachun, committed by GitHub

fix matmul dim error (#36768)

Parent 9d2e0923
@@ -197,15 +197,15 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
         // "embedding_fc_lstm_fuse_pass",  //
         // TODO(wilber): fix correctness problem.
         // "fc_lstm_fuse_pass",            //
         "mul_lstm_fuse_pass",              //
         "fc_gru_fuse_pass",                //
         "mul_gru_fuse_pass",               //
         "seq_concat_fc_fuse_pass",         //
         "squeeze2_matmul_fuse_pass",       //
         "reshape2_matmul_fuse_pass",       //
         "flatten2_matmul_fuse_pass",       //
         "map_matmul_v2_to_mul_pass",       //
-        // "map_matmul_v2_to_matmul_pass", //
+        "map_matmul_v2_to_matmul_pass",    //
         "map_matmul_to_mul_pass",          //
         "fc_fuse_pass",                    //
         "repeated_fc_relu_fuse_pass",      //
......
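The hunk above re-enables map_matmul_v2_to_matmul_pass in the default CPU pass list now that the dim error is fixed. If the pass still misbehaves for a particular model, it can be removed by name through the public inference API. A minimal sketch, assuming the paddle_infer 2.x C++ API; the model path is a placeholder, not part of this commit:

#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config;
  config.SetModel("./model_dir");  // placeholder path, not from this commit
  // Passes in CpuPassStrategy run by default; an individual pass can be
  // removed from the pipeline by name before creating the predictor.
  config.pass_builder()->DeletePass("map_matmul_v2_to_matmul_pass");
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}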
@@ -336,6 +336,7 @@ framework::DDim GetDimForInput(const framework::InferShapeContext &ctx,
                         "The Input(%s) has not been initialized properly. The "
                         "shape of Input(%s) = [%s].",
                         dim));
+  // if mkldnn reshape+transpose+matmul fuse is activated
   if (!shape.empty() && !axis.empty()) {
     PADDLE_ENFORCE_GE(
         shape.size(), 2,
@@ -355,6 +356,43 @@ framework::DDim GetDimForInput(const framework::InferShapeContext &ctx,
                           "Ranks of shape_%s and axis_%s attributes of MatMulOp "
                           "must be equal.",
                           input_name, input_name));
+    int num_negative = std::count(shape.begin(), shape.end(), -1);
+    PADDLE_ENFORCE_LE(num_negative, 1,
+                      platform::errors::InvalidArgument(
+                          "The max number of -1 in fused_reshape_%s is 1 "
+                          "but received %d.",
+                          input_name, num_negative));
+    auto it_zero = std::find(shape.begin(), shape.end(), 0);
+    if (it_zero != shape.end()) {
+      for (uint64_t i = 0; i < shape.size(); i++) {
+        if (shape[i] == 0) {
+          PADDLE_ENFORCE_LT(i, dim.size(),
+                            platform::errors::InvalidArgument(
+                                "The index of 0 in fused_reshape_%s "
+                                "should be less than output dim size, "
+                                "but the index is %d and output dim size is %d.",
+                                input_name, i, dim.size()));
+          shape[i] = dim.at(i);
+        }
+      }
+    }
+    // if "-1" is present then one of reshape dims must be inferred
+    auto it_negative = std::find(shape.begin(), shape.end(), -1);
+    if (it_negative != shape.end()) {
+      int64_t dim_product = 1;
+      for (int i = 0; i < dim.size(); i++) {
+        dim_product *= dim.at(i);
+      }
+      int64_t shape_product = std::accumulate(shape.begin(), shape.end(), -1,
+                                              std::multiplies<int>());
+      int index = std::distance(shape.begin(), it_negative);
+      shape[index] = dim_product / shape_product;
+    }
     dim = dim.reshape(shape).transpose(axis);
   }
   return dim;
......
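The added branch resolves 0 and -1 entries of the fused_reshape attribute before reshaping: a 0 copies the input dimension at the same index, and a single -1 is inferred so the total element count is preserved (seeding std::accumulate with -1 makes the -1 entry cancel out of shape_product). A standalone sketch of the same arithmetic, detached from Paddle's DDim type; the helper name and example values are illustrative only:

#include <algorithm>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Resolve 0 and -1 entries of a fused reshape attribute against the
// actual input dims, mirroring the logic added to GetDimForInput above.
std::vector<int64_t> ResolveReshape(std::vector<int> shape,
                                    const std::vector<int64_t>& dims) {
  // 0 keeps the input dimension at the same index; it must be resolved
  // before the -1 inference so the known product is complete.
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == 0) shape[i] = static_cast<int>(dims.at(i));
  }
  // A single -1 is inferred from the remaining product. Seeding the
  // accumulator with -1 cancels the -1 entry inside shape, so
  // shape_product ends up as the product of the known entries.
  auto it = std::find(shape.begin(), shape.end(), -1);
  if (it != shape.end()) {
    int64_t dim_product = std::accumulate(
        dims.begin(), dims.end(), int64_t{1}, std::multiplies<int64_t>());
    int64_t shape_product = std::accumulate(shape.begin(), shape.end(), -1,
                                            std::multiplies<int>());
    *it = static_cast<int>(dim_product / shape_product);
  }
  return {shape.begin(), shape.end()};
}

// Example: dims = [2, 3, 4] with shape = [0, -1] resolves to [2, 12].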
@@ -245,6 +245,36 @@ class MatMulMKLDNNHandler
     auto input_dims = ctx.Input<Tensor>(input_name)->dims();
     auto new_dims = input_dims;
     if (!shape.empty() && !axis.empty()) {
+      auto it_zero = std::find(shape.begin(), shape.end(), 0);
+      if (it_zero != shape.end()) {
+        for (uint64_t i = 0; i < shape.size(); i++) {
+          if (shape[i] == 0) {
+            PADDLE_ENFORCE_LT(
+                i, input_dims.size(),
+                paddle::platform::errors::InvalidArgument(
+                    "The index of 0 in fused_reshape_%s "
+                    "should be less than output dim size, "
+                    "but the index is %d and output dim size is %d.",
+                    input_name, i, input_dims.size()));
+            shape[i] = input_dims.at(i);
+          }
+        }
+      }
+      // if "-1" is present then one of reshape dims must be inferred
+      auto it_negative = std::find(shape.begin(), shape.end(), -1);
+      if (it_negative != shape.end()) {
+        int64_t dim_product = 1;
+        for (int i = 0; i < input_dims.size(); i++) {
+          dim_product *= input_dims.at(i);
+        }
+        int64_t shape_product = std::accumulate(shape.begin(), shape.end(), -1,
+                                                std::multiplies<int>());
+        int index = std::distance(shape.begin(), it_negative);
+        shape[index] = dim_product / shape_product;
+      }
       new_dims = input_dims.reshape(shape).transpose(axis);
     }
......
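The oneDNN handler repeats the same 0/-1 resolution before computing input_dims.reshape(shape).transpose(axis). The transpose step itself is just an axis permutation of the reshaped dims; a plain-C++ sketch (the helper below is illustrative, not Paddle's DDim::transpose):

#include <cstdint>
#include <vector>

// Permute dims by axis, as the reshape+transpose+matmul fuse does:
// out[i] = dims[axis[i]].
std::vector<int64_t> TransposeDims(const std::vector<int64_t>& dims,
                                   const std::vector<int>& axis) {
  std::vector<int64_t> out(dims.size());
  for (size_t i = 0; i < axis.size(); ++i) out[i] = dims[axis[i]];
  return out;
}

// Example: dims = [2, 12] with axis = [1, 0] yields [12, 2].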