Unverified commit 44855da3, authored by jakpiase, committed by GitHub

Fix for bad_alloc in oneDNN matmul_grad kernel (#48593)

* fix for matmul_grad

* another fix for matmul_grad

* fix
Parent ee4e5323
@@ -19,37 +19,64 @@
 namespace phi {
 
-std::vector<int64_t> ExtendDimsWithOnes(const std::vector<int64_t> &dims,
-                                        int new_size) {
-  std::vector<int64_t> new_dims(new_size, 1);
-  for (size_t i = 0; i < dims.size(); ++i) {
-    new_dims[new_size - dims.size() + i] = dims[i];
-  }
-
-  return new_dims;
-}
+void CalculateMatrixDims(const std::vector<int64_t> &x_dims,
+                         const std::vector<int64_t> &y_dims,
+                         const std::vector<int64_t> &out_dims,
+                         std::vector<int64_t> *x_bd_dims,
+                         std::vector<int64_t> *y_bd_dims,
+                         std::vector<int64_t> *out_bd_dims,
+                         bool trans_x,
+                         bool trans_y) {
+  if (x_dims.size() == 1) {
+    (*x_bd_dims)[x_bd_dims->size() - 1] = x_dims[0];
+  } else if (x_dims.size() == 2) {
+    (*x_bd_dims)[x_bd_dims->size() - 1] = x_dims[1];
+    (*x_bd_dims)[x_bd_dims->size() - 2] = x_dims[0];
+  } else {
+    for (size_t i = 0; i < x_dims.size(); ++i) {
+      (*x_bd_dims)[x_bd_dims->size() - x_dims.size() + i] = x_dims[i];
+    }
+  }
+  if (y_dims.size() == 1) {
+    (*y_bd_dims)[x_bd_dims->size() - 2] = y_dims[0];
+  } else if (y_dims.size() == 2) {
+    (*y_bd_dims)[y_bd_dims->size() - 1] = y_dims[1];
+    (*y_bd_dims)[y_bd_dims->size() - 2] = y_dims[0];
+  } else {
+    for (size_t i = 0; i < y_dims.size(); ++i) {
+      (*y_bd_dims)[y_bd_dims->size() - y_dims.size() + i] = y_dims[i];
+    }
+  }
+
+  for (size_t i = 0; i < x_bd_dims->size() - 2; ++i) {
+    (*out_bd_dims)[i] = std::max((*x_bd_dims)[i], (*y_bd_dims)[i]);
+  }
+  int h_idx = trans_x ? x_bd_dims->size() - 1 : x_bd_dims->size() - 2;
+  int w_idx = trans_y ? y_bd_dims->size() - 2 : y_bd_dims->size() - 1;
+
+  (*out_bd_dims)[x_bd_dims->size() - 2] = (*x_bd_dims)[h_idx];
+  (*out_bd_dims)[y_bd_dims->size() - 1] = (*y_bd_dims)[w_idx];
+}
 
 template <typename T>
 void CalculateGradMatrixDims(const OneDNNContext &dev_ctx,
                              DenseTensor *dx_tmp,
                              DenseTensor *dy_tmp,
-                             const std::vector<int64_t> &dx_dims,
-                             const std::vector<int64_t> &dy_dims,
                              std::vector<int64_t> *dx_bd_dims,
                              std::vector<int64_t> *dy_bd_dims) {
-  for (size_t i = 0; i < dx_dims.size() - 2; ++i) {
-    if (dx_dims[i] != dy_dims[i]) {
-      if (dx_dims[i] == 1) {
-        (*dx_bd_dims)[i] = dy_dims[i];
+  for (size_t i = 0; i < dx_bd_dims->size() - 2; ++i) {
+    if ((*dx_bd_dims)[i] != (*dy_bd_dims)[i]) {
+      if ((*dx_bd_dims)[i] == 1) {
+        (*dx_bd_dims)[i] = (*dy_bd_dims)[i];
       } else {
-        (*dy_bd_dims)[i] = dx_dims[i];
+        (*dy_bd_dims)[i] = (*dx_bd_dims)[i];
       }
     }
   }
 
-  dx_tmp->Resize(make_ddim((*dx_bd_dims)));
+  dx_tmp->Resize(make_ddim(*dx_bd_dims));
   dev_ctx.template Alloc<T>(dx_tmp);
-  dy_tmp->Resize(make_ddim((*dy_bd_dims)));
+  dy_tmp->Resize(make_ddim(*dy_bd_dims));
   dev_ctx.template Alloc<T>(dy_tmp);
 }
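
Illustration (not part of the patch): the shape handling introduced above right-aligns lower-rank operands into a vector of ones and broadcasts the batch dimensions, and the grad helper then reconciles dx/dy batch dims. The standalone sketch below mirrors that logic for x = [5, 4] and y = [2, 4, 3] with no transposes. PadToRank is a hypothetical helper written only for this example; the real CalculateMatrixDims additionally treats a 1-D y as a column by placing it on the next-to-last axis.

// Standalone sketch, not Paddle code: reproduces the padding/broadcast logic
// of CalculateMatrixDims for rank >= 2 operands with trans_x = trans_y = false.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

static std::vector<int64_t> PadToRank(const std::vector<int64_t> &dims,
                                      size_t rank) {
  std::vector<int64_t> padded(rank, 1);
  for (size_t i = 0; i < dims.size(); ++i) {
    padded[rank - dims.size() + i] = dims[i];  // right-align, pad with ones
  }
  return padded;
}

int main() {
  const size_t rank = 3;  // max(rank(x), rank(y), 3), as in MatmulGradKernel
  std::vector<int64_t> x_bd = PadToRank({5, 4}, rank);     // -> {1, 5, 4}
  std::vector<int64_t> y_bd = PadToRank({2, 4, 3}, rank);  // -> {2, 4, 3}
  std::vector<int64_t> out_bd(rank, 1);
  for (size_t i = 0; i + 2 < rank; ++i) {  // broadcast the batch dimensions
    out_bd[i] = std::max(x_bd[i], y_bd[i]);
  }
  out_bd[rank - 2] = x_bd[rank - 2];  // rows of x (trans_x == false)
  out_bd[rank - 1] = y_bd[rank - 1];  // cols of y (trans_y == false)
  std::printf("out_bd = {%lld, %lld, %lld}\n",  // prints {2, 5, 3}
              static_cast<long long>(out_bd[0]),
              static_cast<long long>(out_bd[1]),
              static_cast<long long>(out_bd[2]));
  return 0;
}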
@@ -58,7 +85,7 @@ void ReduceSumForMatmulGradOutput(const OneDNNContext &dev_ctx,
                                   const DenseTensor *dx_tmp,
                                   DenseTensor *dx,
                                   const std::vector<int64_t> &dx_dims,
-                                  const std::vector<int64_t> &squeezed_dims) {
+                                  const std::vector<int64_t> &x_dims) {
   funcs::ReductionOneDNNHandler<T> handler(dnnl::algorithm::reduction_sum,
                                            0.0f,
                                            0.0f,
@@ -66,7 +93,7 @@ void ReduceSumForMatmulGradOutput(const OneDNNContext &dev_ctx,
                                            dev_ctx.GetPlace(),
                                            dx_tmp,
                                            dx,
-                                           dx_dims);
+                                           x_dims);
 
   auto src_memory_p = handler.AcquireSrcMemory(dx_tmp);
   auto dst_memory_p = handler.AcquireDstMemory(dx);
@@ -79,8 +106,6 @@ void ReduceSumForMatmulGradOutput(const OneDNNContext &dev_ctx,
   reduction_p->execute(astream, reduction_args);
   astream.wait();
-
-  dx->set_mem_desc(dst_memory_p->get_desc().reshape(squeezed_dims));
 }
 
 template <typename T, typename Context>
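
Context for the helper above (illustration only, no oneDNN): when a gradient has been computed in the broadcast shape, summing it over the broadcast batch axes recovers the gradient for the original, smaller operand. A minimal sketch with made-up shapes:

#include <cstdio>
#include <vector>

int main() {
  // dx_tmp was produced with a broadcast batch dim: shape {B, M, N}.
  // The original x (and therefore dx) has shape {M, N}.
  const int B = 2, M = 5, N = 4;
  std::vector<float> dx_tmp(B * M * N, 1.0f);
  std::vector<float> dx(M * N, 0.0f);
  for (int b = 0; b < B; ++b) {    // reduce-sum over the broadcast axis
    for (int i = 0; i < M * N; ++i) {
      dx[i] += dx_tmp[b * M * N + i];
    }
  }
  std::printf("dx[0] = %.1f\n", dx[0]);  // 2.0: both batch slices contribute
  return 0;
}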
@@ -99,64 +124,67 @@ void MatmulGradKernel(const Context &dev_ctx,
   size_t ndims = std::max(x_dims.size(), y_dims.size());
   ndims = std::max<size_t>(ndims, 3);
 
-  if (x_dims.size() != ndims) {
-    x_dims = ExtendDimsWithOnes(x_dims, ndims);
-  }
-  if (y_dims.size() != ndims) {
-    y_dims = ExtendDimsWithOnes(y_dims, ndims);
-  }
-  if (dout_dims.size() != ndims) {
-    dout_dims = ExtendDimsWithOnes(dout_dims, ndims);
-  }
-
   // in broadcasting scenario new memory is required because
   // reduce sum must be calculated upon broadcasted dims
   DenseTensor dx_tmp, dy_tmp;
-  std::vector<int64_t> dx_bd_dims(x_dims);
-  std::vector<int64_t> dy_bd_dims(y_dims);
+
+  std::vector<int64_t> dout_bd_dims(ndims, 1);
+  std::vector<int64_t> x_bd_dims(ndims, 1);
+  std::vector<int64_t> y_bd_dims(ndims, 1);
+
+  CalculateMatrixDims(x_dims,
+                      y_dims,
+                      dout_dims,
+                      &x_bd_dims,
+                      &y_bd_dims,
+                      &dout_bd_dims,
+                      transpose_x,
+                      transpose_y);
+
+  std::vector<int64_t> dx_bd_dims(x_bd_dims);
+  std::vector<int64_t> dy_bd_dims(y_bd_dims);
 
   CalculateGradMatrixDims<T>(
-      dev_ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, &dx_bd_dims, &dy_bd_dims);
+      dev_ctx, &dx_tmp, &dy_tmp, &dx_bd_dims, &dy_bd_dims);
 
   if (transpose_x && transpose_y) {
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, y, dout, y_dims, dout_dims, true, true, &dx_tmp);
+        dev_ctx, y, dout, y_bd_dims, dout_bd_dims, true, true, &dx_tmp);
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, dout, x, dout_dims, x_dims, true, true, &dy_tmp);
+        dev_ctx, dout, x, dout_bd_dims, x_bd_dims, true, true, &dy_tmp);
   } else if (transpose_x) {
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, y, dout, y_dims, dout_dims, false, true, &dx_tmp);
+        dev_ctx, y, dout, y_bd_dims, dout_bd_dims, false, true, &dx_tmp);
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, x, dout, x_dims, dout_dims, false, false, &dy_tmp);
+        dev_ctx, x, dout, x_bd_dims, dout_bd_dims, false, false, &dy_tmp);
   } else if (transpose_y) {
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, dout, y, dout_dims, y_dims, false, false, &dx_tmp);
+        dev_ctx, dout, y, dout_bd_dims, y_bd_dims, false, false, &dx_tmp);
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, dout, x, dout_dims, x_dims, true, false, &dy_tmp);
+        dev_ctx, dout, x, dout_bd_dims, x_bd_dims, true, false, &dy_tmp);
   } else {
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, dout, y, dout_dims, y_dims, false, true, &dx_tmp);
+        dev_ctx, dout, y, dout_bd_dims, y_bd_dims, false, true, &dx_tmp);
     funcs::ExecuteMatmul<T, T>(
-        dev_ctx, x, dout, x_dims, dout_dims, true, false, &dy_tmp);
+        dev_ctx, x, dout, x_bd_dims, dout_bd_dims, true, false, &dy_tmp);
   }
 
-  if (x_dims != dx_bd_dims) {
+  if (x_bd_dims != dx_bd_dims) {
     ReduceSumForMatmulGradOutput<T>(
-        dev_ctx, &dx_tmp, dx, x_dims, vectorize(x.dims()));
+        dev_ctx, &dx_tmp, dx, dx_bd_dims, x_bd_dims);
   } else {
     *dx = std::move(dx_tmp);
   }
-  if (y_dims != dy_bd_dims) {
+  if (y_bd_dims != dy_bd_dims) {
     ReduceSumForMatmulGradOutput<T>(
-        dev_ctx, &dy_tmp, dy, y_dims, vectorize(y.dims()));
+        dev_ctx, &dy_tmp, dy, dy_bd_dims, y_bd_dims);
   } else {
     *dy = std::move(dy_tmp);
   }
 
+  dx->set_mem_desc(x.mem_desc());
   dx->Resize(x.dims());
-  dx->set_mem_desc(x.mem_desc().reshape(vectorize(x.dims())));
+  dy->set_mem_desc(y.mem_desc());
   dy->Resize(y.dims());
-  dy->set_mem_desc(y.mem_desc().reshape(vectorize(y.dims())));
 }
 
 template <typename T, typename Context>
......
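
For reference (not part of the patch): the eight funcs::ExecuteMatmul calls in MatmulGradKernel above follow the standard matmul gradient identities, with dOut denoting the incoming gradient of Out:

  Out = X·Y     :  dX = dOut·Y^T     dY = X^T·dOut
  Out = X^T·Y   :  dX = Y·dOut^T     dY = X·dOut
  Out = X·Y^T   :  dX = dOut·Y       dY = dOut^T·X
  Out = X^T·Y^T :  dX = Y^T·dOut^T   dY = dOut^T·X^T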