Unverified commit 4ab18ada authored by Sławomir Siwek, committed by GitHub

[PHI] Migrate matmul_grad kernel (#48023)

* cleanup unused code

* unify is_int8 is_bfloat16

* Simplify matmul_v2 FWD kernel

* remove RunKernel methods

* remove import namespace

* remove headers

* clean fluid/phi cross imports

* remove fluid axpy_handler

* delete fluid methods

* activations

* OneDNNMemDesc

* MKLDNNFormatForSize

* MatchShapeToLayout

* MKLDNNMemoryFormat

* MKLDNNFormat

* ReorderMKLDNNHandler

* to_void_cast

* review suggestions

* interpolate

* remove fluid depedency

* init

* ExecuteMatMulV2

* rm fluid kernel

* matmul_grad

* remove mutable_data
Parent 7073ed5b
@@ -75,20 +75,6 @@ static Tensor FoldFirstAndLastDims(const MKLDNNDeviceContext &dev_ctx,
return output;
}
// Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
// original x_dim is returned.
static paddle::framework::DDim RowMatrixDimsFromVector(
const paddle::framework::DDim &x_dim) {
return x_dim.size() > 1 ? x_dim : phi::make_ddim({1, x_dim[0]});
}
// Get column matrix shape from a vector shape. If the rank of y_dim > 1, the
// original y_dim is returned.
static paddle::framework::DDim ColumnMatrixDimsFromVector(
const paddle::framework::DDim &y_dim) {
return y_dim.size() > 1 ? y_dim : phi::make_ddim({y_dim[0], 1});
}
phi::DDim GetDimForInput(const ExecutionContext &ctx, std::string input_name) {
auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
@@ -245,8 +231,8 @@ static void ReshapeTensorToMatrixSequence(
*/
static void ReshapeXYOutToMatrixSequence(
Tensor *x, Tensor *y, Tensor *out, bool trans_x, bool trans_y) {
-  auto x_dim = RowMatrixDimsFromVector(x->dims());
-  auto y_dim = ColumnMatrixDimsFromVector(y->dims());
+  auto x_dim = phi::funcs::RowMatrixDimsFromVector(x->dims());
+  auto y_dim = phi::funcs::ColumnMatrixDimsFromVector(y->dims());
auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(x_dim, 0, trans_x);
auto mat_dim_y = phi::funcs::CreateMatrixDescriptor(y_dim, 0, trans_y);
if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) {
@@ -304,8 +290,9 @@ std::vector<int64_t> GetInputStrides(const ExecutionContext &ctx,
new_dims = input_dims.reshape(shape).transpose(axis);
}
-  auto &MatrixDimsFromVector =
-      input_name == "X" ? RowMatrixDimsFromVector : ColumnMatrixDimsFromVector;
+  auto &MatrixDimsFromVector = input_name == "X"
+                                   ? phi::funcs::RowMatrixDimsFromVector
+                                   : phi::funcs::ColumnMatrixDimsFromVector;
phi::funcs::MatDescriptor mat_dim = phi::funcs::CreateMatrixDescriptor(
MatrixDimsFromVector(new_dims),
0,
@@ -707,199 +694,6 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
}
};
template <typename T>
class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const ExecutionContext &ctx) const override {
const auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &onednn_engine = dev_ctx.GetEngine();
auto *x = ctx.Input<phi::DenseTensor>("X");
auto *y = ctx.Input<phi::DenseTensor>("Y");
auto x_dims = vectorize(x->dims());
auto y_dims = vectorize(y->dims());
bool is_broadcast = true;
if (x_dims.size() <= 2 || y_dims.size() <= 2) {
is_broadcast = false;
} else if (x_dims.size() != y_dims.size()) {
is_broadcast = true;
} else {
is_broadcast = !std::equal(x_dims.cbegin(),
x_dims.cbegin() + x_dims.size() - 2,
y_dims.cbegin());
}
// if no broadcasting is needed, we can simply use matmul's grad and avoid
// using reduce_sum
if (!is_broadcast) {
matmul_v1_grad_mkldnn_kernel.Compute(ctx);
return;
}
auto *dout = ctx.Input<phi::DenseTensor>(GradVarName("Out"));
auto *dx = ctx.Output<phi::DenseTensor>(GradVarName("X"));
auto *dy = ctx.Output<phi::DenseTensor>(GradVarName("Y"));
bool trans_x = ctx.HasAttr("trans_x") ? ctx.Attr<bool>("trans_x")
: ctx.Attr<bool>("transpose_X");
bool trans_y = ctx.HasAttr("trans_y") ? ctx.Attr<bool>("trans_y")
: ctx.Attr<bool>("transpose_Y");
auto dout_dims = vectorize(dout->dims());
size_t ndims = std::max(x->dims().size(), y->dims().size());
ndims = std::max<size_t>(ndims, 3);
if (x_dims.size() != ndims) {
x_dims = ExtendDimsWithOnes(x_dims, ndims);
} else if (y_dims.size() != ndims) {
y_dims = ExtendDimsWithOnes(y_dims, ndims);
}
// in broadcasting scenario new memory is required because
// reduce sum must be calculated upon broadcasted dims
Tensor dx_tmp, dy_tmp;
std::vector<int64_t> dx_bd_dims(x_dims);
std::vector<int64_t> dy_bd_dims(y_dims);
CalculateGradMatrixDims(
ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, &dx_bd_dims, &dy_bd_dims);
if (trans_x && trans_y) {
ExecuteMatMulV2<T, T>(
ctx, onednn_engine, y, y_dims, true, dout, dout_dims, true, &dx_tmp);
ExecuteMatMulV2<T, T>(
ctx, onednn_engine, dout, dout_dims, true, x, x_dims, true, &dy_tmp);
} else if (trans_x) {
ExecuteMatMulV2<T, T>(
ctx, onednn_engine, y, y_dims, false, dout, dout_dims, true, &dx_tmp);
ExecuteMatMulV2<T, T>(ctx,
onednn_engine,
x,
x_dims,
false,
dout,
dout_dims,
false,
&dy_tmp);
} else if (trans_y) {
ExecuteMatMulV2<T, T>(ctx,
onednn_engine,
dout,
dout_dims,
false,
y,
y_dims,
false,
&dx_tmp);
ExecuteMatMulV2<T, T>(
ctx, onednn_engine, dout, dout_dims, true, x, x_dims, false, &dy_tmp);
} else {
ExecuteMatMulV2<T, T>(
ctx, onednn_engine, dout, dout_dims, false, y, y_dims, true, &dx_tmp);
ExecuteMatMulV2<T, T>(
ctx, onednn_engine, x, x_dims, true, dout, dout_dims, false, &dy_tmp);
}
if (x_dims != dx_bd_dims) {
ReduceSumForMatmulGradOutput(ctx,
dev_ctx,
onednn_engine,
&dx_tmp,
dx,
x_dims,
vectorize(x->dims()));
} else {
*dx = std::move(dx_tmp);
}
if (y_dims != dy_bd_dims) {
ReduceSumForMatmulGradOutput(ctx,
dev_ctx,
onednn_engine,
&dy_tmp,
dy,
y_dims,
vectorize(y->dims()));
} else {
*dy = std::move(dy_tmp);
}
dx->Resize(x->dims());
dy->Resize(y->dims());
}
private:
void CalculateGradMatrixDims(const ExecutionContext &ctx,
Tensor *dx_tmp,
Tensor *dy_tmp,
const std::vector<int64_t> &dx_dims,
const std::vector<int64_t> &dy_dims,
std::vector<int64_t> *dx_bd_dims,
std::vector<int64_t> *dy_bd_dims) const {
for (size_t i = 0; i < dx_dims.size() - 2; ++i) {
if (dx_dims[i] != dy_dims[i]) {
if (dx_dims[i] == 1) {
(*dx_bd_dims)[i] = dy_dims[i];
} else {
(*dy_bd_dims)[i] = dx_dims[i];
}
}
}
dx_tmp->Resize(phi::make_ddim((*dx_bd_dims)));
dx_tmp->mutable_data<T>(ctx.GetPlace());
dy_tmp->Resize(phi::make_ddim((*dy_bd_dims)));
dy_tmp->mutable_data<T>(ctx.GetPlace());
}
void ReduceSumForMatmulGradOutput(
const ExecutionContext &ctx,
const MKLDNNDeviceContext &dev_ctx,
const dnnl::engine onednn_engine,
const Tensor *dx_tmp,
Tensor *dx,
const std::vector<int64_t> &dx_dims,
const std::vector<int64_t> &squeezed_dims) const {
phi::funcs::ReductionOneDNNHandler<T> handler(
dnnl::algorithm::reduction_sum,
0.0f,
0.0f,
onednn_engine,
ctx.GetPlace(),
dx_tmp,
dx,
dx_dims);
auto src_memory_p = handler.AcquireSrcMemory(dx_tmp);
auto dst_memory_p = handler.AcquireDstMemory(dx);
std::unordered_map<int, dnnl::memory> reduction_args = {
{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}};
auto &astream = MKLDNNDeviceContext::tls().get_stream();
auto reduction_p = handler.AcquireForwardPrimitive();
reduction_p->execute(astream, reduction_args);
astream.wait();
dx->set_mem_desc(dst_memory_p->get_desc().reshape(squeezed_dims));
}
std::vector<int64_t> ExtendDimsWithOnes(const std::vector<int64_t> &dims,
int new_size) const {
std::vector<int64_t> new_dims(new_size, 1);
for (size_t i = 0; i < dims.size(); ++i) {
new_dims[new_size - dims.size() + i] = dims[i];
}
return new_dims;
}
private:
MatMulGradMKLDNNKernel<T> matmul_v1_grad_mkldnn_kernel;
};
} // anonymous namespace
REGISTER_OP_KERNEL(matmul,
@@ -923,9 +717,3 @@ REGISTER_OP_KERNEL(matmul_v2,
MatMulV2MKLDNNKernel<paddle::platform::bfloat16>,
MatMulV2MKLDNNKernel<int8_t>,
MatMulV2MKLDNNKernel<uint8_t>);
REGISTER_OP_KERNEL(matmul_v2_grad,
MKLDNN,
::paddle::platform::CPUPlace,
MatMulV2GradMKLDNNKernel<float>,
MatMulV2GradMKLDNNKernel<paddle::platform::bfloat16>);
@@ -29,6 +29,7 @@ limitations under the License. */
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/axis_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/data_layout_transform.h"
#include "paddle/phi/kernels/funcs/pooling.h"
@@ -1331,14 +1332,13 @@ class BatchNormOneDNNHandler
diff_scaleshift_data);
}
-  std::shared_ptr<dnnl::memory> AcquireMeanMemory(
-      const phi::DenseTensor* mean) {
+  std::shared_ptr<dnnl::memory> AcquireMeanMemory(const DenseTensor* mean) {
const T* mean_data = mean->data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
to_void_cast<T>(mean_data));
}
-  std::shared_ptr<dnnl::memory> AcquireMeanMemory(phi::DenseTensor* mean) {
+  std::shared_ptr<dnnl::memory> AcquireMeanMemory(DenseTensor* mean) {
T* mean_data = mean->mutable_data<T>(this->place_,
this->fwd_pd_->mean_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->mean_desc(),
@@ -1346,14 +1346,13 @@ class BatchNormOneDNNHandler
}
std::shared_ptr<dnnl::memory> AcquireVarianceMemory(
-      const phi::DenseTensor* variance) {
+      const DenseTensor* variance) {
const T* variance_data = variance->data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
to_void_cast<T>(variance_data));
}
-  std::shared_ptr<dnnl::memory> AcquireVarianceMemory(
-      phi::DenseTensor* variance) {
+  std::shared_ptr<dnnl::memory> AcquireVarianceMemory(DenseTensor* variance) {
T* variance_data = variance->mutable_data<T>(
this->place_, this->fwd_pd_->variance_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->variance_desc(),
@@ -1630,5 +1629,346 @@ class PoolingOneDNNHandler
}
};
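// Matmul helpers migrated from the fluid matmul kernel. A 1-D operand is
// promoted to a 2-D matrix: X becomes a row matrix {1, n} and Y a column
// matrix {n, 1}; operands with rank > 1 are returned unchanged.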
static DDim RowMatrixDimsFromVector(const DDim& x_dim) {
return x_dim.size() > 1 ? x_dim : make_ddim({1, x_dim[0]});
}
static DDim ColumnMatrixDimsFromVector(const DDim& y_dim) {
return y_dim.size() > 1 ? y_dim : make_ddim({y_dim[0], 1});
}
static std::vector<int64_t> TransposeAxis(const std::vector<int64_t>& x,
const std::vector<int>& axis) {
size_t in_rank = x.size();
size_t axis_size = axis.size();
auto axis_set = std::set<int>(axis.begin(), axis.end());
PADDLE_ENFORCE_EQ(axis_set.size(),
axis_size,
paddle::platform::errors::InvalidArgument(
"In an axis array, elements must be unique."));
PADDLE_ENFORCE_EQ(in_rank,
axis_size,
paddle::platform::errors::InvalidArgument(
"The input dimension's size "
"should be equal to the axis's size. "
"But received dimension is %d, "
"axis's size is %d",
in_rank,
axis_size));
PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()),
axis_size,
paddle::platform::errors::InvalidArgument(
"Axis values must be ranging from 0 to (dims - 1)."));
std::vector<int64_t> new_x(x.size());
for (size_t i = 0; i < x.size(); i++) {
new_x[i] = x[axis[i]];
}
return new_x;
}
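// Computes oneDNN strides for an input that carries fused reshape/transpose
// attributes ("fused_reshape_X/Y" + "fused_transpose_X/Y"). Returns an empty
// vector when no fused reshape is present, in which case default strides
// derived from the dims are used by the handler.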
static std::vector<int64_t> GetInputStrides(const OneDNNContext& dev_ctx,
const DDim& input_dims,
const std::string input_name,
const bool transpose_input) {
auto new_dims = input_dims;
auto shape =
dev_ctx.HasDnnAttr("fused_reshape_" + input_name)
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_reshape_" + input_name))
: std::vector<int>();
auto axis = dev_ctx.HasDnnAttr("fused_transpose_" + input_name)
? PADDLE_GET_CONST(
std::vector<int>,
dev_ctx.GetDnnAttr("fused_transpose_" + input_name))
: std::vector<int>();
if (!shape.empty() && !axis.empty()) {
new_dims = input_dims.reshape(shape).transpose(axis);
}
auto& MatrixDimsFromVector =
input_name == "X" ? RowMatrixDimsFromVector : ColumnMatrixDimsFromVector;
phi::funcs::MatDescriptor mat_dim = phi::funcs::CreateMatrixDescriptor(
MatrixDimsFromVector(new_dims), 0, transpose_input);
std::vector<int64_t> strides;
if (!shape.empty()) {
auto shape2 = input_dims.reshape(shape);
strides.push_back(1);
for (auto i = shape2.size() - 1; i > 0; --i) {
strides.insert(strides.begin(),
strides.front() * static_cast<int64_t>(shape2[i]));
}
strides = TransposeAxis(strides, axis);
if (shape.size() == 2)
strides.insert(strides.begin(),
static_cast<int64_t>(shape[0] * shape[1]));
mat_dim.stride_ = strides[0];
if (mat_dim.trans_) std::swap(*strides.rbegin(), *(++strides.rbegin()));
}
return strides;
}
static bool IsOutputFused(const OneDNNContext& dev_ctx) {
const auto shape =
dev_ctx.HasDnnAttr("fused_reshape_Out")
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_reshape_Out"))
: std::vector<int>();
const auto axis =
dev_ctx.HasDnnAttr("fused_transpose_Out")
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_transpose_Out"))
: std::vector<int>();
return !shape.empty() && !axis.empty();
}
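// oneDNN matmul handler: the last two dims of each operand form the M x K and
// K x N matrices, and every leading dim is treated as a (broadcastable) batch
// dim expressed through the memory strides.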
template <typename XT, typename YT, typename OT>
class MatmulOneDNNHandler
: public phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul> {
public:
MatmulOneDNNHandler(const OneDNNContext& dev_ctx,
const std::vector<int64_t>& x_org_dims,
const std::vector<int64_t>& y_org_dims,
bool trans_x,
bool trans_y,
const std::vector<int64_t>& x_strides_override,
const std::vector<int64_t>& y_strides_override,
bool is_output_fused)
: phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul>(
dev_ctx.GetEngine(), dev_ctx.GetPlace()) {
// M X K * K X N
std::vector<int64_t> x_dims(x_org_dims);
std::vector<int64_t> y_dims(y_org_dims);
const int MB_idx = x_dims.size() - 3;
const int H_idx = x_dims.size() - 2;
const int W_idx = x_dims.size() - 1;
if (trans_x) std::swap(x_dims[H_idx], x_dims[W_idx]);
if (trans_y) std::swap(y_dims[H_idx], y_dims[W_idx]);
const memory::dim M = x_dims[H_idx];
const memory::dim K = x_dims[W_idx];
const memory::dim N = y_dims[W_idx];
std::vector<int64_t> x_strides(x_dims.size() - 3, 1);
std::vector<int64_t> y_strides(x_dims.size() - 3, 1);
std::vector<int64_t> out_strides(x_dims.size() - 3, 1);
std::vector<int64_t> out_ddims(x_dims.size() - 3, 1);
x_strides.reserve(x_dims.size());
y_strides.reserve(x_dims.size());
out_strides.reserve(x_dims.size());
if (!x_strides_override.empty()) {
x_strides = x_strides_override;
} else {
if (!trans_x) {
x_strides.insert(x_strides.end(), {M * K, K, 1});
} else {
x_strides.insert(x_strides.end(), {M * K, 1, M});
}
}
if (!y_strides_override.empty()) {
y_strides = y_strides_override;
} else {
if (!trans_y) {
y_strides.insert(y_strides.end(), {N * K, N, 1});
} else {
y_strides.insert(y_strides.end(), {N * K, 1, K});
}
}
out_strides.insert(out_strides.end(), {M * N, N, 1});
out_ddims.insert(out_ddims.end(),
{std::max(x_dims[MB_idx], y_dims[MB_idx]), M, N});
for (int i = x_dims.size() - 4; i >= 0; --i) {
out_ddims[i] = std::max(x_dims[i], y_dims[i]);
if (x_strides_override.empty()) {
x_strides[i] = x_dims[i + 1] * x_strides[i + 1];
}
if (y_strides_override.empty()) {
y_strides[i] = y_dims[i + 1] * y_strides[i + 1];
}
out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
}
// TODO(jczaja): Why not for int8??
if (!is_int8<OT>() && is_output_fused) {
out_strides = FakeTransposeStrides(out_ddims);
}
auto x_md = memory::desc(x_dims, OneDNNGetDataType<XT>(), x_strides);
auto y_md = memory::desc(y_dims, OneDNNGetDataType<YT>(), y_strides);
auto out_md = memory::desc(out_ddims, OneDNNGetDataType<OT>(), out_strides);
const auto matmul_attrs = CreateMatmulAttrs(dev_ctx);
this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md);
}
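  // Effective output scale: the optional "alpha" attribute combined with int8
  // quantization scales, i.e. alpha * Scale_out / (Scale_x * Scale_y); when
  // force_fp32_output is set, Scale_out is taken as 1.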
float ComputeOutputScale(const OneDNNContext& dev_ctx) {
float alpha = dev_ctx.HasDnnAttr("alpha")
? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("alpha"))
: 1.0f;
if (dev_ctx.HasDnnAttr("Scale_x") && dev_ctx.HasDnnAttr("Scale_y") &&
dev_ctx.HasDnnAttr("Scale_out")) {
float scale_x = PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("Scale_x"));
float scale_y = PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("Scale_y"));
bool force_fp32_out =
dev_ctx.HasDnnAttr("force_fp32_output")
? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output"))
: false;
float scale_out =
force_fp32_out
? 1.f
: PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("Scale_out"));
alpha *= scale_out / (scale_x * scale_y);
}
return alpha;
}
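  // Primitive attributes: output scale, optional residual add (binary_add plus
  // an optional sum post-op scaled by Scale_out / Scale_in_eltwise), fused
  // activation, and an optional eltwise_linear post-op for fused_output_scale.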
dnnl::primitive_attr CreateMatmulAttrs(const OneDNNContext& dev_ctx) {
dnnl::primitive_attr matmul_attrs;
dnnl::post_ops post_operations;
float scale_out = ComputeOutputScale(dev_ctx);
if (scale_out != 1.0f) {
matmul_attrs.set_output_scales(0, {scale_out});
}
if (dev_ctx.HasDnnInput("ResidualData")) {
auto* residual_data = dev_ctx.GetDnnInput("ResidualData");
auto residual_data_tz = vectorize(residual_data->dims());
auto residual_data_md = memory::desc(residual_data_tz,
OneDNNGetDataType<OT>(),
dnnl::memory::format_tag::any);
post_operations.append_binary(dnnl::algorithm::binary_add,
residual_data_md);
if (dev_ctx.HasDnnAttr("Scale_in_eltwise")) {
float scale_in_eltwise =
PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("Scale_in_eltwise"));
float sum_scale = scale_out / scale_in_eltwise;
post_operations.append_sum(sum_scale);
}
}
AppendActivation(dev_ctx, post_operations);
if (dev_ctx.HasDnnAttr("fused_output_scale")) {
float scale_alpha =
PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("fused_output_scale"));
post_operations.append_eltwise(
1.0, dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f);
}
matmul_attrs.set_post_ops(post_operations);
return matmul_attrs;
}
std::vector<int64_t> FakeTransposeStrides(
const std::vector<int64_t>& matmul_out_dims) const {
    // fusing matmul_v2 + transpose + reshape guarantees that the output is 4D
    // and the transpose axes are: {0, 2, 1, 3}
std::vector<int64_t> transpose_axis = {0, 2, 1, 3};
std::vector<int64_t> fake_strides(transpose_axis.size());
int ndims = static_cast<int>(transpose_axis.size());
int total_stride = 1;
for (int i = ndims - 1; i >= 0; --i) {
fake_strides[transpose_axis[i]] = total_stride;
total_stride *= matmul_out_dims[transpose_axis[i]];
}
return fake_strides;
}
std::shared_ptr<memory> AcquireWeightsMemory(const DenseTensor* input) {
const YT* input_data = input->data<YT>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
to_void_cast<YT>(input_data));
}
std::shared_ptr<dnnl::memory> AcquireDstMemory(const OneDNNContext& dev_ctx,
DenseTensor* output) {
    // We cannot use base AcquireDstMemory as it makes an allocation request
    // based on the DST memory primitive size. This is fine in general, but in
    // MatMul we have a primitive that covers only one batch of data and then
    // shifts the pointer for every new batch. Hence the DenseTensor size is
    // bigger than the dst memory primitive size, so we would request less
    // memory than is actually there, which triggers an assertion. As there is
    // no 'any' format here, we can leave the default size of the DenseTensor
    // as computed in ComputeInferShape
OT* ptr = dev_ctx.template Alloc<OT>(output);
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr);
}
};
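// Runs out = op(x) * op(y), where op() transposes its argument when the
// corresponding trans flag is set. Handles the fused residual input
// ("ResidualData") and, for fused transpose+reshape outputs, permutes the
// destination memory descriptor accordingly.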
template <typename T, typename T_out>
void ExecuteMatmul(const OneDNNContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const std::vector<int64_t>& x_dims,
const std::vector<int64_t>& y_dims,
bool trans_x,
bool trans_y,
DenseTensor* out) {
auto x_strides_override = GetInputStrides(dev_ctx, x.dims(), "X", trans_x);
auto y_strides_override = GetInputStrides(dev_ctx, y.dims(), "Y", trans_y);
MatmulOneDNNHandler<T, T, T_out> handler(dev_ctx,
x_dims,
y_dims,
trans_x,
trans_y,
x_strides_override,
y_strides_override,
IsOutputFused(dev_ctx));
const auto src_memory_p = handler.AcquireSrcMemory(&x);
const auto weights_memory_p = handler.AcquireWeightsMemory(&y);
const auto dst_memory_p = handler.AcquireDstMemory(dev_ctx, out);
auto matmul_p = handler.AcquireForwardPrimitive();
std::unordered_map<int, memory> matmul_args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
if (dev_ctx.HasDnnInput("ResidualData")) {
auto* residual_data = dev_ctx.GetDnnInput("ResidualData");
const auto residual_data_memory_p = handler.AcquireSrcMemory(residual_data);
matmul_args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1,
*residual_data_memory_p});
}
auto& astream = OneDNNContext::tls().get_stream();
matmul_p->execute(astream, matmul_args);
astream.wait();
  // TODO(jczaja): Explain why the int8 dst format is ABCD and does not need a
  // permute
if (IsOutputFused(dev_ctx) && !is_int8<T_out>()) {
const auto axis =
dev_ctx.HasDnnAttr("fused_transpose_Out")
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_transpose_Out"))
: std::vector<int>();
auto permuted_md = dst_memory_p->get_desc().permute_axes(axis);
out->set_mem_desc(permuted_md.reshape(vectorize<int64_t>(out->dims())));
} else {
out->set_mem_desc(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
}
}
} // namespace funcs
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/matmul_grad_kernel.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
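// Pads dims on the left with ones up to rank new_size,
// e.g. {6, 4} with new_size == 3 becomes {1, 6, 4}.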
std::vector<int64_t> ExtendDimsWithOnes(const std::vector<int64_t> &dims,
int new_size) {
std::vector<int64_t> new_dims(new_size, 1);
for (size_t i = 0; i < dims.size(); ++i) {
new_dims[new_size - dims.size() + i] = dims[i];
}
return new_dims;
}
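// Determines the broadcast batch dims the intermediate gradients have to be
// computed with: for every batch dim where x and y differ, the side holding 1
// is broadcast up to the other side's extent. dx_tmp/dy_tmp are then
// allocated with those broadcast shapes.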
template <typename T>
void CalculateGradMatrixDims(const OneDNNContext &dev_ctx,
DenseTensor *dx_tmp,
DenseTensor *dy_tmp,
const std::vector<int64_t> &dx_dims,
const std::vector<int64_t> &dy_dims,
std::vector<int64_t> *dx_bd_dims,
std::vector<int64_t> *dy_bd_dims) {
for (size_t i = 0; i < dx_dims.size() - 2; ++i) {
if (dx_dims[i] != dy_dims[i]) {
if (dx_dims[i] == 1) {
(*dx_bd_dims)[i] = dy_dims[i];
} else {
(*dy_bd_dims)[i] = dx_dims[i];
}
}
}
dx_tmp->Resize(make_ddim((*dx_bd_dims)));
dev_ctx.template Alloc<T>(dx_tmp);
dy_tmp->Resize(make_ddim((*dy_bd_dims)));
dev_ctx.template Alloc<T>(dy_tmp);
}
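// Sums the broadcast gradient back to the original operand shape with a
// oneDNN reduction_sum and reshapes the resulting memory descriptor to the
// squeezed (original) dims.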
template <typename T>
void ReduceSumForMatmulGradOutput(const OneDNNContext &dev_ctx,
const DenseTensor *dx_tmp,
DenseTensor *dx,
const std::vector<int64_t> &dx_dims,
const std::vector<int64_t> &squeezed_dims) {
funcs::ReductionOneDNNHandler<T> handler(dnnl::algorithm::reduction_sum,
0.0f,
0.0f,
dev_ctx.GetEngine(),
dev_ctx.GetPlace(),
dx_tmp,
dx,
dx_dims);
auto src_memory_p = handler.AcquireSrcMemory(dx_tmp);
auto dst_memory_p = handler.AcquireDstMemory(dx);
std::unordered_map<int, dnnl::memory> reduction_args = {
{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}};
auto &astream = OneDNNContext::tls().get_stream();
auto reduction_p = handler.AcquireForwardPrimitive();
reduction_p->execute(astream, reduction_args);
astream.wait();
dx->set_mem_desc(dst_memory_p->get_desc().reshape(squeezed_dims));
}
template <typename T, typename Context>
void MatmulGradKernel(const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor &y,
const DenseTensor &dout,
bool transpose_x,
bool transpose_y,
DenseTensor *dx,
DenseTensor *dy) {
auto x_dims = vectorize(x.dims());
auto y_dims = vectorize(y.dims());
auto dout_dims = vectorize(dout.dims());
size_t ndims = std::max(x_dims.size(), y_dims.size());
ndims = std::max<size_t>(ndims, 3);
if (x_dims.size() != ndims) {
x_dims = ExtendDimsWithOnes(x_dims, ndims);
} else if (y_dims.size() != ndims) {
y_dims = ExtendDimsWithOnes(y_dims, ndims);
}
  // in the broadcasting scenario new memory is required because reduce_sum
  // must be calculated over the broadcast dims
DenseTensor dx_tmp, dy_tmp;
std::vector<int64_t> dx_bd_dims(x_dims);
std::vector<int64_t> dy_bd_dims(y_dims);
CalculateGradMatrixDims<T>(
dev_ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, &dx_bd_dims, &dy_bd_dims);
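  // With Out = op(X) * op(Y), the calls below compute:
  //   transpose_x && transpose_y: dX = Y^T * dOut^T,  dY = dOut^T * X^T
  //   transpose_x:                dX = Y * dOut^T,    dY = X * dOut
  //   transpose_y:                dX = dOut * Y,      dY = dOut^T * X
  //   otherwise:                  dX = dOut * Y^T,    dY = X^T * dOut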
if (transpose_x && transpose_y) {
funcs::ExecuteMatmul<T, T>(
dev_ctx, y, dout, y_dims, dout_dims, true, true, &dx_tmp);
funcs::ExecuteMatmul<T, T>(
dev_ctx, dout, x, dout_dims, x_dims, true, true, &dy_tmp);
} else if (transpose_x) {
funcs::ExecuteMatmul<T, T>(
dev_ctx, y, dout, y_dims, dout_dims, false, true, &dx_tmp);
funcs::ExecuteMatmul<T, T>(
dev_ctx, x, dout, x_dims, dout_dims, false, false, &dy_tmp);
} else if (transpose_y) {
funcs::ExecuteMatmul<T, T>(
dev_ctx, dout, y, dout_dims, y_dims, false, false, &dx_tmp);
funcs::ExecuteMatmul<T, T>(
dev_ctx, dout, x, dout_dims, x_dims, true, false, &dy_tmp);
} else {
funcs::ExecuteMatmul<T, T>(
dev_ctx, dout, y, dout_dims, y_dims, false, true, &dx_tmp);
funcs::ExecuteMatmul<T, T>(
dev_ctx, x, dout, x_dims, dout_dims, true, false, &dy_tmp);
}
if (x_dims != dx_bd_dims) {
ReduceSumForMatmulGradOutput<T>(
dev_ctx, &dx_tmp, dx, x_dims, vectorize(x.dims()));
} else {
*dx = std::move(dx_tmp);
}
if (y_dims != dy_bd_dims) {
ReduceSumForMatmulGradOutput<T>(
dev_ctx, &dy_tmp, dy, y_dims, vectorize(y.dims()));
} else {
*dy = std::move(dy_tmp);
}
dx->Resize(x.dims());
dy->Resize(y.dims());
}
} // namespace phi
PD_REGISTER_KERNEL(matmul_grad,
OneDNN,
ONEDNN,
phi::MatmulGradKernel,
float,
phi::dtype::bfloat16) {}