Unverified commit e2054925, authored by Sławomir Siwek, committed by GitHub

oneDNN kernels code cleanup (#50743)

* matmul refactored

* fc

* SetOutMemDescWithLogicalLayoutFusesSupport

* matmul_v2

* alpha support

* group repetitive funcs

* matmul utils

* execute matmul methods

* restore registered kernel names

* split header and impl files

* remove double negatives

* increase coverage

* add onednn tests to ctest

* remove fusion logic from base matmuls
Parent 8d3457f6
......@@ -16,19 +16,12 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/fc_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
namespace paddle {
namespace operators {
using dnnl::inner_product_forward;
using dnnl::memory;
using dnnl::primitive;
using dnnl::prop_kind;
using dnnl::stream;
using framework::DDim;
using framework::ExecutionContext;
using phi::OneDNNContext;
using phi::funcs::OneDNNGetDataType;
......@@ -46,7 +39,7 @@ class FCMKLDNNHandler
: public phi::funcs::OneDNNHandlerNoCachingT<T_in,
dnnl::inner_product_forward> {
public:
FCMKLDNNHandler(const paddle::framework::ExecutionContext& ctx,
FCMKLDNNHandler(const ExecutionContext& ctx,
const OneDNNContext& dev_ctx,
const phi::DenseTensor* x,
const phi::DenseTensor* weights,
......@@ -92,7 +85,7 @@ class FCMKLDNNHandler
const auto attrs = CreateFCAttrs(ctx);
this->AcquireForwardPrimitiveDescriptor(attrs,
prop_kind::forward_inference,
dnnl::prop_kind::forward_inference,
src_md,
weights_md,
bias_md,
......@@ -138,7 +131,7 @@ class FCMKLDNNHandler
// Compute the bias scales so that their values correspond to the
// scale of the data produced by the weights and input multiplication
std::vector<float> GetBiasScales(const framework::ExecutionContext& ctx) {
std::vector<float> GetBiasScales(const ExecutionContext& ctx) {
if (ctx.HasAttr("Bias_scales")) {
return ctx.Attr<std::vector<float>>("Bias_scales");
} else {
......@@ -230,10 +223,8 @@ class FCMKLDNNHandler
}
}
// Computing MKL-DNN's scaling mask which determines along which dimension
// slice should the scaling be applied. For more data plase refer to:
// https://intel.github.io/mkl-dnn/group__c__api__attributes.html
// Section dnnl_status_t DNNL_API dnnl_primitive_attr_set_output_scales
// Compute oneDNN's scaling mask, which determines along which dimension
// slice the scaling should be applied.
int CreateMask(int slice_dimension, bool is_multi_channel_quantizied) {
return is_multi_channel_quantizied ? 1 << slice_dimension : 0;
}
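// Illustrative only (not in this diff): how such a mask is typically consumed
// when setting per-channel output scales. The "scales" vector below is
// hypothetical, with one value per slice of the chosen dimension.
//
//   std::vector<float> scales = {0.5f, 0.25f, 0.125f};
//   dnnl::primitive_attr attr;
//   attr.set_output_scales(CreateMask(1, scales.size() > 1), scales);  // mask = 1 << 1 = 2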
......@@ -287,7 +278,7 @@ class FCMKLDNNHandler
}
std::shared_ptr<dnnl::memory> AcquireBiasMemoryWithReorder(
const framework::ExecutionContext& ctx, const phi::DenseTensor* bias) {
const ExecutionContext& ctx, const phi::DenseTensor* bias) {
const float* bias_data = bias->data<float>();
if (phi::funcs::is_int8<T_w>() == false) {
......@@ -366,7 +357,7 @@ class FCMKLDNNHandler
PADDLE_ENFORCE_EQ(
out->dims(),
residual_param->dims(),
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"Output and elementwise parameter need to have the "
"same dimension sizes, but got output's dimension = %d"
" and residual param's dimension =%d .",
......@@ -391,7 +382,7 @@ class FCMKLDNNHandler
template <typename T_in>
class FCMKLDNNKernel : public framework::OpKernel<T_in> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
void Compute(const ExecutionContext& ctx) const override {
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
bool fuse_relu = ctx.Attr<std::string>("activation_type") == "relu";
......@@ -410,7 +401,7 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
}));
}
void PrepareSrcMem(const std::shared_ptr<inner_product_forward>& fc_p,
void PrepareSrcMem(const std::shared_ptr<dnnl::inner_product_forward>& fc_p,
const std::shared_ptr<dnnl::memory>& src_mem,
const phi::DenseTensor* x,
const dnnl::engine& engine) const {
......@@ -427,74 +418,8 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
}
}
void SetOutMemDescWithUnsqueeze2FuseSupport(
const framework::ExecutionContext& ctx,
phi::DenseTensor* out,
const dnnl::memory::desc& out_md) const {
const std::vector<int>& fused_unsqueeze2_axes =
ctx.Attr<std::vector<int>>("fused_unsqueeze2_axes");
const std::vector<int64_t>& op_tz = out_md.dims();
std::vector<int64_t> unsqueezed_op_tz(
op_tz.size() + fused_unsqueeze2_axes.size(), 0);
for (const auto& axis : fused_unsqueeze2_axes) {
int positive_axis = axis < 0 ? unsqueezed_op_tz.size() + axis : axis;
unsqueezed_op_tz[positive_axis] = 1;
}
int j = 0;
for (size_t i = 0; i < unsqueezed_op_tz.size(); ++i) {
if (unsqueezed_op_tz[i] == 0) {
unsqueezed_op_tz[i] = op_tz[j++];
}
}
out->set_mem_desc(out_md.reshape(unsqueezed_op_tz));
out->Resize(phi::make_ddim(unsqueezed_op_tz));
}
void SetOutMemDescWithReshape2FuseSupport(
const framework::ExecutionContext& ctx,
phi::DenseTensor* out,
const dnnl::memory::desc& out_md) const {
std::vector<int64_t> fused_reshape2_shape(
ctx.Attr<std::vector<int>>("fused_reshape2_shape").begin(),
ctx.Attr<std::vector<int>>("fused_reshape2_shape").end());
const int out_shape_numel = out->numel();
const int new_shape_numel = std::accumulate(fused_reshape2_shape.begin(),
fused_reshape2_shape.end(),
1,
std::multiplies<int64_t>());
for (size_t i = 0; i < fused_reshape2_shape.size(); ++i) {
if (fused_reshape2_shape[i] == -1) {
fused_reshape2_shape[i] = -out_shape_numel / new_shape_numel;
break;
}
}
out->set_mem_desc(out_md.reshape(fused_reshape2_shape));
out->Resize(phi::make_ddim(fused_reshape2_shape));
}
void SetOutMemDescWithLogicalLayoutFusesSupport(
const framework::ExecutionContext& ctx,
phi::DenseTensor* out,
const dnnl::memory::desc& out_md) const {
if (ctx.HasAttr("fused_unsqueeze2_axes")) {
SetOutMemDescWithUnsqueeze2FuseSupport(ctx, out, out_md);
} else if (ctx.HasAttr("fused_reshape2_shape")) {
SetOutMemDescWithReshape2FuseSupport(ctx, out, out_md);
} else if (ctx.HasAttr("fused_squeeze2_axes")) {
out->set_mem_desc(out_md);
out->Resize(phi::make_ddim(out_md.dims()));
} else {
out->set_mem_desc(out_md);
}
}
template <typename T_out, typename T_w>
void RunKernel(const framework::ExecutionContext& ctx) const {
void RunKernel(const ExecutionContext& ctx) const {
const auto& dev_ctx = ctx.template device_context<OneDNNContext>();
const auto& onednn_engine = dev_ctx.GetEngine();
......@@ -601,10 +526,15 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
dev_ctx.SetBlob(cache_key, ip_cache);
}
SetOutMemDescWithLogicalLayoutFusesSupport(
ctx,
out,
dst_memory_p->get_desc().reshape(phi::vectorize(out->dims())));
const auto out_md =
dst_memory_p->get_desc().reshape(phi::vectorize(out->dims()));
if (ctx.HasAttr("fused_reshape2_shape")) {
phi::funcs::SetOutMemDescWithReshape2FuseSupport(
ctx.Attr<std::vector<int>>("fused_reshape2_shape"), out, out_md);
} else {
out->set_mem_desc(out_md);
}
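// Worked example of the fused reshape2 handling above: with out->numel() == 24
// and fused_reshape2_shape == {2, -1, 4}, the known entries multiply to 8, so
// the -1 entry is resolved to 24 / 8 = 3 and the output becomes {2, 3, 4}.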
}
void RecomputeOutputDims(const ExecutionContext& ctx,
......@@ -615,8 +545,8 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
bool padding_weights = ctx.Attr<bool>("padding_weights");
PADDLE_ENFORCE_EQ(padding_weights,
false,
platform::errors::PermissionDenied(
"Weight padding in fc can not be used in MKLDNN."));
phi::errors::PermissionDenied(
"Weight padding in fc can not be used in oneDNN."));
std::vector<int64_t> output_dims;
FCOutputSize(x->dims(),
weights->dims(),
......
......@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/backends/onednn/matmul_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace {
......@@ -27,7 +27,7 @@ using phi::funcs::OneDNNGetDataType;
// Reshape a rank-3 tensor from P x M x N to (P * M) x N.
// Identity op if the tensor is not of rank 3.
static phi::DenseTensor FoldOuterDims(const phi::DenseTensor &input) {
phi::DenseTensor FoldOuterDims(const phi::DenseTensor &input) {
auto output = input;
auto in_dims = input.dims();
if (in_dims.size() == 3) {
......@@ -40,8 +40,8 @@ static phi::DenseTensor FoldOuterDims(const phi::DenseTensor &input) {
// (Warning: This requires transposing data and writes into new memory.)
// Identity op if the tensor is not of rank 3.
template <typename T>
static phi::DenseTensor FoldFirstAndLastDims(const OneDNNContext &dev_ctx,
const phi::DenseTensor *input) {
phi::DenseTensor FoldFirstAndLastDims(const OneDNNContext &dev_ctx,
const phi::DenseTensor *input) {
auto input_dims = vectorize(input->dims());
if (input_dims.size() != 3) {
return *input;
......@@ -82,12 +82,12 @@ phi::DDim GetDimForInput(const ExecutionContext &ctx, std::string input_name) {
}
template <typename XT, typename YT, typename OT>
class MatMulV2MKLDNNHandler
class MatMulV1OneDNNHandler
: public phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul> {
public:
MatMulV2MKLDNNHandler(const ExecutionContext &ctx,
MatMulV1OneDNNHandler(const ExecutionContext &ctx,
const dnnl::engine engine,
paddle::platform::Place cpu_place,
phi::Place cpu_place,
const std::vector<int64_t> &x_org_dims,
bool trans_x,
const std::vector<int64_t> &y_org_dims,
......@@ -121,24 +121,24 @@ class MatMulV2MKLDNNHandler
y_strides.reserve(x_dims.size());
out_strides.reserve(x_dims.size());
if (!x_strides_override.empty()) {
x_strides = x_strides_override;
} else {
if (!trans_x) {
x_strides.insert(x_strides.end(), {M * K, K, 1});
} else {
if (x_strides_override.empty()) {
if (trans_x) {
x_strides.insert(x_strides.end(), {M * K, 1, M});
} else {
x_strides.insert(x_strides.end(), {M * K, K, 1});
}
} else {
x_strides = x_strides_override;
}
if (!y_strides_override.empty()) {
y_strides = y_strides_override;
} else {
if (!trans_y) {
y_strides.insert(y_strides.end(), {N * K, N, 1});
} else {
if (y_strides_override.empty()) {
if (trans_y) {
y_strides.insert(y_strides.end(), {N * K, 1, K});
} else {
y_strides.insert(y_strides.end(), {N * K, N, 1});
}
} else {
y_strides = y_strides_override;
}
out_strides.insert(out_strides.end(), {M * N, N, 1});
......@@ -158,7 +158,8 @@ class MatMulV2MKLDNNHandler
// TODO(jczaja): Why not for int8??
if (!phi::funcs::is_int8<OT>() && is_output_fused) {
out_strides = FakeTransposeStrides(out_ddims);
std::vector<int> transpose_axis = {0, 2, 1, 3};
out_strides = phi::funcs::FakeTransposeStrides(out_ddims, transpose_axis);
}
auto x_md =
......@@ -221,24 +222,6 @@ class MatMulV2MKLDNNHandler
return matmul_attrs;
}
std::vector<int64_t> FakeTransposeStrides(
const std::vector<int64_t> &matmul_out_dims) const {
// fuse matmul_v2 + transpose + reshape guarantees that output is 4D and
// transpose axis are: {0, 2, 1, 3}
std::vector<int64_t> transpose_axis = {0, 2, 1, 3};
std::vector<int64_t> fake_strides(transpose_axis.size());
int ndims = static_cast<int>(transpose_axis.size());
int total_stride = 1;
for (int i = ndims - 1; i >= 0; --i) {
fake_strides[transpose_axis[i]] = total_stride;
total_stride *= matmul_out_dims[transpose_axis[i]];
}
return fake_strides;
}
std::shared_ptr<memory> AcquireWeightsMemory(const phi::DenseTensor *input) {
const YT *input_data = input->data<YT>();
return this->AcquireMemoryFromPrimitive(
......@@ -260,11 +243,11 @@ class MatMulV2MKLDNNHandler
};
template <typename XT, typename YT, typename OT>
class MatMulMKLDNNHandler
class MatMulOneDNNHandler
: public phi::funcs::OneDNNHandlerNoCachingT<XT, dnnl::matmul> {
public:
MatMulMKLDNNHandler(const dnnl::engine engine,
paddle::platform::Place cpu_place,
MatMulOneDNNHandler(const dnnl::engine engine,
phi::Place cpu_place,
phi::DenseTensor *x,
bool trans_x,
phi::DenseTensor *y,
......@@ -312,42 +295,6 @@ class MatMulMKLDNNHandler
phi::funcs::to_void_cast<YT>(input_data));
}
public:
void Execute(const phi::DenseTensor *x,
const phi::DenseTensor *y,
phi::DenseTensor *out) {
const auto src_memory_p = this->AcquireSrcMemory(x);
const auto weights_memory_p = this->AcquireWeightsMemory(y);
const auto dst_memory_p = this->AcquireDstMemory(out);
auto matmul_p = this->AcquireForwardPrimitive();
std::unordered_map<int, dnnl::memory> matmul_args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
auto &astream = OneDNNContext::tls().get_stream();
// Simulate batch matmul by processing in loop
void *x_ptr = src_memory_p->get_data_handle();
void *y_ptr = weights_memory_p->get_data_handle();
void *out_ptr = dst_memory_p->get_data_handle();
auto offsets = std::make_tuple(x_offset_, y_offset_, out_offset_);
for (uint16_t i = 0; i < batch_size_; ++i) {
src_memory_p->set_data_handle(x_ptr);
weights_memory_p->set_data_handle(y_ptr);
dst_memory_p->set_data_handle(out_ptr);
matmul_p->execute(astream, matmul_args);
x_ptr = static_cast<char *>(x_ptr) + std::get<0>(offsets);
y_ptr = static_cast<char *>(y_ptr) + std::get<1>(offsets);
out_ptr = static_cast<char *>(out_ptr) + std::get<2>(offsets);
}
astream.wait();
out->set_mem_desc(dst_memory_p->get_desc().reshape(out->dims()));
}
std::shared_ptr<dnnl::memory> AcquireDstMemory(phi::DenseTensor *output) {
// We cannot use base AcquireDstMemory as it makes an allocation request
// base on DST memory primitive size. This is fine in general, but in MatMul
......@@ -359,12 +306,6 @@ class MatMulMKLDNNHandler
OT *ptr = output->mutable_data<OT>(this->place_);
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr);
}
private:
uint32_t x_offset_;
uint32_t y_offset_;
uint32_t out_offset_;
uint16_t batch_size_;
};
/**
......@@ -373,7 +314,7 @@ class MatMulMKLDNNHandler
* The shape would be [BatchSize, H, W] or [H, W].
* If transposed, `H,W` will be swapped.
*/
static void ReshapeTensorToMatrixSequence(
void ReshapeTensorToMatrixSequence(
phi::DenseTensor *x, const phi::funcs::MatDescriptor &descriptor) {
int64_t h, w;
h = descriptor.height_;
......@@ -402,11 +343,11 @@ static void ReshapeTensorToMatrixSequence(
* If any of `X` and `Y` has batch size BatchSize, the out will have the
* BatchSize.
*/
static void ReshapeXYOutToMatrixSequence(phi::DenseTensor *x,
phi::DenseTensor *y,
phi::DenseTensor *out,
bool trans_x,
bool trans_y) {
void ReshapeXYOutToMatrixSequence(phi::DenseTensor *x,
phi::DenseTensor *y,
phi::DenseTensor *out,
bool trans_x,
bool trans_y) {
auto x_dim = phi::funcs::RowMatrixDimsFromVector(x->dims());
auto y_dim = phi::funcs::ColumnMatrixDimsFromVector(y->dims());
auto mat_dim_x = phi::funcs::CreateMatrixDescriptor(x_dim, 0, trans_x);
......@@ -423,78 +364,6 @@ static void ReshapeXYOutToMatrixSequence(phi::DenseTensor *x,
ReshapeTensorToMatrixSequence(y, mat_dim_y);
}
std::vector<int64_t> Transpose(const std::vector<int64_t> &x,
const std::vector<int> &axis) {
size_t in_rank = x.size();
size_t axis_size = axis.size();
auto axis_set = std::set<int>(axis.begin(), axis.end());
PADDLE_ENFORCE_EQ(axis_set.size(),
axis_size,
paddle::platform::errors::InvalidArgument(
"In an axis array, elements must be unique."));
PADDLE_ENFORCE_EQ(in_rank,
axis_size,
paddle::platform::errors::InvalidArgument(
"The input dimension's size "
"should be equal to the axis's size. "
"But received dimension is %d, "
"axis's size is %d",
in_rank,
axis_size));
PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()),
axis_size,
paddle::platform::errors::InvalidArgument(
"Axis values must be ranging from 0 to (dims - 1)."));
std::vector<int64_t> new_x(x.size());
for (size_t i = 0; i < x.size(); i++) {
new_x[i] = x[axis[i]];
}
return new_x;
}
std::vector<int64_t> GetInputStrides(const ExecutionContext &ctx,
const std::string input_name) {
auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
auto input_dims = ctx.Input<phi::DenseTensor>(input_name)->dims();
auto new_dims = input_dims;
if (!shape.empty() && !axis.empty()) {
new_dims = input_dims.reshape(shape).transpose(axis);
}
auto &MatrixDimsFromVector = input_name == "X"
? phi::funcs::RowMatrixDimsFromVector
: phi::funcs::ColumnMatrixDimsFromVector;
phi::funcs::MatDescriptor mat_dim = phi::funcs::CreateMatrixDescriptor(
MatrixDimsFromVector(new_dims),
0,
ctx.HasAttr("trans_x")
? ctx.Attr<bool>(std::string("trans_") +
static_cast<char>(std::tolower(input_name[0])))
: ctx.Attr<bool>(std::string("transpose_") + input_name[0]));
std::vector<int64_t> strides;
if (!shape.empty()) {
auto shape2 = input_dims.reshape(shape);
strides.push_back(1);
for (auto i = shape2.size() - 1; i > 0; --i) {
strides.insert(strides.begin(),
strides.front() * static_cast<int64_t>(shape2[i]));
}
strides = Transpose(strides, axis);
if (shape.size() == 2)
strides.insert(strides.begin(),
static_cast<int64_t>(shape[0] * shape[1]));
mat_dim.stride_ = strides[0];
if (mat_dim.trans_) std::swap(*strides.rbegin(), *(++strides.rbegin()));
}
return strides;
}
bool IsOutputFused(const ExecutionContext &ctx) {
auto &fused_reshape_Out = ctx.Attr<std::vector<int>>("fused_reshape_Out");
auto &fused_transpose_Out = ctx.Attr<std::vector<int>>("fused_transpose_Out");
......@@ -502,7 +371,7 @@ bool IsOutputFused(const ExecutionContext &ctx) {
}
template <typename T, typename T_out>
void ExecuteMatMulV2(const ExecutionContext &ctx,
void ExecuteMatMulV1(const ExecutionContext &ctx,
const dnnl::engine onednn_engine,
const phi::DenseTensor *x,
const std::vector<int64_t> &x_dims,
......@@ -511,9 +380,20 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
const std::vector<int64_t> &y_dims,
bool trans_y,
phi::DenseTensor *out) {
std::vector<int64_t> x_strides_override = GetInputStrides(ctx, "X");
std::vector<int64_t> y_strides_override = GetInputStrides(ctx, "Y");
MatMulV2MKLDNNHandler<T, T, T_out> handler(ctx,
std::vector<int64_t> x_strides_override = phi::funcs::GetInputStrides(
"X",
x->dims(),
trans_x,
ctx.Attr<std::vector<int>>("fused_reshape_X"),
ctx.Attr<std::vector<int>>("fused_transpose_X"));
std::vector<int64_t> y_strides_override = phi::funcs::GetInputStrides(
"Y",
y->dims(),
trans_y,
ctx.Attr<std::vector<int>>("fused_reshape_Y"),
ctx.Attr<std::vector<int>>("fused_transpose_Y"));
MatMulV1OneDNNHandler<T, T, T_out> handler(ctx,
onednn_engine,
ctx.GetPlace(),
x_dims,
......@@ -523,7 +403,6 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
IsOutputFused(ctx),
x_strides_override,
y_strides_override);
const auto src_memory_p = handler.AcquireSrcMemory(x);
const auto weights_memory_p = handler.AcquireWeightsMemory(y);
const auto dst_memory_p = handler.AcquireDstMemory(out);
......@@ -566,7 +445,7 @@ class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(
ctx.Attr<int>("head_number"),
1,
paddle::platform::errors::Unimplemented(
phi::errors::Unimplemented(
"oneDNN matmul doesn't support multiple heads. Expected "
"head_number=1. But received `head_number` is %d",
ctx.Attr<int>("head_number")));
......@@ -576,7 +455,6 @@ class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> {
const bool force_fp32_output = ctx.HasAttr("force_fp32_output")
? ctx.Attr<bool>("force_fp32_output")
: false;
constexpr bool fuse_relu = false; // TODO(intel): Enable eltwise fuses
const auto &dev_ctx = ctx.template device_context<OneDNNContext>();
const auto &onednn_engine = dev_ctx.GetEngine();
......@@ -601,7 +479,7 @@ class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> {
CalculateMatrixDims(ctx, x_dims, y_dims, &x_bd_dims, &y_bd_dims, out);
if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) {
ExecuteMatMulV2<T, float>(ctx,
ExecuteMatMulV1<T, float>(ctx,
onednn_engine,
x,
x_bd_dims,
......@@ -611,7 +489,7 @@ class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> {
trans_y,
out);
} else if (is_bfloat16) {
ExecuteMatMulV2<T, paddle::platform::bfloat16>(ctx,
ExecuteMatMulV1<T, paddle::platform::bfloat16>(ctx,
onednn_engine,
x,
x_bd_dims,
......@@ -620,18 +498,8 @@ class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> {
y_bd_dims,
trans_y,
out);
} else if (fuse_relu) {
ExecuteMatMulV2<T, uint8_t>(ctx,
onednn_engine,
x,
x_bd_dims,
trans_x,
y,
y_bd_dims,
trans_y,
out);
} else {
ExecuteMatMulV2<T, int8_t>(ctx,
ExecuteMatMulV1<T, int8_t>(ctx,
onednn_engine,
x,
x_bd_dims,
......@@ -678,7 +546,7 @@ class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> {
(*x_bd_dims)[i] == (*y_bd_dims)[i] || (*x_bd_dims)[i] == 1 ||
(*y_bd_dims)[i] == 1,
true,
paddle::platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"phi::DenseTensor dimensions are incorrect for broadcasting."
"Dimensions in X and Y must be same or equal to 1, but "
"received x_dim[%d]=%d and y_dims[%d]= %d",
......@@ -701,7 +569,7 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(
ctx.Attr<int>("head_number"),
1,
paddle::platform::errors::Unimplemented(
phi::errors::Unimplemented(
"oneDNN matmul doesn't support multiple heads. Expected "
"head_number=1. But received `head_number` is %d",
ctx.Attr<int>("head_number")));
......@@ -728,7 +596,7 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
ReshapeXYOutToMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
paddle::framework::DDim dx_dims;
phi::DDim dx_dims;
if (dx) {
dx_dims = dx->dims();
if (dx_dims != x.dims()) {
......@@ -736,7 +604,7 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
}
}
paddle::framework::DDim dy_dims;
phi::DDim dy_dims;
if (dy) {
dy_dims = dy->dims();
if (dy_dims != y.dims()) {
......@@ -838,16 +706,14 @@ class MatMulGradMKLDNNKernel : public paddle::framework::OpKernel<T> {
: FoldFirstAndLastDims<T>(dev_ctx, y);
}
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
MatMulMKLDNNHandler<T, T, T> handler(engine,
MatMulOneDNNHandler<T, T, T> handler(engine,
ctx.GetPlace(),
&x_combined,
trans_x,
&y_combined,
trans_y,
out,
alpha);
ctx.Attr<float>("alpha"));
const auto src_memory_p = handler.AcquireSrcMemory(&x_combined);
const auto weights_memory_p = handler.AcquireWeightsMemory(&y_combined);
......
......@@ -60,7 +60,8 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
phi::funcs::GetPlainOneDNNFormat(x_vec_dims.size()));
// a trick is used here to fake a transpose of out_md, so later it will be
// "untransposed", leaving the output data in a plain format tag
auto dst_strides = FakeTranposeStrides(dst_md, transpose_axis);
auto dst_strides =
phi::funcs::FakeTransposeStrides(dst_md.dims(), transpose_axis);
dst_md =
dnnl::memory::desc(x_vec_dims, x->mem_desc().data_type(), dst_strides);
......@@ -77,36 +78,7 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
astream.wait();
out->set_mem_desc(reorder_dst_memory_p->get_desc().permute_axes(
TransposeToPermuteAxis(transpose_axis)));
}
private:
// it is needed because oneDNN's permute axis understand axes order in
// different way PaddlePaddle's transpose
std::vector<int> TransposeToPermuteAxis(
const std::vector<int>& transpose_axis) const {
std::vector<int> permute_axis(transpose_axis.size());
for (size_t i = 0; i < transpose_axis.size(); ++i) {
permute_axis[transpose_axis[i]] = i;
}
return permute_axis;
}
std::vector<int64_t> FakeTranposeStrides(
const dnnl::memory::desc& dst_md,
const std::vector<int>& transpose_axis) const {
std::vector<int64_t> fake_strides(transpose_axis.size());
auto dims = dst_md.dims();
int total_stride = 1;
int ndims = static_cast<int>(dims.size());
for (int i = ndims - 1; i >= 0; --i) {
fake_strides[transpose_axis[i]] = total_stride;
total_stride *= dims[transpose_axis[i]];
}
return fake_strides;
phi::funcs::TransposeToPermuteAxes(transpose_axis)));
}
};
......
......@@ -33,6 +33,7 @@ endif()
if(WITH_MKLDNN)
list(APPEND BACKENDS_SRCS onednn/onednn_context.cc)
list(APPEND BACKENDS_SRCS onednn/axpy_handler.cc)
list(APPEND BACKENDS_SRCS onednn/matmul_utils.cc)
list(APPEND BACKENDS_DEPS mkldnn)
endif()
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/onednn/matmul_utils.h"
namespace phi {
namespace funcs {
DDim RowMatrixDimsFromVector(const DDim& x_dim) {
return x_dim.size() > 1 ? x_dim : make_ddim({1, x_dim[0]});
}
DDim ColumnMatrixDimsFromVector(const DDim& y_dim) {
return y_dim.size() > 1 ? y_dim : make_ddim({y_dim[0], 1});
}
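// Example: a 1-D tensor of shape {5} is promoted to {1, 5} when used as the
// row operand (X) and to {5, 1} when used as the column operand (Y); tensors
// of rank 2 or higher are returned unchanged.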
std::vector<int64_t> TransposeAxis(const std::vector<int64_t>& x,
const std::vector<int>& axis) {
size_t in_rank = x.size();
size_t axis_size = axis.size();
auto axis_set = std::set<int>(axis.begin(), axis.end());
PADDLE_ENFORCE_EQ(
axis_set.size(),
axis_size,
errors::InvalidArgument("In an axis array, elements must be unique."));
PADDLE_ENFORCE_EQ(
in_rank,
axis_size,
errors::InvalidArgument("The input dimension's size "
"should be equal to the axis's size. "
"But received dimension is %d, "
"axis's size is %d",
in_rank,
axis_size));
PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()),
axis_size,
errors::InvalidArgument(
"Axis values must be ranging from 0 to (dims - 1)."));
std::vector<int64_t> new_x(x.size());
for (size_t i = 0; i < x.size(); i++) {
new_x[i] = x[axis[i]];
}
return new_x;
}
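// Example: TransposeAxis({10, 20, 30}, {2, 0, 1}) picks x[2], x[0], x[1] in
// that order and returns {30, 10, 20}.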
std::vector<int64_t> GetInputStrides(const std::string input_name,
const DDim& input_dims,
const bool transpose_input,
std::vector<int> shape,
std::vector<int> axis) {
auto new_dims = input_dims;
if (!shape.empty() && !axis.empty()) {
new_dims = input_dims.reshape(shape).transpose(axis);
}
auto& MatrixDimsFromVector =
input_name == "X" ? RowMatrixDimsFromVector : ColumnMatrixDimsFromVector;
MatDescriptor mat_dim = CreateMatrixDescriptor(
MatrixDimsFromVector(new_dims), 0, transpose_input);
std::vector<int64_t> strides;
if (!shape.empty()) {
auto shape2 = input_dims.reshape(shape);
strides.push_back(1);
for (auto i = shape2.size() - 1; i > 0; --i) {
strides.insert(strides.begin(),
strides.front() * static_cast<int64_t>(shape2[i]));
}
strides = TransposeAxis(strides, axis);
if (shape.size() == 2)
strides.insert(strides.begin(),
static_cast<int64_t>(shape[0] * shape[1]));
mat_dim.stride_ = strides[0];
if (mat_dim.trans_) std::swap(*strides.rbegin(), *(++strides.rbegin()));
}
return strides;
}
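// Worked example: for a 24-element input with fused shape = {2, 3, 4}, fused
// axis = {0, 2, 1} and transpose_input == false, the contiguous strides of the
// reshaped tensor are {12, 4, 1}; permuting them by the axis gives the
// returned strides {12, 1, 4}, i.e. the strides of a {2, 4, 3} view over the
// original memory. When no fused reshape is present the function returns an
// empty vector.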
} // namespace funcs
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/onednn/onednn_reuse.h"
namespace phi {
namespace funcs {
DDim RowMatrixDimsFromVector(const DDim& x_dim);
DDim ColumnMatrixDimsFromVector(const DDim& y_dim);
std::vector<int64_t> TransposeAxis(const std::vector<int64_t>& x,
const std::vector<int>& axis);
std::vector<int64_t> GetInputStrides(const std::string input_name,
const DDim& input_dims,
const bool transpose_input,
std::vector<int> shape,
std::vector<int> axis);
template <typename XT, typename YT, typename OT>
class MatmulOneDNNHandler : public OneDNNHandlerNoCachingT<XT, dnnl::matmul> {
public:
MatmulOneDNNHandler(const OneDNNContext& dev_ctx,
const std::vector<int64_t>& x_org_dims,
const std::vector<int64_t>& y_org_dims,
bool trans_x,
bool trans_y)
: OneDNNHandlerNoCachingT<XT, dnnl::matmul>(dev_ctx.GetEngine(),
dev_ctx.GetPlace()) {
// M X K * K X N
std::vector<int64_t> x_dims(x_org_dims);
std::vector<int64_t> y_dims(y_org_dims);
const int MB_idx = x_dims.size() - 3;
const int H_idx = x_dims.size() - 2;
const int W_idx = x_dims.size() - 1;
if (trans_x) std::swap(x_dims[H_idx], x_dims[W_idx]);
if (trans_y) std::swap(y_dims[H_idx], y_dims[W_idx]);
const memory::dim M = x_dims[H_idx];
const memory::dim K = x_dims[W_idx];
const memory::dim N = y_dims[W_idx];
std::vector<int64_t> x_strides(x_dims.size() - 3, 1);
std::vector<int64_t> y_strides(x_dims.size() - 3, 1);
std::vector<int64_t> out_strides(x_dims.size() - 3, 1);
std::vector<int64_t> out_ddims(x_dims.size() - 3, 1);
x_strides.reserve(x_dims.size());
y_strides.reserve(x_dims.size());
out_strides.reserve(x_dims.size());
if (trans_x) {
x_strides.insert(x_strides.end(), {M * K, 1, M});
} else {
x_strides.insert(x_strides.end(), {M * K, K, 1});
}
if (trans_y) {
y_strides.insert(y_strides.end(), {N * K, 1, K});
} else {
y_strides.insert(y_strides.end(), {N * K, N, 1});
}
out_strides.insert(out_strides.end(), {M * N, N, 1});
out_ddims.insert(out_ddims.end(),
{std::max(x_dims[MB_idx], y_dims[MB_idx]), M, N});
for (int i = x_dims.size() - 4; i >= 0; --i) {
out_ddims[i] = std::max(x_dims[i], y_dims[i]);
x_strides[i] = x_dims[i + 1] * x_strides[i + 1];
y_strides[i] = y_dims[i + 1] * y_strides[i + 1];
out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
}
auto x_md = memory::desc(x_dims, OneDNNGetDataType<XT>(), x_strides);
auto y_md = memory::desc(y_dims, OneDNNGetDataType<YT>(), y_strides);
auto out_md = memory::desc(out_ddims, OneDNNGetDataType<OT>(), out_strides);
const auto matmul_attrs = CreateMatmulAttrs(dev_ctx);
this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md);
}
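// Illustration of the stride setup above (3-D case, per-batch M x K times
// K x N): without trans_x the x memory desc gets strides {M * K, K, 1}, a
// plain row-major M x K block per batch; with trans_x the stored block is
// K x M and strides {M * K, 1, M} present it as an M x K view of the same
// buffer without moving any data. out_strides is always the plain
// {M * N, N, 1}.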
dnnl::primitive_attr CreateMatmulAttrs(const OneDNNContext& dev_ctx) {
dnnl::primitive_attr matmul_attrs;
dnnl::post_ops post_operations;
float scale_out = dev_ctx.HasDnnAttr("alpha")
? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("alpha"))
: 1.0f;
if (scale_out != 1.0f) {
matmul_attrs.set_output_scales(0, {scale_out});
}
matmul_attrs.set_post_ops(post_operations);
return matmul_attrs;
}
std::shared_ptr<memory> AcquireWeightsMemory(const DenseTensor* input) {
const YT* input_data = input->data<YT>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
to_void_cast<YT>(input_data));
}
std::shared_ptr<dnnl::memory> AcquireDstMemory(const OneDNNContext& dev_ctx,
DenseTensor* output) {
// We cannot use the base AcquireDstMemory, as it makes an allocation request
// based on the DST memory primitive size. This is fine in general, but in
// MatMul the primitive covers only a single batch of data and the pointers are
// shifted for every new batch. Hence the DenseTensor size is bigger than the
// dst memory primitive size, so we would request less memory than is actually
// needed and trigger an assertion. As there is no 'any' format here, we can
// keep the default DenseTensor size as computed in ComputeInferShape.
OT* ptr = dev_ctx.template Alloc<OT>(output);
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr);
}
};
template <typename T>
inline void ExecuteMul(const OneDNNContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const std::vector<int64_t>& x_dims,
const std::vector<int64_t>& y_dims,
bool trans_x,
bool trans_y,
DenseTensor* out) {
MatmulOneDNNHandler<T, T, T> handler(
dev_ctx, x_dims, y_dims, trans_x, trans_y);
const auto src_memory_p = handler.AcquireSrcMemory(&x);
const auto weights_memory_p = handler.AcquireWeightsMemory(&y);
const auto dst_memory_p = handler.AcquireDstMemory(dev_ctx, out);
auto matmul_p = handler.AcquireForwardPrimitive();
std::unordered_map<int, dnnl::memory> matmul_args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
auto& astream = OneDNNContext::tls().get_stream();
matmul_p->execute(astream, matmul_args);
astream.wait();
// This kernel flattens the dims, so we need to set the unflattened version
// on out. The reshape requires a plain layout, but MatmulOneDNNHandler
// enforces one, so it should work.
out->set_mem_desc(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
}
template <typename T, typename T_out>
inline void ExecuteMatmul(const OneDNNContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const std::vector<int64_t>& x_dims,
const std::vector<int64_t>& y_dims,
bool trans_x,
bool trans_y,
DenseTensor* out) {
auto shape_x = dev_ctx.HasDnnAttr("fused_reshape_X")
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_reshape_X"))
: std::vector<int>();
auto axis_x = dev_ctx.HasDnnAttr("fused_transpose_X")
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_transpose_X"))
: std::vector<int>();
auto shape_y = dev_ctx.HasDnnAttr("fused_reshape_Y")
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_reshape_Y"))
: std::vector<int>();
auto axis_y = dev_ctx.HasDnnAttr("fused_transpose_Y")
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_transpose_Y"))
: std::vector<int>();
auto x_strides_override =
    GetInputStrides("X", x.dims(), trans_x, shape_x, axis_x);
auto y_strides_override =
GetInputStrides("Y", y.dims(), trans_y, shape_y, axis_y);
MatmulOneDNNHandler<T, T, T_out> handler(
dev_ctx, x_dims, y_dims, trans_x, trans_y);
const auto src_memory_p = handler.AcquireSrcMemory(&x);
const auto weights_memory_p = handler.AcquireWeightsMemory(&y);
const auto dst_memory_p = handler.AcquireDstMemory(dev_ctx, out);
auto matmul_p = handler.AcquireForwardPrimitive();
std::unordered_map<int, memory> matmul_args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
auto& astream = OneDNNContext::tls().get_stream();
matmul_p->execute(astream, matmul_args);
astream.wait();
out->set_mem_desc(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
}
} // namespace funcs
} // namespace phi
......@@ -50,6 +50,31 @@ constexpr bool is_bfloat16() {
return std::is_same<T, dtype::bfloat16>::value;
}
// oneDNN's permute_axes understands the axes order in a
// different way than PaddlePaddle's transpose
static std::vector<int> TransposeToPermuteAxes(const std::vector<int>& axis) {
std::vector<int> permute_axis(axis.size());
for (size_t i = 0; i < axis.size(); ++i) {
permute_axis[axis[i]] = i;
}
return permute_axis;
}
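// Example: a transpose with axis = {0, 2, 3, 1} becomes permute_axis =
// {0, 3, 1, 2} for oneDNN's memory::desc::permute_axes; position axis[i] of
// the result receives i, so this is the inverse permutation.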
static std::vector<int64_t> FakeTransposeStrides(
const std::vector<int64_t>& out_dims, const std::vector<int>& axis) {
std::vector<int64_t> fake_strides(axis.size());
int ndims = static_cast<int>(axis.size());
int total_stride = 1;
for (int i = ndims - 1; i >= 0; --i) {
fake_strides[axis[i]] = total_stride;
total_stride *= out_dims[axis[i]];
}
return fake_strides;
}
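// Example: FakeTransposeStrides({2, 3, 4, 5}, {0, 2, 1, 3}) walks the axes
// from the innermost outwards and returns {60, 5, 15, 1}; the logical dims
// stay {2, 3, 4, 5} but the data is laid out as if transposed to {2, 4, 3, 5}.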
static std::unordered_map<std::string, dnnl::algorithm> OneDNNActivationMap() {
return {{"abs", dnnl::algorithm::eltwise_abs},
{"clip", dnnl::algorithm::eltwise_clip},
......@@ -1645,425 +1670,5 @@ static void SetOutMemDescWithReshape2FuseSupport(
out->Resize(phi::make_ddim(fused_reshape2_shape));
}
static void SetOutMemDescWithLogicalLayoutFusesSupport(
const OneDNNContext& dev_ctx,
phi::DenseTensor* out,
const dnnl::memory::desc& out_md) {
const auto fused_unsqueeze2_axes =
dev_ctx.HasDnnAttr("fused_unsqueeze2_axes")
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_unsqueeze2_axes"))
: std::vector<int>();
const auto fused_reshape2_shape =
dev_ctx.HasDnnAttr("fused_reshape2_shape")
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_reshape2_shape"))
: std::vector<int>();
const auto fused_squeeze2_axes =
dev_ctx.HasDnnAttr("fused_squeeze2_axes")
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_squeeze2_axes"))
: std::vector<int>();
if (!fused_unsqueeze2_axes.empty()) {
SetOutMemDescWithUnsqueeze2FuseSupport(fused_unsqueeze2_axes, out, out_md);
} else if (!fused_reshape2_shape.empty()) {
SetOutMemDescWithReshape2FuseSupport(fused_reshape2_shape, out, out_md);
} else if (!fused_squeeze2_axes.empty()) {
out->set_mem_desc(out_md);
out->Resize(make_ddim(out_md.dims()));
} else {
out->set_mem_desc(out_md);
}
}
static DDim RowMatrixDimsFromVector(const DDim& x_dim) {
return x_dim.size() > 1 ? x_dim : make_ddim({1, x_dim[0]});
}
static DDim ColumnMatrixDimsFromVector(const DDim& y_dim) {
return y_dim.size() > 1 ? y_dim : make_ddim({y_dim[0], 1});
}
static std::vector<int64_t> TransposeAxis(const std::vector<int64_t>& x,
const std::vector<int>& axis) {
size_t in_rank = x.size();
size_t axis_size = axis.size();
auto axis_set = std::set<int>(axis.begin(), axis.end());
PADDLE_ENFORCE_EQ(axis_set.size(),
axis_size,
phi::errors::InvalidArgument(
"In an axis array, elements must be unique."));
PADDLE_ENFORCE_EQ(
in_rank,
axis_size,
phi::errors::InvalidArgument("The input dimension's size "
"should be equal to the axis's size. "
"But received dimension is %d, "
"axis's size is %d",
in_rank,
axis_size));
PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()),
axis_size,
phi::errors::InvalidArgument(
"Axis values must be ranging from 0 to (dims - 1)."));
std::vector<int64_t> new_x(x.size());
for (size_t i = 0; i < x.size(); i++) {
new_x[i] = x[axis[i]];
}
return new_x;
}
static std::vector<int64_t> GetInputStrides(const OneDNNContext& dev_ctx,
const DDim& input_dims,
const std::string input_name,
const bool transpose_input) {
auto new_dims = input_dims;
auto shape =
dev_ctx.HasDnnAttr("fused_reshape_" + input_name)
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_reshape_" + input_name))
: std::vector<int>();
auto axis = dev_ctx.HasDnnAttr("fused_transpose_" + input_name)
? PADDLE_GET_CONST(
std::vector<int>,
dev_ctx.GetDnnAttr("fused_transpose_" + input_name))
: std::vector<int>();
if (!shape.empty() && !axis.empty()) {
new_dims = input_dims.reshape(shape).transpose(axis);
}
auto& MatrixDimsFromVector =
input_name == "X" ? RowMatrixDimsFromVector : ColumnMatrixDimsFromVector;
MatDescriptor mat_dim = CreateMatrixDescriptor(
MatrixDimsFromVector(new_dims), 0, transpose_input);
std::vector<int64_t> strides;
if (!shape.empty()) {
auto shape2 = input_dims.reshape(shape);
strides.push_back(1);
for (auto i = shape2.size() - 1; i > 0; --i) {
strides.insert(strides.begin(),
strides.front() * static_cast<int64_t>(shape2[i]));
}
strides = TransposeAxis(strides, axis);
if (shape.size() == 2)
strides.insert(strides.begin(),
static_cast<int64_t>(shape[0] * shape[1]));
mat_dim.stride_ = strides[0];
if (mat_dim.trans_) std::swap(*strides.rbegin(), *(++strides.rbegin()));
}
return strides;
}
static bool IsOutputFused(const OneDNNContext& dev_ctx) {
const auto shape =
dev_ctx.HasDnnAttr("fused_reshape_Out")
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_reshape_Out"))
: std::vector<int>();
const auto axis =
dev_ctx.HasDnnAttr("fused_transpose_Out")
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_transpose_Out"))
: std::vector<int>();
return !shape.empty() && !axis.empty();
}
template <typename XT, typename YT, typename OT>
class MatmulOneDNNHandler : public OneDNNHandlerNoCachingT<XT, dnnl::matmul> {
public:
MatmulOneDNNHandler(const OneDNNContext& dev_ctx,
const std::vector<int64_t>& x_org_dims,
const std::vector<int64_t>& y_org_dims,
bool trans_x,
bool trans_y,
const std::vector<int64_t>& x_strides_override,
const std::vector<int64_t>& y_strides_override,
bool is_output_fused)
: OneDNNHandlerNoCachingT<XT, dnnl::matmul>(dev_ctx.GetEngine(),
dev_ctx.GetPlace()) {
// M X K * K X N
std::vector<int64_t> x_dims(x_org_dims);
std::vector<int64_t> y_dims(y_org_dims);
const int MB_idx = x_dims.size() - 3;
const int H_idx = x_dims.size() - 2;
const int W_idx = x_dims.size() - 1;
if (trans_x) std::swap(x_dims[H_idx], x_dims[W_idx]);
if (trans_y) std::swap(y_dims[H_idx], y_dims[W_idx]);
const memory::dim M = x_dims[H_idx];
const memory::dim K = x_dims[W_idx];
const memory::dim N = y_dims[W_idx];
std::vector<int64_t> x_strides(x_dims.size() - 3, 1);
std::vector<int64_t> y_strides(x_dims.size() - 3, 1);
std::vector<int64_t> out_strides(x_dims.size() - 3, 1);
std::vector<int64_t> out_ddims(x_dims.size() - 3, 1);
x_strides.reserve(x_dims.size());
y_strides.reserve(x_dims.size());
out_strides.reserve(x_dims.size());
if (!x_strides_override.empty()) {
x_strides = x_strides_override;
} else {
if (!trans_x) {
x_strides.insert(x_strides.end(), {M * K, K, 1});
} else {
x_strides.insert(x_strides.end(), {M * K, 1, M});
}
}
if (!y_strides_override.empty()) {
y_strides = y_strides_override;
} else {
if (!trans_y) {
y_strides.insert(y_strides.end(), {N * K, N, 1});
} else {
y_strides.insert(y_strides.end(), {N * K, 1, K});
}
}
out_strides.insert(out_strides.end(), {M * N, N, 1});
out_ddims.insert(out_ddims.end(),
{std::max(x_dims[MB_idx], y_dims[MB_idx]), M, N});
for (int i = x_dims.size() - 4; i >= 0; --i) {
out_ddims[i] = std::max(x_dims[i], y_dims[i]);
if (x_strides_override.empty()) {
x_strides[i] = x_dims[i + 1] * x_strides[i + 1];
}
if (y_strides_override.empty()) {
y_strides[i] = y_dims[i + 1] * y_strides[i + 1];
}
out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
}
// TODO(jczaja): Why not for int8??
if (!is_int8<OT>() && is_output_fused) {
out_strides = FakeTransposeStrides(out_ddims);
}
auto x_md = memory::desc(x_dims, OneDNNGetDataType<XT>(), x_strides);
auto y_md = memory::desc(y_dims, OneDNNGetDataType<YT>(), y_strides);
auto out_md = memory::desc(out_ddims, OneDNNGetDataType<OT>(), out_strides);
const auto matmul_attrs = CreateMatmulAttrs(dev_ctx);
this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md);
}
float ComputeOutputScale(const OneDNNContext& dev_ctx) {
float alpha = dev_ctx.HasDnnAttr("alpha")
? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("alpha"))
: 1.0f;
if (dev_ctx.HasDnnAttr("Scale_x") && dev_ctx.HasDnnAttr("Scale_y") &&
dev_ctx.HasDnnAttr("Scale_out")) {
float scale_x = PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("Scale_x"));
float scale_y = PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("Scale_y"));
bool force_fp32_out =
dev_ctx.HasDnnAttr("force_fp32_output")
? PADDLE_GET_CONST(bool, dev_ctx.GetDnnAttr("force_fp32_output"))
: false;
float scale_out =
force_fp32_out
? 1.f
: PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("Scale_out"));
alpha *= scale_out / (scale_x * scale_y);
}
return alpha;
}
dnnl::primitive_attr CreateMatmulAttrs(const OneDNNContext& dev_ctx) {
dnnl::primitive_attr matmul_attrs;
dnnl::post_ops post_operations;
float scale_out = ComputeOutputScale(dev_ctx);
if (scale_out != 1.0f) {
matmul_attrs.set_output_scales(0, {scale_out});
}
const auto* residual_data = dev_ctx.HasDnnInput("ResidualData")
? dev_ctx.GetDnnInput("ResidualData")
: nullptr;
if (residual_data) {
auto residual_data_tz = vectorize(residual_data->dims());
auto residual_data_md = memory::desc(residual_data_tz,
OneDNNGetDataType<OT>(),
dnnl::memory::format_tag::any);
post_operations.append_binary(dnnl::algorithm::binary_add,
residual_data_md);
if (dev_ctx.HasDnnAttr("Scale_in_eltwise")) {
float scale_in_eltwise =
PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("Scale_in_eltwise"));
float sum_scale = scale_out / scale_in_eltwise;
post_operations.append_sum(sum_scale);
}
}
AppendActivation(dev_ctx, post_operations);
const float scale_alpha =
dev_ctx.HasDnnAttr("fused_output_scale")
? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("fused_output_scale"))
: 1.0f;
if (scale_alpha != 1.0f) {
post_operations.append_eltwise(
1.0, dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f);
}
matmul_attrs.set_post_ops(post_operations);
return matmul_attrs;
}
std::vector<int64_t> FakeTransposeStrides(
const std::vector<int64_t>& matmul_out_dims) const {
// fuse matmul_v2 + transpose + reshape guarantees that output is 4D and
// transpose axis are: {0, 2, 1, 3}
std::vector<int64_t> transpose_axis = {0, 2, 1, 3};
std::vector<int64_t> fake_strides(transpose_axis.size());
int ndims = static_cast<int>(transpose_axis.size());
int total_stride = 1;
for (int i = ndims - 1; i >= 0; --i) {
fake_strides[transpose_axis[i]] = total_stride;
total_stride *= matmul_out_dims[transpose_axis[i]];
}
return fake_strides;
}
std::shared_ptr<memory> AcquireWeightsMemory(const DenseTensor* input) {
const YT* input_data = input->data<YT>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
to_void_cast<YT>(input_data));
}
std::shared_ptr<dnnl::memory> AcquireDstMemory(const OneDNNContext& dev_ctx,
DenseTensor* output) {
// We cannot use base AcquireDstMemory as it makes an allocation request
// base on DST memory primitive size. This is fine in general, but in MatMul
// we have primitive that covers only one batch of Data and then shift
// pointer for every new batch. Hence DenseTensor size is bigger that
// dst memory primitive size. So would we request less memory that is there
// and it triggers an assertion. So as there is no 'any' format here we can
// leave default size of DenseTensor as computed in ComputeInferShape
OT* ptr = dev_ctx.template Alloc<OT>(output);
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr);
}
};
template <typename T>
static void ExecuteMul(const OneDNNContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const std::vector<int64_t>& x_dims,
const std::vector<int64_t>& y_dims,
bool trans_x,
bool trans_y,
DenseTensor* out) {
static const std::vector<int64_t> vec_placeholder;
MatmulOneDNNHandler<T, T, T> handler(dev_ctx,
x_dims,
y_dims,
trans_x,
trans_y,
vec_placeholder,
vec_placeholder,
false);
const auto src_memory_p = handler.AcquireSrcMemory(&x);
const auto weights_memory_p = handler.AcquireWeightsMemory(&y);
const auto dst_memory_p = handler.AcquireDstMemory(dev_ctx, out);
auto matmul_p = handler.AcquireForwardPrimitive();
std::unordered_map<int, dnnl::memory> matmul_args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
auto& astream = OneDNNContext::tls().get_stream();
matmul_p->execute(astream, matmul_args);
astream.wait();
// This kernel is flattening dims so then we need to unflattened version
// that should be set in out reshape require plain layout, but
// MatmulV2MKLDNNHanlder enforces one so it should work
out->set_mem_desc(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
}
template <typename T, typename T_out>
void ExecuteMatmul(const OneDNNContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
const std::vector<int64_t>& x_dims,
const std::vector<int64_t>& y_dims,
bool trans_x,
bool trans_y,
DenseTensor* out) {
auto x_strides_override = GetInputStrides(dev_ctx, x.dims(), "X", trans_x);
auto y_strides_override = GetInputStrides(dev_ctx, y.dims(), "Y", trans_y);
MatmulOneDNNHandler<T, T, T_out> handler(dev_ctx,
x_dims,
y_dims,
trans_x,
trans_y,
x_strides_override,
y_strides_override,
IsOutputFused(dev_ctx));
const auto src_memory_p = handler.AcquireSrcMemory(&x);
const auto weights_memory_p = handler.AcquireWeightsMemory(&y);
const auto dst_memory_p = handler.AcquireDstMemory(dev_ctx, out);
auto matmul_p = handler.AcquireForwardPrimitive();
std::unordered_map<int, memory> matmul_args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
const auto* residual_data = dev_ctx.HasDnnInput("ResidualData")
? dev_ctx.GetDnnInput("ResidualData")
: nullptr;
if (residual_data) {
const auto residual_data_memory_p = handler.AcquireSrcMemory(residual_data);
matmul_args.insert({DNNL_ARG_ATTR_MULTIPLE_POST_OP(0) | DNNL_ARG_SRC_1,
*residual_data_memory_p});
}
auto& astream = OneDNNContext::tls().get_stream();
matmul_p->execute(astream, matmul_args);
astream.wait();
// TODO(jczaja): Explain why int8 format of dst is ABCD and do not need
// permute
if (IsOutputFused(dev_ctx) && !is_int8<T_out>()) {
const auto axis =
dev_ctx.HasDnnAttr("fused_transpose_Out")
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_transpose_Out"))
: std::vector<int>();
auto permuted_md = dst_memory_p->get_desc().permute_axes(axis);
out->set_mem_desc(permuted_md.reshape(vectorize<int64_t>(out->dims())));
} else {
out->set_mem_desc(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
}
}
} // namespace funcs
} // namespace phi
......@@ -14,7 +14,7 @@
#include <string>
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/backends/onednn/matmul_utils.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
......@@ -113,7 +113,8 @@ class FusedMatmulOneDNNHandler
// TODO(jczaja): Why not for int8??
if (!funcs::is_int8<OT>() && is_output_fused) {
out_strides = FakeTransposeStrides(out_ddims);
std::vector<int> transpose_axis = {0, 2, 1, 3};
out_strides = phi::funcs::FakeTransposeStrides(out_ddims, transpose_axis);
}
auto x_md = memory::desc(x_dims, funcs::OneDNNGetDataType<XT>(), x_strides);
......@@ -198,24 +199,6 @@ class FusedMatmulOneDNNHandler
return matmul_attrs;
}
std::vector<int64_t> FakeTransposeStrides(
const std::vector<int64_t> &matmul_out_dims) const {
// fuse matmul_v2 + transpose + reshape guarantees that output is 4D and
// transpose axis are: {0, 2, 1, 3}
std::vector<int64_t> transpose_axis = {0, 2, 1, 3};
std::vector<int64_t> fake_strides(transpose_axis.size());
int ndims = static_cast<int>(transpose_axis.size());
int total_stride = 1;
for (int i = ndims - 1; i >= 0; --i) {
fake_strides[transpose_axis[i]] = total_stride;
total_stride *= matmul_out_dims[transpose_axis[i]];
}
return fake_strides;
}
std::shared_ptr<memory> AcquireWeightsMemory(const DenseTensor *input) {
const YT *input_data = input->data<YT>();
return this->AcquireMemoryFromPrimitive(
......@@ -236,80 +219,6 @@ class FusedMatmulOneDNNHandler
}
};
static DDim RowMatrixDimsFromVector(const DDim &x_dim) {
return x_dim.size() > 1 ? x_dim : make_ddim({1, x_dim[0]});
}
static DDim ColumnMatrixDimsFromVector(const DDim &y_dim) {
return y_dim.size() > 1 ? y_dim : make_ddim({y_dim[0], 1});
}
static std::vector<int64_t> TransposeAxis(const std::vector<int64_t> &x,
const std::vector<int> &axis) {
size_t in_rank = x.size();
size_t axis_size = axis.size();
auto axis_set = std::set<int>(axis.begin(), axis.end());
PADDLE_ENFORCE_EQ(axis_set.size(),
axis_size,
phi::errors::InvalidArgument(
"In an axis array, elements must be unique."));
PADDLE_ENFORCE_EQ(
in_rank,
axis_size,
phi::errors::InvalidArgument("The input dimension's size "
"should be equal to the axis's size. "
"But received dimension is %d, "
"axis's size is %d",
in_rank,
axis_size));
PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()),
axis_size,
phi::errors::InvalidArgument(
"Axis values must be ranging from 0 to (dims - 1)."));
std::vector<int64_t> new_x(x.size());
for (size_t i = 0; i < x.size(); i++) {
new_x[i] = x[axis[i]];
}
return new_x;
}
static std::vector<int64_t> GetInputStrides(const std::string input_name,
const DDim &input_dims,
std::vector<int> shape,
std::vector<int> axis,
const bool transpose_input) {
auto new_dims = input_dims;
if (!shape.empty() && !axis.empty()) {
new_dims = input_dims.reshape(shape).transpose(axis);
}
auto &MatrixDimsFromVector =
input_name == "X" ? RowMatrixDimsFromVector : ColumnMatrixDimsFromVector;
funcs::MatDescriptor mat_dim = funcs::CreateMatrixDescriptor(
MatrixDimsFromVector(new_dims), 0, transpose_input);
std::vector<int64_t> strides;
if (!shape.empty()) {
auto shape2 = input_dims.reshape(shape);
strides.push_back(1);
for (auto i = shape2.size() - 1; i > 0; --i) {
strides.insert(strides.begin(),
strides.front() * static_cast<int64_t>(shape2[i]));
}
strides = TransposeAxis(strides, axis);
if (shape.size() == 2)
strides.insert(strides.begin(),
static_cast<int64_t>(shape[0] * shape[1]));
mat_dim.stride_ = strides[0];
if (mat_dim.trans_) std::swap(*strides.rbegin(), *(++strides.rbegin()));
}
return strides;
}
template <typename T, typename T_out>
void ExecuteFusedMatmul(const OneDNNContext &dev_ctx,
const DenseTensor &x,
......@@ -492,10 +401,10 @@ void FusedMatmulKernel(const Context &dev_ctx,
auto is_output_fused =
!fused_reshape_Out.empty() && !fused_transpose_Out.empty();
auto x_strides_override = GetInputStrides(
"X", x.dims(), fused_reshape_X, fused_transpose_X, transpose_x);
auto y_strides_override = GetInputStrides(
"Y", y.dims(), fused_reshape_Y, fused_transpose_Y, transpose_y);
auto x_strides_override = funcs::GetInputStrides(
"X", x.dims(), transpose_x, fused_reshape_X, fused_transpose_X);
auto y_strides_override = funcs::GetInputStrides(
"Y", y.dims(), transpose_y, fused_reshape_Y, fused_transpose_Y);
int ndims = std::max(x_dims.size(), y_dims.size());
ndims = std::max(ndims, 3);
......
......@@ -113,15 +113,26 @@ void ElementwiseKernel(const OneDNNContext& dev_ctx,
binary_prim->execute(astream, args);
astream.wait();
if (handler.use_broadcasting_hack == false) {
funcs::SetOutMemDescWithLogicalLayoutFusesSupport(
dev_ctx, out, dst_memory->get_desc());
} else {
auto dims = dst_memory->get_desc().dims();
auto out_md = dst_memory->get_desc();
if (handler.use_broadcasting_hack) {
auto dims = out_md.dims();
dims.insert(dims.begin(), non_const_x->dims()[0]);
dims[1] /= dims[0];
funcs::SetOutMemDescWithLogicalLayoutFusesSupport(
dev_ctx, out, dst_memory->get_desc().reshape(dims));
out_md = out_md.reshape(dims);
}
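// Worked example of the broadcasting hack above: a desc with dims {6, 8} gains
// a leading dimension of 2 taken from x, and its second dimension is divided
// by it, so out_md is reshaped to {2, 3, 8}.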
const auto fused_unsqueeze2_axes =
dev_ctx.HasDnnAttr("fused_unsqueeze2_axes")
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_unsqueeze2_axes"))
: std::vector<int>();
if (!fused_unsqueeze2_axes.empty()) {
funcs::SetOutMemDescWithUnsqueeze2FuseSupport(
fused_unsqueeze2_axes, out, out_md);
} else {
out->set_mem_desc(out_md);
}
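// Worked example of the unsqueeze2 fuse above: with out_md dims {6, 8} and
// fused_unsqueeze2_axes = {0, 3}, positions 0 and 3 are marked with ones and
// the remaining slots are filled with the original dims in order, giving an
// output of shape {1, 6, 8, 1}.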
}
......
......@@ -14,7 +14,7 @@
#include "paddle/phi/kernels/matmul_grad_kernel.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/backends/onednn/matmul_utils.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
......
......@@ -16,7 +16,7 @@
#include "paddle/phi/kernels/matmul_kernel.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/backends/onednn/matmul_utils.h"
#include "paddle/phi/core/kernel_registry.h"
using dnnl::engine;
......
......@@ -18,6 +18,40 @@
namespace phi {
void SetOutMemDescWithLogicalLayoutFusesSupport(
const OneDNNContext& dev_ctx,
phi::DenseTensor* out,
const dnnl::memory::desc& out_md) {
const auto fused_unsqueeze2_axes =
dev_ctx.HasDnnAttr("fused_unsqueeze2_axes")
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_unsqueeze2_axes"))
: std::vector<int>();
const auto fused_reshape2_shape =
dev_ctx.HasDnnAttr("fused_reshape2_shape")
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_reshape2_shape"))
: std::vector<int>();
const auto fused_squeeze2_axes =
dev_ctx.HasDnnAttr("fused_squeeze2_axes")
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_squeeze2_axes"))
: std::vector<int>();
if (!fused_unsqueeze2_axes.empty()) {
funcs::SetOutMemDescWithUnsqueeze2FuseSupport(
fused_unsqueeze2_axes, out, out_md);
} else if (!fused_reshape2_shape.empty()) {
funcs::SetOutMemDescWithReshape2FuseSupport(
fused_reshape2_shape, out, out_md);
} else if (!fused_squeeze2_axes.empty()) {
out->set_mem_desc(out_md);
out->Resize(make_ddim(out_md.dims()));
} else {
out->set_mem_desc(out_md);
}
}
void SetInMemDescWithSqueeze2FuseSupport(
const std::vector<int> fused_squeeze2_axes,
DenseTensor* in,
......@@ -71,8 +105,8 @@ void TransposeKernel(const Context& dev_ctx,
const std::vector<int>& axis,
DenseTensor* out) {
PADDLE_ENFORCE_EQ(
dev_ctx.GetPlace().GetType() == AllocationType::CPU,
true,
dev_ctx.GetPlace().GetType(),
AllocationType::CPU,
errors::PreconditionNotMet("oneDNN Transpose kernel must use CPUPlace"));
SetInMemDescWithLogicalLayoutFusesSupport(
......@@ -135,14 +169,7 @@ void TransposeKernel(const Context& dev_ctx,
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x.mem_desc(), funcs::to_void_cast(x.data<T>()));
// a trick is used here to fake transpose of out_md, so later it will be
// "untransposed", leaving output data in plain format tag
std::vector<int64_t> fake_strides(axis.size());
int total_stride = 1;
for (int i = static_cast<int>(x_vec_dims.size()) - 1; i >= 0; --i) {
fake_strides[axis[i]] = total_stride;
total_stride *= x_vec_dims[axis[i]];
}
auto fake_strides = funcs::FakeTransposeStrides(x_vec_dims, axis);
auto dst_md = dnnl::memory::desc(x_vec_dims, out_type, fake_strides);
auto reorder_dst_memory_p =
reorder_handler.AcquireDstMemory(out, dst_md, dev_ctx.GetPlace());
......@@ -161,10 +188,11 @@ void TransposeKernel(const Context& dev_ctx,
permute_axis[axis[i]] = i;
}
funcs::SetOutMemDescWithLogicalLayoutFusesSupport(
SetOutMemDescWithLogicalLayoutFusesSupport(
dev_ctx,
out,
reorder_dst_memory_p->get_desc().permute_axes(permute_axis));
reorder_dst_memory_p->get_desc().permute_axes(
funcs::TransposeToPermuteAxes(axis)));
}
} // namespace phi
......
......@@ -92,11 +92,25 @@ foreach(TEST_INFERENCE_IR_PASS ${TEST_MKLDNN_IR_PASSES})
list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES ${TEST_INFERENCE_IR_PASS})
endforeach()
file(
GLOB TEST_ONEDNN_IR_PASSES
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"test_onednn_*.py")
string(REPLACE ".py" "" TEST_ONEDNN_IR_PASSES "${TEST_ONEDNN_IR_PASSES}")
foreach(TEST_INFERENCE_IR_PASS ${TEST_ONEDNN_IR_PASSES})
list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES ${TEST_INFERENCE_IR_PASS})
endforeach()
if(WITH_MKLDNN)
foreach(target ${TEST_MKLDNN_IR_PASSES})
py_test_modules(${target} MODULES ${target})
set_tests_properties(${target} PROPERTIES LABELS "RUN_TYPE=INFER")
endforeach()
foreach(target ${TEST_ONEDNN_IR_PASSES})
py_test_modules(${target} MODULES ${target})
set_tests_properties(${target} PROPERTIES LABELS "RUN_TYPE=INFER")
endforeach()
endif()
file(
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestTranspose2Unsqueeze2OneDNNFusePass(PassAutoScanTest):
def sample_program_config(self, draw):
def generate_input(shape):
return np.random.random(shape).astype(np.float32)
channel = draw(st.sampled_from([1, 2, 4]))
transpose_axis = draw(
st.sampled_from([[0, 1, 2, 3], [2, 1, 3, 0], [3, 2, 1, 0]])
)
unsqueeze_axes = draw(st.sampled_from([[0, 1], [0, 4], [1, 2], [3]]))
transpose2_op = OpConfig(
type="transpose2",
inputs={
"X": ["transpose_x"],
},
outputs={
"Out": ["transpose_out"],
"XShape": ['transpose2_xshape'],
},
attrs={
"axis": transpose_axis,
"use_mkldnn": True,
},
)
unsqueeze2_op = OpConfig(
type="unsqueeze2",
inputs={"X": ["transpose_out"]},
outputs={"Out": ["unsqueeze_out"]},
attrs={
"axes": unsqueeze_axes,
},
)
model_net = [transpose2_op, unsqueeze2_op]
program_config = ProgramConfig(
ops=model_net,
weights={},
inputs={
"transpose_x": TensorConfig(
data_gen=partial(generate_input, [channel, 16, 64, 32])
)
},
outputs=["unsqueeze_out"],
)
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(
use_mkldnn=True,
passes=[
"operator_unsqueeze2_onednn_fuse_pass",
],
)
yield config, ["transpose2"], (1e-5, 1e-5)
def test(self):
self.run_and_statis(
quant=False,
passes=[
"operator_unsqueeze2_onednn_fuse_pass",
],
)
class TestElementwiseMulUnsqueeze2OneDNNFusePass(PassAutoScanTest):
def sample_program_config(self, draw):
def generate_input(shape):
return np.random.random(shape).astype(np.float32)
batch_size = draw(st.sampled_from([1, 2, 4]))
channel = draw(st.sampled_from([1, 3, 16]))
unsqueeze_axes = draw(st.sampled_from([[0, 1], [0, 4], [1, 2], [3]]))
elementwise_op = OpConfig(
type='elementwise_mul',
inputs={'X': ['eltwise_X'], 'Y': ['eltwise_Y']},
outputs={'Out': ['eltwise_output']},
attrs={"use_mkldnn": True},
)
unsqueeze2_op = OpConfig(
type="unsqueeze2",
inputs={"X": ["eltwise_output"]},
outputs={"Out": ["unsqueeze_out"]},
attrs={
"axes": unsqueeze_axes,
},
)
model_net = [elementwise_op, unsqueeze2_op]
program_config = ProgramConfig(
ops=model_net,
weights={},
inputs={
"eltwise_X": TensorConfig(
data_gen=partial(
generate_input, [batch_size, channel, 100, 100]
)
),
"eltwise_Y": TensorConfig(
data_gen=partial(
generate_input, [batch_size, channel, 100, 100]
)
),
},
outputs=["unsqueeze_out"],
)
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(
use_mkldnn=True,
passes=[
"operator_unsqueeze2_onednn_fuse_pass",
],
)
yield config, ["elementwise_mul"], (1e-5, 1e-5)
def test(self):
self.run_and_statis(
quant=False,
passes=[
"operator_unsqueeze2_onednn_fuse_pass",
],
)
if __name__ == "__main__":
unittest.main()