未验证 提交 b522ca52 编写于 作者: J jakpiase 提交者: GitHub

OneDNN md-in-tensor refactoring part 3: Changes in quantize and dequantize (#42766)

* added md support inside (de)quantizes

* added missing file

* changed paddle enforce text

* another paddle enforce change

* same as before

* removed broken tests
上级 6d0e4e4a
......@@ -44,14 +44,6 @@ class MKLDNNActivationKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *x = ctx.Input<Tensor>("X");
PADDLE_ENFORCE_EQ(
x->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument("Wrong layout set for X tensor"));
PADDLE_ENFORCE_NE(
x->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument("Wrong format set for X tensor"));
Functor functor;
functor(ctx);
}
......@@ -62,14 +54,6 @@ class MKLDNNActivationGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(diff_y->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"Wrong layout set for Input OutGrad tensor"));
PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument(
"Wrong format set for Input OutGrad tensor"));
Functor functor;
functor(ctx);
}
......
......@@ -36,100 +36,58 @@ template <typename T>
class DeQuantOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("Input");
auto scale_data = ctx.Attr<float>("Scale");
auto scale_shift = ctx.Attr<float>("Shift");
bool with_shift = scale_shift != 0.0f;
auto* output = ctx.Output<Tensor>("Output");
PADDLE_ENFORCE_NE(scale_data, 0.0f,
platform::errors::InvalidArgument(
"Dequantization scale cannot be 0.0"));
PADDLE_ENFORCE_GE(scale_shift, 0,
platform::errors::Unimplemented(
"Dequantization shift must be nonnegative."));
PADDLE_ENFORCE_LE(
scale_shift, 255,
platform::errors::Unimplemented(
"Dequantization shift must be less than or equal to 255."));
auto* x = ctx.Input<Tensor>("Input");
const auto quantization_scale = ctx.Attr<float>("Scale");
const auto quantization_shift = ctx.Attr<float>("Shift");
const bool with_shift = quantization_shift != 0.0f;
auto* out = ctx.Output<Tensor>("Output");
PADDLE_ENFORCE(quantization_scale != 0.0f,
platform::errors::InvalidArgument(
"Dequantization scale must be different than 0.0f"));
PADDLE_ENFORCE(
quantization_shift <= 255 && quantization_shift >= 0,
platform::errors::InvalidArgument(
"Dequantization shift must be lower or equal to ",
"255 and greater or equal to 0, but got %f", quantization_shift));
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& engine = dev_ctx.GetEngine();
const T* input_data = input->data<T>();
float* output_data = output->mutable_data<float>(ctx.GetPlace());
float reorder_shift = -scale_shift / scale_data;
auto src_tz = phi::vectorize<int64_t>(input->dims());
auto dst_tz = phi::vectorize<int64_t>(output->dims());
dnnl::memory::data_type src_dt = paddle::framework::ToMKLDNNDataType(
framework::TransToProtoVarType(input->dtype()));
MKLDNNMemoryFormat src_fmt = input->format();
std::string key =
platform::CreateKey(dev_ctx, src_dt, src_tz, ctx.OutputName("Output"));
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
const std::string key_prim = key + "@r";
const std::string key_src_mem = key + "@s";
const std::string key_dst_mem = key + "@d";
std::shared_ptr<dnnl::memory> src_memory;
std::shared_ptr<dnnl::memory> dst_memory;
std::shared_ptr<reorder> reorder_p;
reorder_p = std::static_pointer_cast<reorder>(dev_ctx.GetBlob(key_prim));
if (reorder_p == nullptr) {
dnnl::primitive_attr attri;
int mask = 0;
float reorder_scale = 1. / scale_data;
attri.set_output_scales(mask, {reorder_scale});
if (with_shift) {
dnnl::post_ops post_operations;
post_operations.append_sum();
attri.set_post_ops(post_operations);
std::fill(output_data, output_data + output->numel(), reorder_shift);
}
auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, src_fmt);
src_memory = std::make_shared<dnnl::memory>(src_md, engine,
to_void_cast<T>(input_data));
auto dst_md =
platform::MKLDNNMemDesc({dst_tz}, memory::data_type::f32,
platform::MKLDNNFormatForSize(
dst_tz.size(), MKLDNNMemoryFormat::nchw));
dst_memory = std::make_shared<dnnl::memory>(
dst_md, engine, to_void_cast<float>(output_data));
auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
new reorder::primitive_desc(*src_memory, *dst_memory, attri));
reorder_p = std::shared_ptr<reorder>(new reorder(*reorder_pd));
dev_ctx.SetBlob(key_prim, reorder_p);
dev_ctx.SetBlob(key_src_mem, src_memory);
dev_ctx.SetBlob(key_dst_mem, dst_memory);
} else {
src_memory =
std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(key_src_mem));
src_memory->set_data_handle(to_void_cast<T>(input_data));
dst_memory =
std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(key_dst_mem));
if (with_shift)
std::fill(output_data, output_data + output->numel(), reorder_shift);
dst_memory->set_data_handle(output->mutable_data<float>(ctx.GetPlace()));
auto x_tz = phi::vectorize<int64_t>(x->dims());
auto x_paddle_dtype = framework::TransToProtoVarType(x->dtype());
auto out_paddle_dtype = framework::TransToProtoVarType(out->dtype());
dnnl::primitive_attr attrs;
static constexpr int32_t mask = 0; // same shift and scale for whole tensor
const float reorder_scale = 1. / quantization_scale;
attrs.set_output_scales(mask, {reorder_scale});
if (with_shift) {
attrs.set_zero_points(DNNL_ARG_SRC, mask,
{static_cast<int32_t>(quantization_shift)});
}
platform::ReorderMKLDNNHandler reorder_handler(
x_tz, x_paddle_dtype, framework::ToMKLDNNDataType(x_paddle_dtype),
out_paddle_dtype, framework::ToMKLDNNDataType(out_paddle_dtype),
dev_ctx.GetEngine());
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x->mem_desc(), platform::to_void_cast(x->data<T>()));
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
out, x->mem_desc(), dev_ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(
reorder_dst_memory_p, reorder_src_memory_p, attrs);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
reorder_p->execute(astream, *src_memory, *dst_memory);
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory));
out->set_mem_desc(reorder_dst_memory_p->get_desc());
}
};
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "dnnl.hpp"
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/quantize_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
......@@ -34,83 +35,73 @@ template <typename T>
class QuantOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<Tensor>("Input");
auto scale_data = ctx.Attr<float>("Scale");
auto scale_shift = ctx.Attr<float>("Shift");
bool with_shift = scale_shift != 0.0f;
auto* output = ctx.Output<Tensor>("Output");
PADDLE_ENFORCE_NE(
scale_data, 0.0f,
platform::errors::InvalidArgument("Quantization scale cannot be 0.0"));
PADDLE_ENFORCE_GE(scale_shift, 0,
platform::errors::Unimplemented(
"Quantization shift must be nonnegative."));
PADDLE_ENFORCE_LE(
scale_shift, 255,
platform::errors::Unimplemented(
"Quantization shift must be less than or equal to 255."));
auto* x = ctx.Input<Tensor>("Input");
auto* out = ctx.Output<Tensor>("Output");
const auto quantization_scale = ctx.Attr<float>("Scale");
const auto quantization_shift = ctx.Attr<float>("Shift");
const bool with_scale = quantization_scale != 1.0f;
const bool with_shift = quantization_shift != 0.0f;
PADDLE_ENFORCE_NE(quantization_scale, 0.0f,
platform::errors::InvalidArgument(
"Quantization scale must be different than 0.0f"));
PADDLE_ENFORCE(
quantization_shift <= 255 && quantization_shift >= 0,
platform::errors::InvalidArgument(
"Quantization shift must be lower or equal to ",
"255 and greater or equal to 0, but got %f", quantization_shift));
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& engine = dev_ctx.GetEngine();
std::vector<primitive> pipeline;
auto src_tz = phi::vectorize<int64_t>(input->dims());
auto dst_tz = phi::vectorize<int64_t>(output->dims());
auto x_tz = phi::vectorize<int64_t>(x->dims());
const T* input_data = input->data<T>();
const bool is_negative_input = ctx.Attr<bool>("is_negative_input");
const bool bfloat16 = ctx.Attr<bool>("bfloat16");
bool is_negative_input = ctx.Attr<bool>("is_negative_input");
bool bfloat16 = ctx.Attr<bool>("bfloat16");
dnnl::primitive_attr attrs;
static constexpr int32_t mask = 0;
// TODO(jczaja): Refactor with Acquire API
std::shared_ptr<dnnl::memory> src_memory;
std::shared_ptr<dnnl::memory> dst_memory;
std::shared_ptr<reorder> reorder_p;
std::string out_layout = ctx.Attr<std::string>("output_format");
MKLDNNMemoryFormat out_format =
platform::data_format_to_memory_format(out_layout);
dnnl::primitive_attr attri;
int mask = 0;
attri.set_output_scales(mask, {scale_data});
if (with_scale) {
attrs.set_output_scales(mask, {quantization_scale});
}
if (with_shift) {
dnnl::post_ops post_operations;
post_operations.append_sum();
attri.set_post_ops(post_operations);
uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
// memset casts scale_shift to unsigned char (uint8_t) internally
std::memset(output_data, scale_shift, output->numel());
attrs.set_zero_points(DNNL_ARG_DST, mask,
{static_cast<int32_t>(quantization_shift)});
}
auto src_md = platform::MKLDNNMemDesc({src_tz}, memory::data_type::f32,
input->format());
src_memory = std::make_shared<dnnl::memory>(src_md, engine,
to_void_cast<T>(input_data));
framework::proto::VarType::Type x_paddle_dtype =
framework::TransToProtoVarType(x->dtype());
framework::proto::VarType::Type out_paddle_dtype;
std::shared_ptr<dnnl::memory::desc> dst_md;
if (bfloat16) {
platform::SetDstMemoryQuantized<paddle::platform::bfloat16>(
ctx, output, dst_tz, engine, dst_md, dst_memory, out_format);
out_paddle_dtype = framework::proto::VarType::BF16;
} else if (is_negative_input && !with_shift) {
platform::SetDstMemoryQuantized<int8_t>(ctx, output, dst_tz, engine,
dst_md, dst_memory, out_format);
out_paddle_dtype = framework::proto::VarType::INT8;
} else {
platform::SetDstMemoryQuantized<uint8_t>(ctx, output, dst_tz, engine,
dst_md, dst_memory, out_format);
out_paddle_dtype = framework::proto::VarType::UINT8;
}
auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
new reorder::primitive_desc(*src_memory, *dst_memory, attri));
reorder_p = std::shared_ptr<reorder>(new reorder(*reorder_pd));
platform::ReorderMKLDNNHandler reorder_handler(
x_tz, x_paddle_dtype, framework::ToMKLDNNDataType(x_paddle_dtype),
out_paddle_dtype, framework::ToMKLDNNDataType(out_paddle_dtype),
dev_ctx.GetEngine());
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x->mem_desc(), platform::to_void_cast(x->data<T>()));
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
out, x->mem_desc(), dev_ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(
reorder_dst_memory_p, reorder_src_memory_p, attrs);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
reorder_p->execute(astream, *src_memory, *dst_memory);
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory));
out->set_mem_desc(reorder_dst_memory_p->get_desc());
}
};
} // namespace operators
......
......@@ -1057,6 +1057,14 @@ class ReorderMKLDNNHandler {
return std::make_shared<dnnl::reorder>(*(src_memory_p), *(dst_memory_p));
}
std::shared_ptr<dnnl::reorder> AcquireReorder(
std::shared_ptr<dnnl::memory> dst_memory_p,
std::shared_ptr<dnnl::memory> src_memory_p,
const dnnl::primitive_attr& attrs) {
return std::make_shared<dnnl::reorder>(*(src_memory_p), *(dst_memory_p),
attrs);
}
private:
std::vector<int64_t> dims_;
framework::proto::VarType::Type vtype_, vtype_dst_;
......
......@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
import paddle
class TestDeQuantizeOp(OpTest):
......@@ -110,19 +111,6 @@ class TestDeQuantizeOpBf16(TestDeQuantizeOp):
self.data_type = 'uint16'
class TestDeQuantizeOp_ZeroScale(TestDeQuantizeOp):
def set_scale(self):
self.scale = 0.0
def prepare_output_int8(self):
self.output = np.zeros(self.input_size)
self.outputs = {'Output': self.output}
def test_check_output(self):
self.assertRaises(AttributeError, self.check_raise_error,
'Dequantization scale cannot be 0.0')
# 2-dim input
# P - positive input, with shift
class TestDeQuantizeOpShift_2_P(TestDeQuantizeOp):
......@@ -177,28 +165,6 @@ class TestDeQuantizeOpShift_4_N(TestDeQuantizeOpShift_2_N):
self.input_size = [2, 3, 4, 5]
class TestDeQuantizeOp_NegativeShift(TestDeQuantizeOp):
def set_shift(self):
self.shift = -10.0
def prepare_output_int8(self):
self.output = np.zeros(self.input_size)
self.outputs = {'Output': self.output}
def test_check_output(self):
self.assertRaises(AttributeError, self.check_raise_error,
'Dequantization shift must be nonnegative.')
class TestDeQuantizeOp_TooBigShift(TestDeQuantizeOp_NegativeShift):
def set_shift(self):
self.shift = 300.0
def test_check_output(self):
self.assertRaises(
AttributeError, self.check_raise_error,
'Dequantization shift must be less than or equal to 255.')
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
......@@ -17,6 +17,7 @@ from __future__ import print_function
import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest
import paddle
class TestQuantizeOp(OpTest):
......@@ -104,19 +105,6 @@ class TestQuantizeOp2(TestQuantizeOp):
self.is_nagative = False
class TestQuantizeOp_ZeroScale(TestQuantizeOp):
def set_scale(self):
self.scale = 0.0
def prepare_output(self):
self.output = np.zeros(self.input_size)
self.outputs = {'Output': self.output}
def test_check_output(self):
self.assertRaises(AttributeError, self.check_raise_error,
'Quantization scale cannot be 0.0')
# 2-dim input
# P - positive input
class TestQuantizeOpShift_NCHW_2_P(TestQuantizeOp):
......@@ -201,34 +189,6 @@ class TestQuantizeOpShift_NHWC_4_N(TestQuantizeOpShift_NCHW_4_N):
self.output_format = 'NHWC'
class TestQuantizeOp_NegativeShift(TestQuantizeOp):
def set_is_negative(self):
self.is_nagative = False
def set_scale(self):
self.scale = 100.0
def set_shift(self):
self.shift = -10.0
def prepare_output(self):
self.output = np.zeros(self.input_size)
self.outputs = {'Output': self.output}
def test_check_output(self):
self.assertRaises(AttributeError, self.check_raise_error,
'Quantization shift must be nonnegative.')
class TestQuantizeOp_TooBigShift(TestQuantizeOp_NegativeShift):
def set_shift(self):
self.shift = 300.0
def test_check_output(self):
self.assertRaises(
AttributeError, self.check_raise_error,
'Quantization shift must be less than or equal to 255.')
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册