Unverified · Commit b522ca52 · Author: J jakpiase · Committed by: GitHub

OneDNN md-in-tensor refactoring part 3: Changes in quantize and dequantize (#42766)

* added md support inside (de)quantizes

* added missing file

* changed paddle enforce text

* another paddle enforce change

* same as before

* removed broken tests
Parent commit: 6d0e4e4a
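At a high level, both kernels now hand the op's Scale and Shift attributes to a oneDNN reorder through primitive attributes instead of hand-built memory descriptors, cached blobs, and sum post-ops. A minimal sketch of that mapping, assuming the oneDNN 2.x primitive_attr API used in the hunks below (variable names are illustrative, not part of the patch):

    // Sketch only: how Scale/Shift are assumed to map onto a oneDNN reorder.
    dnnl::primitive_attr attrs;
    const int mask = 0;  // a single scale / zero point for the whole tensor

    // quantize:   out = Scale * x + Shift  ->  output scale + DST zero point
    attrs.set_output_scales(mask, {quantization_scale});
    attrs.set_zero_points(DNNL_ARG_DST, mask,
                          {static_cast<int32_t>(quantization_shift)});

    // dequantize: out = (x - Shift) / Scale  ->  output scale 1/Scale + SRC zero point
    // attrs.set_output_scales(mask, {1.f / quantization_scale});
    // attrs.set_zero_points(DNNL_ARG_SRC, mask,
    //                       {static_cast<int32_t>(quantization_shift)});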
@@ -44,14 +44,6 @@ class MKLDNNActivationKernel
     : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    const auto *x = ctx.Input<Tensor>("X");
-    PADDLE_ENFORCE_EQ(
-        x->layout(), DataLayout::kMKLDNN,
-        platform::errors::InvalidArgument("Wrong layout set for X tensor"));
-    PADDLE_ENFORCE_NE(
-        x->format(), MKLDNNMemoryFormat::undef,
-        platform::errors::InvalidArgument("Wrong format set for X tensor"));
     Functor functor;
     functor(ctx);
   }
@@ -62,14 +54,6 @@ class MKLDNNActivationGradKernel
     : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-    const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    PADDLE_ENFORCE_EQ(diff_y->layout(), DataLayout::kMKLDNN,
-                      platform::errors::InvalidArgument(
-                          "Wrong layout set for Input OutGrad tensor"));
-    PADDLE_ENFORCE_NE(diff_y->format(), MKLDNNMemoryFormat::undef,
-                      platform::errors::InvalidArgument(
-                          "Wrong format set for Input OutGrad tensor"));
     Functor functor;
     functor(ctx);
   }
......
@@ -36,100 +36,58 @@ template <typename T>
 class DeQuantOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* input = ctx.Input<Tensor>("Input");
-    auto scale_data = ctx.Attr<float>("Scale");
-    auto scale_shift = ctx.Attr<float>("Shift");
-    bool with_shift = scale_shift != 0.0f;
-    auto* output = ctx.Output<Tensor>("Output");
+    auto* x = ctx.Input<Tensor>("Input");
+    const auto quantization_scale = ctx.Attr<float>("Scale");
+    const auto quantization_shift = ctx.Attr<float>("Shift");
+    const bool with_shift = quantization_shift != 0.0f;
+    auto* out = ctx.Output<Tensor>("Output");
 
-    PADDLE_ENFORCE_NE(scale_data, 0.0f,
-                      platform::errors::InvalidArgument(
-                          "Dequantization scale cannot be 0.0"));
-    PADDLE_ENFORCE_GE(scale_shift, 0,
-                      platform::errors::Unimplemented(
-                          "Dequantization shift must be nonnegative."));
-    PADDLE_ENFORCE_LE(
-        scale_shift, 255,
-        platform::errors::Unimplemented(
-            "Dequantization shift must be less than or equal to 255."));
+    PADDLE_ENFORCE(quantization_scale != 0.0f,
+                   platform::errors::InvalidArgument(
+                       "Dequantization scale must be different than 0.0f"));
+    PADDLE_ENFORCE(
+        quantization_shift <= 255 && quantization_shift >= 0,
+        platform::errors::InvalidArgument(
+            "Dequantization shift must be lower or equal to ",
+            "255 and greater or equal to 0, but got %f", quantization_shift));
 
     auto& dev_ctx =
         ctx.template device_context<platform::MKLDNNDeviceContext>();
-    const auto& engine = dev_ctx.GetEngine();
 
-    const T* input_data = input->data<T>();
-    float* output_data = output->mutable_data<float>(ctx.GetPlace());
-    float reorder_shift = -scale_shift / scale_data;
+    auto x_tz = phi::vectorize<int64_t>(x->dims());
+    auto x_paddle_dtype = framework::TransToProtoVarType(x->dtype());
+    auto out_paddle_dtype = framework::TransToProtoVarType(out->dtype());
 
-    auto src_tz = phi::vectorize<int64_t>(input->dims());
-    auto dst_tz = phi::vectorize<int64_t>(output->dims());
-    dnnl::memory::data_type src_dt = paddle::framework::ToMKLDNNDataType(
-        framework::TransToProtoVarType(input->dtype()));
-    MKLDNNMemoryFormat src_fmt = input->format();
+    dnnl::primitive_attr attrs;
+    static constexpr int32_t mask = 0;  // same shift and scale for whole tensor
 
-    std::string key =
-        platform::CreateKey(dev_ctx, src_dt, src_tz, ctx.OutputName("Output"));
-    key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
-    const std::string key_prim = key + "@r";
-    const std::string key_src_mem = key + "@s";
-    const std::string key_dst_mem = key + "@d";
+    const float reorder_scale = 1. / quantization_scale;
+    attrs.set_output_scales(mask, {reorder_scale});
 
-    std::shared_ptr<dnnl::memory> src_memory;
-    std::shared_ptr<dnnl::memory> dst_memory;
-    std::shared_ptr<reorder> reorder_p;
-    reorder_p = std::static_pointer_cast<reorder>(dev_ctx.GetBlob(key_prim));
-
-    if (reorder_p == nullptr) {
-      dnnl::primitive_attr attri;
-      int mask = 0;
-      float reorder_scale = 1. / scale_data;
-      attri.set_output_scales(mask, {reorder_scale});
-
-      if (with_shift) {
-        dnnl::post_ops post_operations;
-        post_operations.append_sum();
-        attri.set_post_ops(post_operations);
-        std::fill(output_data, output_data + output->numel(), reorder_shift);
-      }
-
-      auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, src_fmt);
-      src_memory = std::make_shared<dnnl::memory>(src_md, engine,
-                                                  to_void_cast<T>(input_data));
-      auto dst_md =
-          platform::MKLDNNMemDesc({dst_tz}, memory::data_type::f32,
-                                  platform::MKLDNNFormatForSize(
-                                      dst_tz.size(), MKLDNNMemoryFormat::nchw));
-      dst_memory = std::make_shared<dnnl::memory>(
-          dst_md, engine, to_void_cast<float>(output_data));
-
-      auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
-          new reorder::primitive_desc(*src_memory, *dst_memory, attri));
-      reorder_p = std::shared_ptr<reorder>(new reorder(*reorder_pd));
-      dev_ctx.SetBlob(key_prim, reorder_p);
-      dev_ctx.SetBlob(key_src_mem, src_memory);
-      dev_ctx.SetBlob(key_dst_mem, dst_memory);
-    } else {
-      src_memory =
-          std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(key_src_mem));
-      src_memory->set_data_handle(to_void_cast<T>(input_data));
-      dst_memory =
-          std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(key_dst_mem));
-      if (with_shift)
-        std::fill(output_data, output_data + output->numel(), reorder_shift);
-      dst_memory->set_data_handle(output->mutable_data<float>(ctx.GetPlace()));
+    if (with_shift) {
+      attrs.set_zero_points(DNNL_ARG_SRC, mask,
+                            {static_cast<int32_t>(quantization_shift)});
     }
 
+    platform::ReorderMKLDNNHandler reorder_handler(
+        x_tz, x_paddle_dtype, framework::ToMKLDNNDataType(x_paddle_dtype),
+        out_paddle_dtype, framework::ToMKLDNNDataType(out_paddle_dtype),
+        dev_ctx.GetEngine());
+
+    auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
+        x->mem_desc(), platform::to_void_cast(x->data<T>()));
+    auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
+        out, x->mem_desc(), dev_ctx.GetPlace());
+
+    auto reorder_p = reorder_handler.AcquireReorder(
+        reorder_dst_memory_p, reorder_src_memory_p, attrs);
+
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
-    reorder_p->execute(astream, *src_memory, *dst_memory);
+    reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
     astream.wait();
 
-    output->set_layout(DataLayout::kMKLDNN);
-    output->set_format(GetMKLDNNFormat(*dst_memory));
+    out->set_mem_desc(reorder_dst_memory_p->get_desc());
   }
 };
......
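A worked numeric check of the dequantize path above, under the assumption that oneDNN's reorder computes dst = output_scale * (src - src_zero_point); the concrete values are illustrative and not taken from the patch:

    // Illustrative values: Scale = 2.0f, Shift = 128.0f, one uint8 input element equal to 130.
    // New path:  out = (1.f / 2.0f) * (130 - 128)                   = 1.0f
    // Old path:  dst pre-filled with -Shift / Scale = -64.0f, then sum post-op:
    //            out = 130 * (1.f / 2.0f) + (-64.0f)                = 1.0f

Both paths therefore produce the same value; the quantize kernel below is the mirror image, with the shift applied as a DST zero point instead of a pre-filled sum buffer.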
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "dnnl.hpp"
+#include "paddle/fluid/framework/data_layout_transform.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/operators/quantize_op.h"
 #include "paddle/fluid/platform/mkldnn_helper.h"
@@ -34,83 +35,73 @@ template <typename T>
 class QuantOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* input = ctx.Input<Tensor>("Input");
-    auto scale_data = ctx.Attr<float>("Scale");
-    auto scale_shift = ctx.Attr<float>("Shift");
-    bool with_shift = scale_shift != 0.0f;
-    auto* output = ctx.Output<Tensor>("Output");
+    auto* x = ctx.Input<Tensor>("Input");
+    auto* out = ctx.Output<Tensor>("Output");
 
-    PADDLE_ENFORCE_NE(
-        scale_data, 0.0f,
-        platform::errors::InvalidArgument("Quantization scale cannot be 0.0"));
-    PADDLE_ENFORCE_GE(scale_shift, 0,
-                      platform::errors::Unimplemented(
-                          "Quantization shift must be nonnegative."));
-    PADDLE_ENFORCE_LE(
-        scale_shift, 255,
-        platform::errors::Unimplemented(
-            "Quantization shift must be less than or equal to 255."));
+    const auto quantization_scale = ctx.Attr<float>("Scale");
+    const auto quantization_shift = ctx.Attr<float>("Shift");
+    const bool with_scale = quantization_scale != 1.0f;
+    const bool with_shift = quantization_shift != 0.0f;
+
+    PADDLE_ENFORCE_NE(quantization_scale, 0.0f,
+                      platform::errors::InvalidArgument(
+                          "Quantization scale must be different than 0.0f"));
+    PADDLE_ENFORCE(
+        quantization_shift <= 255 && quantization_shift >= 0,
+        platform::errors::InvalidArgument(
+            "Quantization shift must be lower or equal to ",
+            "255 and greater or equal to 0, but got %f", quantization_shift));
 
     auto& dev_ctx =
         ctx.template device_context<platform::MKLDNNDeviceContext>();
-    const auto& engine = dev_ctx.GetEngine();
 
-    std::vector<primitive> pipeline;
-    auto src_tz = phi::vectorize<int64_t>(input->dims());
-    auto dst_tz = phi::vectorize<int64_t>(output->dims());
+    auto x_tz = phi::vectorize<int64_t>(x->dims());
 
-    const T* input_data = input->data<T>();
-
-    bool is_negative_input = ctx.Attr<bool>("is_negative_input");
-    bool bfloat16 = ctx.Attr<bool>("bfloat16");
+    const bool is_negative_input = ctx.Attr<bool>("is_negative_input");
+    const bool bfloat16 = ctx.Attr<bool>("bfloat16");
 
-    // TODO(jczaja): Refactor with Acquire API
-    std::shared_ptr<dnnl::memory> src_memory;
-    std::shared_ptr<dnnl::memory> dst_memory;
-    std::shared_ptr<reorder> reorder_p;
-
-    std::string out_layout = ctx.Attr<std::string>("output_format");
-    MKLDNNMemoryFormat out_format =
-        platform::data_format_to_memory_format(out_layout);
-    dnnl::primitive_attr attri;
-    int mask = 0;
-    attri.set_output_scales(mask, {scale_data});
+    dnnl::primitive_attr attrs;
+    static constexpr int32_t mask = 0;
+
+    if (with_scale) {
+      attrs.set_output_scales(mask, {quantization_scale});
+    }
 
     if (with_shift) {
-      dnnl::post_ops post_operations;
-      post_operations.append_sum();
-      attri.set_post_ops(post_operations);
-      uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
-      // memset casts scale_shift to unsigned char (uint8_t) internally
-      std::memset(output_data, scale_shift, output->numel());
+      attrs.set_zero_points(DNNL_ARG_DST, mask,
+                            {static_cast<int32_t>(quantization_shift)});
     }
 
-    auto src_md = platform::MKLDNNMemDesc({src_tz}, memory::data_type::f32,
-                                          input->format());
-    src_memory = std::make_shared<dnnl::memory>(src_md, engine,
-                                                to_void_cast<T>(input_data));
+    framework::proto::VarType::Type x_paddle_dtype =
+        framework::TransToProtoVarType(x->dtype());
+    framework::proto::VarType::Type out_paddle_dtype;
 
-    std::shared_ptr<dnnl::memory::desc> dst_md;
     if (bfloat16) {
-      platform::SetDstMemoryQuantized<paddle::platform::bfloat16>(
-          ctx, output, dst_tz, engine, dst_md, dst_memory, out_format);
+      out_paddle_dtype = framework::proto::VarType::BF16;
     } else if (is_negative_input && !with_shift) {
-      platform::SetDstMemoryQuantized<int8_t>(ctx, output, dst_tz, engine,
-                                              dst_md, dst_memory, out_format);
+      out_paddle_dtype = framework::proto::VarType::INT8;
     } else {
-      platform::SetDstMemoryQuantized<uint8_t>(ctx, output, dst_tz, engine,
-                                               dst_md, dst_memory, out_format);
+      out_paddle_dtype = framework::proto::VarType::UINT8;
     }
 
-    auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
-        new reorder::primitive_desc(*src_memory, *dst_memory, attri));
-    reorder_p = std::shared_ptr<reorder>(new reorder(*reorder_pd));
+    platform::ReorderMKLDNNHandler reorder_handler(
+        x_tz, x_paddle_dtype, framework::ToMKLDNNDataType(x_paddle_dtype),
+        out_paddle_dtype, framework::ToMKLDNNDataType(out_paddle_dtype),
+        dev_ctx.GetEngine());
+
+    auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
+        x->mem_desc(), platform::to_void_cast(x->data<T>()));
+    auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
+        out, x->mem_desc(), dev_ctx.GetPlace());
+
+    auto reorder_p = reorder_handler.AcquireReorder(
+        reorder_dst_memory_p, reorder_src_memory_p, attrs);
 
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
-    reorder_p->execute(astream, *src_memory, *dst_memory);
+    reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
     astream.wait();
 
-    output->set_layout(DataLayout::kMKLDNN);
-    output->set_format(GetMKLDNNFormat(*dst_memory));
+    out->set_mem_desc(reorder_dst_memory_p->get_desc());
   }
 };
 
 }  // namespace operators
......
@@ -1057,6 +1057,14 @@ class ReorderMKLDNNHandler {
     return std::make_shared<dnnl::reorder>(*(src_memory_p), *(dst_memory_p));
   }
 
+  std::shared_ptr<dnnl::reorder> AcquireReorder(
+      std::shared_ptr<dnnl::memory> dst_memory_p,
+      std::shared_ptr<dnnl::memory> src_memory_p,
+      const dnnl::primitive_attr& attrs) {
+    return std::make_shared<dnnl::reorder>(*(src_memory_p), *(dst_memory_p),
+                                           attrs);
+  }
+
  private:
   std::vector<int64_t> dims_;
   framework::proto::VarType::Type vtype_, vtype_dst_;
......
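The attrs-taking overload is needed because output scales and zero points have to be attached to the reorder when the primitive is created; they cannot be added to a reorder built with the attribute-less overload above. A rough sketch of what the new overload amounts to in plain oneDNN terms, assuming two existing dnnl::memory objects src and dst and a configured dnnl::primitive_attr attrs:

    // Either build the reorder primitive descriptor explicitly with the attributes...
    auto reorder_pd = dnnl::reorder::primitive_desc(src, dst, attrs);
    auto reorder_prim = dnnl::reorder(reorder_pd);
    // ...or, as the helper does, use the convenience constructor that takes them directly:
    // auto reorder_prim = dnnl::reorder(src, dst, attrs);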
@@ -17,6 +17,7 @@ from __future__ import print_function
 import unittest
 import numpy as np
 from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
+import paddle
 
 
 class TestDeQuantizeOp(OpTest):
@@ -110,19 +111,6 @@ class TestDeQuantizeOpBf16(TestDeQuantizeOp):
         self.data_type = 'uint16'
 
 
-class TestDeQuantizeOp_ZeroScale(TestDeQuantizeOp):
-    def set_scale(self):
-        self.scale = 0.0
-
-    def prepare_output_int8(self):
-        self.output = np.zeros(self.input_size)
-        self.outputs = {'Output': self.output}
-
-    def test_check_output(self):
-        self.assertRaises(AttributeError, self.check_raise_error,
-                          'Dequantization scale cannot be 0.0')
-
-
 # 2-dim input
 # P - positive input, with shift
 class TestDeQuantizeOpShift_2_P(TestDeQuantizeOp):
@@ -177,28 +165,6 @@ class TestDeQuantizeOpShift_4_N(TestDeQuantizeOpShift_2_N):
         self.input_size = [2, 3, 4, 5]
 
 
-class TestDeQuantizeOp_NegativeShift(TestDeQuantizeOp):
-    def set_shift(self):
-        self.shift = -10.0
-
-    def prepare_output_int8(self):
-        self.output = np.zeros(self.input_size)
-        self.outputs = {'Output': self.output}
-
-    def test_check_output(self):
-        self.assertRaises(AttributeError, self.check_raise_error,
-                          'Dequantization shift must be nonnegative.')
-
-
-class TestDeQuantizeOp_TooBigShift(TestDeQuantizeOp_NegativeShift):
-    def set_shift(self):
-        self.shift = 300.0
-
-    def test_check_output(self):
-        self.assertRaises(
-            AttributeError, self.check_raise_error,
-            'Dequantization shift must be less than or equal to 255.')
-
-
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
@@ -17,6 +17,7 @@ from __future__ import print_function
 import unittest
 import numpy as np
 from paddle.fluid.tests.unittests.op_test import OpTest
+import paddle
 
 
 class TestQuantizeOp(OpTest):
@@ -104,19 +105,6 @@ class TestQuantizeOp2(TestQuantizeOp):
         self.is_nagative = False
 
 
-class TestQuantizeOp_ZeroScale(TestQuantizeOp):
-    def set_scale(self):
-        self.scale = 0.0
-
-    def prepare_output(self):
-        self.output = np.zeros(self.input_size)
-        self.outputs = {'Output': self.output}
-
-    def test_check_output(self):
-        self.assertRaises(AttributeError, self.check_raise_error,
-                          'Quantization scale cannot be 0.0')
-
-
 # 2-dim input
 # P - positive input
 class TestQuantizeOpShift_NCHW_2_P(TestQuantizeOp):
@@ -201,34 +189,6 @@ class TestQuantizeOpShift_NHWC_4_N(TestQuantizeOpShift_NCHW_4_N):
         self.output_format = 'NHWC'
 
 
-class TestQuantizeOp_NegativeShift(TestQuantizeOp):
-    def set_is_negative(self):
-        self.is_nagative = False
-
-    def set_scale(self):
-        self.scale = 100.0
-
-    def set_shift(self):
-        self.shift = -10.0
-
-    def prepare_output(self):
-        self.output = np.zeros(self.input_size)
-        self.outputs = {'Output': self.output}
-
-    def test_check_output(self):
-        self.assertRaises(AttributeError, self.check_raise_error,
-                          'Quantization shift must be nonnegative.')
-
-
-class TestQuantizeOp_TooBigShift(TestQuantizeOp_NegativeShift):
-    def set_shift(self):
-        self.shift = 300.0
-
-    def test_check_output(self):
-        self.assertRaises(
-            AttributeError, self.check_raise_error,
-            'Quantization shift must be less than or equal to 255.')
-
-
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()