Unverified · Commit 44855da3, authored by jakpiase, committed by GitHub

Fix for bad_alloc in oneDNN matmul_grad kernel (#48593)

* fix for matmul_grad

* another fix for matmul_grad

* fix
Parent: ee4e5323
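In short: MatmulGradKernel computes gradients on rank-padded, broadcast shapes and reduce-sums them back to the original input shapes afterwards. The patch replaces the generic ExtendDimsWithOnes padding with an explicit CalculateMatrixDims step that derives consistent broadcast shapes for x, y, and dout (special-casing 1-D and 2-D operands and honoring transpose_x/transpose_y), and sizes the temporary gradient buffers from those shapes. Sizing the temporaries from inconsistent dims appears to be what could previously request an oversized allocation and fail with std::bad_alloc.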
@@ -19,37 +19,64 @@
 namespace phi {
 
-std::vector<int64_t> ExtendDimsWithOnes(const std::vector<int64_t> &dims,
-                                        int new_size) {
-  std::vector<int64_t> new_dims(new_size, 1);
-  for (size_t i = 0; i < dims.size(); ++i) {
-    new_dims[new_size - dims.size() + i] = dims[i];
-  }
-  return new_dims;
-}
+void CalculateMatrixDims(const std::vector<int64_t> &x_dims,
+                         const std::vector<int64_t> &y_dims,
+                         const std::vector<int64_t> &out_dims,
+                         std::vector<int64_t> *x_bd_dims,
+                         std::vector<int64_t> *y_bd_dims,
+                         std::vector<int64_t> *out_bd_dims,
+                         bool trans_x,
+                         bool trans_y) {
+  if (x_dims.size() == 1) {
+    (*x_bd_dims)[x_bd_dims->size() - 1] = x_dims[0];
+  } else if (x_dims.size() == 2) {
+    (*x_bd_dims)[x_bd_dims->size() - 1] = x_dims[1];
+    (*x_bd_dims)[x_bd_dims->size() - 2] = x_dims[0];
+  } else {
+    for (size_t i = 0; i < x_dims.size(); ++i) {
+      (*x_bd_dims)[x_bd_dims->size() - x_dims.size() + i] = x_dims[i];
+    }
+  }
+  if (y_dims.size() == 1) {
+    (*y_bd_dims)[x_bd_dims->size() - 2] = y_dims[0];
+  } else if (y_dims.size() == 2) {
+    (*y_bd_dims)[y_bd_dims->size() - 1] = y_dims[1];
+    (*y_bd_dims)[y_bd_dims->size() - 2] = y_dims[0];
+  } else {
+    for (size_t i = 0; i < y_dims.size(); ++i) {
+      (*y_bd_dims)[y_bd_dims->size() - y_dims.size() + i] = y_dims[i];
+    }
+  }
+
+  for (size_t i = 0; i < x_bd_dims->size() - 2; ++i) {
+    (*out_bd_dims)[i] = std::max((*x_bd_dims)[i], (*y_bd_dims)[i]);
+  }
+  int h_idx = trans_x ? x_bd_dims->size() - 1 : x_bd_dims->size() - 2;
+  int w_idx = trans_y ? y_bd_dims->size() - 2 : y_bd_dims->size() - 1;
+
+  (*out_bd_dims)[x_bd_dims->size() - 2] = (*x_bd_dims)[h_idx];
+  (*out_bd_dims)[y_bd_dims->size() - 1] = (*y_bd_dims)[w_idx];
+}
 
 template <typename T>
 void CalculateGradMatrixDims(const OneDNNContext &dev_ctx,
                              DenseTensor *dx_tmp,
                              DenseTensor *dy_tmp,
-                             const std::vector<int64_t> &dx_dims,
-                             const std::vector<int64_t> &dy_dims,
                              std::vector<int64_t> *dx_bd_dims,
                              std::vector<int64_t> *dy_bd_dims) {
-  for (size_t i = 0; i < dx_dims.size() - 2; ++i) {
-    if (dx_dims[i] != dy_dims[i]) {
-      if (dx_dims[i] == 1) {
-        (*dx_bd_dims)[i] = dy_dims[i];
+  for (size_t i = 0; i < dx_bd_dims->size() - 2; ++i) {
+    if ((*dx_bd_dims)[i] != (*dy_bd_dims)[i]) {
+      if ((*dx_bd_dims)[i] == 1) {
+        (*dx_bd_dims)[i] = (*dy_bd_dims)[i];
       } else {
-        (*dy_bd_dims)[i] = dx_dims[i];
+        (*dy_bd_dims)[i] = (*dx_bd_dims)[i];
       }
     }
   }
 
-  dx_tmp->Resize(make_ddim((*dx_bd_dims)));
+  dx_tmp->Resize(make_ddim(*dx_bd_dims));
   dev_ctx.template Alloc<T>(dx_tmp);
-  dy_tmp->Resize(make_ddim((*dy_bd_dims)));
+  dy_tmp->Resize(make_ddim(*dy_bd_dims));
   dev_ctx.template Alloc<T>(dy_tmp);
 }
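To make the new shape bookkeeping concrete, here is a minimal, self-contained sketch (illustration only, not part of the patch; plain std::vector stand-ins for the tensors) of what CalculateMatrixDims produces for a 2-D x multiplied by a 4-D y without transposes. One caveat the sketch glosses over: the patch special-cases a 1-D y by writing its length into the second-to-last slot, treating it as a column vector.

```cpp
// Minimal sketch (not Paddle code): what CalculateMatrixDims computes for
// x: [5, 4] (2-D) times y: [2, 5, 4, 3] (4-D), trans_x = trans_y = false.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int64_t> x_dims{5, 4};
  const std::vector<int64_t> y_dims{2, 5, 4, 3};
  // ndims = max(rank_x, rank_y, 3), as in MatmulGradKernel.
  const size_t ndims =
      std::max(std::max(x_dims.size(), y_dims.size()), static_cast<size_t>(3));

  // Right-align each operand's dims in an ndims-long vector padded with 1s;
  // this matches the 2-D and n-D branches of CalculateMatrixDims.
  std::vector<int64_t> x_bd(ndims, 1), y_bd(ndims, 1), out_bd(ndims, 1);
  std::copy(x_dims.rbegin(), x_dims.rend(), x_bd.rbegin());
  std::copy(y_dims.rbegin(), y_dims.rend(), y_bd.rbegin());

  // Batch dims broadcast to the elementwise max; the last two dims are the
  // output matrix height and width (indices shift if trans_x/trans_y is set).
  for (size_t i = 0; i < ndims - 2; ++i) {
    out_bd[i] = std::max(x_bd[i], y_bd[i]);
  }
  out_bd[ndims - 2] = x_bd[ndims - 2];  // h_idx for trans_x == false
  out_bd[ndims - 1] = y_bd[ndims - 1];  // w_idx for trans_y == false

  for (int64_t d : out_bd) std::cout << d << ' ';  // prints: 2 5 5 3
  std::cout << '\n';
}
```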
@@ -58,7 +85,7 @@ void ReduceSumForMatmulGradOutput(const OneDNNContext &dev_ctx,
                                   const DenseTensor *dx_tmp,
                                   DenseTensor *dx,
                                   const std::vector<int64_t> &dx_dims,
-                                  const std::vector<int64_t> &squeezed_dims) {
+                                  const std::vector<int64_t> &x_dims) {
   funcs::ReductionOneDNNHandler<T> handler(dnnl::algorithm::reduction_sum,
                                            0.0f,
                                            0.0f,
@@ -66,7 +93,7 @@ void ReduceSumForMatmulGradOutput(const OneDNNContext &dev_ctx,
                                            dev_ctx.GetPlace(),
                                            dx_tmp,
                                            dx,
-                                           dx_dims);
+                                           x_dims);
 
   auto src_memory_p = handler.AcquireSrcMemory(dx_tmp);
   auto dst_memory_p = handler.AcquireDstMemory(dx);
@@ -79,8 +106,6 @@ void ReduceSumForMatmulGradOutput(const OneDNNContext &dev_ctx,
   reduction_p->execute(astream, reduction_args);
   astream.wait();
-
-  dx->set_mem_desc(dst_memory_p->get_desc().reshape(squeezed_dims));
 }
 
 template <typename T, typename Context>
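Why the reduce-sum exists: if an operand had a batch dimension of 1 that was broadcast to B copies in the forward matmul, all B gradient slices flow back into the same elements, so the true gradient is their sum over the broadcast axis. ReduceSumForMatmulGradOutput performs this with oneDNN's reduction_sum primitive; the reference semantics are just a summing loop, as in this hypothetical sketch:

```cpp
#include <cstdint>
#include <vector>

// Reference semantics only (hypothetical helper, not the oneDNN path):
// collapse a [B, M, N] gradient computed at the broadcast shape into the
// operand's real [1, M, N] shape by summing over the broadcast batch axis.
std::vector<float> ReduceBatchDim(const std::vector<float> &dx_tmp,
                                  int64_t B, int64_t M, int64_t N) {
  std::vector<float> dx(static_cast<size_t>(M * N), 0.0f);
  for (int64_t b = 0; b < B; ++b) {
    for (int64_t i = 0; i < M * N; ++i) {
      dx[static_cast<size_t>(i)] += dx_tmp[static_cast<size_t>(b * M * N + i)];
    }
  }
  return dx;
}

int main() {
  // Two broadcast copies of a 1x2 gradient: {1, 2} and {10, 20} -> {11, 22}.
  std::vector<float> grad{1, 2, 10, 20};
  auto dx = ReduceBatchDim(grad, /*B=*/2, /*M=*/1, /*N=*/2);
  (void)dx;  // dx == {11, 22}
}
```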
@@ -99,64 +124,67 @@ void MatmulGradKernel(const Context &dev_ctx,
   size_t ndims = std::max(x_dims.size(), y_dims.size());
   ndims = std::max<size_t>(ndims, 3);
 
-  if (x_dims.size() != ndims) {
-    x_dims = ExtendDimsWithOnes(x_dims, ndims);
-  }
-  if (y_dims.size() != ndims) {
-    y_dims = ExtendDimsWithOnes(y_dims, ndims);
-  }
-  if (dout_dims.size() != ndims) {
-    dout_dims = ExtendDimsWithOnes(dout_dims, ndims);
-  }
-
   // in broadcasting scenario new memory is required because
   // reduce sum must be calculated upon broadcasted dims
   DenseTensor dx_tmp, dy_tmp;
-  std::vector<int64_t> dx_bd_dims(x_dims);
-  std::vector<int64_t> dy_bd_dims(y_dims);
+
+  std::vector<int64_t> dout_bd_dims(ndims, 1);
+  std::vector<int64_t> x_bd_dims(ndims, 1);
+  std::vector<int64_t> y_bd_dims(ndims, 1);
+
+  CalculateMatrixDims(x_dims,
+                      y_dims,
+                      dout_dims,
+                      &x_bd_dims,
+                      &y_bd_dims,
+                      &dout_bd_dims,
+                      transpose_x,
+                      transpose_y);
+
+  std::vector<int64_t> dx_bd_dims(x_bd_dims);
+  std::vector<int64_t> dy_bd_dims(y_bd_dims);
 
   CalculateGradMatrixDims<T>(
-      dev_ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, &dx_bd_dims, &dy_bd_dims);
+      dev_ctx, &dx_tmp, &dy_tmp, &dx_bd_dims, &dy_bd_dims);
 
   if (transpose_x && transpose_y) {
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, y, dout, y_dims, dout_dims, true, true, &dx_tmp);
+        dev_ctx, y, dout, y_bd_dims, dout_bd_dims, true, true, &dx_tmp);
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, dout, x, dout_dims, x_dims, true, true, &dy_tmp);
+        dev_ctx, dout, x, dout_bd_dims, x_bd_dims, true, true, &dy_tmp);
   } else if (transpose_x) {
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, y, dout, y_dims, dout_dims, false, true, &dx_tmp);
+        dev_ctx, y, dout, y_bd_dims, dout_bd_dims, false, true, &dx_tmp);
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, x, dout, x_dims, dout_dims, false, false, &dy_tmp);
+        dev_ctx, x, dout, x_bd_dims, dout_bd_dims, false, false, &dy_tmp);
   } else if (transpose_y) {
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, dout, y, dout_dims, y_dims, false, false, &dx_tmp);
+        dev_ctx, dout, y, dout_bd_dims, y_bd_dims, false, false, &dx_tmp);
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, dout, x, dout_dims, x_dims, true, false, &dy_tmp);
+        dev_ctx, dout, x, dout_bd_dims, x_bd_dims, true, false, &dy_tmp);
   } else {
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, dout, y, dout_dims, y_dims, false, true, &dx_tmp);
+        dev_ctx, dout, y, dout_bd_dims, y_bd_dims, false, true, &dx_tmp);
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, x, dout, x_dims, dout_dims, true, false, &dy_tmp);
+        dev_ctx, x, dout, x_bd_dims, dout_bd_dims, true, false, &dy_tmp);
   }
 
-  if (x_dims != dx_bd_dims) {
+  if (x_bd_dims != dx_bd_dims) {
     ReduceSumForMatmulGradOutput<T>(
-        dev_ctx, &dx_tmp, dx, x_dims, vectorize(x.dims()));
+        dev_ctx, &dx_tmp, dx, dx_bd_dims, x_bd_dims);
   } else {
     *dx = std::move(dx_tmp);
  }
-  if (y_dims != dy_bd_dims) {
+  if (y_bd_dims != dy_bd_dims) {
     ReduceSumForMatmulGradOutput<T>(
-        dev_ctx, &dy_tmp, dy, y_dims, vectorize(y.dims()));
+        dev_ctx, &dy_tmp, dy, dy_bd_dims, y_bd_dims);
   } else {
     *dy = std::move(dy_tmp);
   }
 
-  dx->set_mem_desc(x.mem_desc());
+  dx->Resize(x.dims());
+  dx->set_mem_desc(x.mem_desc().reshape(vectorize(x.dims())));
-  dy->set_mem_desc(y.mem_desc());
+  dy->Resize(y.dims());
+  dy->set_mem_desc(y.mem_desc().reshape(vectorize(y.dims())));
 }
 
 template <typename T, typename Context>
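The four ExecuteMatmul branches above encode the standard matmul gradient identities: for Out = X * Y, dX = dOut * Y^T and dY = X^T * dOut, with the transpose folded into ExecuteMatmul's flags instead of materializing transposed copies. A tiny numeric check of the no-transpose branch:

```cpp
#include <cstdio>

// Numeric check (plain C++, independent of Paddle) of the no-transpose
// identities: for Out = X * Y, dX = dOut * Y^T and dY = X^T * dOut.
int main() {
  const double X[2][3] = {{1, 2, 3}, {4, 5, 6}};    // 2x3
  const double Y[3][2] = {{1, 0}, {0, 1}, {1, 1}};  // 3x2
  const double dOut[2][2] = {{1, 1}, {1, 1}};       // 2x2, all ones

  // dX = dOut (2x2) * Y^T (2x3) -> 2x3; Y is indexed as Y[j][k], i.e.
  // transposed, mirroring ExecuteMatmul(..., dout, y, false, true, &dx_tmp).
  double dX[2][3] = {};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 3; ++j)
      for (int k = 0; k < 2; ++k) dX[i][j] += dOut[i][k] * Y[j][k];

  // dY = X^T (3x2) * dOut (2x2) -> 3x2; X is indexed as X[k][i], mirroring
  // ExecuteMatmul(..., x, dout, true, false, &dy_tmp).
  double dY[3][2] = {};
  for (int i = 0; i < 3; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 2; ++k) dY[i][j] += X[k][i] * dOut[k][j];

  std::printf("dX[0][0] = %g (expect 1), dY[0][0] = %g (expect 5)\n",
              dX[0][0], dY[0][0]);
}
```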