From 0caa08ea409d52070e1cbcf5c5cb068312cb0241 Mon Sep 17 00:00:00 2001
From: Physher <50398879+Minghui-Intel@users.noreply.github.com>
Date: Tue, 9 Jul 2019 16:29:18 +0800
Subject: [PATCH] Add mkldnn int8 mul-op kernel (#17834)

---
 .../fluid/operators/mkldnn/mul_mkldnn_op.cc   | 433 ++++++++++++++++++
 paddle/fluid/operators/mul_op.cc              |  52 +++
 paddle/fluid/operators/mul_op.h               |   2 +
 paddle/fluid/platform/mkldnn_helper.h         |  11 +
 .../mkldnn/test_mul_int8_mkldnn_op.py         | 166 +++++++
 .../tests/unittests/test_operator_desc.py     |   3 +-
 6 files changed, 666 insertions(+), 1 deletion(-)
 create mode 100644 paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc
 create mode 100644 python/paddle/fluid/tests/unittests/mkldnn/test_mul_int8_mkldnn_op.py

diff --git a/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc
new file mode 100644
index 0000000000..4819bb300d
--- /dev/null
+++ b/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc
@@ -0,0 +1,433 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/data_layout_transform.h"
+#include "paddle/fluid/memory/malloc.h"
+#include "paddle/fluid/operators/mul_op.h"
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::DataLayout;
+using framework::DDim;
+using framework::ExecutionContext;
+using framework::Tensor;
+using mkldnn::inner_product_forward;
+using mkldnn::memory;
+using mkldnn::prop_kind;
+using mkldnn::stream;
+using platform::MKLDNNDeviceContext;
+using platform::to_void_cast;
+
+template <typename XT, typename YT, typename OT>
+class MulPrimitiveFactory {
+ public:
+  explicit MulPrimitiveFactory(const mkldnn::engine &engine)
+      : engine_(engine) {}
+
+  virtual ~MulPrimitiveFactory() {}
+
+  virtual inner_product_forward CreateMulPrimitive(
+      const Tensor *input_x, const Tensor *input_y, Tensor *output,
+      const ExecutionContext &ctx) {
+    /* check format and reorder if needed */
+    int x_num_col_dims = ctx.Attr<int>("x_num_col_dims");
+    int y_num_col_dims = ctx.Attr<int>("y_num_col_dims");
+
+    auto x_matrix = UpdateDataFormat<XT>(input_x, x_num_col_dims, ctx);
+    auto y_matrix = UpdateDataFormat<YT>(input_y, y_num_col_dims, ctx);
+
+    auto output_dim = output->dims();
+    if (output_dim.size() != 2) {
+      output->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
+    }
+
+    if (mul_) {
+      UpdateDataPointers(ctx, output, &x_matrix);
+      return *mul_;
+    }
+
+    auto src_desc = CreateMemDescriptor<XT>(&x_matrix, memory::format::nc);
+    x_input_ = CreateMemory<XT>(src_desc, &x_matrix);
+    y_input_ = TransposeInputY(&y_matrix);
+    auto dst_desc = CreateMemDescriptor<OT>(output, memory::format::any);
+
+    mul_ = CreateMulPrimitive(*x_input_, *y_input_, dst_desc, output, ctx);
+    return *mul_;
+  }
+
+ protected:
+  template <typename T>
+  Tensor UpdateDataFormat(const Tensor *data, int num_col_dims,
+                          const ExecutionContext &ctx) {
+    Tensor x_tmp;
+    Tensor data_matrix;
+    memory::format src_fmt = data->format();
+    memory::format dst_fmt;
+    auto src_mdesc = CreateMemDescriptor<T>(data, src_fmt);
+
+    if ((data->dims().size() == 4 &&
+         src_fmt != (dst_fmt = memory::format::nchw)) ||
+        (data->dims().size() == 5 &&
+         src_fmt != (dst_fmt = memory::format::ncdhw))) {
+      auto dst_mdesc = CreateMemDescriptor<T>(data, dst_fmt);
+      x_tmp.mutable_data<T>(ctx.GetPlace(), data->memory_size());
+
+      Reorder(src_mdesc, dst_mdesc, to_void_cast<T>(data->data<T>()),
+              to_void_cast<T>(x_tmp.data<T>()));
+
+      x_tmp.Resize(data->dims());
+      x_tmp.set_format((memory::format)dst_mdesc.data.format);
+      data_matrix = framework::ReshapeToMatrix(x_tmp, num_col_dims);
+    } else {
+      data_matrix = framework::ReshapeToMatrix(*data, num_col_dims);
+    }
+
+    return data_matrix;
+  }
+
+  void UpdateDataPointers(const ExecutionContext &ctx, Tensor *out,
+                          const Tensor *in) {
+    x_input_->set_data_handle(to_void_cast<XT>(in->data<XT>()));
+    output_->set_data_handle(out->mutable_data<OT>(ctx.GetPlace()));
+
+    if (out->format() == memory::format::format_undef) {
+      auto output_format = output_->get_primitive_desc().desc().data.format;
+      out->set_format((memory::format)output_format);
+    }
+  }
+
+  template <typename T>
+  memory::desc CreateMemDescriptor(
+      const Tensor *tensor, memory::format format,
+      memory::data_type type = platform::MKLDNNGetDataType<T>()) {
+    auto dims = framework::vectorize2int(tensor->dims());
+    return platform::MKLDNNMemDesc(dims, type, format);
+  }
+
+  template <typename T>
+  memory::desc CreateMemDescriptor(
+      const std::vector<int> &dims, memory::format format,
+      memory::data_type type = platform::MKLDNNGetDataType<T>()) {
+    return platform::MKLDNNMemDesc(dims, type, format);
+  }
+
+  template <typename T>
+  memory CreateMemory(const memory::desc &desc, const Tensor *tensor) {
+    return memory({desc, engine_}, to_void_cast<T>(tensor->data<T>()));
+  }
+
+  memory CreateDstMemory(
+      const inner_product_forward::primitive_desc &mul_prim_desc,
+      const ExecutionContext &ctx, Tensor *output) {
+    auto dst_prim_desc = mul_prim_desc.dst_primitive_desc();
+    auto buffer_size = dst_prim_desc.get_size();
+
+    OT *output_data = output->mutable_data<OT>(ctx.GetPlace(), buffer_size);
+    output->set_format((memory::format)dst_prim_desc.desc().data.format);
+    return memory(dst_prim_desc, to_void_cast<OT>(output_data));
+  }
+
+  memory Reorder(const memory::desc &src_desc, const memory::desc &dst_desc,
+                 void *src_data, void *dst_data = NULL) {
+    auto src_mem = memory({src_desc, engine_}, src_data);
+    auto dst_mem = dst_data ? memory({dst_desc, engine_}, dst_data)
+                            : memory({dst_desc, engine_});
+
+    auto reorder = mkldnn::reorder(src_mem, dst_mem);
+    stream(stream::kind::eager).submit({reorder}).wait();
+
+    return dst_mem;
+  }
+
+  memory TransposeInputY(const Tensor *input_y) {
+    auto dims = framework::vectorize2int(input_y->dims());
+    std::swap(dims[0], dims[1]);  // Correct output dimensions
+    auto src_desc = CreateMemDescriptor<YT>(dims, memory::format::io);
+    auto dst_desc = CreateMemDescriptor<YT>(dims, memory::format::oi);
+    return Reorder(src_desc, dst_desc, to_void_cast<YT>(input_y->data<YT>()));
+  }
+
+  inner_product_forward CreateMulPrimitive(const memory &x_memory,
+                                           const memory &y_memory,
+                                           const memory::desc &dst_desc,
+                                           Tensor *output,
+                                           const ExecutionContext &ctx) {
+    const auto y_desc = y_memory.get_primitive_desc().desc();
+    const auto x_desc = x_memory.get_primitive_desc().desc();
+
+    auto mul_prim_desc = CreateMulPrimDesc(x_desc, y_desc, dst_desc);
+    output_ = CreateDstMemory(mul_prim_desc, ctx, output);
+
+    return inner_product_forward(mul_prim_desc, x_memory, y_memory, *output_);
+  }
+
+  inner_product_forward::primitive_desc CreateMulPrimDesc(
+      const memory::desc &x_desc, const memory::desc &y_desc,
+      const memory::desc &dst_desc) {
+    auto mul_desc = inner_product_forward::desc(prop_kind::forward, x_desc,
+                                                y_desc, dst_desc);
+
+    return inner_product_forward::primitive_desc(mul_desc, engine_);
+  }
+
+ protected:
+  const mkldnn::engine &engine_;
+  boost::optional<memory> x_input_;
+  boost::optional<memory> y_input_;
+  boost::optional<memory> output_;
+  boost::optional<inner_product_forward> mul_;
+};  // namespace operators
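+
+/* INT8 path (see QuantMulPrimitiveFactory below): X arrives already
+ * quantized as s8/u8; the FP32 input Y is quantized to s8 with scale_y
+ * in QuantInputY; CreateMulAttr folds scale_out / (scale_x * scale_y)
+ * into the primitive's output scales, so the integer accumulator is
+ * rescaled to the requested output scale (or left as plain FP32 when
+ * force_fp32_output is set). */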
+
+template <typename XT, typename YT, typename OT>
+class QuantMulPrimitiveFactory : public MulPrimitiveFactory<XT, YT, OT> {
+ public:
+  using MulPrimitiveFactory<XT, YT, OT>::MulPrimitiveFactory;
+
+  virtual inner_product_forward CreateMulPrimitive(
+      const Tensor *x_input, const Tensor *y_input, Tensor *output,
+      const ExecutionContext &ctx) {
+    /* check data format and reorder if needed */
+    int x_num_col_dims = ctx.Attr<int>("x_num_col_dims");
+    int y_num_col_dims = ctx.Attr<int>("y_num_col_dims");
+    auto scale_y = ctx.Attr<std::vector<float>>("scale_y");
+
+    auto x_matrix =
+        this->template UpdateDataFormat<XT>(x_input, x_num_col_dims, ctx);
+    auto y_matrix =
+        this->template UpdateDataFormat<YT>(y_input, y_num_col_dims, ctx);
+
+    auto output_dim = output->dims();
+    if (output_dim.size() != 2) {
+      output->Resize({x_matrix.dims()[0], y_matrix.dims()[1]});
+    }
+
+    if (this->mul_) {
+      this->UpdateDataPointers(ctx, output, &x_matrix);
+      return *(this->mul_);
+    }
+
+    auto src_desc =
+        this->template CreateMemDescriptor<XT>(&x_matrix, memory::format::nc);
+    this->x_input_ = this->template CreateMemory<XT>(src_desc, &x_matrix);
+
+    const auto trans_y = this->TransposeInputY(&y_matrix);
+    this->y_input_ = QuantInputY(trans_y, scale_y);
+
+    auto dst_desc =
+        this->template CreateMemDescriptor<OT>(output, memory::format::any);
+
+    this->mul_ = CreateMulPrimitive(*(this->x_input_), *(this->y_input_),
+                                    dst_desc, output, ctx);
+    return *(this->mul_);
+  }
+
+  memory ReorderWithScale(const memory::desc &src_desc,
+                          const memory::desc &dst_desc, void *src_data,
+                          const std::vector<float> &scale) {
+    auto mask = scale.size() > 1 ? 1 : 0;
+    mkldnn::primitive_attr attr;
+    attr.set_output_scales(mask, scale);
+
+    auto src_mem = memory({src_desc, this->engine_}, src_data);
+    auto dst_mem = memory({dst_desc, this->engine_});
+
+    auto reorder_pd = mkldnn::reorder::primitive_desc(
+        src_mem.get_primitive_desc(), dst_mem.get_primitive_desc(), attr);
+
+    auto reorder = mkldnn::reorder(reorder_pd, src_mem, dst_mem);
+    stream(stream::kind::eager).submit({reorder}).wait();
+
+    return dst_mem;
+  }
+
+  memory QuantInputY(memory input_y, const std::vector<float> &scale_y) {
+    const auto &dims = input_y.get_primitive_desc().desc().data.dims;
+    auto ndims = input_y.get_primitive_desc().desc().data.ndims;
+    auto y_dims = std::vector<int>(dims, dims + ndims);
+
+    auto user_y_desc =
+        this->template CreateMemDescriptor<YT>(y_dims, memory::format::oi);
+    auto y_desc =
+        this->template CreateMemDescriptor<int8_t>(y_dims, memory::format::oi);
+
+    return ReorderWithScale(user_y_desc, y_desc, input_y.get_data_handle(),
+                            scale_y);
+  }
+
+  mkldnn::primitive_attr CreateMulAttr(const ExecutionContext &ctx,
+                                       bool force_fp32_output) {
+    mkldnn::primitive_attr mul_attr;
+
+    auto scale_y_data = ctx.Attr<std::vector<float>>("scale_y");
+    auto scale_x_data = ctx.Attr<float>("scale_x");
+    auto scale_out_data =
+        force_fp32_output ? 1.0f : ctx.Attr<float>("scale_out");
+
+    bool is_multi_channel = scale_y_data.size() > 1;
+    int count = is_multi_channel ? scale_y_data.size() : 1;
+    std::vector<float> output_shift_scale(count);
+    for (int i = 0; i < count; i++) {
+      if (scale_y_data[i] == 0.0)
+        output_shift_scale[i] = scale_out_data;
+      else
+        output_shift_scale[i] =
+            scale_out_data / (scale_x_data * scale_y_data[i]);
+    }
+    int mul_mask = is_multi_channel ? 1 : 0;
+    mul_attr.set_output_scales(mul_mask, output_shift_scale);
+
+    return mul_attr;
+  }
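+
+  /* Worked example with the unit-test values: scale_x = 0.6,
+   * scale_y = {0.8}, scale_out = 1.0 gives an output scale of
+   * 1.0 / (0.6 * 0.8) ~= 2.083, which maps the s32 product of the
+   * quantized inputs back to the requested output scale. */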
+
+  inner_product_forward CreateMulPrimitive(const memory &x_memory,
+                                           const memory &y_memory,
+                                           const memory::desc &dst_desc,
+                                           Tensor *output,
+                                           const ExecutionContext &ctx) {
+    const auto x_desc = x_memory.get_primitive_desc().desc();
+    const auto y_desc = y_memory.get_primitive_desc().desc();
+    bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
+
+    mkldnn::primitive_attr mul_attr = CreateMulAttr(ctx, force_fp32_output);
+    auto mul_prim_desc = CreateMulPrimDesc(x_desc, y_desc, dst_desc, mul_attr);
+
+    this->output_ = this->CreateDstMemory(mul_prim_desc, ctx, output);
+
+    return inner_product_forward(mul_prim_desc, x_memory, y_memory,
+                                 *(this->output_));
+  }
+
+  inner_product_forward::primitive_desc CreateMulPrimDesc(
+      const memory::desc &x_desc, const memory::desc &y_desc,
+      const memory::desc &dst_desc, const mkldnn::primitive_attr &mul_attr) {
+    const auto &mul_desc = inner_product_forward::desc(
+        prop_kind::forward, x_desc, y_desc, dst_desc);
+
+    return inner_product_forward::primitive_desc(mul_desc, mul_attr,
+                                                 this->engine_);
+  }
+};
+
+static std::string GetHash(const Tensor *input_x, const Tensor *input_y,
+                           const std::string &suffix) {
+  auto dim2str = [](const DDim &operand_dims) {
+    std::string str = "";
+    for (int i = 0; i < operand_dims.size(); ++i) {
+      str += std::to_string(operand_dims[i]) + "-";
+    }
+    return str;
+  };
+
+  std::string hash = std::to_string((unsigned)input_x->format()) +
+                     std::to_string((unsigned)input_x->type()) +
+                     dim2str(input_x->dims()) +
+                     std::to_string((unsigned)input_y->format()) +
+                     std::to_string((unsigned)input_y->type()) +
+                     dim2str(input_y->dims()) + suffix;
+
+  return hash;
+}
+
+/* OT: output data type */
+template <typename XT, typename YT, typename OT>
+std::shared_ptr<MulPrimitiveFactory<XT, YT, OT>> GetPrimitiveFactory(
+    const MKLDNNDeviceContext &dev_ctx, const ExecutionContext &ctx,
+    const Tensor *input_x, const Tensor *input_y,
+    const mkldnn::engine &mkldnn_engine, bool enable_quant) {
+  const std::string key = GetHash(input_x, input_y, ctx.op().Output("Out"));
+
+  auto prim_creator =
+      std::static_pointer_cast<MulPrimitiveFactory<XT, YT, OT>>(
+          dev_ctx.GetBlob(key));
+
+  if (prim_creator == nullptr) {
+    prim_creator =
+        enable_quant
+            ? std::make_shared<QuantMulPrimitiveFactory<XT, YT, OT>>(
+                  mkldnn_engine)
+            : std::make_shared<MulPrimitiveFactory<XT, YT, OT>>(mkldnn_engine);
+    dev_ctx.SetBlob(key, prim_creator);
+  }
+
+  return prim_creator;
+}
+
+template <typename XT, typename YT>
+inner_product_forward GetMulPrimitive(const MKLDNNDeviceContext &dev_ctx,
+                                      const ExecutionContext &ctx,
+                                      const Tensor *input_x,
+                                      const Tensor *input_y, Tensor *output,
+                                      const mkldnn::engine &mkldnn_engine) {
+  bool enable_quant =
+      std::is_same<XT, int8_t>::value || std::is_same<XT, uint8_t>::value;
+  bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
+
+  if (enable_quant && !force_fp32_output) {
+    return GetPrimitiveFactory<XT, YT, int8_t>(dev_ctx, ctx, input_x, input_y,
+                                               mkldnn_engine, enable_quant)
+        ->CreateMulPrimitive(input_x, input_y, output, ctx);
+
+  } else {
+    return GetPrimitiveFactory<XT, YT, float>(dev_ctx, ctx, input_x, input_y,
+                                              mkldnn_engine, enable_quant)
+        ->CreateMulPrimitive(input_x, input_y, output, ctx);
+  }
+}
+
+/* XT: input x data type, YT: input y data type */
+template <typename XT, typename YT>
+class MulMKLDNNKernel : public framework::OpKernel<XT> {
+ public:
+  void Compute(const ExecutionContext &ctx) const override {
+    PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
+                   "It must use CPUPlace.");
+
+    auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+    const auto &mkldnn_engine = dev_ctx.GetEngine();
+
+    const Tensor *x = ctx.Input<Tensor>("X");
+    const Tensor *y = ctx.Input<Tensor>("Y");
+    Tensor *out = ctx.Output<Tensor>("Out");
+    auto out_dims = out->dims();
+
+    auto mul = GetMulPrimitive<XT, YT>(dev_ctx, ctx, x, y, out, mkldnn_engine);
+
+    stream(stream::kind::eager).submit({mul}).wait();
+
+    if (out_dims.size() != 2) {
+      out->Resize(out_dims);
+    }
+    out->set_layout(DataLayout::kMKLDNN);
+    out->set_format(out->format());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul, MKLDNN, ::paddle::platform::CPUPlace,
+                                    U8, ops::kMULMKLDNNINT8,
+                                    ops::MulMKLDNNKernel<uint8_t, float>);
+
+REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(mul, MKLDNN, ::paddle::platform::CPUPlace,
+                                    S8, ops::kMULMKLDNNINT8,
+                                    ops::MulMKLDNNKernel<int8_t, float>);
+
+REGISTER_OP_KERNEL(mul, MKLDNN, ::paddle::platform::CPUPlace,
+                   ops::MulMKLDNNKernel<float, float>);
diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc
index bbf9fbfa1f..381f7c1e98 100644
--- a/paddle/fluid/operators/mul_op.cc
+++ b/paddle/fluid/operators/mul_op.cc
@@ -17,6 +17,9 @@ limitations under the License. */
 #include <memory>
 #include <string>
 #include <vector>
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif
 
 namespace paddle {
 namespace operators {
@@ -76,6 +79,30 @@ class MulOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
     ctx->ShareLoD("X", /*->*/ "Out");
   }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const {
+    framework::LibraryType library = framework::LibraryType::kPlain;
+    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+    int customized_type_value =
+        framework::OpKernelType::kDefaultCustomizedTypeValue;
+    auto input_data_type = ctx.Input<Tensor>("X")->type();
+#ifdef PADDLE_WITH_MKLDNN
+    if (library == framework::LibraryType::kPlain &&
+        platform::CanMKLDNNBeUsed(ctx)) {
+      library = framework::LibraryType::kMKLDNN;
+      layout = framework::DataLayout::kMKLDNN;
+
+      if (input_data_type == framework::DataTypeTrait<int8_t>::DataType ||
+          input_data_type == framework::DataTypeTrait<uint8_t>::DataType) {
+        customized_type_value = kMULMKLDNNINT8;
+      }
+    }
+#endif
+
+    return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
+                                   library, customized_type_value);
+  }
 };
 
 class MulOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -84,6 +111,9 @@
     AddInput("X", "(Tensor), The first input tensor of mul op.");
     AddInput("Y", "(Tensor), The second input tensor of mul op.");
     AddOutput("Out", "(Tensor), The output tensor of mul op.");
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
     AddAttr<int>(
         "x_num_col_dims",
         R"DOC((int, default 1), The mul_op can take tensors with more than two
@@ -114,6 +144,23 @@
     )DOC")
        .SetDefault(1)
        .EqualGreaterThan(1);
+    AddAttr<float>("scale_x",
+                   "scale_x to be used for int8 input data x."
+                   "Only used with MKL-DNN INT8")
+        .SetDefault(1.0f);
+    AddAttr<std::vector<float>>("scale_y",
+                                "scale_y to be used for int8 input data y."
+                                "Only used with MKL-DNN INT8")
+        .SetDefault({1.0f});
+    AddAttr<float>("scale_out",
+                   "scale_out to be used for int8 output data."
+                   "Only used with MKL-DNN INT8")
+        .SetDefault(1.0f);
+    AddAttr<bool>(
+        "force_fp32_output",
+        "(bool, default false) Force the quantized kernel to output FP32, "
+        "only used in quantized MKL-DNN.")
+        .SetDefault(false);
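+    // Illustrative attribute set (values mirror the unit tests, not
+    // defaults): use_mkldnn=true, scale_x=0.6, scale_y={0.8},
+    // scale_out=1.0; with force_fp32_output=true the INT8 kernel
+    // dequantizes and writes FP32 instead of int8.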
     AddComment(R"DOC(
 Mul Operator.
 
@@ -237,14 +284,19 @@ class MulDoubleGradMaker : public framework::SingleGradOpDescMaker {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(mul, ops::MulOp, ops::MulOpMaker, ops::MulOpInferVarType,
                   ops::MulOpGradMaker);
+
 REGISTER_OPERATOR(mul_grad, ops::MulGradOp, ops::MulDoubleGradMaker);
+
 REGISTER_OPERATOR(mul_grad_grad, ops::MulDoubleGradOp);
+
 REGISTER_OP_CPU_KERNEL(
     mul, ops::MulKernel<paddle::platform::CPUDeviceContext, float>,
     ops::MulKernel<paddle::platform::CPUDeviceContext, double>);
+
 REGISTER_OP_CPU_KERNEL(
     mul_grad, ops::MulGradKernel<paddle::platform::CPUDeviceContext, float>,
     ops::MulGradKernel<paddle::platform::CPUDeviceContext, double>);
+
 REGISTER_OP_CPU_KERNEL(
     mul_grad_grad,
     ops::MulDoubleGradKernel<paddle::platform::CPUDeviceContext, float>,
diff --git a/paddle/fluid/operators/mul_op.h b/paddle/fluid/operators/mul_op.h
index c77eb5c4cc..3a13e0576e 100644
--- a/paddle/fluid/operators/mul_op.h
+++ b/paddle/fluid/operators/mul_op.h
@@ -24,6 +24,8 @@ namespace operators {
 
 using Tensor = framework::Tensor;
 
+constexpr int kMULMKLDNNINT8 = 1;
+
 template <typename DeviceContext, typename T>
 class MulKernel : public framework::OpKernel<T> {
  public:
diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h
index e53064893e..dafdb4eab9 100644
--- a/paddle/fluid/platform/mkldnn_helper.h
+++ b/paddle/fluid/platform/mkldnn_helper.h
@@ -15,6 +15,7 @@ limitations under the License. */
 
 #include <mkldnn.h>
 #include <algorithm>
+#include <memory>
 #include <string>
 #include <vector>
 #include "paddle/fluid/framework/operator.h"
@@ -89,6 +90,16 @@ inline mkldnn::memory::data_type MKLDNNGetDataType<float>() {
   return mkldnn::memory::f32;
 }
 
+template <>
+inline mkldnn::memory::data_type MKLDNNGetDataType<int8_t>() {
+  return mkldnn::memory::s8;
+}
+
+template <>
+inline mkldnn::memory::data_type MKLDNNGetDataType<uint8_t>() {
+  return mkldnn::memory::u8;
+}
+
 inline void Reorder(const mkldnn::memory& src, const mkldnn::memory& dst) {
   auto reorder_prim = mkldnn::reorder(src, dst);
   std::vector<mkldnn::primitive> pipeline;
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_mul_int8_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_mul_int8_mkldnn_op.py
new file mode 100644
index 0000000000..51ab00e191
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_mul_int8_mkldnn_op.py
@@ -0,0 +1,166 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
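+#
+# Reference model used by the checks below (it mirrors the kernel's scale
+# math rather than calling any MKL-DNN API): Y is quantized as
+# round(Y * scale_y), X is generated already quantized, and the integer
+# product is rescaled by scale_out / (scale_x * scale_y) before comparison.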
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import paddle.fluid.core as core
+from paddle.fluid.tests.unittests.op_test import OpTest
+'''
+ test case for s8 * s8
+'''
+
+
+class TestMKLDNNMulOpS8S8(OpTest):
+    def setUp(self):
+        self.op_type = "mul"
+        self.init_kernel_type()
+        self.init_data_type()
+        self.init_data()
+        self.attrs = {
+            "use_mkldnn": self.use_mkldnn,
+            "scale_x": self.scale_x,
+            "scale_y": self.scale_y,
+            "scale_out": self.scale_out,
+            "force_fp32_output": self.force_fp32,
+        }
+
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+        self.force_fp32 = True
+
+    def init_data_type(self):
+        self.srctype = np.int8
+        self.dsttype = np.float32 if self.force_fp32 else np.int8
+
+    def init_data(self):
+        self.scale_x = 0.6
+        self.scale_y = [0.8]
+        self.scale_out = 1.0
+
+        # limit random range inside [-127, 127] to avoid overflow on SKL
+        if self.srctype == np.int8:
+            A_data = np.random.randint(-127, 127, (2, 5)).astype(np.int8)
+        else:
+            A_data = np.random.randint(0, 127, (2, 5)).astype(np.uint8)
+
+        B_data = np.random.uniform(-127, 127, (5, 3)).astype(np.float32)
+
+        quant_B = np.round(B_data * self.scale_y[0]).astype(np.int)
+        output = np.dot(A_data, quant_B)
+
+        scale_output_shift = (self.scale_out) / \
+            (self.scale_x * self.scale_y[0])
+
+        if (self.force_fp32):
+            output = (output * scale_output_shift).astype(self.dsttype)
+        else:
+            output = np.round(output * scale_output_shift).astype(self.dsttype)
+
+        self.inputs = {'X': A_data, 'Y': B_data}
+        self.outputs = {'Out': output}
+
+    def test_check_output(self):
+        self.check_output_with_place(core.CPUPlace(), atol=0)
+
+    def test_check_grad_ignore_x(self):
+        pass
+
+    def test_check_grad_ignore_y(self):
+        pass
+
+    def test_check_grad_normal(self):
+        pass
+
+
+'''
+ test case for s8 * u8
+'''
+
+
+class TestMKLDNNMulOpS8U8(TestMKLDNNMulOpS8S8):
+    def init_data_type(self):
+        self.srctype = np.uint8
+        self.dsttype = np.float32 if self.force_fp32 else np.int8
+
+
+'''
+ test case for s8 * s8
+'''
+
+
+class TestMKLDNNMulOpS8S8WithFlatten(TestMKLDNNMulOpS8S8):
+    def setUp(self):
+        self.op_type = "mul"
+        self.init_kernel_type()
+        self.init_data_type()
+        self.init_data()
+        self.attrs = {
+            "use_mkldnn": self.use_mkldnn,
+            "scale_x": self.scale_x,
+            "scale_y": self.scale_y,
+            "scale_out": self.scale_out,
+            "force_fp32_output": self.force_fp32,
+            "x_num_col_dims": 2,
+            "y_num_col_dims": 2,
+        }
+
+    def init_data(self):
+        self.scale_x = 0.6
+        self.scale_y = [0.8]
+        self.scale_out = 1.0
+
+        # limit random range inside [-127, 127] to avoid overflow on SKL
+        if self.srctype == np.int8:
+            A_data = np.random.randint(-127, 127, (3, 4, 4, 3)).astype(np.int8)
+        else:
+            A_data = np.random.randint(0, 127, (3, 4, 4, 3)).astype(np.uint8)
+
+        B_data = np.random.uniform(-127, 127,
+                                   (2, 6, 1, 2, 3)).astype(np.float32)
+
+        A_data_reshape = A_data.reshape(3 * 4, 4 * 3)
+        B_data_reshape = B_data.reshape(2 * 6, 1 * 2 * 3)
+
+        quant_B = np.round(B_data_reshape * self.scale_y[0]).astype(np.int)
+        output = np.dot(A_data_reshape, quant_B)
+
+        scale_output_shift = (self.scale_out) / \
+            (self.scale_x * self.scale_y[0])
+
+        if (self.force_fp32):
+            output = (output * scale_output_shift).astype(self.dsttype)
+        else:
+            output = np.round(output * scale_output_shift).astype(self.dsttype)
+
+        output = output.reshape(3, 4, 1, 2, 3)
+
+        self.inputs = {'X': A_data, 'Y': B_data}
+        self.outputs = {'Out': output}
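+
+# Shape check for the flattened case: with x_num_col_dims=2, X of shape
+# (3, 4, 4, 3) flattens to (3*4) x (4*3) = 12 x 12; with y_num_col_dims=2,
+# Y of shape (2, 6, 1, 2, 3) flattens to (2*6) x (1*2*3) = 12 x 6; Out is
+# then reshaped back to (3, 4, 1, 2, 3).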
+
+
+'''
+ test case for s8 * u8
+'''
+
+
+class TestMKLDNNMulOpS8U8WithFlatten(TestMKLDNNMulOpS8S8WithFlatten):
+    def init_data_type(self):
+        self.srctype = np.uint8
+        self.dsttype = np.float32 if self.force_fp32 else np.int8
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_operator_desc.py b/python/paddle/fluid/tests/unittests/test_operator_desc.py
index aa9634a2d4..5932112c3c 100644
--- a/python/paddle/fluid/tests/unittests/test_operator_desc.py
+++ b/python/paddle/fluid/tests/unittests/test_operator_desc.py
@@ -69,7 +69,8 @@ class TestOperator(unittest.TestCase):
             set(mul_op.attr_names),
             set([
                 "x_num_col_dims", "y_num_col_dims", "op_role", "op_role_var",
-                "op_namescope", "op_callstack"
+                "use_mkldnn", "scale_x", "scale_y", "scale_out",
+                "force_fp32_output", "op_namescope", "op_callstack"
             ]))
         self.assertEqual(mul_op.has_attr("x_num_col_dims"), True)
         self.assertEqual(mul_op.attr_type("x_num_col_dims"), core.AttrType.INT)
--
GitLab