未验证 提交 1c6dcfd9 编写于 作者: S Sylwester Fraczek 提交者: GitHub

fix reshape+transpose+matmul (#40948)

上级 6a94adbe
......@@ -235,6 +235,47 @@ class MatMulMKLDNNHandler
out_strides;
};
phi::DDim GetDimForInput(const ExecutionContext& ctx,
std::string input_name) {
auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
auto input_dims = ctx.Input<Tensor>(input_name)->dims();
if (!shape.empty() && !axis.empty()) {
auto it_zero = std::find(shape.begin(), shape.end(), 0);
if (it_zero != shape.end()) {
for (uint64_t i = 0; i < shape.size(); i++) {
if (shape[i] == 0) {
PADDLE_ENFORCE_LT(
i, input_dims.size(),
paddle::platform::errors::InvalidArgument(
"The index of 0 in fused_reshape_%s ",
"should be less than output dim size, ",
"but the index is %d and output dim size is %d", input_name,
i, input_dims.size()));
shape[i] = input_dims.at(i);
}
}
}
// if "-1" is present then one of reshape dims must be infered
auto it_negative = std::find(shape.begin(), shape.end(), -1);
if (it_negative != shape.end()) {
int64_t dim_product = 1;
for (int i = 0; i < input_dims.size(); i++) {
dim_product *= input_dims.at(i);
}
int64_t shape_product = std::accumulate(shape.begin(), shape.end(), -1,
std::multiplies<int>());
int index = std::distance(shape.begin(), it_negative);
shape[index] = dim_product / shape_product;
}
return input_dims.reshape(shape).transpose(axis);
}
return input_dims;
}
std::pair<phi::funcs::MatDescriptor, memory::dims> GetInputDimsAndStrides(
const ExecutionContext& ctx, std::string input_name) {
auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
......@@ -342,8 +383,8 @@ class MatMulMKLDNNHandler
batch_size_ = 1;
if (out_bs > 1 && (IsOutputFused(ctx) || IsInputFused(ctx))) {
auto& x_dims = ctx.Input<Tensor>("X")->dims();
auto& y_dims = ctx.Input<Tensor>("Y")->dims();
auto x_dims = GetDimForInput(ctx, "X");
auto y_dims = GetDimForInput(ctx, "Y");
batch_size_ = x_bs > y_bs ? x_dims[0] : y_dims[0];
x_bs /= batch_size_;
y_bs /= batch_size_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册