未验证 提交 e9ae8dd0 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] Disable cache matmul v1 & refactoring (#35331)

* - refactoring progressing

- Fix

- compilation fix

- another compilation fix

- refactoring

* - fix

* - compilation fix

* - compilation fix

* - missing set_format

* - compilation fix

* - reverted setting memeory format

* - Brought back format

* - Fix

* - fixes after review

* CI rerun

* CI rerun
上级 36cdb6e2
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/mkldnn/matmul_mkldnn_op.h"
#include <tuple>
using dnnl::memory;
using dnnl::primitive;
......@@ -20,6 +21,7 @@ using paddle::framework::DataLayout;
using paddle::framework::ExecutionContext;
using paddle::framework::vectorize;
using paddle::platform::GetMKLDNNFormat;
using paddle::platform::MKLDNNFormatForSize;
using paddle::platform::MKLDNNDeviceContext;
using paddle::platform::MKLDNNGetDataType;
using paddle::platform::to_void_cast;
......@@ -82,15 +84,39 @@ static Tensor FoldFirstAndLastDims(const MKLDNNDeviceContext& dev_ctx,
}
template <typename T>
constexpr bool IsInt8() {
return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
}
template <typename T>
constexpr bool IsBfloat16() {
return std::is_same<T, paddle::platform::bfloat16>::value;
}
// Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
// original x_dim is returned.
static paddle::framework::DDim RowMatrixDimsFromVector(
const paddle::framework::DDim& x_dim) {
return x_dim.size() > 1 ? x_dim : paddle::framework::make_ddim({1, x_dim[0]});
}
// Get column matrix shape from a vector shape. If the ran of y_dim > 1, the
// original y_dim is returned.
static paddle::framework::DDim ColumnMatrixDimsFromVector(
const paddle::framework::DDim& y_dim) {
return y_dim.size() > 1 ? y_dim : paddle::framework::make_ddim({y_dim[0], 1});
}
template <typename XT, typename YT, typename OT>
class MatMulMKLDNNHandler
: public paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul> {
: public paddle::platform::MKLDNNHandlerNoCachingT<XT, dnnl::matmul> {
public:
MatMulMKLDNNHandler(const mkldnn::engine engine,
paddle::platform::Place cpu_place, Tensor* x,
bool trans_x, Tensor* y, bool trans_y, Tensor* out,
float scale)
: paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul>(engine,
cpu_place) {
: paddle::platform::MKLDNNHandlerNoCachingT<XT, dnnl::matmul>(engine,
cpu_place) {
auto mat_dim_x =
paddle::operators::math::CreateMatrixDescriptor(x->dims(), 0, trans_x);
auto mat_dim_y =
......@@ -115,117 +141,98 @@ class MatMulMKLDNNHandler
!trans_y ? memory::dims{N * K, N, 1} : memory::dims{N * K, 1, K};
memory::dims out_strides = memory::dims{M * N, N, 1};
auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
auto out_md = memory::desc(out_dims, MKLDNNGetDataType<T>(), out_strides);
auto x_md = memory::desc(x_dims, MKLDNNGetDataType<XT>(), x_strides);
auto y_md = memory::desc(y_dims, MKLDNNGetDataType<YT>(), y_strides);
auto out_md = memory::desc(out_dims, MKLDNNGetDataType<OT>(), out_strides);
dnnl::primitive_attr attrs;
if (scale != 1.0f) attrs.set_output_scales(0, {scale});
this->AcquireForwardPrimitiveDescriptor(attrs, x_md, y_md, out_md);
}
// Constructor for FWD MatMul
MatMulMKLDNNHandler(const mkldnn::engine engine, const ExecutionContext& ctx,
float scale)
: paddle::platform::MKLDNNHandlerNoCachingT<XT, dnnl::matmul>(
engine, ctx.GetPlace()),
matmul_dims_(GetMatmulDims(ctx)) {
dnnl::primitive_attr attr;
float scale_out = ComputeOutputScale(ctx);
if (scale_out != 1.0f) {
constexpr unsigned tensor_wide_scale = 0;
attr.set_output_scales(tensor_wide_scale, {scale_out});
}
auto x_md = memory::desc(matmul_dims_.x_dims, MKLDNNGetDataType<XT>(),
matmul_dims_.x_strides);
auto y_md = memory::desc(matmul_dims_.y_dims, MKLDNNGetDataType<YT>(),
matmul_dims_.y_strides);
auto out_md = memory::desc(matmul_dims_.out_dims, MKLDNNGetDataType<OT>(),
matmul_dims_.out_strides);
this->AcquireForwardPrimitiveDescriptor(attr, x_md, y_md, out_md);
}
std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) {
const T* input_data = input->data<T>();
const YT* input_data = input->data<YT>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
to_void_cast<T>(input_data));
to_void_cast<YT>(input_data));
}
};
template <typename T>
constexpr bool IsInt8() {
return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
}
template <typename T>
constexpr bool IsBfloat16() {
return std::is_same<T, paddle::platform::bfloat16>::value;
}
// Get row matrix shape from a vector shape. If the rank of x_dim > 1, the
// original x_dim is returned.
static paddle::framework::DDim RowMatrixDimsFromVector(
const paddle::framework::DDim& x_dim) {
return x_dim.size() > 1 ? x_dim : paddle::framework::make_ddim({1, x_dim[0]});
}
// Get column matrix shape from a vector shape. If the ran of y_dim > 1, the
// original y_dim is returned.
static paddle::framework::DDim ColumnMatrixDimsFromVector(
const paddle::framework::DDim& y_dim) {
return y_dim.size() > 1 ? y_dim : paddle::framework::make_ddim({y_dim[0], 1});
}
/**
* Reshape a tensor to 3-D or 2-D tensor by matrix descriptor.
*
* The shape would be [BatchSize, H, W] or [H, W].
* If transposed, `H,W` will be swapped.
*/
static void ReshapeTensorToMatrixSequence(
Tensor* x, const paddle::operators::math::MatDescriptor& descriptor) {
int64_t h, w;
h = descriptor.height_;
w = descriptor.width_;
if (descriptor.trans_) {
std::swap(w, h);
}
if (descriptor.batch_size_) {
x->Resize({descriptor.batch_size_, h, w});
} else {
x->Resize({h, w});
}
}
public:
void Execute(const paddle::framework::Tensor* x,
const paddle::framework::Tensor* y,
paddle::framework::Tensor* out) {
const auto src_memory_p = this->AcquireSrcMemory(x);
const auto weights_memory_p = this->AcquireWeightsMemory(y);
const auto dst_memory_p = this->AcquireDstMemory(out);
auto matmul_p = this->AcquireForwardPrimitive();
std::unordered_map<int, dnnl::memory> matmul_args = {
{DNNL_ARG_SRC, *src_memory_p},
{DNNL_ARG_WEIGHTS, *weights_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
// Simulate batch matmul by processing in loop
void* x_ptr = src_memory_p->get_data_handle();
void* y_ptr = weights_memory_p->get_data_handle();
void* out_ptr = dst_memory_p->get_data_handle();
auto offsets = this->GetOffsets();
for (uint16_t i = 0; i < this->GetBatchSize(); ++i) {
src_memory_p->set_data_handle(x_ptr);
weights_memory_p->set_data_handle(y_ptr);
dst_memory_p->set_data_handle(out_ptr);
matmul_p->execute(astream, {
{MKLDNN_ARG_SRC, *src_memory_p},
{MKLDNN_ARG_WEIGHTS, *weights_memory_p},
{MKLDNN_ARG_DST, *dst_memory_p},
});
x_ptr = static_cast<char*>(x_ptr) + std::get<0>(offsets);
y_ptr = static_cast<char*>(y_ptr) + std::get<1>(offsets);
out_ptr = static_cast<char*>(out_ptr) + std::get<2>(offsets);
}
astream.wait();
/**
* Reshape the x,y,out tensor to 3-D or 2-D tensor by matrix descriptor
* Out = matmul(x, y)
*
* This method will first calculate X,Y matrix sequence, and then calculate
* the out shape.
*
* Assume X = [BatchSize, H1, W1], Y = [BatchSize, H2, W2]
* The out = [BatchSize, H1, W2]
*
* If there is no batch size in `X` and `Y`, the out will be [H1, W2]
* If any of `X` and `Y` has batch size BatchSize, the out will have the
* BatchSize.
*/
static void ReshapeXYOutToMatrixSequence(Tensor* x, Tensor* y, Tensor* out,
bool trans_x, bool trans_y) {
auto x_dim = RowMatrixDimsFromVector(x->dims());
auto y_dim = ColumnMatrixDimsFromVector(y->dims());
auto mat_dim_x =
paddle::operators::math::CreateMatrixDescriptor(x_dim, 0, trans_x);
auto mat_dim_y =
paddle::operators::math::CreateMatrixDescriptor(y_dim, 0, trans_y);
if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) {
out->Resize({mat_dim_x.height_, mat_dim_y.width_});
} else {
out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_),
mat_dim_x.height_, mat_dim_y.width_});
auto format =
MKLDNNFormatForSize(out->dims().size(), dnnl::memory::format_tag::nchw);
out->set_format(format);
out->set_layout(DataLayout::kMKLDNN);
}
ReshapeTensorToMatrixSequence(x, mat_dim_x);
ReshapeTensorToMatrixSequence(y, mat_dim_y);
}
template <typename XT, typename YT, typename OT>
class MatMulFactory {
public:
void CreateAndExecute(const ExecutionContext& ctx) {
SetDNNLEngine(ctx);
if (IsInitialized()) {
UpdateDataPointers(ctx);
Execute();
SetOutputFormat(ctx);
return;
}
CreateMemories(ctx);
CreatePrimitive(ctx);
Execute();
SetOutputFormat(ctx);
SetInitialized();
std::shared_ptr<mkldnn::memory> AcquireDstMemory(
paddle::framework::Tensor* output) {
// We cannot use base AcquireDstMemory as it makes an allocation request
// base on DST memory primitive size. This is fine in general, but in MatMul
// we have primitive that covers only one batch of Data and then shift
// pointer for every new batch. Hence Tensor size is bigger that dst memory
// primitive size. So would we request less memory that is there and it
// triggers an
// assertion. So as there is no 'any' format here we can leave default size
// of Tensor as computed in ComputeInferShape
OT* ptr = output->mutable_data<OT>(this->place_);
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr);
}
private:
......@@ -234,47 +241,6 @@ class MatMulFactory {
out_strides;
};
void SetDNNLEngine(const ExecutionContext& ctx) {
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
engine_ = dev_ctx.GetEngine();
}
template <typename T>
dnnl::memory CreateMemory(const memory::dims& dims,
const memory::dims& strides, const T* data) {
auto md = memory::desc(dims, MKLDNNGetDataType<T>(), strides);
return dnnl::memory(md, engine_, to_void_cast(data));
}
std::vector<int64_t> Transpose(const std::vector<int64_t>& x,
const std::vector<int>& axis) {
size_t in_rank = x.size();
size_t axis_size = axis.size();
auto axis_set = std::set<int>(axis.begin(), axis.end());
PADDLE_ENFORCE_EQ(axis_set.size(), axis_size,
paddle::platform::errors::InvalidArgument(
"In an axis array, elements must be unique."));
PADDLE_ENFORCE_EQ(in_rank, axis_size,
paddle::platform::errors::InvalidArgument(
"The input dimension's size "
"should be equal to the axis's size. "
"But received dimension is %d, "
"axis's size is %d",
in_rank, axis_size));
PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), axis_size,
paddle::platform::errors::InvalidArgument(
"Axis values must be ranging from 0 to (dims - 1)."));
std::vector<int64_t> new_x(x.size());
for (size_t i = 0; i < x.size(); i++) {
new_x[i] = x[axis[i]];
}
return new_x;
}
std::pair<paddle::operators::math::MatDescriptor, memory::dims>
GetInputDimsAndStrides(const ExecutionContext& ctx, std::string input_name) {
auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
......@@ -310,6 +276,15 @@ class MatMulFactory {
return std::make_pair(mat_dim, strides);
}
float ComputeOutputScale(const ExecutionContext& ctx) {
float scale_x = ctx.Attr<float>("Scale_x");
float scale_y = ctx.Attr<float>("Scale_y");
bool force_fp32_out = ctx.Attr<bool>("force_fp32_output");
float scale_out = force_fp32_out ? 1.f : ctx.Attr<float>("Scale_out");
float alpha = ctx.Attr<float>("alpha");
return alpha * scale_out / (scale_x * scale_y);
}
bool IsInputFused(const ExecutionContext& ctx) const {
return !(ctx.Attr<std::vector<int>>("fused_reshape_X").empty() &&
ctx.Attr<std::vector<int>>("fused_reshape_Y").empty());
......@@ -322,14 +297,6 @@ class MatMulFactory {
return !fused_reshape_Out.empty() && !fused_transpose_Out.empty();
}
void CorrectStridesWhenFloatOutputFused(const ExecutionContext& ctx,
const memory::dim N, memory::dim b,
memory::dims* out_strides) const {
if (!IsInt8<OT>() && !IsBfloat16<OT>() && IsOutputFused(ctx)) {
*out_strides = {N, b * N, 1};
}
}
MatMulDims GetMatmulDims(const ExecutionContext& ctx) {
paddle::operators::math::MatDescriptor mat_dim_x;
memory::dims strides_x;
......@@ -381,125 +348,112 @@ class MatMulFactory {
return {x_dims, y_dims, out_dims, strides_x, strides_y, out_strides};
}
void CreateMemories(const ExecutionContext& ctx) {
auto matmul_dims = GetMatmulDims(ctx);
std::vector<int64_t> Transpose(const std::vector<int64_t>& x,
const std::vector<int>& axis) {
size_t in_rank = x.size();
size_t axis_size = axis.size();
x_mem_ = CreateMemory<XT>(matmul_dims.x_dims, matmul_dims.x_strides,
ctx.Input<Tensor>("X")->data<XT>());
y_mem_ = CreateMemory<YT>(matmul_dims.y_dims, matmul_dims.y_strides,
ctx.Input<Tensor>("Y")->data<YT>());
out_mem_ = CreateMemory<OT>(
matmul_dims.out_dims, matmul_dims.out_strides,
ctx.Output<Tensor>("Out")->mutable_data<OT>(ctx.GetPlace()));
}
auto axis_set = std::set<int>(axis.begin(), axis.end());
PADDLE_ENFORCE_EQ(axis_set.size(), axis_size,
paddle::platform::errors::InvalidArgument(
"In an axis array, elements must be unique."));
float ComputeOutputScale(const ExecutionContext& ctx) {
float scale_x = ctx.Attr<float>("Scale_x");
float scale_y = ctx.Attr<float>("Scale_y");
bool force_fp32_out = ctx.Attr<bool>("force_fp32_output");
float scale_out = force_fp32_out ? 1.f : ctx.Attr<float>("Scale_out");
float alpha = ctx.Attr<float>("alpha");
return alpha * scale_out / (scale_x * scale_y);
}
PADDLE_ENFORCE_EQ(in_rank, axis_size,
paddle::platform::errors::InvalidArgument(
"The input dimension's size "
"should be equal to the axis's size. "
"But received dimension is %d, "
"axis's size is %d",
in_rank, axis_size));
void CreatePrimitive(const ExecutionContext& ctx) {
dnnl::primitive_attr attr;
float scale_out = ComputeOutputScale(ctx);
if (scale_out != 1.0f) {
constexpr unsigned tensor_wide_scale = 0;
attr.set_output_scales(tensor_wide_scale, {scale_out});
}
PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), axis_size,
paddle::platform::errors::InvalidArgument(
"Axis values must be ranging from 0 to (dims - 1)."));
auto matmul_d = dnnl::matmul::desc(x_mem_.get_desc(), y_mem_.get_desc(),
out_mem_.get_desc());
auto matmul_pd = dnnl::matmul::primitive_desc(matmul_d, attr, engine_);
matmul_prim_ = dnnl::matmul(matmul_pd);
std::vector<int64_t> new_x(x.size());
for (size_t i = 0; i < x.size(); i++) {
new_x[i] = x[axis[i]];
}
return new_x;
}
void Execute() {
dnnl::stream stream(engine_);
void* x_ptr = x_mem_.get_data_handle();
void* y_ptr = y_mem_.get_data_handle();
void* out_ptr = out_mem_.get_data_handle();
for (uint16_t i = 0; i < batch_size_; i++) {
x_mem_.set_data_handle(x_ptr);
y_mem_.set_data_handle(y_ptr);
out_mem_.set_data_handle(out_ptr);
matmul_prim_.execute(stream, {
{MKLDNN_ARG_SRC, x_mem_},
{MKLDNN_ARG_WEIGHTS, y_mem_},
{MKLDNN_ARG_DST, out_mem_},
});
x_ptr = static_cast<char*>(x_ptr) + x_offset_;
y_ptr = static_cast<char*>(y_ptr) + y_offset_;
out_ptr = static_cast<char*>(out_ptr) + out_offset_;
void CorrectStridesWhenFloatOutputFused(const ExecutionContext& ctx,
const memory::dim N, memory::dim b,
memory::dims* out_strides) const {
if (!IsInt8<OT>() && !IsBfloat16<OT>() && IsOutputFused(ctx)) {
*out_strides = {N, b * N, 1};
}
stream.wait();
}
void SetOutputFormat(const ExecutionContext& ctx) {
using paddle::platform::MKLDNNFormatForSize;
auto* out = ctx.Output<Tensor>("Out");
auto format =
MKLDNNFormatForSize(out->dims().size(), dnnl::memory::format_tag::nchw);
out->set_format(format);
out->set_layout(DataLayout::kMKLDNN);
}
uint16_t GetBatchSize(void) const { return batch_size_; }
void UpdateDataPointers(const ExecutionContext& ctx) {
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Output<Tensor>("Out");
x_mem_.set_data_handle(to_void_cast(x->data<XT>()));
y_mem_.set_data_handle(to_void_cast(y->data<YT>()));
out_mem_.set_data_handle(out->mutable_data<OT>(ctx.GetPlace()));
std::tuple<uint32_t, uint32_t, uint32_t> GetOffsets() const {
return std::make_tuple(x_offset_, y_offset_, out_offset_);
}
// If initialized, x memory should've been already initialized
bool IsInitialized() { return initialized_; }
void SetInitialized() { initialized_ = true; }
private:
struct memory_offsets {
size_t x_offset;
size_t y_offset;
size_t out_offset;
};
dnnl::engine engine_;
dnnl::memory x_mem_;
dnnl::memory y_mem_;
dnnl::memory out_mem_;
dnnl::matmul matmul_prim_;
MatMulDims matmul_dims_;
uint32_t x_offset_;
uint32_t y_offset_;
uint32_t out_offset_;
uint16_t batch_size_;
bool initialized_ = false;
};
template <typename XT, typename YT, typename OT>
static std::shared_ptr<MatMulFactory<XT, YT, OT>> GetPrimitiveFactory(
const ExecutionContext& ctx) {
const auto& out_name = ctx.OutputName("Out");
const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto batch_size = ctx.Input<Tensor>("X")->dims()[0];
std::string key = paddle::platform::CreateKey(dev_ctx, batch_size, out_name);
key = paddle::platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
auto factory =
std::static_pointer_cast<MatMulFactory<XT, YT, OT>>(dev_ctx.GetBlob(key));
if (factory == nullptr) {
factory = std::make_shared<MatMulFactory<XT, YT, OT>>();
dev_ctx.SetBlob(key, factory);
/**
* Reshape a tensor to 3-D or 2-D tensor by matrix descriptor.
*
* The shape would be [BatchSize, H, W] or [H, W].
* If transposed, `H,W` will be swapped.
*/
static void ReshapeTensorToMatrixSequence(
Tensor* x, const paddle::operators::math::MatDescriptor& descriptor) {
int64_t h, w;
h = descriptor.height_;
w = descriptor.width_;
if (descriptor.trans_) {
std::swap(w, h);
}
if (descriptor.batch_size_) {
x->Resize({descriptor.batch_size_, h, w});
} else {
x->Resize({h, w});
}
}
/**
* Reshape the x,y,out tensor to 3-D or 2-D tensor by matrix descriptor
* Out = matmul(x, y)
*
* This method will first calculate X,Y matrix sequence, and then calculate
* the out shape.
*
* Assume X = [BatchSize, H1, W1], Y = [BatchSize, H2, W2]
* The out = [BatchSize, H1, W2]
*
* If there is no batch size in `X` and `Y`, the out will be [H1, W2]
* If any of `X` and `Y` has batch size BatchSize, the out will have the
* BatchSize.
*/
static void ReshapeXYOutToMatrixSequence(Tensor* x, Tensor* y, Tensor* out,
bool trans_x, bool trans_y) {
auto x_dim = RowMatrixDimsFromVector(x->dims());
auto y_dim = ColumnMatrixDimsFromVector(y->dims());
auto mat_dim_x =
paddle::operators::math::CreateMatrixDescriptor(x_dim, 0, trans_x);
auto mat_dim_y =
paddle::operators::math::CreateMatrixDescriptor(y_dim, 0, trans_y);
if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) {
out->Resize({mat_dim_x.height_, mat_dim_y.width_});
} else {
out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_),
mat_dim_x.height_, mat_dim_y.width_});
}
return factory;
ReshapeTensorToMatrixSequence(x, mat_dim_x);
ReshapeTensorToMatrixSequence(y, mat_dim_y);
}
// Choose appropriate primitive factory implementation based on inferred
// Choose appropriate Handler instances based on inferred
// output type (uint8, int8 or float).
template <typename XT, typename YT>
static void ExecuteMatMul(const ExecutionContext& ctx) {
......@@ -507,31 +461,41 @@ static void ExecuteMatMul(const ExecutionContext& ctx) {
constexpr bool is_bfloat16 = IsBfloat16<XT>();
const bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
constexpr bool fuse_relu = false; // TODO(intel): Enable eltwise fuses
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* out = ctx.Output<Tensor>("Out");
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
const auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) {
GetPrimitiveFactory<XT, YT, float>(ctx)->CreateAndExecute(ctx);
MatMulMKLDNNHandler<XT, YT, float>(dev_ctx.GetEngine(), ctx, alpha)
.Execute(x, y, out);
} else if (is_bfloat16) {
GetPrimitiveFactory<XT, YT, paddle::platform::bfloat16>(ctx)
->CreateAndExecute(ctx);
MatMulMKLDNNHandler<XT, YT, paddle::platform::bfloat16>(dev_ctx.GetEngine(),
ctx, alpha)
.Execute(x, y, out);
} else if (fuse_relu) {
GetPrimitiveFactory<XT, YT, uint8_t>(ctx)->CreateAndExecute(ctx);
MatMulMKLDNNHandler<XT, YT, uint8_t>(dev_ctx.GetEngine(), ctx, alpha)
.Execute(x, y, out);
} else {
GetPrimitiveFactory<XT, YT, int8_t>(ctx)->CreateAndExecute(ctx);
MatMulMKLDNNHandler<XT, YT, int8_t>(dev_ctx.GetEngine(), ctx, alpha)
.Execute(x, y, out);
}
}
template <typename T>
class DNNLMatMulKernel : public paddle::framework::OpKernel<T> {
class MatMulMKLDNNKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const ExecutionContext& ctx) const override {
if (ctx.HasAttr("head_number")) {
PADDLE_ENFORCE_EQ(
ctx.Attr<int>("head_number"), 1,
paddle::platform::errors::Unimplemented(
"DNNL matmul doesn't support multiple heads. Expected "
"oneDNN matmul doesn't support multiple heads. Expected "
"head_number=1. But received `head_number` is %d",
ctx.Attr<int>("head_number")));
}
MKLDNNDeviceContext::tls().log_lib_version();
ExecuteMatMul<T, T>(ctx);
}
};
......@@ -547,7 +511,7 @@ void MatMulGradMKLDNNKernel<T>::Compute(const ExecutionContext& ctx) const {
PADDLE_ENFORCE_EQ(
ctx.Attr<int>("head_number"), 1,
platform::errors::Unimplemented(
"DNNL matmul doesn't support multiple heads. Expected "
"oneDNN matmul doesn't support multiple heads. Expected "
"head_number=1. But received `head_number` is %d",
ctx.Attr<int>("head_number")));
}
......@@ -577,8 +541,9 @@ void MatMulGradMKLDNNKernel<T>::ExecuteMatMulGrad(
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
MatMulMKLDNNHandler<T> handler(engine, ctx.GetPlace(), &x_combined, trans_x,
&y_combined, trans_y, out, alpha);
MatMulMKLDNNHandler<T, T, T> handler(engine, ctx.GetPlace(), &x_combined,
trans_x, &y_combined, trans_y, out,
alpha);
const auto src_memory_p = handler.AcquireSrcMemory(&x_combined);
const auto weights_memory_p = handler.AcquireWeightsMemory(&y_combined);
......@@ -679,9 +644,9 @@ template class MatMulGradMKLDNNKernel<paddle::platform::bfloat16>;
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(matmul, MKLDNN, ::paddle::platform::CPUPlace,
DNNLMatMulKernel<float>,
DNNLMatMulKernel<paddle::platform::bfloat16>,
DNNLMatMulKernel<int8_t>, DNNLMatMulKernel<uint8_t>);
MatMulMKLDNNKernel<float>,
MatMulMKLDNNKernel<paddle::platform::bfloat16>,
MatMulMKLDNNKernel<int8_t>, MatMulMKLDNNKernel<uint8_t>);
REGISTER_OP_KERNEL(matmul_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::MatMulGradMKLDNNKernel<float>,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册