From a7944904d3b50aeffcaae61f00ecd5d6920805b3 Mon Sep 17 00:00:00 2001 From: Jacek Czaja Date: Thu, 18 Jun 2020 12:54:07 +0200 Subject: [PATCH] [oneDNN]elementwise_add and elementwise_mul int8 support (#24984) * Start implementing int8 eltwise add test=develop * - Fix to Michal PR * - Fix test=develop * - Lint fixes test=develop * - Added checking if elementwise_mul can be used test=develop * - Added attribs to skip_attrs_set test=develop * - Improved broadcasting test=develop - fixes to compilation - fix - fix - Lint fixes test=develop * - removed redundant condition test=develop Co-authored-by: Michal Gallus --- .../framework/ir/graph_pattern_detector.cc | 12 +- .../elementwise/elementwise_mul_op.h | 42 ++-- .../operators/elementwise/elementwise_op.h | 25 +- .../mkldnn/elementwise_add_mkldnn_op.cc | 90 ++----- .../mkldnn/elementwise_mkldnn_op.h | 79 ++++++ .../mkldnn/elementwise_mul_mkldnn_op.cc | 96 +------- paddle/fluid/platform/mkldnn_reuse.h | 92 +++++-- python/paddle/fluid/layers/nn.py | 6 +- .../mkldnn/test_elementwise_add_mkldnn_op.py | 103 +++++++- .../mkldnn/test_elementwise_mul_mkldnn_op.py | 224 ++++++++---------- python/paddle/tensor/math.py | 6 +- 11 files changed, 429 insertions(+), 346 deletions(-) create mode 100644 paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc index c45d6b5282..5c2301d6e0 100644 --- a/paddle/fluid/framework/ir/graph_pattern_detector.cc +++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc @@ -1881,8 +1881,16 @@ PDNode *patterns::MultipleQuantize::operator()() { PDNode *patterns::MKLDNNInPlace::operator()() { const std::unordered_set &supported_op_types = { - "abs", "elementwise_add", "gelu", "leaky_relu", "relu", "softmax", - "sqrt", "swish", "tanh"}; + "abs", + "elementwise_mul", + "elementwise_add", + "gelu", + "leaky_relu", + "relu", + "softmax", + "sqrt", + "swish", + "tanh"}; auto possible_inplace_op = pattern->NewNode(inplace_to_be_op_repr()) ->assert_is_ops(supported_op_types); diff --git a/paddle/fluid/operators/elementwise/elementwise_mul_op.h b/paddle/fluid/operators/elementwise/elementwise_mul_op.h index c3695cabe7..718321b441 100644 --- a/paddle/fluid/operators/elementwise/elementwise_mul_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_mul_op.h @@ -28,38 +28,30 @@ class ElementwiseMulOp : public ElementwiseOp { using Tensor = framework::Tensor; using ElementwiseOp::ElementwiseOp; -#ifdef PADDLE_WITH_MKLDNN - static bool AreDimsAndFormatCorrect(const framework::ExecutionContext& ctx, - int simd_width, - mkldnn::memory::format_tag x_format) { - using Tensor = framework::Tensor; - using paddle::framework::vectorize; - using mkldnn::memory; - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto x_dims = vectorize(x->dims()); - const bool are_dims_divisable = !(x_dims[1] % simd_width); - const bool is_x_format_correct = x->format() == x_format; - const bool is_y_format_correct = vectorize(y->dims()).size() == 2; - return are_dims_divisable && is_x_format_correct && is_y_format_correct; - } -#endif - framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN using mkldnn::memory; - if (platform::CanMKLDNNBeUsed(ctx)) { - bool can_use_avx512_kernel = - platform::MayIUse(platform::avx512f) && - AreDimsAndFormatCorrect(ctx, 16, memory::format_tag::nChw16c); - if (can_use_avx512_kernel) { - return framework::OpKernelType(input_data_type, ctx.GetPlace(), - framework::DataLayout::kMKLDNN, - framework::LibraryType::kMKLDNN); + auto CanMKLDNNElementwiseMulBeUsed = [&]() { + auto x_dims = ctx.Input("X")->dims(); + auto y_dims = ctx.Input("Y")->dims(); + int rankdiff = x_dims.size() - y_dims.size(); + // TODO(jczaja): Remove this when oneDNN performance for scalar + // broadcasting + // is improved (Ernie large situation) + if (rankdiff != 0 && y_dims.size() == 1 && y_dims[0] == 1) { + return false; } + + return true; + }; + + if (platform::CanMKLDNNBeUsed(ctx) && CanMKLDNNElementwiseMulBeUsed()) { + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); } #endif return framework::OpKernelType(input_data_type, ctx.GetPlace()); diff --git a/paddle/fluid/operators/elementwise/elementwise_op.h b/paddle/fluid/operators/elementwise/elementwise_op.h index c24b5c0208..f32086c94a 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_op.h @@ -100,15 +100,7 @@ class ElementwiseOp : public framework::OperatorWithKernel { auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); #ifdef PADDLE_WITH_MKLDNN - auto CanMKLDNNElementwiseAddBeUsed = [&]() { - int axis = ctx.Attr("axis"); - int rankdiff = ctx.Input("X")->dims().size() - - ctx.Input("Y")->dims().size(); - return (rankdiff == 0) || (axis == -1) || (axis == rankdiff); - }; - - if (platform::CanMKLDNNBeUsed(ctx) && - (ctx.Type() != "elementwise_add" || CanMKLDNNElementwiseAddBeUsed())) { + if (platform::CanMKLDNNBeUsed(ctx)) { return framework::OpKernelType(input_data_type, ctx.GetPlace(), framework::DataLayout::kMKLDNN, framework::LibraryType::kMKLDNN); @@ -148,6 +140,21 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker { .SetDefault(""); AddAttr("y_data_format", "This parameter is no longer used.") .SetDefault(""); + /* int8 parameters */ + AddAttr("use_quantizer", + "(bool, default false) " + "Set to true for operators that should be quantized and use " + "int8 kernel. Only used on CPU.") + .SetDefault(false); + AddAttr("Scale_x", + "(float, default 1.0f), The quantize scale of X tensor") + .SetDefault(1.0f); + AddAttr("Scale_y", + "(float, default 1.0f), The quantize scale of Y tensor") + .SetDefault(1.0f); + AddAttr("Scale_out", + "(float, default 1.0f), The quantize scale of output data") + .SetDefault(1.0f); AddOpComment(); } diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc index 98b79d6bb2..caaaf2c931 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc @@ -1,74 +1,21 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/memory/memcpy.h" -#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" - -#include "paddle/fluid/framework/data_layout_transform.h" -#include "paddle/fluid/platform/mkldnn_reuse.h" +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h" namespace paddle { namespace operators { - -using framework::DataLayout; -using framework::Tensor; -using mkldnn::memory; -using mkldnn::primitive; -using mkldnn::reorder; -using mkldnn::stream; -using mkldnn::sum; - -template -class EltwiseAddMKLDNNKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const auto& dev_ctx = - ctx.template device_context(); - const auto& mkldnn_engine = dev_ctx.GetEngine(); - - const auto* x = ctx.Input("X"); - const auto* y = ctx.Input("Y"); - auto* z = ctx.Output("Out"); - - platform::BinaryMKLDNNHandler handler( - dev_ctx, mkldnn_engine, ctx.GetPlace(), x, y, z, ctx.OutputName("Out")); - - const auto src_x_memory = handler.AcquireSrcMemory(x); - const auto src_y_memory = handler.AcquireSecondSrcMemory(y); - - // For Inplace src and and dst are the same memory object - const auto dst_memory = - x->IsSharedBufferWith(*z) ? src_x_memory : handler.AcquireDstMemory(z); - - const auto binary_prim = handler.AcquireForwardPrimitive(); - - mkldnn::stream astream(mkldnn_engine); - - const std::unordered_map args = { - {DNNL_ARG_SRC_0, *src_x_memory}, - {DNNL_ARG_SRC_1, *src_y_memory}, - {DNNL_ARG_DST, *dst_memory}}; - - binary_prim->execute(astream, args); - astream.wait(); - - z->set_layout(DataLayout::kMKLDNN); - z->set_format(platform::GetMKLDNNFormat(*dst_memory)); - } -}; - template class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel { public: @@ -106,8 +53,11 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel { namespace ops = paddle::operators; -REGISTER_OP_KERNEL(elementwise_add, MKLDNN, ::paddle::platform::CPUPlace, - ops::EltwiseAddMKLDNNKernel) +REGISTER_OP_KERNEL( + elementwise_add, MKLDNN, ::paddle::platform::CPUPlace, + ops::EltwiseMKLDNNKernel, + ops::EltwiseMKLDNNKernel, + ops::EltwiseMKLDNNKernel) REGISTER_OP_KERNEL(elementwise_add_grad, MKLDNN, ::paddle::platform::CPUPlace, ops::EltwiseAddMKLDNNGradKernel) diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h new file mode 100644 index 0000000000..c5f55138d9 --- /dev/null +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h @@ -0,0 +1,79 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "paddle/fluid/memory/memcpy.h" +#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" + +#include "paddle/fluid/framework/data_layout_transform.h" +#include "paddle/fluid/platform/mkldnn_reuse.h" + +namespace paddle { +namespace operators { + +using framework::DataLayout; +using framework::Tensor; +using mkldnn::memory; +using mkldnn::primitive; +using mkldnn::stream; + +template +class EltwiseMKLDNNKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const auto& dev_ctx = + ctx.template device_context(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); + + const auto* x = ctx.Input("X"); + const auto* y = ctx.Input("Y"); + auto* z = ctx.Output("Out"); + + float scale_x = ctx.Attr("Scale_x"); + float scale_y = ctx.Attr("Scale_y"); + float scale_o = ctx.Attr("Scale_out"); + + int axis = ctx.Attr("axis"); + + platform::BinaryMKLDNNHandler handler( + BINARY_OP, axis, dev_ctx, mkldnn_engine, ctx.GetPlace(), x, y, z, + scale_x, scale_y, scale_o, ctx.OutputName("Out")); + + const auto src_x_memory = handler.AcquireSrcMemory(x); + const auto src_y_memory = handler.AcquireSecondSrcMemory(y); + + // For Inplace src and and dst are the same memory object + const auto dst_memory = + x->IsSharedBufferWith(*z) ? src_x_memory : handler.AcquireDstMemory(z); + + const auto binary_prim = handler.AcquireForwardPrimitive(); + + mkldnn::stream astream(mkldnn_engine); + + const std::unordered_map args = { + {DNNL_ARG_SRC_0, *src_x_memory}, + {DNNL_ARG_SRC_1, *src_y_memory}, + {DNNL_ARG_DST, *dst_memory}}; + + binary_prim->execute(astream, args); + astream.wait(); + + z->set_layout(DataLayout::kMKLDNN); + z->set_format(platform::GetMKLDNNFormat(*dst_memory)); + } +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc index 695ec23dba..c73b502a40 100644 --- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc +++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc @@ -1,4 +1,4 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -12,94 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include -#include "paddle/fluid/operators/elementwise/elementwise_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" - -#include "paddle/fluid/operators/jit/kernels.h" -#include "paddle/fluid/platform/cpu_info.h" -#include "paddle/fluid/platform/mkldnn_helper.h" - -#ifdef PADDLE_WITH_XBYAK -#include "xbyak/xbyak.h" -#include "xbyak/xbyak_util.h" -#endif - -namespace paddle { -namespace operators { - -using framework::DataLayout; -using mkldnn::memory; -using platform::StringToMKLDNNFormat; - -template -static void ComputeBroadcastedMultiply(const T* x_data, const T* y_data, - T* z_data, int64_t n, int64_t c, - int64_t h, int64_t w, int simd_width, - void (*multiply)(const T*, const T*, T*, - int, int)) { - const int64_t C = c / simd_width; -#pragma omp parallel for collapse(2) - for (int ni = 0; ni < n; ni++) { - for (int ci = 0; ci < C; ci++) { - auto ptr_x = - x_data + ni * C * h * w * simd_width + ci * h * w * simd_width; - - auto ptr_y = y_data + ni * C * simd_width + ci * simd_width; - auto ptr_z = - z_data + ni * C * h * w * simd_width + ci * h * w * simd_width; - - multiply(ptr_x, ptr_y, ptr_z, h, w); - } - } -} - -template -class ElementwiseMulMKLDNNKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - using Tensor = framework::Tensor; - - int axis = ctx.Attr("axis"); - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* z = ctx.Output("Out"); - const T* x_data = x->data(); - const T* y_data = y->data(); - T* z_data = z->mutable_data(ctx.GetPlace()); - - auto x_dims = x->dims(); - auto y_dims_untrimmed = y->dims(); - auto x_int_dims = paddle::framework::vectorize(x_dims); - - int pre, num, post, is_run_common_broadcast; - get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &num, &post, - &is_run_common_broadcast); - - if (post == 1) - PADDLE_THROW( - platform::errors::Unimplemented("Not implemented when post is 1.")); - - const int64_t n = x_dims[0]; - const int64_t c = x_dims[1]; - const int64_t h = x_dims[2]; - const int64_t w = x_dims[3]; - - const int simd_width = 16; - auto multiply = - jit::KernelFuncs, platform::CPUPlace>::Cache() - .At(0); - ComputeBroadcastedMultiply(x_data, y_data, z_data, n, c, h, w, simd_width, - multiply); - - z->set_layout(DataLayout::kMKLDNN); - z->set_format(x->format()); - } -}; -} // namespace operators -} // namespace paddle +#include "paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h" namespace ops = paddle::operators; -REGISTER_OP_KERNEL(elementwise_mul, MKLDNN, ::paddle::platform::CPUPlace, - ops::ElementwiseMulMKLDNNKernel) +REGISTER_OP_KERNEL( + elementwise_mul, MKLDNN, ::paddle::platform::CPUPlace, + ops::EltwiseMKLDNNKernel, + ops::EltwiseMKLDNNKernel, + ops::EltwiseMKLDNNKernel) diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index ff42bb4144..f76df5ca26 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -120,8 +120,12 @@ class MKLDNNHandlerT { return (dev_ctx_.GetBlob(key_p) != nullptr); } - template - void AcquireForwardPrimitiveDescriptor(Args&&... args) { + // If your primitive descriptor requires attributes, pass them as a + // first argument and paramters to descriptor constructor in the following + // arguments. Otherwise, all arguments will be forwarded to descriptor + // constructor, including the first one. + template + void AcquireForwardPrimitiveDescriptor(Arg&& first_arg, Args&&... args) { // Forward PD has to be passed to Grad op that // may be executed by diffrent thread, hence // for that one we use key that does not contain TID @@ -135,14 +139,34 @@ class MKLDNNHandlerT { fwd_pd_ = std::static_pointer_cast( dev_ctx_.GetBlob(key_pd)); if (fwd_pd_ == nullptr) { - auto fwd_desc = typename TForward::desc(std::forward(args)...); - fwd_pd_ = std::make_shared(fwd_desc, - engine_); + CreateForwardPrimitiveDescriptor(first_arg, + std::forward(args)...); dev_ctx_.SetBlob(key_pd, fwd_pd_); } } } + // Using sfinae to specialise variadic function. Workaround for not having + // if constexpr in C++ 11. + template + typename std::enable_if::type, + dnnl::primitive_attr>::value>::type + CreateForwardPrimitiveDescriptor(First&& first, Args&&... args) { + auto fwd_desc = typename TForward::desc(std::forward(args)...); + fwd_pd_ = std::make_shared( + fwd_desc, first, engine_); + } + + template + typename std::enable_if::type, + dnnl::primitive_attr>::value>::type + CreateForwardPrimitiveDescriptor(First&& first, Args&&... args) { + auto fwd_desc = typename TForward::desc(std::forward(first), + std::forward(args)...); + fwd_pd_ = + std::make_shared(fwd_desc, engine_); + } + template void AcquireBackwardPrimitiveDescriptor(Args&&... args) { const std::string key_fwd_pd = key_common_ + "@forward_pd"; @@ -385,18 +409,23 @@ class MKLDNNHandler { template class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT { public: - BinaryMKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, + BinaryMKLDNNHandler(const dnnl::algorithm algo, const int axis, + const MKLDNNDeviceContext& dev_ctx, const mkldnn::engine engine, platform::Place cpu_place, const Tensor* x, const Tensor* y, Tensor* z, + float scale_x, float scale_y, float scale_z, const std::string& uniq_name) : platform::MKLDNNHandlerT( dev_ctx, engine, cpu_place, - platform::CreateKey(framework::vectorize(x->dims()), uniq_name)) { - // bradcasting combined with in-place may require longer key + platform::CreateKey( + framework::vectorize(x->dims()), + uniq_name + (algo == dnnl::algorithm::binary_mul ? "M" : ""))) { + // bradcasting combined with in-place may require auto rankdiff = x->dims().size() - y->dims().size(); if (rankdiff > 0) { - this->key_ += std::to_string(rankdiff); - this->key_common_ += std::to_string(rankdiff); + auto suffix = std::to_string(rankdiff); + this->key_ += suffix; + this->key_common_ += suffix; } if (!this->isCached()) { @@ -423,16 +452,17 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT { auto src1_md = dnnl::memory::desc( src_y_tz, platform::MKLDNNGetDataType(), y->format()); if (rankdiff > 0) { - std::vector ones(rankdiff, 1); - std::vector dims1_ex(src_y_tz); - dims1_ex.insert(dims1_ex.begin(), ones.begin(), ones.end()); + std::vector dims1_ex(rankdiff, 1); + dims1_ex.insert(next(dims1_ex.begin(), (axis == -1 ? rankdiff : axis)), + src_y_tz.begin(), src_y_tz.end()); src1_md = src1_md.reshape(dims1_ex); } const auto dst_md = memory::desc(dst_tz, platform::MKLDNNGetDataType(), MKLDNNMemoryFormat::any); - this->AcquireForwardPrimitiveDescriptor(dnnl::algorithm::binary_add, - src0_md, src1_md, dst_md); + auto attributes = CreateAttributes(algo, scale_x, scale_y, scale_z); + this->AcquireForwardPrimitiveDescriptor(attributes, algo, src0_md, + src1_md, dst_md); } } @@ -442,6 +472,38 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT { return this->AcquireMemoryFromPrimitive( this->fwd_pd_->src1_desc(), to_void_cast(input_data), "@src1_mem_p"); } + + private: + static inline dnnl::primitive_attr CreateAttributes(dnnl::algorithm op, + float scale_x, + float scale_y, + float scale_z) { + // Scales set in attributes for inputs contibute to the output equation + // in the following way (assuming no broadcasting takes place): + // output_i = scale_0 * x_i <+ or *> scale_1 * y_i; + // Hence we have to create scales that will: + // 1. Dequantize both values, by multiplying with (1.0 / scale_x_or_y) + // 2. Quantize their result to output scale range, by multiplying with + // (scale_z) + // If we combine these two, we end up with following equation + // output = scale_out * (1/scale_x * x <* or +> 1/scale_y * y) + // Hence, to mimic such behaviour using provided interface, + // For add operation the equation is equal to: + // output = (scale_out / scale_x) * x + (scale_out / scale_y) * y + // + // For mul operation on the other hand + // output = (scale_out / scale_x) * x * (1.0 / scale_y) * y + // + float scale_0 = scale_z / scale_x; + float scale_1 = + op == dnnl::algorithm::binary_add ? scale_z / scale_y : 1.0 / scale_y; + dnnl::primitive_attr attributes; + attributes.set_scales(/* input_x_id = */ DNNL_ARG_SRC_0, /* mask = */ 0, + {scale_0}); + attributes.set_scales(/* input_y_id = */ DNNL_ARG_SRC_1, /* mask = */ 0, + {scale_1}); + return attributes; + } }; class SumMKLDNNHandler : public MKLDNNHandler { diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index ebc97cad58..11a4d93324 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -11907,8 +11907,10 @@ for func in [ Default is None. It's used to print debug info for developers. Details: \ :ref:`api_guide_Name` " ], - skip_attrs_set={"x_data_format", "y_data_format", "axis" - }) + """\n""" + str(func.__doc__) + skip_attrs_set={ + "x_data_format", "y_data_format", "axis", "use_quantizer", + "Scale_x", "Scale_y", "Scale_out" + }) + """\n""" + str(func.__doc__) for func in []: op_proto = OpProtoHolder.instance().get_op_proto(func.__name__) diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_add_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_add_mkldnn_op.py index ba0690841d..532c6a606d 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_add_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_add_mkldnn_op.py @@ -15,19 +15,11 @@ from __future__ import print_function import unittest import numpy as np -from paddle.fluid.tests.unittests.test_elementwise_add_op import * -''' -MKLDNN does not support tensors of dimensions number equal to 3. -Such dimensions cause exceptions in MKLDNN reorder primitive. -The DNNL-based kernel is used only when broadcasting is not required -(see GetExpectedKernelType() methods in elementwise_add_op.h). -''' +from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci +from paddle.fluid.tests.unittests.test_elementwise_add_op import TestElementwiseAddOp class TestMKLDNNElementwiseAddOp(TestElementwiseAddOp): - def init_data_format(self): - self.data_format = 'MKLDNN' - def init_kernel_type(self): self.use_mkldnn = True @@ -66,5 +58,96 @@ class TestMKLDNNElementwiseAddOp4(TestMKLDNNElementwiseAddOp): pass +class TestMKLDNNElementwiseAddOp_broadcast_3(TestMKLDNNElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 10, 12, 3).astype(self.dtype) + self.y = np.random.rand(10, 12).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 10, 12, 1) + + def init_axis(self): + self.axis = 1 + + +''' INT8 Tests ''' + + +@skip_check_grad_ci( + reason="oneDNN's int8 elementwise_ops don't implemend grad kernel.") +class TestInt8(TestElementwiseAddOp): + def init_kernel_type(self): + self.use_mkldnn = True + self._cpu_only = True + + def init_dtype(self): + self.dtype = np.int8 + + def init_input_output(self): + self.x = np.random.randint(0, 3, (12, 9)).astype("int8") + self.y = np.random.randint(0, 3, (12, 9)).astype("int8") + self.out = np.add(self.x, self.y) + + def init_scales(self): + self.attrs['Scale_x'] = 1.0 + self.attrs['Scale_y'] = 1.0 + self.attrs['Scale_out'] = 1.0 + + def test_check_output(self): + # TODO(wangzhongpu): support mkldnn op in dygraph mode + self.init_scales() + self.check_output(check_dygraph=(self.use_mkldnn == False)) + + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + + +class TestInt8Scales(TestInt8): + def quantize(self, tensor, dt="int8"): + max_int = 127.0 if dt == "int8" else 255.0 + scale = max_int / np.abs(np.amax(tensor)) + quantized = np.round(scale * tensor).astype(dt) + return scale, quantized + + def init_input_output(self): + self.x_f = np.random.random((100, )).astype("float") + self.y_f = np.random.random((100, )).astype("float") + self.out_f = np.add(self.x_f, self.y_f) + + self.scale_x, self.x = self.quantize(self.x_f) + self.scale_y, self.y = self.quantize(self.y_f) + self.scale_o, self.out = self.quantize(self.out_f) + + def init_scales(self): + self.attrs['Scale_x'] = self.scale_x + self.attrs['Scale_y'] = self.scale_y + self.attrs['Scale_out'] = self.scale_o + + def test_check_output(self): + # TODO(wangzhongpu): support mkldnn op in dygraph mode + self.init_scales() + int_atol = 1 # different quantization techniques + self.check_output( + check_dygraph=(self.use_mkldnn == False), atol=int_atol) + + +class TestUint8Scales(TestInt8Scales): + def init_input_output(self): + self.x_f = np.random.random((100, )).astype("float") + self.y_f = np.random.random((100, )).astype("float") + self.out_f = np.add(self.x_f, self.y_f) + + self.scale_x, self.x = self.quantize(self.x_f, "uint8") + self.scale_y, self.y = self.quantize(self.y_f, "uint8") + self.scale_o, self.out = self.quantize(self.out_f, "uint8") + + def init_dtype(self): + self.dtype = np.uint8 + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_mkldnn_op.py index b89b3adce3..d66f3dfb89 100644 --- a/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_mkldnn_op.py +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_elementwise_mul_mkldnn_op.py @@ -15,137 +15,76 @@ from __future__ import print_function import unittest import numpy as np -from paddle.fluid.tests.unittests.op_test import OpTest -import paddle.fluid.core as core -from paddle.fluid.op import Operator -from paddle.fluid.tests.unittests.test_elementwise_mul_op import * -from paddle.fluid.tests.unittests.test_conv2d_op import conv2d_forward_naive -from paddle.fluid.tests.unittests.mkldnn.mkldnn_op_test import __assert_close -import paddle.fluid as fluid from paddle.fluid.tests.unittests.op_test import skip_check_grad_ci +from paddle.fluid.tests.unittests.test_elementwise_mul_op import ElementwiseMulOp -# For UT coverage, integrate conv2d + elementwise-mul so that nchw16C could be automatically chosen when mkldnn-kernel is enabled -@skip_check_grad_ci( - reason="TODO: this test cannot use white list to skip check_grad, need to add check_grad." -) -class TestElementwiseMulMKLDNNOp_Integrated_With_Convs(ElementwiseMulOp): - def setUp(self): +class TestMKLDNNElementwiseMulOp(ElementwiseMulOp): + def init_kernel_type(self): + self.use_mkldnn = True + + def init_dtype(self): self.dtype = np.float32 - self.init_dtype() - self.init_kernel_type() - self.init_axis() - self._cpu_only = True - self.pad = [0, 0] - self.stride = [1, 1] - self.groups = 1 - self.input_size = [1, 3, 5, 5] # NCHW - self.filter_size = [16, 3, 3, 3] - self.filter_size2 = [1, 16, 2, 2] - self.dilations = [1, 1] - self.use_cudnn = False - self.data_format = "ANYLAYOUT" - self.input = np.random.random(self.input_size).astype(self.dtype) - self.filter = np.random.random(self.filter_size).astype(self.dtype) - self.filter2 = np.random.random(self.filter_size2).astype(self.dtype) - self.elt_mul_y_size = [1, 16] - self.elt_mul_y = np.random.random(self.elt_mul_y_size).astype( - self.dtype) - conv2d_param = { - 'stride': self.stride, - 'pad': self.pad, - 'dilation': self.dilations - } - conv_out, _, _, _, _ = conv2d_forward_naive( - self.input, self.filter, self.groups, conv2d_param) #[1, 16, 2, 2] - self.conv_output = conv_out - self.elt_mul_output = self.conv_output * self.elt_mul_y.reshape( - 1, 16, 1, 1) # the result shape is [1, 16, 2, 2] - conv_output2, _, _, _, _ = conv2d_forward_naive( - self.elt_mul_output, self.filter2, self.groups, conv2d_param) - self.conv_output2 = conv_output2 - self.fetch_list = ["conv_output2"] + +class TestMKLDNNElementwiseMulOp2(TestMKLDNNElementwiseMulOp): + def init_input_output(self): + self.x = np.random.random((100, )).astype(self.dtype) + self.y = np.random.random((100, )).astype(self.dtype) + self.out = np.multiply(self.x, self.y) + + +class TestMKLDNNElementwiseMulOp3(TestMKLDNNElementwiseMulOp): + def init_input_output(self): + self.x = np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype(self.dtype) + self.y = np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype(self.dtype) + self.out = np.multiply(self.x, self.y) + + +class TestMKLDNNElementwiseMulOp4(TestMKLDNNElementwiseMulOp): + def init_input_output(self): + self.x = np.random.uniform(1, 2, [2, 3, 4, 32]).astype(self.dtype) + self.y = np.random.uniform(1, 2, [4, 32]).astype(self.dtype) + self.out = np.multiply(self.x, self.y) + + # TODO(jczaja): Enable when grad is ready + def test_check_grad_normal(self): + pass + + def test_check_grad_ingore_x(self): + pass + + def test_check_grad_ingore_y(self): + pass + + +''' INT8 Tests ''' + + +@skip_check_grad_ci( + reason="oneDNN's int8 elementwise_ops don't implemend grad kernel.") +class TestInt8(ElementwiseMulOp): def init_kernel_type(self): self.use_mkldnn = True + self._cpu_only = True + + def init_dtype(self): + self.dtype = np.int8 + + def init_input_output(self): + self.x = np.random.randint(0, 3, (12, 9)).astype("int8") + self.y = np.random.randint(0, 3, (12, 9)).astype("int8") + self.out = np.multiply(self.x, self.y) - def init_axis(self): - self.axis = 0 + def init_scales(self): + self.attrs['Scale_x'] = 1.0 + self.attrs['Scale_y'] = 1.0 + self.attrs['Scale_out'] = 1.0 def test_check_output(self): - ground_truth = { - "input": self.input, - "filter": self.filter, - "filter2": self.filter2, - "conv_output": self.conv_output, - "elt_mul_y": self.elt_mul_y, - "elt_mul_output": self.elt_mul_output, - "conv_output2": self.conv_output2, - } - program = fluid.Program() - with fluid.program_guard(program): - block = program.global_block() - for name in ground_truth: - block.create_var( - name=name, dtype="float32", shape=ground_truth[name].shape) - conv2d_op = block.append_op( - type="conv2d", - inputs={ - "Input": block.var('input'), - 'Filter': block.var('filter') - }, - outputs={"Output": block.var('conv_output')}, - attrs={ - 'strides': self.stride, - 'paddings': self.pad, - 'groups': self.groups, - 'dilations': self.dilations, - 'use_cudnn': self.use_cudnn, - 'use_mkldnn': self.use_mkldnn, - 'data_format': self.data_format - }) - elementwise_mul_op = block.append_op( - type="elementwise_mul", - inputs={ - 'X': block.var('conv_output'), - 'Y': block.var('elt_mul_y'), - }, - outputs={"Out": block.var('elt_mul_output')}, - attrs={ - 'use_cudnn': self.use_cudnn, - 'use_mkldnn': self.use_mkldnn, - 'axis': self.axis - }) - conv2d_op2 = block.append_op( - type="conv2d", - inputs={ - "Input": block.var('elt_mul_output'), - 'Filter': block.var('filter2') - }, - outputs={"Output": block.var('conv_output2')}, - attrs={ - 'strides': self.stride, - 'paddings': self.pad, - 'groups': self.groups, - 'dilations': self.dilations, - 'use_cudnn': self.use_cudnn, - 'use_mkldnn': self.use_mkldnn, - 'data_format': self.data_format - }) - place = core.CPUPlace() - exe = fluid.Executor(place) - out = exe.run( - program, - feed={ - name: ground_truth[name] - for name in ["input", "filter", "filter2", "elt_mul_y"] - }, - fetch_list=self.fetch_list) - - for id, name in enumerate(self.fetch_list): - self.assertTrue( - np.allclose( - ground_truth[name], out[id], atol=1e-4), name) + # TODO(wangzhongpu): support mkldnn op in dygraph mode + self.init_scales() + self.check_output(check_dygraph=(self.use_mkldnn == False)) def test_check_grad_normal(self): pass @@ -157,5 +96,48 @@ class TestElementwiseMulMKLDNNOp_Integrated_With_Convs(ElementwiseMulOp): pass +class TestInt8Scales(TestInt8): + def quantize(self, tensor, dt="int8"): + max_int = 127.0 if dt == "int8" else 255.0 + scale = max_int / np.abs(np.amax(tensor)) + quantized = np.round(scale * tensor).astype(dt) + return scale, quantized + + def init_input_output(self): + self.x_f = np.random.random((100, )).astype("float") + self.y_f = np.random.random((100, )).astype("float") + self.out_f = np.multiply(self.x_f, self.y_f) + + self.scale_x, self.x = self.quantize(self.x_f) + self.scale_y, self.y = self.quantize(self.y_f) + self.scale_o, self.out = self.quantize(self.out_f) + + def init_scales(self): + self.attrs['Scale_x'] = self.scale_x + self.attrs['Scale_y'] = self.scale_y + self.attrs['Scale_out'] = self.scale_o + + def test_check_output(self): + # TODO(wangzhongpu): support mkldnn op in dygraph mode + self.init_scales() + int_atol = 1 # different quantization techniques + self.check_output( + check_dygraph=(self.use_mkldnn == False), atol=int_atol) + + +class TestUint8Scales(TestInt8Scales): + def init_input_output(self): + self.x_f = np.random.random((100, )).astype("float") + self.y_f = np.random.random((100, )).astype("float") + self.out_f = np.multiply(self.x_f, self.y_f) + + self.scale_x, self.x = self.quantize(self.x_f, "uint8") + self.scale_y, self.y = self.quantize(self.y_f, "uint8") + self.scale_o, self.out = self.quantize(self.out_f, "uint8") + + def init_dtype(self): + self.dtype = np.uint8 + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py index 9144a29ac9..7cc19186d0 100644 --- a/python/paddle/tensor/math.py +++ b/python/paddle/tensor/math.py @@ -703,9 +703,9 @@ for func in [ func.__doc__ = _generate_doc_string_( op_proto, additional_args_lines=additional_args_lines, - skip_attrs_set={"x_data_format", "y_data_format", "axis" - }) + """\n""" + str(func.__doc__) - + skip_attrs_set={"x_data_format", "y_data_format", "axis", + "use_quantizer", "Scale_x", "Scale_y", "Scale_out" + }) + """\n""" + str(func.__doc__) def sum(input, dim=None, dtype=None, keep_dim=False, name=None): """ -- GitLab