未验证 提交 fb16fea3 编写于 作者: S Sławomir Siwek 提交者: GitHub

cleanup unused code (#47762)

上级 14f261ad
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -18,10 +18,8 @@ using dnnl::memory; ...@@ -18,10 +18,8 @@ using dnnl::memory;
using paddle::framework::ExecutionContext; using paddle::framework::ExecutionContext;
using paddle::platform::MatMulV2MKLDNNHandler; using paddle::platform::MatMulV2MKLDNNHandler;
using paddle::platform::MKLDNNDeviceContext; using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::MKLDNNFormatForSize;
using paddle::platform::MKLDNNGetDataType; using paddle::platform::MKLDNNGetDataType;
using paddle::platform::to_void_cast; using paddle::platform::to_void_cast;
using phi::DataLayout;
using phi::vectorize; using phi::vectorize;
using Tensor = phi::DenseTensor; using Tensor = phi::DenseTensor;
using paddle::framework::GradVarName; using paddle::framework::GradVarName;
...@@ -157,22 +155,6 @@ class MatMulMKLDNNHandler ...@@ -157,22 +155,6 @@ class MatMulMKLDNNHandler
this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md); this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md);
} }
// Constructor for FWD MatMul
MatMulMKLDNNHandler(const dnnl::engine engine, const ExecutionContext &ctx)
: paddle::platform::MKLDNNHandlerNoCachingT<XT, dnnl::matmul>(
engine, ctx.GetPlace()) {
const dnnl::primitive_attr matmul_attrs = CreateMatmulAttrs(ctx);
auto matmul_dims_ = GetMatmulDims(ctx);
auto x_md = memory::desc(
matmul_dims_.x_dims, MKLDNNGetDataType<XT>(), matmul_dims_.x_strides);
auto y_md = memory::desc(
matmul_dims_.y_dims, MKLDNNGetDataType<YT>(), matmul_dims_.y_strides);
auto out_md = memory::desc(matmul_dims_.out_dims,
MKLDNNGetDataType<OT>(),
matmul_dims_.out_strides);
this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md);
}
std::shared_ptr<memory> AcquireWeightsMemory(const Tensor *input) { std::shared_ptr<memory> AcquireWeightsMemory(const Tensor *input) {
const YT *input_data = input->data<YT>(); const YT *input_data = input->data<YT>();
...@@ -201,8 +183,8 @@ class MatMulMKLDNNHandler ...@@ -201,8 +183,8 @@ class MatMulMKLDNNHandler
void *x_ptr = src_memory_p->get_data_handle(); void *x_ptr = src_memory_p->get_data_handle();
void *y_ptr = weights_memory_p->get_data_handle(); void *y_ptr = weights_memory_p->get_data_handle();
void *out_ptr = dst_memory_p->get_data_handle(); void *out_ptr = dst_memory_p->get_data_handle();
auto offsets = this->GetOffsets(); auto offsets = std::make_tuple(x_offset_, y_offset_, out_offset_);
for (uint16_t i = 0; i < this->GetBatchSize(); ++i) { for (uint16_t i = 0; i < batch_size_; ++i) {
src_memory_p->set_data_handle(x_ptr); src_memory_p->set_data_handle(x_ptr);
weights_memory_p->set_data_handle(y_ptr); weights_memory_p->set_data_handle(y_ptr);
dst_memory_p->set_data_handle(out_ptr); dst_memory_p->set_data_handle(out_ptr);
...@@ -229,182 +211,6 @@ class MatMulMKLDNNHandler ...@@ -229,182 +211,6 @@ class MatMulMKLDNNHandler
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr); return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr);
} }
private:
struct MatMulDims {
const memory::dims x_dims, y_dims, out_dims, x_strides, y_strides,
out_strides;
};
std::pair<phi::funcs::MatDescriptor, memory::dims> GetInputDimsAndStrides(
const ExecutionContext &ctx, std::string input_name) {
auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
auto input_dims = ctx.Input<phi::DenseTensor>(input_name)->dims();
auto new_dims = input_dims;
if (!shape.empty() && !axis.empty()) {
new_dims = input_dims.reshape(shape).transpose(axis);
}
auto &MatrixDimsFromVector = input_name == "X" ? RowMatrixDimsFromVector
: ColumnMatrixDimsFromVector;
phi::funcs::MatDescriptor mat_dim = phi::funcs::CreateMatrixDescriptor(
MatrixDimsFromVector(new_dims),
0,
ctx.Attr<bool>("transpose_" + input_name));
memory::dims strides;
if (!shape.empty()) {
auto shape2 = input_dims.reshape(shape);
strides.push_back(1);
for (auto i = shape2.size() - 1; i > 0; --i) {
strides.insert(strides.begin(), strides.front() * shape2[i]);
}
strides = Transpose(strides, axis);
if (shape.size() == 4)
strides.erase(strides.begin());
else if (shape.size() == 2)
strides.insert(strides.begin(), shape[0] * shape[1]);
mat_dim.stride_ = strides[0];
if (mat_dim.trans_) std::swap(*strides.rbegin(), *(++strides.rbegin()));
}
return std::make_pair(mat_dim, strides);
}
float ComputeOutputScale(const ExecutionContext &ctx) {
float scale_x = ctx.Attr<float>("Scale_x");
float scale_y = ctx.Attr<float>("Scale_y");
bool force_fp32_out = ctx.Attr<bool>("force_fp32_output");
float scale_out = force_fp32_out ? 1.f : ctx.Attr<float>("Scale_out");
float alpha = ctx.Attr<float>("alpha");
return alpha * scale_out / (scale_x * scale_y);
}
bool IsInputFused(const ExecutionContext &ctx) const {
return !(ctx.Attr<std::vector<int>>("fused_reshape_X").empty() &&
ctx.Attr<std::vector<int>>("fused_reshape_Y").empty());
}
bool IsOutputFused(const ExecutionContext &ctx) const {
auto &fused_reshape_Out = ctx.Attr<std::vector<int>>("fused_reshape_Out");
auto &fused_transpose_Out =
ctx.Attr<std::vector<int>>("fused_transpose_Out");
return !fused_reshape_Out.empty() && !fused_transpose_Out.empty();
}
MatMulDims GetMatmulDims(const ExecutionContext &ctx) {
phi::funcs::MatDescriptor mat_dim_x;
memory::dims strides_x;
std::tie(mat_dim_x, strides_x) = GetInputDimsAndStrides(ctx, "X");
phi::funcs::MatDescriptor mat_dim_y;
memory::dims strides_y;
std::tie(mat_dim_y, strides_y) = GetInputDimsAndStrides(ctx, "Y");
auto x_bs = mat_dim_x.batch_size_;
auto y_bs = mat_dim_y.batch_size_;
PADDLE_ENFORCE_EQ(x_bs > 0 && y_bs > 0 && x_bs != y_bs,
false,
paddle::platform::errors::InvalidArgument(
"If batch sizes of X and Y are positive,"
"they have to be equal."));
memory::dim out_bs = x_bs || y_bs ? std::max(x_bs, y_bs) : 1;
const memory::dim M = mat_dim_x.height_;
const memory::dim N = mat_dim_y.width_;
const memory::dim K = mat_dim_x.width_;
batch_size_ = 1;
if (out_bs > 1 && (IsOutputFused(ctx) || IsInputFused(ctx))) {
auto x_dims = GetDimForInput(ctx, "X");
auto y_dims = GetDimForInput(ctx, "Y");
batch_size_ = x_bs > y_bs ? x_dims[0] : y_dims[0];
x_bs /= batch_size_;
y_bs /= batch_size_;
out_bs /= batch_size_;
}
memory::dims x_dims = {x_bs > 0 ? x_bs : 1, M, K};
memory::dims y_dims = {y_bs > 0 ? y_bs : 1, K, N};
memory::dims out_dims = {out_bs, M, N};
x_offset_ = x_bs * M * K * sizeof(XT);
y_offset_ = y_bs * K * N * sizeof(YT);
out_offset_ = out_bs * M * N * sizeof(OT);
// Translate transA and transB
if (strides_x.empty())
strides_x = !ctx.Attr<bool>("transpose_X") ? memory::dims{M * K, K, 1}
: memory::dims{M * K, 1, M};
if (strides_y.empty())
strides_y = !ctx.Attr<bool>("transpose_Y") ? memory::dims{N * K, N, 1}
: memory::dims{N * K, 1, K};
memory::dims out_strides = memory::dims{M * N, N, 1};
CorrectStridesWhenFloatOutputFused(ctx, N, out_bs, &out_strides);
return {x_dims, y_dims, out_dims, strides_x, strides_y, out_strides};
}
std::vector<int64_t> Transpose(const std::vector<int64_t> &x,
const std::vector<int> &axis) {
size_t in_rank = x.size();
size_t axis_size = axis.size();
auto axis_set = std::set<int>(axis.begin(), axis.end());
PADDLE_ENFORCE_EQ(axis_set.size(),
axis_size,
paddle::platform::errors::InvalidArgument(
"In an axis array, elements must be unique."));
PADDLE_ENFORCE_EQ(in_rank,
axis_size,
paddle::platform::errors::InvalidArgument(
"The input dimension's size "
"should be equal to the axis's size. "
"But received dimension is %d, "
"axis's size is %d",
in_rank,
axis_size));
PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()),
axis_size,
paddle::platform::errors::InvalidArgument(
"Axis values must be ranging from 0 to (dims - 1)."));
std::vector<int64_t> new_x(x.size());
for (size_t i = 0; i < x.size(); i++) {
new_x[i] = x[axis[i]];
}
return new_x;
}
void CorrectStridesWhenFloatOutputFused(const ExecutionContext &ctx,
const memory::dim N,
memory::dim b,
memory::dims *out_strides) const {
if (!IsInt8<OT>() && !IsBfloat16<OT>() && IsOutputFused(ctx)) {
*out_strides = {N, b * N, 1};
}
}
uint16_t GetBatchSize(void) const { return batch_size_; }
std::tuple<uint32_t, uint32_t, uint32_t> GetOffsets() const {
return std::make_tuple(x_offset_, y_offset_, out_offset_);
}
dnnl::primitive_attr CreateMatmulAttrs(const ExecutionContext &ctx) {
dnnl::primitive_attr matmul_attrs;
dnnl::post_ops post_operations;
float scale_out = ComputeOutputScale(ctx);
if (scale_out != 1.0f) {
matmul_attrs.set_output_scales(0, {scale_out});
}
paddle::platform::AppendActivation(ctx, post_operations);
matmul_attrs.set_post_ops(post_operations);
return matmul_attrs;
}
private: private:
uint32_t x_offset_; uint32_t x_offset_;
uint32_t y_offset_; uint32_t y_offset_;
...@@ -465,55 +271,8 @@ static void ReshapeXYOutToMatrixSequence( ...@@ -465,55 +271,8 @@ static void ReshapeXYOutToMatrixSequence(
ReshapeTensorToMatrixSequence(y, mat_dim_y); ReshapeTensorToMatrixSequence(y, mat_dim_y);
} }
// Choose appropriate Handler instances based on inferred std::vector<int64_t> Transpose(const std::vector<int64_t> &x,
// output type (uint8, int8 or float). const std::vector<int> &axis) {
template <typename XT, typename YT>
static void ExecuteMatMul(const ExecutionContext &ctx) {
constexpr bool is_int8 = IsInt8<XT>();
constexpr bool is_bfloat16 = IsBfloat16<XT>();
const bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
const bool fuse_relu =
ctx.HasAttr("fuse_activation")
? ctx.Attr<std::string>("fuse_activation") == "relu"
: false;
auto *x = ctx.Input<phi::DenseTensor>("X");
auto *y = ctx.Input<phi::DenseTensor>("Y");
auto *out = ctx.Output<phi::DenseTensor>("Out");
const auto &dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto &onednn_engine = dev_ctx.GetEngine();
if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) {
MatMulMKLDNNHandler<XT, YT, float>(onednn_engine, ctx).Execute(x, y, out);
} else if (is_bfloat16) {
MatMulMKLDNNHandler<XT, YT, paddle::platform::bfloat16>(onednn_engine, ctx)
.Execute(x, y, out);
} else if (fuse_relu) {
MatMulMKLDNNHandler<XT, YT, uint8_t>(onednn_engine, ctx).Execute(x, y, out);
} else {
MatMulMKLDNNHandler<XT, YT, int8_t>(onednn_engine, ctx).Execute(x, y, out);
}
}
template <typename T>
class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const ExecutionContext &ctx) const override {
if (ctx.HasAttr("head_number")) {
PADDLE_ENFORCE_EQ(
ctx.Attr<int>("head_number"),
1,
paddle::platform::errors::Unimplemented(
"oneDNN matmul doesn't support multiple heads. Expected "
"head_number=1. But received `head_number` is %d",
ctx.Attr<int>("head_number")));
}
ExecuteMatMul<T, T>(ctx);
}
};
static std::vector<int64_t> Transpose(const std::vector<int64_t> &x,
const std::vector<int> &axis) {
size_t in_rank = x.size(); size_t in_rank = x.size();
size_t axis_size = axis.size(); size_t axis_size = axis.size();
...@@ -589,15 +348,6 @@ bool IsOutputFused(const ExecutionContext &ctx) { ...@@ -589,15 +348,6 @@ bool IsOutputFused(const ExecutionContext &ctx) {
return !fused_reshape_Out.empty() && !fused_transpose_Out.empty(); return !fused_reshape_Out.empty() && !fused_transpose_Out.empty();
} }
float ComputeOutputScale(const ExecutionContext &ctx) {
float scale_x = ctx.Attr<float>("Scale_x");
float scale_y = ctx.Attr<float>("Scale_y");
bool force_fp32_out = ctx.Attr<bool>("force_fp32_output");
float scale_out = force_fp32_out ? 1.f : ctx.Attr<float>("Scale_out");
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
return alpha * scale_out / (scale_x * scale_y);
}
template <typename T, typename T_out> template <typename T, typename T_out>
void ExecuteMatMulV2(const ExecutionContext &ctx, void ExecuteMatMulV2(const ExecutionContext &ctx,
const MKLDNNDeviceContext &dev_ctx, const MKLDNNDeviceContext &dev_ctx,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册