未验证 提交 f3c14762 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Add int8 support for matmulV2 (#44908)

上级 075d7219
...@@ -659,7 +659,7 @@ float ComputeOutputScale(const ExecutionContext &ctx) { ...@@ -659,7 +659,7 @@ float ComputeOutputScale(const ExecutionContext &ctx) {
return alpha * scale_out / (scale_x * scale_y); return alpha * scale_out / (scale_x * scale_y);
} }
template <typename T> template <typename T, typename T_out>
void ExecuteMatMulV2(const ExecutionContext &ctx, void ExecuteMatMulV2(const ExecutionContext &ctx,
const MKLDNNDeviceContext &dev_ctx, const MKLDNNDeviceContext &dev_ctx,
const dnnl::engine onednn_engine, const dnnl::engine onednn_engine,
...@@ -675,16 +675,16 @@ void ExecuteMatMulV2(const ExecutionContext &ctx, ...@@ -675,16 +675,16 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
int execution_number = 0) { int execution_number = 0) {
std::vector<int64_t> x_strides_override = GetInputStrides(ctx, "X"); std::vector<int64_t> x_strides_override = GetInputStrides(ctx, "X");
std::vector<int64_t> y_strides_override = GetInputStrides(ctx, "Y"); std::vector<int64_t> y_strides_override = GetInputStrides(ctx, "Y");
MatMulV2MKLDNNHandler<T> handler(ctx, MatMulV2MKLDNNHandler<T, T, T_out> handler(ctx,
onednn_engine, onednn_engine,
ctx.GetPlace(), ctx.GetPlace(),
x_dims, x_dims,
trans_x, trans_x,
y_dims, y_dims,
trans_y, trans_y,
IsOutputFused(ctx), IsOutputFused(ctx),
x_strides_override, x_strides_override,
y_strides_override); y_strides_override);
const auto src_memory_p = handler.AcquireSrcMemory(x); const auto src_memory_p = handler.AcquireSrcMemory(x);
const auto weights_memory_p = handler.AcquireWeightsMemory(y); const auto weights_memory_p = handler.AcquireWeightsMemory(y);
...@@ -707,17 +707,41 @@ void ExecuteMatMulV2(const ExecutionContext &ctx, ...@@ -707,17 +707,41 @@ void ExecuteMatMulV2(const ExecutionContext &ctx,
auto &astream = MKLDNNDeviceContext::tls().get_stream(); auto &astream = MKLDNNDeviceContext::tls().get_stream();
matmul_p->execute(astream, matmul_args); matmul_p->execute(astream, matmul_args);
astream.wait(); astream.wait();
auto format =
auto format = paddle::platform::MKLDNNFormatForSize( MKLDNNFormatForSize(out->dims().size(), dnnl::memory::format_tag::nchw);
out->dims().size(), dnnl::memory::format_tag::nchw);
out->set_layout(paddle::framework::DataLayout::kMKLDNN);
out->set_format(format); out->set_format(format);
out->set_layout(DataLayout::kMKLDNN);
} }
template <typename T> template <typename T>
class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> { class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
public: public:
void Compute(const ExecutionContext &ctx) const override { RunKernel(ctx); } void Compute(const ExecutionContext &ctx) const override {
if (ctx.HasAttr("head_number")) {
PADDLE_ENFORCE_EQ(
ctx.Attr<int>("head_number"),
1,
paddle::platform::errors::Unimplemented(
"oneDNN matmul doesn't support multiple heads. Expected "
"head_number=1. But received `head_number` is %d",
ctx.Attr<int>("head_number")));
}
constexpr bool is_int8 = IsInt8<T>();
constexpr bool is_bfloat16 = IsBfloat16<T>();
const bool force_fp32_output = ctx.HasAttr("force_fp32_output")
? ctx.Attr<bool>("force_fp32_output")
: false;
constexpr bool fuse_relu = false; // TODO(intel): Enable eltwise fuses
if (force_fp32_output || ((!is_int8) && (!is_bfloat16))) {
RunKernel<float>(ctx);
} else if (is_bfloat16) {
RunKernel<paddle::platform::bfloat16>(ctx);
} else if (fuse_relu) {
RunKernel<uint8_t>(ctx);
} else {
RunKernel<int8_t>(ctx);
}
}
private: private:
void CalculateMatrixDims(const ExecutionContext &ctx, void CalculateMatrixDims(const ExecutionContext &ctx,
...@@ -768,6 +792,7 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -768,6 +792,7 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
} }
} }
template <typename T_out>
void RunKernel(const ExecutionContext &ctx) const { void RunKernel(const ExecutionContext &ctx) const {
const auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); const auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &onednn_engine = dev_ctx.GetEngine(); const auto &onednn_engine = dev_ctx.GetEngine();
...@@ -793,18 +818,18 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -793,18 +818,18 @@ class MatMulV2MKLDNNKernel : public paddle::framework::OpKernel<T> {
CalculateMatrixDims( CalculateMatrixDims(
ctx, x_dims, y_dims, &x_bd_dims, &y_bd_dims, &out_dims, out); ctx, x_dims, y_dims, &x_bd_dims, &y_bd_dims, &out_dims, out);
ExecuteMatMulV2<T>(ctx, ExecuteMatMulV2<T, T_out>(ctx,
dev_ctx, dev_ctx,
onednn_engine, onednn_engine,
ctx.GetPlace(), ctx.GetPlace(),
x, x,
x_bd_dims, x_bd_dims,
trans_x, trans_x,
y, y,
y_bd_dims, y_bd_dims,
trans_y, trans_y,
out, out,
out_dims); out_dims);
} }
}; };
...@@ -939,113 +964,113 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -939,113 +964,113 @@ class MatMulV2GradMKLDNNKernel : public paddle::framework::OpKernel<T> {
ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, &dx_bd_dims, &dy_bd_dims); ctx, &dx_tmp, &dy_tmp, x_dims, y_dims, &dx_bd_dims, &dy_bd_dims);
if (trans_x && trans_y) { if (trans_x && trans_y) {
ExecuteMatMulV2<T>(ctx, ExecuteMatMulV2<T, T>(ctx,
dev_ctx, dev_ctx,
onednn_engine, onednn_engine,
ctx.GetPlace(), ctx.GetPlace(),
y, y,
y_dims, y_dims,
true, true,
dout, dout,
dout_dims, dout_dims,
true, true,
&dx_tmp, &dx_tmp,
dx_bd_dims, dx_bd_dims,
1); 1);
ExecuteMatMulV2<T>(ctx, ExecuteMatMulV2<T, T>(ctx,
dev_ctx, dev_ctx,
onednn_engine, onednn_engine,
ctx.GetPlace(), ctx.GetPlace(),
dout, dout,
dout_dims, dout_dims,
true, true,
x, x,
x_dims, x_dims,
true, true,
&dy_tmp, &dy_tmp,
dy_bd_dims, dy_bd_dims,
2); 2);
} else if (trans_x) { } else if (trans_x) {
ExecuteMatMulV2<T>(ctx, ExecuteMatMulV2<T, T>(ctx,
dev_ctx, dev_ctx,
onednn_engine, onednn_engine,
ctx.GetPlace(), ctx.GetPlace(),
y, y,
y_dims, y_dims,
false, false,
dout, dout,
dout_dims, dout_dims,
true, true,
&dx_tmp, &dx_tmp,
dx_bd_dims, dx_bd_dims,
1); 1);
ExecuteMatMulV2<T>(ctx, ExecuteMatMulV2<T, T>(ctx,
dev_ctx, dev_ctx,
onednn_engine, onednn_engine,
ctx.GetPlace(), ctx.GetPlace(),
x, x,
x_dims, x_dims,
false, false,
dout, dout,
dout_dims, dout_dims,
false, false,
&dy_tmp, &dy_tmp,
dy_bd_dims, dy_bd_dims,
2); 2);
} else if (trans_y) { } else if (trans_y) {
ExecuteMatMulV2<T>(ctx, ExecuteMatMulV2<T, T>(ctx,
dev_ctx, dev_ctx,
onednn_engine, onednn_engine,
ctx.GetPlace(), ctx.GetPlace(),
dout, dout,
dout_dims, dout_dims,
false, false,
y, y,
y_dims, y_dims,
false, false,
&dx_tmp, &dx_tmp,
dx_bd_dims, dx_bd_dims,
1); 1);
ExecuteMatMulV2<T>(ctx, ExecuteMatMulV2<T, T>(ctx,
dev_ctx, dev_ctx,
onednn_engine, onednn_engine,
ctx.GetPlace(), ctx.GetPlace(),
dout, dout,
dout_dims, dout_dims,
true, true,
x, x,
x_dims, x_dims,
false, false,
&dy_tmp, &dy_tmp,
dy_bd_dims, dy_bd_dims,
2); 2);
} else { } else {
ExecuteMatMulV2<T>(ctx, ExecuteMatMulV2<T, T>(ctx,
dev_ctx, dev_ctx,
onednn_engine, onednn_engine,
ctx.GetPlace(), ctx.GetPlace(),
dout, dout,
dout_dims, dout_dims,
false, false,
y, y,
y_dims, y_dims,
true, true,
&dx_tmp, &dx_tmp,
dx_bd_dims, dx_bd_dims,
1); 1);
ExecuteMatMulV2<T>(ctx, ExecuteMatMulV2<T, T>(ctx,
dev_ctx, dev_ctx,
onednn_engine, onednn_engine,
ctx.GetPlace(), ctx.GetPlace(),
x, x,
x_dims, x_dims,
true, true,
dout, dout,
dout_dims, dout_dims,
false, false,
&dy_tmp, &dy_tmp,
dy_bd_dims, dy_bd_dims,
2); 2);
} }
if (x_dims != dx_bd_dims) { if (x_dims != dx_bd_dims) {
...@@ -1234,34 +1259,13 @@ template class MatMulGradMKLDNNKernel<paddle::platform::bfloat16>; ...@@ -1234,34 +1259,13 @@ template class MatMulGradMKLDNNKernel<paddle::platform::bfloat16>;
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(matmul, REGISTER_OP_KERNEL(matmul,
MKLDNN, MKLDNN,
::paddle::platform::CPUPlace, ::paddle::platform::CPUPlace,
S8, MatMulV2MKLDNNKernel<float>,
0, MatMulV2MKLDNNKernel<paddle::platform::bfloat16>,
MatMulMKLDNNKernel<int8_t>); MatMulV2MKLDNNKernel<int8_t>,
MatMulV2MKLDNNKernel<uint8_t>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(matmul,
MKLDNN,
::paddle::platform::CPUPlace,
U8,
0,
MatMulMKLDNNKernel<uint8_t>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(matmul,
MKLDNN,
::paddle::platform::CPUPlace,
FP32,
0,
MatMulV2MKLDNNKernel<float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(
matmul,
MKLDNN,
::paddle::platform::CPUPlace,
BF16,
0,
MatMulV2MKLDNNKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(matmul_grad, REGISTER_OP_KERNEL(matmul_grad,
MKLDNN, MKLDNN,
...@@ -1273,7 +1277,9 @@ REGISTER_OP_KERNEL(matmul_v2, ...@@ -1273,7 +1277,9 @@ REGISTER_OP_KERNEL(matmul_v2,
MKLDNN, MKLDNN,
::paddle::platform::CPUPlace, ::paddle::platform::CPUPlace,
MatMulV2MKLDNNKernel<float>, MatMulV2MKLDNNKernel<float>,
MatMulV2MKLDNNKernel<paddle::platform::bfloat16>); MatMulV2MKLDNNKernel<paddle::platform::bfloat16>,
MatMulV2MKLDNNKernel<int8_t>,
MatMulV2MKLDNNKernel<uint8_t>);
REGISTER_OP_KERNEL(matmul_v2_grad, REGISTER_OP_KERNEL(matmul_v2_grad,
MKLDNN, MKLDNN,
......
...@@ -416,16 +416,16 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> { ...@@ -416,16 +416,16 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> {
bool trans_y, bool trans_y,
Tensor *out) const { Tensor *out) const {
static const std::vector<int64_t> vec_placeholder; static const std::vector<int64_t> vec_placeholder;
MatMulV2MKLDNNHandler<XT> handler(ctx, MatMulV2MKLDNNHandler<XT, YT, XT> handler(ctx,
onednn_engine, onednn_engine,
ctx.GetPlace(), ctx.GetPlace(),
x_dims, x_dims,
trans_x, trans_x,
y_dims, y_dims,
trans_y, trans_y,
false, false,
vec_placeholder, vec_placeholder,
vec_placeholder); vec_placeholder);
const auto src_memory_p = handler.AcquireSrcMemory(x); const auto src_memory_p = handler.AcquireSrcMemory(x);
const auto weights_memory_p = handler.AcquireWeightsMemory(y); const auto weights_memory_p = handler.AcquireWeightsMemory(y);
......
...@@ -860,8 +860,18 @@ class ReductionMKLDNNHandler ...@@ -860,8 +860,18 @@ class ReductionMKLDNNHandler
}; };
template <typename T> template <typename T>
constexpr bool IsInt8() {
return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
}
template <typename T>
constexpr bool IsBfloat16() {
return std::is_same<T, paddle::platform::bfloat16>::value;
}
template <typename XT, typename YT, typename OT>
class MatMulV2MKLDNNHandler class MatMulV2MKLDNNHandler
: public paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul> { : public paddle::platform::MKLDNNHandlerNoCachingT<XT, dnnl::matmul> {
public: public:
MatMulV2MKLDNNHandler(const framework::ExecutionContext& ctx, MatMulV2MKLDNNHandler(const framework::ExecutionContext& ctx,
const dnnl::engine engine, const dnnl::engine engine,
...@@ -873,8 +883,8 @@ class MatMulV2MKLDNNHandler ...@@ -873,8 +883,8 @@ class MatMulV2MKLDNNHandler
bool is_output_fused, bool is_output_fused,
const std::vector<int64_t>& x_strides_override, const std::vector<int64_t>& x_strides_override,
const std::vector<int64_t>& y_strides_override) const std::vector<int64_t>& y_strides_override)
: paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul>(engine, : paddle::platform::MKLDNNHandlerNoCachingT<XT, dnnl::matmul>(engine,
cpu_place) { cpu_place) {
// M X K * K X N // M X K * K X N
std::vector<int64_t> x_dims(x_org_dims); std::vector<int64_t> x_dims(x_org_dims);
std::vector<int64_t> y_dims(y_org_dims); std::vector<int64_t> y_dims(y_org_dims);
...@@ -934,28 +944,42 @@ class MatMulV2MKLDNNHandler ...@@ -934,28 +944,42 @@ class MatMulV2MKLDNNHandler
out_strides[i] = out_ddims[i + 1] * out_strides[i + 1]; out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
} }
if (is_output_fused) { if (!IsInt8<OT>() && !IsBfloat16<OT>() && is_output_fused) {
out_strides = FakeTransposeStrides(out_ddims); out_strides = FakeTransposeStrides(out_ddims);
} }
auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides); auto x_md = memory::desc(x_dims, MKLDNNGetDataType<XT>(), x_strides);
auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides); auto y_md = memory::desc(y_dims, MKLDNNGetDataType<YT>(), y_strides);
auto out_md = memory::desc(out_ddims, MKLDNNGetDataType<T>(), out_strides); auto out_md = memory::desc(out_ddims, MKLDNNGetDataType<OT>(), out_strides);
const dnnl::primitive_attr matmul_attrs = CreateMatmulAttrs(ctx); const dnnl::primitive_attr matmul_attrs = CreateMatmulAttrs(ctx);
this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md); this->AcquireForwardPrimitiveDescriptor(matmul_attrs, x_md, y_md, out_md);
} }
// TODO(jczaja) : Adapt to int8 float ComputeOutputScale(const framework::ExecutionContext& ctx) {
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f;
if (ctx.HasAttr("Scale_x") && ctx.HasAttr("Scale_y") &&
ctx.HasAttr("Scale_out")) {
float scale_x = ctx.Attr<float>("Scale_x");
float scale_y = ctx.Attr<float>("Scale_y");
bool force_fp32_out = ctx.HasAttr("force_fp32_output")
? ctx.Attr<bool>("force_fp32_output")
: false;
float scale_out = force_fp32_out ? 1.f : ctx.Attr<float>("Scale_out");
alpha *= scale_out / (scale_x * scale_y);
}
return alpha;
}
dnnl::primitive_attr CreateMatmulAttrs( dnnl::primitive_attr CreateMatmulAttrs(
const framework::ExecutionContext& ctx) { const framework::ExecutionContext& ctx) {
dnnl::primitive_attr matmul_attrs; dnnl::primitive_attr matmul_attrs;
dnnl::post_ops post_operations; dnnl::post_ops post_operations;
float alpha = ctx.HasAttr("alpha") ? ctx.Attr<float>("alpha") : 1.0f; float scale_out = ComputeOutputScale(ctx);
if (alpha != 1.0f) { if (scale_out != 1.0f) {
matmul_attrs.set_output_scales(0, {alpha}); matmul_attrs.set_output_scales(0, {scale_out});
} }
if (ctx.HasInput("ResidualData")) { if (ctx.HasInput("ResidualData")) {
...@@ -993,9 +1017,23 @@ class MatMulV2MKLDNNHandler ...@@ -993,9 +1017,23 @@ class MatMulV2MKLDNNHandler
} }
std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) { std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) {
const T* input_data = input->data<T>(); const YT* input_data = input->data<YT>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(), return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
to_void_cast<T>(input_data)); to_void_cast<YT>(input_data));
}
std::shared_ptr<dnnl::memory> AcquireDstMemory(
paddle::framework::Tensor* output) {
// We cannot use base AcquireDstMemory as it makes an allocation request
// base on DST memory primitive size. This is fine in general, but in MatMul
// we have primitive that covers only one batch of Data and then shift
// pointer for every new batch. Hence Tensor size is bigger that dst memory
// primitive size. So would we request less memory that is there and it
// triggers an
// assertion. So as there is no 'any' format here we can leave default size
// of Tensor as computed in ComputeInferShape
OT* ptr = output->mutable_data<OT>(this->place_);
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr);
} }
}; };
...@@ -1099,11 +1137,11 @@ class ActivationMKLDNNHandler ...@@ -1099,11 +1137,11 @@ class ActivationMKLDNNHandler
static std::unordered_map<std::string, std::string> GetAttributeMap( static std::unordered_map<std::string, std::string> GetAttributeMap(
std::string act_type) { std::string act_type) {
std::unordered_map<std::string, std::string> attr_map; std::unordered_map<std::string, std::string> attr_map;
if (act_type == "swish") if (act_type == "swish") {
attr_map.emplace("beta", "fuse_alpha"); attr_map.emplace("beta", "fuse_alpha");
else if (act_type == "relu6") } else if (act_type == "relu6") {
attr_map.emplace("threshold", "fuse_alpha"); attr_map.emplace("threshold", "fuse_alpha");
else if (act_type == "hard_sigmoid") { } else if (act_type == "hard_sigmoid") {
attr_map.emplace("slope", "fuse_alpha"); attr_map.emplace("slope", "fuse_alpha");
attr_map.emplace("offset", "fuse_beta"); attr_map.emplace("offset", "fuse_beta");
} else if (act_type == "clip") { } else if (act_type == "clip") {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册