未验证 提交 f3c14762 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Add int8 support for matmulV2 (#44908)

上级 075d7219
......@@ -659,7 +659,7 @@ float ComputeOutputScale(const ExecutionContext &ctx) {
return alpha * scale_out / (scale_x * scale_y);
}
template <typename T>
template <typename T, typename T_out>
void ExecuteMatMulV2(const ExecutionContext &ctx,
const MKLDNNDeviceContext &dev_ctx,
const dnnl::engine onednn_engine,
......@@ -675,16 +675,16 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
int execution_number = 0) {
std::vector<int64_t> x_strides_override = GetInputStrides(ctx, "X");
std::vector<int64_t> y_strides_override = GetInputStrides(ctx, "Y");
MatMulV2MKLDNNHandler<T> handler(ctx,
onednn_engine,
ctx.GetPlace(),
x_dims,
trans_x,
y_dims,
trans_y,
IsOutputFused(ctx),
x_strides_override,
y_strides_override);
MatMulV2MKLDNNHandler<T, T, T_out> handler(ctx,
onednn_engine,
ctx.GetPlace(),
x_dims,
trans_x,
y_dims,
trans_y,
IsOutputFused(ctx),
x_strides_override,
y_strides_override);
const auto src_memory_p = handler.AcquireSrcMemory(x);
const auto weights_memory_p = handler.AcquireWeightsMemory(y);
......@@ -707,17 +707,41 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
auto &astream = MKLDNNDeviceContext::tls().get_stream();
matmul_p->execute(astream, matmul_args);
astream.wait();
auto format = paddle::platform::MKLDNNFormatForSize(
out->dims().size(), dnnl::memory::format_tag::nchw);
out->set_layout(paddle::framework::DataLayout::kMKLDNN);
auto format =
MKLDNNFormatForSize(out->dims().size(), dnnl::memory::format_tag::nchw);
out->set_format(format);
out->set_layout(DataLayout::kMKLDNN);
}
template <typename T>
class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const ExecutionContext &ctx) const override { RunKernel(ctx); }
void Compute(const ExecutionContext &ctx) const override {
if (ctx.HasAttr("head_number")) {
PADDLE_ENFORCE_EQ(
ctx.Attr<int>("head_number"),
1,
paddle::platform::errors::Unimplemented(
"oneDNN matmul doesn't support multiple heads. Expected "
"head_number=1. But received `head_number` is %d",
ctx.Attr<int>("head_number")));
}
constexpr bool is_int8 = IsInt8<T>();
constexpr bool is_bfloat16 = IsBfloat16<T>();
const bool force_fp32_output = ctx.HasAttr("force_fp32_output")
? ctx.Attr<bool>("force_fp32_output")
: false;
constexpr bool fuse_relu = false; // TODO(intel): Enable eltwise fuses
if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) {
RunKernel<float>(ctx);
} else if (is_bfloat16) {
RunKernel<paddle::platform::bfloat16>(ctx);
} else if (fuse_relu) {
RunKernel<uint8_t>(ctx);
} else {
RunKernel<int8_t>(ctx);
}
}
private:
void CalculateMatrixDims(const ExecutionContext &ctx,
......@@ -768,6 +792,7 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
}
}
template <typename T_out>
void RunKernel(const ExecutionContext &ctx) const {
const auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &onednn_engine = dev_ctx.GetEngine();
......@@ -793,18 +818,18 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
CalculateMatrixDims(
ctx, x_dims, y_dims, &x_bd_dims, &y_bd_dims, &out_dims, out);
ExecuteMatMulV2<T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
x,
x_bd_dims,
trans_x,
y,
y_bd_dims,
trans_y,
out,
out_dims);
ExecuteMatMulV2<T, T_out>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
x,
x_bd_dims,
trans_x,
y,
y_bd_dims,
trans_y,
out,
out_dims);
}
};
......@@ -939,113 +964,113 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, &dx_bd_dims, &dy_bd_dims);
if (trans_x && trans_y) {
ExecuteMatMulV2<T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
y,
y_dims,
true,
dout,
dout_dims,
true,
&dx_tmp,
dx_bd_dims,
1);
ExecuteMatMulV2<T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
dout,
dout_dims,
true,
x,
x_dims,
true,
&dy_tmp,
dy_bd_dims,
2);
ExecuteMatMulV2<T, T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
y,
y_dims,
true,
dout,
dout_dims,
true,
&dx_tmp,
dx_bd_dims,
1);
ExecuteMatMulV2<T, T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
dout,
dout_dims,
true,
x,
x_dims,
true,
&dy_tmp,
dy_bd_dims,
2);
} else if (trans_x) {
ExecuteMatMulV2<T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
y,
y_dims,
false,
dout,
dout_dims,
true,
&dx_tmp,
dx_bd_dims,
1);
ExecuteMatMulV2<T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
x,
x_dims,
false,
dout,
dout_dims,
false,
&dy_tmp,
dy_bd_dims,
2);
ExecuteMatMulV2<T, T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
y,
y_dims,
false,
dout,
dout_dims,
true,
&dx_tmp,
dx_bd_dims,
1);
ExecuteMatMulV2<T, T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
x,
x_dims,
false,
dout,
dout_dims,
false,
&dy_tmp,
dy_bd_dims,
2);
} else if (trans_y) {
ExecuteMatMulV2<T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
dout,
dout_dims,
false,
y,
y_dims,
false,
&dx_tmp,
dx_bd_dims,
1);
ExecuteMatMulV2<T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
dout,
dout_dims,
true,
x,
x_dims,
false,
&dy_tmp,
dy_bd_dims,
2);
ExecuteMatMulV2<T, T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
dout,
dout_dims,
false,
y,
y_dims,
false,
&dx_tmp,
dx_bd_dims,
1);
ExecuteMatMulV2<T, T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
dout,
dout_dims,
true,
x,
x_dims,
false,
&dy_tmp,
dy_bd_dims,
2);
} else {
ExecuteMatMulV2<T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
dout,
dout_dims,
false,
y,
y_dims,
true,
&dx_tmp,
dx_bd_dims,
1);
ExecuteMatMulV2<T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
x,
x_dims,
true,
dout,
dout_dims,
false,
&dy_tmp,
dy_bd_dims,
2);
ExecuteMatMulV2<T, T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
dout,
dout_dims,
false,
y,
y_dims,
true,
&dx_tmp,
dx_bd_dims,
1);
ExecuteMatMulV2<T, T>(ctx,
dev_ctx,
onednn_engine,
ctx.GetPlace(),
x,
x_dims,
true,
dout,
dout_dims,
false,
&dy_tmp,
dy_bd_dims,
2);
}
if (x_dims != dx_bd_dims) {
......@@ -1234,34 +1259,13 @@ template class MatMulGradMKLDNNKernel<paddle::platform::bfloat16>;
namespace ops = paddle::operators;
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(matmul,
MKLDNN,
::paddle::platform::CPUPlace,
S8,
0,
MatMulMKLDNNKernel<int8_t>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(matmul,
MKLDNN,
::paddle::platform::CPUPlace,
U8,
0,
MatMulMKLDNNKernel<uint8_t>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(matmul,
MKLDNN,
::paddle::platform::CPUPlace,
FP32,
0,
MatMulV2MKLDNNKernel<float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
matmul,
MKLDNN,
::paddle::platform::CPUPlace,
BF16,
0,
MatMulV2MKLDNNKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(matmul,
MKLDNN,
::paddle::platform::CPUPlace,
MatMulV2MKLDNNKernel<float>,
MatMulV2MKLDNNKernel<paddle::platform::bfloat16>,
MatMulV2MKLDNNKernel<int8_t>,
MatMulV2MKLDNNKernel<uint8_t>);
REGISTER_OP_KERNEL(matmul_grad,
MKLDNN,
......@@ -1273,7 +1277,9 @@ REGISTER_OP_KERNEL(matmul_v2,
MKLDNN,
::paddle::platform::CPUPlace,
MatMulV2MKLDNNKernel<float>,
MatMulV2MKLDNNKernel<paddle::platform::bfloat16>);
MatMulV2MKLDNNKernel<paddle::platform::bfloat16>,
MatMulV2MKLDNNKernel<int8_t>,
MatMulV2MKLDNNKernel<uint8_t>);
REGISTER_OP_KERNEL(matmul_v2_grad,
MKLDNN,
......
......@@ -416,16 +416,16 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> {
bool trans_y,
Tensor *out) const {
static const std::vector<int64_t> vec_placeholder;
MatMulV2MKLDNNHandler<XT> handler(ctx,
onednn_engine,
ctx.GetPlace(),
x_dims,
trans_x,
y_dims,
trans_y,
false,
vec_placeholder,
vec_placeholder);
MatMulV2MKLDNNHandler<XT, YT, XT> handler(ctx,
onednn_engine,
ctx.GetPlace(),
x_dims,
trans_x,
y_dims,
trans_y,
false,
vec_placeholder,
vec_placeholder);
const auto src_memory_p = handler.AcquireSrcMemory(x);
const auto weights_memory_p = handler.AcquireWeightsMemory(y);
......
......@@ -860,8 +860,18 @@ class ReductionMKLDNNHandler
};
template <typename T>
constexpr bool IsInt8() {
return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
}
template <typename T>
constexpr bool IsBfloat16() {
return std::is_same<T, paddle::platform::bfloat16>::value;
}
template <typename XT, typename YT, typename OT>
class MatMulV2MKLDNNHandler
: public paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul> {
: public paddle::platform::MKLDNNHandlerNoCachingT<XT, dnnl::matmul> {
public:
MatMulV2MKLDNNHandler(const framework::ExecutionContext& ctx,
const dnnl::engine engine,
......@@ -873,8 +883,8 @@ class MatMulV2MKLDNNHandler
bool is_output_fused,
const std::vector<int64_t>& x_strides_override,
const std::vector<int64_t>& y_strides_override)
: paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul>(engine,
cpu_place) {
: paddle::platform::MKLDNNHandlerNoCachingT<XT, dnnl::matmul>(engine,
cpu_place) {
// M X K * K X N
std::vector<int64_t> x_dims(x_org_dims);
std::vector<int64_t> y_dims(y_org_dims);
......@@ -934,28 +944,42 @@ class MatMulV2MKLDNNHandler
out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
}
if (is_output_fused) {
if (!IsInt8<OT>() && !IsBfloat16<OT>() && is_output_fused) {
out_strides = FakeTransposeStrides(out_ddims);
}
auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
auto out_md = memory::desc(out_ddims, MKLDNNGetDataType<T>(), out_strides);
auto x_md = memory::desc(x_dims, MKLDNNGetDataType<XT>(), x_strides);
auto y_md = memory::desc(y_dims, MKLDNNGetDataType<YT>(), y_strides);
auto out_md = memory::desc(out_ddims, MKLDNNGetDataType<OT>(), out_strides);
const dnnl::primitive_attr matmul_attrs = CreateMatmulAttrs(ctx);
this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md);
}
// TODO(jczaja) : Adapt to int8
float ComputeOutputScale(const framework::ExecutionContext& ctx) {
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
if (ctx.HasAttr("Scale_x") && ctx.HasAttr("Scale_y") &&
ctx.HasAttr("Scale_out")) {
float scale_x = ctx.Attr<float>("Scale_x");
float scale_y = ctx.Attr<float>("Scale_y");
bool force_fp32_out = ctx.HasAttr("force_fp32_output")
? ctx.Attr<bool>("force_fp32_output")
: false;
float scale_out = force_fp32_out ? 1.f : ctx.Attr<float>("Scale_out");
alpha *= scale_out / (scale_x * scale_y);
}
return alpha;
}
dnnl::primitive_attr CreateMatmulAttrs(
const framework::ExecutionContext& ctx) {
dnnl::primitive_attr matmul_attrs;
dnnl::post_ops post_operations;
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
if (alpha != 1.0f) {
matmul_attrs.set_output_scales(0, {alpha});
float scale_out = ComputeOutputScale(ctx);
if (scale_out != 1.0f) {
matmul_attrs.set_output_scales(0, {scale_out});
}
if (ctx.HasInput("ResidualData")) {
......@@ -993,9 +1017,23 @@ class MatMulV2MKLDNNHandler
}
std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) {
const T* input_data = input->data<T>();
const YT* input_data = input->data<YT>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
to_void_cast<T>(input_data));
to_void_cast<YT>(input_data));
}
std::shared_ptr<dnnl::memory> AcquireDstMemory(
paddle::framework::Tensor* output) {
// We cannot use base AcquireDstMemory as it makes an allocation request
// base on DST memory primitive size. This is fine in general, but in MatMul
// we have primitive that covers only one batch of Data and then shift
// pointer for every new batch. Hence Tensor size is bigger that dst memory
// primitive size. So would we request less memory that is there and it
// triggers an
// assertion. So as there is no 'any' format here we can leave default size
// of Tensor as computed in ComputeInferShape
OT* ptr = output->mutable_data<OT>(this->place_);
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr);
}
};
......@@ -1099,11 +1137,11 @@ class ActivationMKLDNNHandler
static std::unordered_map<std::string, std::string> GetAttributeMap(
std::string act_type) {
std::unordered_map<std::string, std::string> attr_map;
if (act_type == "swish")
if (act_type == "swish") {
attr_map.emplace("beta", "fuse_alpha");
else if (act_type == "relu6")
} else if (act_type == "relu6") {
attr_map.emplace("threshold", "fuse_alpha");
else if (act_type == "hard_sigmoid") {
} else if (act_type == "hard_sigmoid") {
attr_map.emplace("slope", "fuse_alpha");
attr_map.emplace("offset", "fuse_beta");
} else if (act_type == "clip") {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册