Unverified · Commit e9ae8dd0, authored by Jacek Czaja, committed by GitHub

[oneDNN] Disable cache matmul v1 & refactoring (#35331)

* - refactoring progressing

- Fix

- compilation fix

- another compilation fix

- refactoring

* - fix

* - compilation fix

* - compilation fix

* - missing set_format

* - compilation fix

* - reverted setting memory format

* - Brought back format

* - Fix

* - fixes after review

* CI rerun

* CI rerun
Parent 36cdb6e2
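For orientation before the diff: the change drops the blob-cached MatMulFactory (keyed via CreateKey/GetBlob on the device context) in favour of a MatMulMKLDNNHandler built on MKLDNNHandlerNoCachingT, so a fresh handler is constructed on every kernel invocation, and a batched matmul is emulated by executing a single-batch primitive in a loop while shifting the src/weights/dst data handles by per-batch offsets. The following is a minimal, self-contained sketch of that pointer-offset batching technique using the oneDNN 2.x C++ API (the same dnnl::matmul::desc style the diff uses); the sizes, buffers, and variable names are illustrative and are not taken from the Paddle sources.

// Sketch only: emulate a batched matmul with a single-batch oneDNN primitive
// by shifting data handles per batch (oneDNN 2.x API; illustrative sizes).
#include <dnnl.hpp>

#include <vector>

int main() {
  using dnnl::memory;

  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream astream(eng);

  // Hypothetical problem: 4 independent (2x3)*(3x5) matmuls stored contiguously.
  const memory::dim B = 4, M = 2, K = 3, N = 5;
  std::vector<float> x(B * M * K, 1.0f), y(B * K * N, 1.0f), out(B * M * N, 0.0f);

  // Descriptors cover only ONE batch; the loop below walks through all batches.
  auto x_md = memory::desc({M, K}, memory::data_type::f32, memory::format_tag::ab);
  auto y_md = memory::desc({K, N}, memory::data_type::f32, memory::format_tag::ab);
  auto out_md = memory::desc({M, N}, memory::data_type::f32, memory::format_tag::ab);

  auto matmul_pd =
      dnnl::matmul::primitive_desc(dnnl::matmul::desc(x_md, y_md, out_md), eng);
  auto matmul_p = dnnl::matmul(matmul_pd);

  memory x_mem(x_md, eng, x.data());
  memory y_mem(y_md, eng, y.data());
  memory out_mem(out_md, eng, out.data());

  // Per-batch offsets in bytes, analogous to x_offset_/y_offset_/out_offset_.
  const size_t x_off = M * K * sizeof(float);
  const size_t y_off = K * N * sizeof(float);
  const size_t out_off = M * N * sizeof(float);

  char* x_ptr = reinterpret_cast<char*>(x.data());
  char* y_ptr = reinterpret_cast<char*>(y.data());
  char* out_ptr = reinterpret_cast<char*>(out.data());

  for (memory::dim i = 0; i < B; ++i) {
    x_mem.set_data_handle(x_ptr);
    y_mem.set_data_handle(y_ptr);
    out_mem.set_data_handle(out_ptr);
    matmul_p.execute(astream, {{DNNL_ARG_SRC, x_mem},
                               {DNNL_ARG_WEIGHTS, y_mem},
                               {DNNL_ARG_DST, out_mem}});
    x_ptr += x_off;
    y_ptr += y_off;
    out_ptr += out_off;
  }
  astream.wait();  // out now holds all B result blocks
  return 0;
}

The real handler additionally applies output scales and derives strides from the fused reshape/transpose attributes; this sketch shows only the batching loop.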
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h"
+#include <tuple>
 
 using dnnl::memory;
 using dnnl::primitive;
@@ -20,6 +21,7 @@ using paddle::framework::DataLayout;
 using paddle::framework::ExecutionContext;
 using paddle::framework::vectorize;
 using paddle::platform::GetMKLDNNFormat;
+using paddle::platform::MKLDNNFormatForSize;
 using paddle::platform::MKLDNNDeviceContext;
 using paddle::platform::MKLDNNGetDataType;
 using paddle::platform::to_void_cast;
@@ -82,15 +84,39 @@ static Tensor FoldFirstAndLastDims(const MKLDNNDeviceContext& dev_ctx,
 }
 
 template <typename T>
+constexpr bool IsInt8() {
+  return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
+}
+
+template <typename T>
+constexpr bool IsBfloat16() {
+  return std::is_same<T, paddle::platform::bfloat16>::value;
+}
+
+// Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
+// original x_dim is returned.
+static paddle::framework::DDim RowMatrixDimsFromVector(
+    const paddle::framework::DDim& x_dim) {
+  return x_dim.size() > 1 ? x_dim : paddle::framework::make_ddim({1, x_dim[0]});
+}
+
+// Get column matrix shape from a vector shape. If the rank of y_dim > 1, the
+// original y_dim is returned.
+static paddle::framework::DDim ColumnMatrixDimsFromVector(
+    const paddle::framework::DDim& y_dim) {
+  return y_dim.size() > 1 ? y_dim : paddle::framework::make_ddim({y_dim[0], 1});
+}
+
+template <typename XT, typename YT, typename OT>
 class MatMulMKLDNNHandler
-    : public paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul> {
+    : public paddle::platform::MKLDNNHandlerNoCachingT<XT, dnnl::matmul> {
  public:
   MatMulMKLDNNHandler(const mkldnn::engine engine,
                       paddle::platform::Place cpu_place, Tensor* x,
                       bool trans_x, Tensor* y, bool trans_y, Tensor* out,
                       float scale)
-      : paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul>(engine,
-                                                                   cpu_place) {
+      : paddle::platform::MKLDNNHandlerNoCachingT<XT, dnnl::matmul>(engine,
+                                                                    cpu_place) {
     auto mat_dim_x =
         paddle::operators::math::CreateMatrixDescriptor(x->dims(), 0, trans_x);
     auto mat_dim_y =
@@ -115,117 +141,98 @@ class MatMulMKLDNNHandler
         !trans_y ? memory::dims{N * K, N, 1} : memory::dims{N * K, 1, K};
     memory::dims out_strides = memory::dims{M * N, N, 1};
 
-    auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
-    auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
-    auto out_md = memory::desc(out_dims, MKLDNNGetDataType<T>(), out_strides);
+    auto x_md = memory::desc(x_dims, MKLDNNGetDataType<XT>(), x_strides);
+    auto y_md = memory::desc(y_dims, MKLDNNGetDataType<YT>(), y_strides);
+    auto out_md = memory::desc(out_dims, MKLDNNGetDataType<OT>(), out_strides);
 
     dnnl::primitive_attr attrs;
     if (scale != 1.0f) attrs.set_output_scales(0, {scale});
 
     this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md);
   }
 
+  // Constructor for FWD MatMul
+  MatMulMKLDNNHandler(const mkldnn::engine engine, const ExecutionContext& ctx,
+                      float scale)
+      : paddle::platform::MKLDNNHandlerNoCachingT<XT, dnnl::matmul>(
+            engine, ctx.GetPlace()),
+        matmul_dims_(GetMatmulDims(ctx)) {
+    dnnl::primitive_attr attr;
+    float scale_out = ComputeOutputScale(ctx);
+    if (scale_out != 1.0f) {
+      constexpr unsigned tensor_wide_scale = 0;
+      attr.set_output_scales(tensor_wide_scale, {scale_out});
+    }
+
+    auto x_md = memory::desc(matmul_dims_.x_dims, MKLDNNGetDataType<XT>(),
+                             matmul_dims_.x_strides);
+    auto y_md = memory::desc(matmul_dims_.y_dims, MKLDNNGetDataType<YT>(),
+                             matmul_dims_.y_strides);
+    auto out_md = memory::desc(matmul_dims_.out_dims, MKLDNNGetDataType<OT>(),
+                               matmul_dims_.out_strides);
+
+    this->AcquireForwardPrimitiveDescriptor(attr, x_md, y_md, out_md);
+  }
+
   std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) {
-    const T* input_data = input->data<T>();
+    const YT* input_data = input->data<YT>();
     return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
-                                            to_void_cast<T>(input_data));
+                                            to_void_cast<YT>(input_data));
   }
-};
-
-template <typename T>
-constexpr bool IsInt8() {
-  return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
-}
-
-template <typename T>
-constexpr bool IsBfloat16() {
-  return std::is_same<T, paddle::platform::bfloat16>::value;
-}
-
-// Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
-// original x_dim is returned.
-static paddle::framework::DDim RowMatrixDimsFromVector(
-    const paddle::framework::DDim& x_dim) {
-  return x_dim.size() > 1 ? x_dim : paddle::framework::make_ddim({1, x_dim[0]});
-}
-
-// Get column matrix shape from a vector shape. If the rank of y_dim > 1, the
-// original y_dim is returned.
-static paddle::framework::DDim ColumnMatrixDimsFromVector(
-    const paddle::framework::DDim& y_dim) {
-  return y_dim.size() > 1 ? y_dim : paddle::framework::make_ddim({y_dim[0], 1});
-}
-
-/**
- * Reshape a tensor to 3-D or 2-D tensor by matrix descriptor.
- *
- * The shape would be [BatchSize, H, W] or [H, W].
- * If transposed, `H,W` will be swapped.
- */
-static void ReshapeTensorToMatrixSequence(
-    Tensor* x, const paddle::operators::math::MatDescriptor& descriptor) {
-  int64_t h, w;
-  h = descriptor.height_;
-  w = descriptor.width_;
-  if (descriptor.trans_) {
-    std::swap(w, h);
-  }
-  if (descriptor.batch_size_) {
-    x->Resize({descriptor.batch_size_, h, w});
-  } else {
-    x->Resize({h, w});
-  }
-}
-
-/**
- * Reshape the x,y,out tensor to 3-D or 2-D tensor by matrix descriptor
- * Out = matmul(x, y)
- *
- * This method will first calculate X,Y matrix sequence, and then calculate
- * the out shape.
- *
- * Assume X = [BatchSize, H1, W1], Y = [BatchSize, H2, W2]
- * The out = [BatchSize, H1, W2]
- *
- * If there is no batch size in `X` and `Y`, the out will be [H1, W2]
- * If any of `X` and `Y` has batch size BatchSize, the out will have the
- * BatchSize.
- */
-static void ReshapeXYOutToMatrixSequence(Tensor* x, Tensor* y, Tensor* out,
-                                         bool trans_x, bool trans_y) {
-  auto x_dim = RowMatrixDimsFromVector(x->dims());
-  auto y_dim = ColumnMatrixDimsFromVector(y->dims());
-  auto mat_dim_x =
-      paddle::operators::math::CreateMatrixDescriptor(x_dim, 0, trans_x);
-  auto mat_dim_y =
-      paddle::operators::math::CreateMatrixDescriptor(y_dim, 0, trans_y);
-  if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) {
-    out->Resize({mat_dim_x.height_, mat_dim_y.width_});
-  } else {
-    out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_),
-                 mat_dim_x.height_, mat_dim_y.width_});
-  }
-  ReshapeTensorToMatrixSequence(x, mat_dim_x);
-  ReshapeTensorToMatrixSequence(y, mat_dim_y);
-}
-
-template <typename XT, typename YT, typename OT>
-class MatMulFactory {
- public:
-  void CreateAndExecute(const ExecutionContext& ctx) {
-    SetDNNLEngine(ctx);
-    if (IsInitialized()) {
-      UpdateDataPointers(ctx);
-      Execute();
-      SetOutputFormat(ctx);
-      return;
-    }
-    CreateMemories(ctx);
-    CreatePrimitive(ctx);
-    Execute();
-    SetOutputFormat(ctx);
-    SetInitialized();
-  }
+
+ public:
+  void Execute(const paddle::framework::Tensor* x,
+               const paddle::framework::Tensor* y,
+               paddle::framework::Tensor* out) {
+    const auto src_memory_p = this->AcquireSrcMemory(x);
+    const auto weights_memory_p = this->AcquireWeightsMemory(y);
+    const auto dst_memory_p = this->AcquireDstMemory(out);
+
+    auto matmul_p = this->AcquireForwardPrimitive();
+
+    std::unordered_map<int, dnnl::memory> matmul_args = {
+        {DNNL_ARG_SRC, *src_memory_p},
+        {DNNL_ARG_WEIGHTS, *weights_memory_p},
+        {DNNL_ARG_DST, *dst_memory_p}};
+
+    auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
+
+    // Simulate batch matmul by processing in loop
+    void* x_ptr = src_memory_p->get_data_handle();
+    void* y_ptr = weights_memory_p->get_data_handle();
+    void* out_ptr = dst_memory_p->get_data_handle();
+    auto offsets = this->GetOffsets();
+    for (uint16_t i = 0; i < this->GetBatchSize(); ++i) {
+      src_memory_p->set_data_handle(x_ptr);
+      weights_memory_p->set_data_handle(y_ptr);
+      dst_memory_p->set_data_handle(out_ptr);
+      matmul_p->execute(astream, {
+                                     {MKLDNN_ARG_SRC, *src_memory_p},
+                                     {MKLDNN_ARG_WEIGHTS, *weights_memory_p},
+                                     {MKLDNN_ARG_DST, *dst_memory_p},
+                                 });
+      x_ptr = static_cast<char*>(x_ptr) + std::get<0>(offsets);
+      y_ptr = static_cast<char*>(y_ptr) + std::get<1>(offsets);
+      out_ptr = static_cast<char*>(out_ptr) + std::get<2>(offsets);
+    }
+    astream.wait();
+
+    auto format = MKLDNNFormatForSize(out->dims().size(),
+                                      dnnl::memory::format_tag::nchw);
+    out->set_format(format);
+    out->set_layout(DataLayout::kMKLDNN);
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireDstMemory(
+      paddle::framework::Tensor* output) {
+    // We cannot use base AcquireDstMemory as it makes an allocation request
+    // based on DST memory primitive size. This is fine in general, but in
+    // MatMul we have a primitive that covers only one batch of data and then
+    // shifts the pointer for every new batch. Hence the Tensor size is bigger
+    // than the dst memory primitive size, so we would request less memory
+    // than is actually there and trigger an assertion. As there is no 'any'
+    // format here, we can leave the default size of the Tensor as computed in
+    // ComputeInferShape.
+    OT* ptr = output->mutable_data<OT>(this->place_);
+    return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr);
+  }
 
  private:
@@ -234,47 +241,6 @@ class MatMulFactory {
         out_strides;
   };
 
-  void SetDNNLEngine(const ExecutionContext& ctx) {
-    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
-    engine_ = dev_ctx.GetEngine();
-  }
-
-  template <typename T>
-  dnnl::memory CreateMemory(const memory::dims& dims,
-                            const memory::dims& strides, const T* data) {
-    auto md = memory::desc(dims, MKLDNNGetDataType<T>(), strides);
-    return dnnl::memory(md, engine_, to_void_cast(data));
-  }
-
-  std::vector<int64_t> Transpose(const std::vector<int64_t>& x,
-                                 const std::vector<int>& axis) {
-    size_t in_rank = x.size();
-    size_t axis_size = axis.size();
-
-    auto axis_set = std::set<int>(axis.begin(), axis.end());
-    PADDLE_ENFORCE_EQ(axis_set.size(), axis_size,
-                      paddle::platform::errors::InvalidArgument(
-                          "In an axis array, elements must be unique."));
-
-    PADDLE_ENFORCE_EQ(in_rank, axis_size,
-                      paddle::platform::errors::InvalidArgument(
-                          "The input dimension's size "
-                          "should be equal to the axis's size. "
-                          "But received dimension is %d, "
-                          "axis's size is %d",
-                          in_rank, axis_size));
-
-    PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), axis_size,
-                      paddle::platform::errors::InvalidArgument(
-                          "Axis values must be ranging from 0 to (dims - 1)."));
-
-    std::vector<int64_t> new_x(x.size());
-    for (size_t i = 0; i < x.size(); i++) {
-      new_x[i] = x[axis[i]];
-    }
-    return new_x;
-  }
-
   std::pair<paddle::operators::math::MatDescriptor, memory::dims>
   GetInputDimsAndStrides(const ExecutionContext& ctx, std::string input_name) {
     auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
@@ -310,6 +276,15 @@ class MatMulFactory {
     return std::make_pair(mat_dim, strides);
   }
 
+  float ComputeOutputScale(const ExecutionContext& ctx) {
+    float scale_x = ctx.Attr<float>("Scale_x");
+    float scale_y = ctx.Attr<float>("Scale_y");
+    bool force_fp32_out = ctx.Attr<bool>("force_fp32_output");
+    float scale_out = force_fp32_out ? 1.f : ctx.Attr<float>("Scale_out");
+    float alpha = ctx.Attr<float>("alpha");
+    return alpha * scale_out / (scale_x * scale_y);
+  }
+
   bool IsInputFused(const ExecutionContext& ctx) const {
     return !(ctx.Attr<std::vector<int>>("fused_reshape_X").empty() &&
              ctx.Attr<std::vector<int>>("fused_reshape_Y").empty());
@@ -322,14 +297,6 @@ class MatMulFactory {
     return !fused_reshape_Out.empty() && !fused_transpose_Out.empty();
   }
 
-  void CorrectStridesWhenFloatOutputFused(const ExecutionContext& ctx,
-                                          const memory::dim N, memory::dim b,
-                                          memory::dims* out_strides) const {
-    if (!IsInt8<OT>() && !IsBfloat16<OT>() && IsOutputFused(ctx)) {
-      *out_strides = {N, b * N, 1};
-    }
-  }
-
   MatMulDims GetMatmulDims(const ExecutionContext& ctx) {
     paddle::operators::math::MatDescriptor mat_dim_x;
     memory::dims strides_x;
@@ -381,125 +348,112 @@ class MatMulFactory {
     return {x_dims, y_dims, out_dims, strides_x, strides_y, out_strides};
   }
 
-  void CreateMemories(const ExecutionContext& ctx) {
-    auto matmul_dims = GetMatmulDims(ctx);
-
-    x_mem_ = CreateMemory<XT>(matmul_dims.x_dims, matmul_dims.x_strides,
-                              ctx.Input<Tensor>("X")->data<XT>());
-    y_mem_ = CreateMemory<YT>(matmul_dims.y_dims, matmul_dims.y_strides,
-                              ctx.Input<Tensor>("Y")->data<YT>());
-    out_mem_ = CreateMemory<OT>(
-        matmul_dims.out_dims, matmul_dims.out_strides,
-        ctx.Output<Tensor>("Out")->mutable_data<OT>(ctx.GetPlace()));
-  }
-
-  float ComputeOutputScale(const ExecutionContext& ctx) {
-    float scale_x = ctx.Attr<float>("Scale_x");
-    float scale_y = ctx.Attr<float>("Scale_y");
-    bool force_fp32_out = ctx.Attr<bool>("force_fp32_output");
-    float scale_out = force_fp32_out ? 1.f : ctx.Attr<float>("Scale_out");
-    float alpha = ctx.Attr<float>("alpha");
-    return alpha * scale_out / (scale_x * scale_y);
-  }
-
-  void CreatePrimitive(const ExecutionContext& ctx) {
-    dnnl::primitive_attr attr;
-    float scale_out = ComputeOutputScale(ctx);
-    if (scale_out != 1.0f) {
-      constexpr unsigned tensor_wide_scale = 0;
-      attr.set_output_scales(tensor_wide_scale, {scale_out});
-    }
-
-    auto matmul_d = dnnl::matmul::desc(x_mem_.get_desc(), y_mem_.get_desc(),
-                                       out_mem_.get_desc());
-    auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine_);
-    matmul_prim_ = dnnl::matmul(matmul_pd);
-  }
-
-  void Execute() {
-    dnnl::stream stream(engine_);
-
-    void* x_ptr = x_mem_.get_data_handle();
-    void* y_ptr = y_mem_.get_data_handle();
-    void* out_ptr = out_mem_.get_data_handle();
-    for (uint16_t i = 0; i < batch_size_; i++) {
-      x_mem_.set_data_handle(x_ptr);
-      y_mem_.set_data_handle(y_ptr);
-      out_mem_.set_data_handle(out_ptr);
-      matmul_prim_.execute(stream, {
-                                       {MKLDNN_ARG_SRC, x_mem_},
-                                       {MKLDNN_ARG_WEIGHTS, y_mem_},
-                                       {MKLDNN_ARG_DST, out_mem_},
-                                   });
-      x_ptr = static_cast<char*>(x_ptr) + x_offset_;
-      y_ptr = static_cast<char*>(y_ptr) + y_offset_;
-      out_ptr = static_cast<char*>(out_ptr) + out_offset_;
-    }
-    stream.wait();
-  }
-
-  void SetOutputFormat(const ExecutionContext& ctx) {
-    using paddle::platform::MKLDNNFormatForSize;
-    auto* out = ctx.Output<Tensor>("Out");
-    auto format = MKLDNNFormatForSize(out->dims().size(),
-                                      dnnl::memory::format_tag::nchw);
-    out->set_format(format);
-    out->set_layout(DataLayout::kMKLDNN);
-  }
-
-  void UpdateDataPointers(const ExecutionContext& ctx) {
-    auto* x = ctx.Input<Tensor>("X");
-    auto* y = ctx.Input<Tensor>("Y");
-    auto* out = ctx.Output<Tensor>("Out");
-    x_mem_.set_data_handle(to_void_cast(x->data<XT>()));
-    y_mem_.set_data_handle(to_void_cast(y->data<YT>()));
-    out_mem_.set_data_handle(out->mutable_data<OT>(ctx.GetPlace()));
-  }
-
-  // If initialized, x memory should've been already initialized
-  bool IsInitialized() { return initialized_; }
-
-  void SetInitialized() { initialized_ = true; }
+  std::vector<int64_t> Transpose(const std::vector<int64_t>& x,
+                                 const std::vector<int>& axis) {
+    size_t in_rank = x.size();
+    size_t axis_size = axis.size();
+
+    auto axis_set = std::set<int>(axis.begin(), axis.end());
+    PADDLE_ENFORCE_EQ(axis_set.size(), axis_size,
+                      paddle::platform::errors::InvalidArgument(
+                          "In an axis array, elements must be unique."));
+
+    PADDLE_ENFORCE_EQ(in_rank, axis_size,
+                      paddle::platform::errors::InvalidArgument(
+                          "The input dimension's size "
+                          "should be equal to the axis's size. "
+                          "But received dimension is %d, "
+                          "axis's size is %d",
+                          in_rank, axis_size));
+
+    PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), axis_size,
+                      paddle::platform::errors::InvalidArgument(
+                          "Axis values must be ranging from 0 to (dims - 1)."));
+
+    std::vector<int64_t> new_x(x.size());
+    for (size_t i = 0; i < x.size(); i++) {
+      new_x[i] = x[axis[i]];
+    }
+    return new_x;
+  }
+
+  void CorrectStridesWhenFloatOutputFused(const ExecutionContext& ctx,
+                                          const memory::dim N, memory::dim b,
+                                          memory::dims* out_strides) const {
+    if (!IsInt8<OT>() && !IsBfloat16<OT>() && IsOutputFused(ctx)) {
+      *out_strides = {N, b * N, 1};
+    }
+  }
+
+  uint16_t GetBatchSize(void) const { return batch_size_; }
+
+  std::tuple<uint32_t, uint32_t, uint32_t> GetOffsets() const {
+    return std::make_tuple(x_offset_, y_offset_, out_offset_);
+  }
 
  private:
-  struct memory_offsets {
-    size_t x_offset;
-    size_t y_offset;
-    size_t out_offset;
-  };
-
-  dnnl::engine engine_;
-  dnnl::memory x_mem_;
-  dnnl::memory y_mem_;
-  dnnl::memory out_mem_;
-  dnnl::matmul matmul_prim_;
+  MatMulDims matmul_dims_;
   uint32_t x_offset_;
   uint32_t y_offset_;
   uint32_t out_offset_;
   uint16_t batch_size_;
-  bool initialized_ = false;
 };
 
-template <typename XT, typename YT, typename OT>
-static std::shared_ptr<MatMulFactory<XT, YT, OT>> GetPrimitiveFactory(
-    const ExecutionContext& ctx) {
-  const auto& out_name = ctx.OutputName("Out");
-  const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
-  const auto batch_size = ctx.Input<Tensor>("X")->dims()[0];
-  std::string key = paddle::platform::CreateKey(dev_ctx, batch_size, out_name);
-  key = paddle::platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
-
-  auto factory =
-      std::static_pointer_cast<MatMulFactory<XT, YT, OT>>(dev_ctx.GetBlob(key));
-  if (factory == nullptr) {
-    factory = std::make_shared<MatMulFactory<XT, YT, OT>>();
-    dev_ctx.SetBlob(key, factory);
-  }
-
-  return factory;
-}
+/**
+ * Reshape a tensor to 3-D or 2-D tensor by matrix descriptor.
+ *
+ * The shape would be [BatchSize, H, W] or [H, W].
+ * If transposed, `H,W` will be swapped.
+ */
+static void ReshapeTensorToMatrixSequence(
+    Tensor* x, const paddle::operators::math::MatDescriptor& descriptor) {
+  int64_t h, w;
+  h = descriptor.height_;
+  w = descriptor.width_;
+  if (descriptor.trans_) {
+    std::swap(w, h);
+  }
+  if (descriptor.batch_size_) {
+    x->Resize({descriptor.batch_size_, h, w});
+  } else {
+    x->Resize({h, w});
+  }
+}
+
+/**
+ * Reshape the x,y,out tensor to 3-D or 2-D tensor by matrix descriptor
+ * Out = matmul(x, y)
+ *
+ * This method will first calculate X,Y matrix sequence, and then calculate
+ * the out shape.
+ *
+ * Assume X = [BatchSize, H1, W1], Y = [BatchSize, H2, W2]
+ * The out = [BatchSize, H1, W2]
+ *
+ * If there is no batch size in `X` and `Y`, the out will be [H1, W2]
+ * If any of `X` and `Y` has batch size BatchSize, the out will have the
+ * BatchSize.
+ */
+static void ReshapeXYOutToMatrixSequence(Tensor* x, Tensor* y, Tensor* out,
+                                         bool trans_x, bool trans_y) {
+  auto x_dim = RowMatrixDimsFromVector(x->dims());
+  auto y_dim = ColumnMatrixDimsFromVector(y->dims());
+  auto mat_dim_x =
+      paddle::operators::math::CreateMatrixDescriptor(x_dim, 0, trans_x);
+  auto mat_dim_y =
+      paddle::operators::math::CreateMatrixDescriptor(y_dim, 0, trans_y);
+  if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) {
+    out->Resize({mat_dim_x.height_, mat_dim_y.width_});
+  } else {
+    out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_),
+                 mat_dim_x.height_, mat_dim_y.width_});
+  }
+  ReshapeTensorToMatrixSequence(x, mat_dim_x);
+  ReshapeTensorToMatrixSequence(y, mat_dim_y);
+}
 
-// Choose appropriate primitive factory implementation based on inferred
+// Choose appropriate Handler instances based on inferred
 // output type (uint8, int8 or float).
 template <typename XT, typename YT>
 static void ExecuteMatMul(const ExecutionContext& ctx) {
@@ -507,31 +461,41 @@ static void ExecuteMatMul(const ExecutionContext& ctx) {
   constexpr bool is_bfloat16 = IsBfloat16<XT>();
   const bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
   constexpr bool fuse_relu = false;  // TODO(intel): Enable eltwise fuses
+  auto* x = ctx.Input<Tensor>("X");
+  auto* y = ctx.Input<Tensor>("Y");
+  auto* out = ctx.Output<Tensor>("Out");
+  float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
+  const auto& dev_ctx =
+      ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
+
   if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) {
-    GetPrimitiveFactory<XT, YT, float>(ctx)->CreateAndExecute(ctx);
+    MatMulMKLDNNHandler<XT, YT, float>(dev_ctx.GetEngine(), ctx, alpha)
+        .Execute(x, y, out);
   } else if (is_bfloat16) {
-    GetPrimitiveFactory<XT, YT, paddle::platform::bfloat16>(ctx)
-        ->CreateAndExecute(ctx);
+    MatMulMKLDNNHandler<XT, YT, paddle::platform::bfloat16>(dev_ctx.GetEngine(),
+                                                            ctx, alpha)
+        .Execute(x, y, out);
   } else if (fuse_relu) {
-    GetPrimitiveFactory<XT, YT, uint8_t>(ctx)->CreateAndExecute(ctx);
+    MatMulMKLDNNHandler<XT, YT, uint8_t>(dev_ctx.GetEngine(), ctx, alpha)
+        .Execute(x, y, out);
   } else {
-    GetPrimitiveFactory<XT, YT, int8_t>(ctx)->CreateAndExecute(ctx);
+    MatMulMKLDNNHandler<XT, YT, int8_t>(dev_ctx.GetEngine(), ctx, alpha)
+        .Execute(x, y, out);
   }
 }
 
 template <typename T>
-class DNNLMatMulKernel : public paddle::framework::OpKernel<T> {
+class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> {
  public:
  void Compute(const ExecutionContext& ctx) const override {
    if (ctx.HasAttr("head_number")) {
      PADDLE_ENFORCE_EQ(
          ctx.Attr<int>("head_number"), 1,
          paddle::platform::errors::Unimplemented(
-              "DNNL matmul doesn't support multiple heads. Expected "
+              "oneDNN matmul doesn't support multiple heads. Expected "
              "head_number=1. But received `head_number` is %d",
              ctx.Attr<int>("head_number")));
    }
-    MKLDNNDeviceContext::tls().log_lib_version();
    ExecuteMatMul<T, T>(ctx);
  }
 };
@@ -547,7 +511,7 @@ void MatMulGradMKLDNNKernel<T>::Compute(const ExecutionContext& ctx) const {
     PADDLE_ENFORCE_EQ(
         ctx.Attr<int>("head_number"), 1,
         platform::errors::Unimplemented(
-            "DNNL matmul doesn't support multiple heads. Expected "
+            "oneDNN matmul doesn't support multiple heads. Expected "
            "head_number=1. But received `head_number` is %d",
            ctx.Attr<int>("head_number")));
  }
@@ -577,8 +541,9 @@ void MatMulGradMKLDNNKernel<T>::ExecuteMatMulGrad(
   float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
 
-  MatMulMKLDNNHandler<T> handler(engine, ctx.GetPlace(), &x_combined, trans_x,
-                                 &y_combined, trans_y, out, alpha);
+  MatMulMKLDNNHandler<T, T, T> handler(engine, ctx.GetPlace(), &x_combined,
+                                       trans_x, &y_combined, trans_y, out,
+                                       alpha);
 
   const auto src_memory_p = handler.AcquireSrcMemory(&x_combined);
   const auto weights_memory_p = handler.AcquireWeightsMemory(&y_combined);
@@ -679,9 +644,9 @@ template class MatMulGradMKLDNNKernel<paddle::platform::bfloat16>;
 namespace ops = paddle::operators;
 
 REGISTER_OP_KERNEL(matmul, MKLDNN, ::paddle::platform::CPUPlace,
-                   DNNLMatMulKernel<float>,
-                   DNNLMatMulKernel<paddle::platform::bfloat16>,
-                   DNNLMatMulKernel<int8_t>, DNNLMatMulKernel<uint8_t>);
+                   MatMulMKLDNNKernel<float>,
+                   MatMulMKLDNNKernel<paddle::platform::bfloat16>,
+                   MatMulMKLDNNKernel<int8_t>, MatMulMKLDNNKernel<uint8_t>);
 
 REGISTER_OP_KERNEL(matmul_grad, MKLDNN, ::paddle::platform::CPUPlace,
                    ops::MatMulGradMKLDNNKernel<float>,
...