diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc
index b43dddfcf19db3ec959656a3c3918d2be244af5a..8f519de075760ebfebe82b3aa27f125bbd584301 100644
--- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc
+++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc
@@ -86,7 +86,8 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
         platform::ReductionMKLDNNHandler<T> handler_sum(
             dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, onednn_engine,
             ctx.GetPlace(), dout, dy,
-            ctx.InputName(framework::GradVarName("Out")));
+            ctx.InputName(framework::GradVarName("Out")),
+            CalculateBroadcastedDims(dout, dy));
         auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
         auto reduction_p = handler_sum.AcquireForwardPrimitive();
         reduction_p->execute(astream, {{DNNL_ARG_SRC, *reorder_src_memory_p},
diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h
index df827117a0d302294e1c6259b92c21c682838f31..e5d20893335f702c0188ff7a8deaa2b41b848b85 100644
--- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h
+++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h
@@ -81,5 +81,20 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
     z->set_format(platform::GetMKLDNNFormat(*dst_memory));
   }
 };
+// Pads y's dims with 1s up to x's rank, keeping in-order matching dims of y.
+inline std::vector<int64_t> CalculateBroadcastedDims(const Tensor* x,
+                                                     const Tensor* y) {
+  const auto src_tz = framework::vectorize(x->dims());
+  const auto dst_tz = framework::vectorize(y->dims());
+  // j walks y's dims; x dims that differ from y's current dim stay 1.
+  size_t j = 0;
+  std::vector<int64_t> dst_tz_ex(src_tz.size(), 1);
+  for (size_t i = 0; i < src_tz.size(); ++i) {
+    if (j == dst_tz.size()) break;  // y exhausted (also guards empty y)
+    dst_tz_ex[i] = (src_tz[i] != dst_tz[j]) ? 1 : dst_tz[j++];
+  }
+
+  return dst_tz_ex;
+}
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc
index c9209cc39d5e35353cc9eb50fa6b54ac67c99db4..1c246e8d18937087639129d32001a297eec3ca42 100644
--- a/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc
+++ b/paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc
@@ -105,7 +105,8 @@ class EltwiseMulMKLDNNGradKernel : public ElemwiseGradKernel<T> {
         platform::ReductionMKLDNNHandler<T> handler_sum(
             dnnl::algorithm::reduction_sum, 0.0f, 0.0f, dev_ctx, mkldnn_engine,
             ctx.GetPlace(), dout, dy,
-            ctx.InputName(framework::GradVarName("Out")));
+            ctx.InputName(framework::GradVarName("Out")),
+            CalculateBroadcastedDims(dout, dy));
         auto dy_memory_p = handler_sum.AcquireDstMemory(dy);
         auto reduction_p = handler_sum.AcquireForwardPrimitive();
         // As source we use mem object with results from binary operation
diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_max_mkldnn_op.cc b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_max_mkldnn_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b18c16c8c71f7a98e1f65079031cbecc947d0344
--- /dev/null
+++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_max_mkldnn_op.cc
@@ -0,0 +1,34 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h"
+
+namespace paddle {
+namespace operators {
+// oneDNN reduce_max kernel: forwards to ReduceMKLDNNKernel<T>::RunKernel.
+template <typename T>
+class ReduceMaxMKLDNNKernel : public ReduceMKLDNNKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    this->RunKernel(ctx, dnnl::algorithm::reduction_max);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_KERNEL(reduce_max, MKLDNN, paddle::platform::CPUPlace,
+                   ops::ReduceMaxMKLDNNKernel<float>,
+                   ops::ReduceMaxMKLDNNKernel<paddle::platform::bfloat16>);
diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a9eed0d7eb0427e83c2eb1e7c6ed4a2d533778fe
--- /dev/null
+++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mean_mkldnn_op.cc
@@ -0,0 +1,34 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/ + +#include "paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h" + +namespace paddle { +namespace operators { + +template +class ReduceMeanMKLDNNKernel : public ReduceMKLDNNKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + this->RunKernel(ctx, dnnl::algorithm::reduction_mean); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_KERNEL(reduce_mean, MKLDNN, paddle::platform::CPUPlace, + ops::ReduceMeanMKLDNNKernel, + ops::ReduceMeanMKLDNNKernel); diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_min_mkldnn_op.cc b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_min_mkldnn_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..ce63a1485471f714ab4a9266f4a37843c3810a1f --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_min_mkldnn_op.cc @@ -0,0 +1,34 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h" + +namespace paddle { +namespace operators { + +template +class ReduceMinMKLDNNKernel : public ReduceMKLDNNKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + this->RunKernel(ctx, dnnl::algorithm::reduction_min); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_KERNEL(reduce_min, MKLDNN, paddle::platform::CPUPlace, + ops::ReduceMinMKLDNNKernel, + ops::ReduceMinMKLDNNKernel); diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h new file mode 100644 index 0000000000000000000000000000000000000000..7e09aaa126effe73bf4389c94542018dc200fe45 --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h @@ -0,0 +1,125 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/platform/mkldnn_reuse.h" + +namespace paddle { +namespace operators { + +using paddle::framework::LoDTensor; +using paddle::framework::Tensor; +using platform::to_void_cast; + +template +class ReduceMKLDNNKernel : public framework::OpKernel { + public: + void RunKernel(const framework::ExecutionContext& ctx, + dnnl::algorithm reduction_type) const { + auto& dev_ctx = + ctx.template device_context(); + const auto& onednn_engine = dev_ctx.GetEngine(); + + const auto* input = ctx.Input("X"); + auto* output = ctx.Output("Out"); + + auto reduce_dims = ctx.Attr>("dim"); + bool reduce_all = ctx.Attr("reduce_all"); + bool keep_dim = ctx.Attr("keep_dim"); + + std::vector output_dims = + CalculateOutputDims(input, output, reduce_dims, reduce_all, keep_dim); + + auto input_dims = framework::vectorize(input->dims()); + + auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); + + // oneDNN reduce op does not support edge case in which memory is being + // copied without actual reduction. 
+ // In that case reorder must be executed to maintain compatibility with + // PaddlePaddle reduce op + if (input_dims == output_dims) { + mkldnn::memory::data_type input_type = + framework::ToMKLDNNDataType(input->type()); + std::string key = platform::CreateKey( + dev_ctx, input_dims, input->format(), input->format(), input_type); + platform::ReorderMKLDNNHandler reorder_handler( + input_dims, input->type(), input_type, dev_ctx, onednn_engine, key); + + auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( + input->format(), platform::to_void_cast(input->data())); + + auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( + output, input->format(), ctx.GetPlace()); + + auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p, + reorder_dst_memory_p); + + platform::RecordEvent record_reorder("int_reorder", + platform::EventRole::kUniqueOp); + + reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); + astream.wait(); + + output->set_layout(framework::DataLayout::kMKLDNN); + output->set_format( + platform::GetMKLDNNFormat(reorder_dst_memory_p->get_desc().reshape( + paddle::framework::vectorize(output->dims())))); + } else { + platform::ReductionMKLDNNHandler handler( + reduction_type, 0.0f, 0.0f, dev_ctx, onednn_engine, ctx.GetPlace(), + input, output, ctx.InputName("X"), output_dims); + + auto src_memory_p = handler.AcquireSrcMemory(input); + auto dst_memory_p = handler.AcquireDstMemory(output); + + std::unordered_map reduction_args = { + {DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}}; + + auto reduction_p = handler.AcquireForwardPrimitive(); + + reduction_p->execute(astream, reduction_args); + astream.wait(); + output->set_layout(framework::DataLayout::kMKLDNN); + output->set_format( + platform::GetMKLDNNFormat(dst_memory_p->get_desc().reshape( + paddle::framework::vectorize(output->dims())))); + } + } + + private: + std::vector CalculateOutputDims(const Tensor* input, + const Tensor* output, + 
std::vector& reduce_dims, + bool reduce_all, + bool keep_dim) const { + if (keep_dim) return framework::vectorize(output->dims()); + + if (reduce_all) + return std::vector(framework::vectorize(input->dims()).size(), + 1); + + std::vector output_dims(framework::vectorize(input->dims())); + for (size_t i = 0; i < reduce_dims.size(); ++i) { + reduce_dims[i] = (reduce_dims[i] >= 0) + ? reduce_dims[i] + : input->dims().size() + reduce_dims[i]; + output_dims[reduce_dims[i]] = 1; + } + + return output_dims; + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..4676589e68910a7845a57c84ed4af2283c42328f --- /dev/null +++ b/paddle/fluid/operators/reduce_ops/mkldnn/reduce_sum_mkldnn_op.cc @@ -0,0 +1,34 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h" + +namespace paddle { +namespace operators { + +template +class ReduceSumMKLDNNKernel : public ReduceMKLDNNKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + this->RunKernel(ctx, dnnl::algorithm::reduction_sum); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_KERNEL(reduce_sum, MKLDNN, paddle::platform::CPUPlace, + ops::ReduceSumMKLDNNKernel, + ops::ReduceSumMKLDNNKernel); diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.h b/paddle/fluid/operators/reduce_ops/reduce_op.h index 25f9453571ac632e70b0755ca1e5566eb5bf6ee6..280464ea85279319c82551163c461a5ce0c4c3a7 100644 --- a/paddle/fluid/operators/reduce_ops/reduce_op.h +++ b/paddle/fluid/operators/reduce_ops/reduce_op.h @@ -489,6 +489,30 @@ class ReduceOp : public framework::OperatorWithKernel { } } } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + // choose cudnn kernel if the runtime supported. 
+ auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); + + if (ctx.Input("X")->dims().size() > 5) + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); + } +#endif + + if (input_data_type == framework::proto::VarType::FP16) { + PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, + platform::errors::InvalidArgument( + "float16 can only be used on GPU place")); + } + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class ReduceOpUseInputPlace : public ReduceOp { @@ -579,6 +603,9 @@ class ReduceOpMaker : public framework::OpProtoAndCheckerMaker { "(int, default -1)" "The dtype of output, default value is -1, the dtype is same as intput") .SetDefault(-1); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false); AddComment(string::Sprintf(R"DOC( %s Operator. 
diff --git a/paddle/fluid/platform/mkldnn_reuse.h b/paddle/fluid/platform/mkldnn_reuse.h index c79b642c51b1f58b08f70829d5c27024f094334d..0c45da63edd70ed26e427b6faec070e5292f283e 100644 --- a/paddle/fluid/platform/mkldnn_reuse.h +++ b/paddle/fluid/platform/mkldnn_reuse.h @@ -638,7 +638,8 @@ class ReductionMKLDNNHandler const float eps, const MKLDNNDeviceContext& dev_ctx, const mkldnn::engine engine, platform::Place cpu_place, const Tensor* x, const Tensor* y, - const std::string& uniq_name) + const std::string& uniq_name, + std::vector output_dims) : platform::MKLDNNHandlerT( dev_ctx, engine, cpu_place, platform::CreateKey(dev_ctx, framework::vectorize(x->dims()), @@ -653,20 +654,11 @@ class ReductionMKLDNNHandler platform::errors::InvalidArgument("Wrong format set for X tensor.")); const auto src_tz = framework::vectorize(x->dims()); - const auto dst_tz = framework::vectorize(y->dims()); - - // For oneDNN dimensionality should match so we need to - // extend Y tensor dims with values of 1 (before and after pattern) - int j = 0; - std::vector dst_tz_ex(src_tz.size(), 1); - for (size_t i = 0; i < src_tz.size(); ++i) { - dst_tz_ex[i] = (src_tz[i] != dst_tz[j]) ? 1 : dst_tz[j++]; - } const auto src_md = dnnl::memory::desc( src_tz, platform::MKLDNNGetDataType(), x->format()); const auto dst_md = memory::desc( - dst_tz_ex, platform::MKLDNNGetDataType(), x->format()); + output_dims, platform::MKLDNNGetDataType(), x->format()); this->AcquireForwardPrimitiveDescriptor(algo, src_md, dst_md, p, eps); } diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py new file mode 100644 index 0000000000000000000000000000000000000000..a894d042e426c0f224d3fe13a5ded10c44cddbe5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_bf16_mkldnn_op.py @@ -0,0 +1,185 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16
+import paddle.fluid.core as core
+import paddle.fluid as fluid
+import paddle
+
+
+@unittest.skipIf(not core.supports_bfloat16(),
+                 "place does not support BF16 evaluation")
+@unittest.skipIf(core.is_compiled_with_cuda(),
+                 "core is compiled with CUDA which has no BF implementation")
+@skip_check_grad_ci(reason="not implemented")
+class TestReduceSumDefaultBF16ONEDNNOp(OpTest):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.use_mkldnn = True
+        x_fp32 = np.random.random((5, 6, 10)).astype("float32")
+        x_bf16 = convert_float_to_uint16(x_fp32)
+        self.inputs = {'X': x_bf16}
+        self.outputs = {'Out': x_fp32.sum(axis=0)}  # no 'dim' attr -> axis 0
+        self.attrs = {'use_mkldnn': self.use_mkldnn}
+
+    def test_check_output(self):
+        self.check_output(check_dygraph=False)  # dygraph check disabled
+
+
+@skip_check_grad_ci(reason="not implemented")
+class TestReduceSum4DBF16ONEDNNOp(TestReduceSumDefaultBF16ONEDNNOp):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.use_mkldnn = True
+        x_fp32 = np.random.random((5, 10, 5, 5)).astype("float32")
+        x_bf16 = convert_float_to_uint16(x_fp32)
+        self.inputs = {'X': x_bf16}
+        self.attrs = {'use_mkldnn': self.use_mkldnn, 'dim': [2]}
+        self.outputs = {'Out': x_fp32.sum(axis=tuple(self.attrs['dim']))}
+
+
+@skip_check_grad_ci(reason="not implemented")
+class TestReduceSum4DReduceAllWithoutReduceAllAttributeBF16ONEDNNOp(
+        TestReduceSumDefaultBF16ONEDNNOp):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.use_mkldnn = True
+        x_fp32 = np.random.normal(size=(2, 3, 5, 6)).astype('float32')
+        x_bf16 = convert_float_to_uint16(x_fp32)
+        self.inputs = {'X': x_bf16}
+        self.attrs = {'use_mkldnn': self.use_mkldnn, 'dim': [0, 1, 2, 3]}
+        self.outputs = {'Out': x_fp32.sum(axis=tuple(self.attrs['dim']))}
+
+
+@skip_check_grad_ci(reason="not implemented")
+class TestReduceSum4DReduceAllWithoutReduceAllAttributeNegativeDimsBF16ONEDNNOp(
+        TestReduceSumDefaultBF16ONEDNNOp):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.use_mkldnn = True
+        x_fp32 = np.random.normal(size=(2, 7, 3, 5)).astype('float32')
+        x_bf16 = convert_float_to_uint16(x_fp32)
+        self.inputs = {'X': x_bf16}
+        self.attrs = {'use_mkldnn': self.use_mkldnn, 'dim': [-1, -2, -3, -4]}
+        self.outputs = {'Out': x_fp32.sum(axis=tuple(self.attrs['dim']))}
+
+
+@skip_check_grad_ci(reason="not implemented")
+class TestReduceSum5DKeepDimsONEDNNOp(TestReduceSumDefaultBF16ONEDNNOp):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.use_mkldnn = True  # NOTE(review): class name lacks "BF16"
+        x_fp32 = np.random.random((2, 5, 3, 2, 2)).astype("float32")
+        x_bf16 = convert_float_to_uint16(x_fp32)
+        self.inputs = {'X': x_bf16}
+        self.attrs = {'dim': (2, 3, 4), 'keep_dim': True, 'use_mkldnn': True}
+        self.outputs = {
+            'Out': x_fp32.sum(axis=tuple(self.attrs['dim']),
+                              keepdims=self.attrs['keep_dim'])
+        }
+
+
+@skip_check_grad_ci(reason="not implemented")
+class TestReduceSum5DReduceAllKeepDimsBF16ONEDNNOp(
+        TestReduceSumDefaultBF16ONEDNNOp):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.use_mkldnn = True
+        x_fp32 = np.random.normal(size=(2, 5, 3, 2, 4)).astype('float32')
+        x_bf16 = convert_float_to_uint16(x_fp32)
+        self.inputs = {'X': x_bf16}
+        self.attrs = {'reduce_all': True, 'keep_dim': True, 'use_mkldnn': True}
+        self.outputs = {'Out': x_fp32.sum(keepdims=self.attrs['keep_dim'])}
+
+
+@skip_check_grad_ci(reason="not implemented")
+class TestReduceSum4DReduceAllBF16ONEDNNOp(TestReduceSumDefaultBF16ONEDNNOp):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.use_mkldnn = True
+        x_fp32 = np.random.normal(size=(4, 3, 2, 3)).astype('float32')
+        x_bf16 = convert_float_to_uint16(x_fp32)
+        self.inputs = {'X': x_bf16}
+        self.attrs = {'reduce_all': True, 'use_mkldnn': self.use_mkldnn}
+        self.outputs = {'Out': x_fp32.sum()}  # reduce_all: sum of everything
+
+
+@skip_check_grad_ci(
+    reason="reduce_max is discontinuous non-derivable function,"
+    " its gradient check is not supported by unittest framework.")
+class TestReduceMax3DBF16ONEDNNOp(TestReduceSumDefaultBF16ONEDNNOp):
+    """Remove Max with subgradient from gradient check to confirm the success of CI."""
+
+    def setUp(self):
+        self.op_type = "reduce_max"
+        self.use_mkldnn = True
+        x_fp32 = np.random.random((5, 6, 10)).astype("float32")
+        x_bf16 = convert_float_to_uint16(x_fp32)
+        self.inputs = {'X': x_bf16}
+        self.attrs = {'dim': [-1], 'use_mkldnn': self.use_mkldnn}
+        self.outputs = {'Out': x_fp32.max(axis=tuple(self.attrs['dim']))}
+
+
+@skip_check_grad_ci(
+    reason="reduce_max is discontinuous non-derivable function,"
+    " its gradient check is not supported by unittest framework.")
+class TestReduceMax4DNegativeAndPositiveDimsBF16ONEDNNOp(
+        TestReduceSumDefaultBF16ONEDNNOp):
+    """Remove Max with subgradient from gradient check to confirm the success of CI."""
+
+    def setUp(self):
+        self.op_type = "reduce_max"
+        self.use_mkldnn = True
+        x_fp32 = np.random.random((5, 6, 10, 9)).astype("float32")
+        x_bf16 = convert_float_to_uint16(x_fp32)
+        self.inputs = {'X': x_bf16}
+        self.attrs = {'dim': [-1, 0, 1], 'use_mkldnn': self.use_mkldnn}
+        self.outputs = {'Out': x_fp32.max(axis=tuple(self.attrs['dim']))}
+
+
+@skip_check_grad_ci(
+    reason="reduce_min is discontinuous non-derivable function,"
+    " its gradient check is not supported by unittest framework.")
+class TestReduceMin3DBF16ONEDNNOp(TestReduceSumDefaultBF16ONEDNNOp):
+    """Remove Min with subgradient from gradient check to confirm the success of CI."""
+
+    def setUp(self):
+        self.op_type = "reduce_min"
+        self.use_mkldnn = True
+        x_fp32 = np.random.random((5, 6, 10)).astype("float32")
+        x_bf16 = convert_float_to_uint16(x_fp32)
+        self.inputs = {'X': x_bf16}
+        self.attrs = {'dim': [2], 'use_mkldnn': self.use_mkldnn}
+        self.outputs = {'Out': x_fp32.min(axis=tuple(self.attrs['dim']))}
+
+
+@skip_check_grad_ci(reason="not implemented")
+class TestReduceMean3DBF16ONEDNNOp(TestReduceSumDefaultBF16ONEDNNOp):
+    def setUp(self):
+        self.op_type = "reduce_mean"
+        self.use_mkldnn = True
+        x_fp32 = np.random.random((5, 6, 10)).astype("float32")
+        x_bf16 = convert_float_to_uint16(x_fp32)
+        self.inputs = {'X': x_bf16}
+        self.attrs = {'use_mkldnn': self.use_mkldnn}  # no 'dim' -> axis 0
+        self.outputs = {'Out': x_fp32.sum(axis=0) / x_fp32.shape[0]}
+
+
+if __name__ == '__main__':
+    paddle.enable_static()
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..c913b9eeea27df8757f7b4dba8e6c49bff4c9a85
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_reduce_mkldnn_op.py
@@ -0,0 +1,194 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci
+import paddle.fluid as fluid
+import paddle
+
+
+@skip_check_grad_ci(reason="not implemented")
+class TestReduceSumDefaultONEDNNOp(OpTest):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.use_mkldnn = True
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
+        self.outputs = {'Out': self.inputs['X'].sum(axis=0)}  # expects axis 0
+        self.attrs = {'use_mkldnn': self.use_mkldnn}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+@skip_check_grad_ci(reason="not implemented")
+class TestReduceSum4DONEDNNOp(TestReduceSumDefaultONEDNNOp):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.use_mkldnn = True
+        self.inputs = {'X': np.random.random((5, 10, 5, 5)).astype("float32")}
+        self.attrs = {'use_mkldnn': self.use_mkldnn, 'dim': [2]}
+        self.outputs = {
+            'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
+        }
+
+
+@skip_check_grad_ci(reason="not implemented")
+class TestReduceSum4DReduceAllWithoutReduceAllAttributeONEDNNOp(
+        TestReduceSumDefaultONEDNNOp):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.use_mkldnn = True
+        self.inputs = {'X': np.random.random((5, 10, 5, 5)).astype("float32")}
+        self.attrs = {'use_mkldnn': self.use_mkldnn, 'dim': [0, 1, 2, 3]}
+        self.outputs = {
+            'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
+        }
+
+
+@skip_check_grad_ci(reason="not implemented")
+class TestReduceSum4DReduceAllWithoutReduceAllAttributeNegativeDimsONEDNNOp(
+        TestReduceSumDefaultONEDNNOp):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.use_mkldnn = True
+        self.inputs = {'X': np.random.random((5, 10, 5, 5)).astype("float32")}
+        self.attrs = {'use_mkldnn': self.use_mkldnn, 'dim': [-1, -2, -3, -4]}
+        self.outputs = {
+            'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']))
+        }
+
+
+@skip_check_grad_ci(reason="not implemented")
+class TestReduceSum5DKeepDimsONEDNNOp(TestReduceSumDefaultONEDNNOp):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.use_mkldnn = True
+        self.inputs = {'X': np.random.random((2, 5, 3, 2, 2)).astype("float32")}
+        self.attrs = {'dim': (2, 3, 4), 'keep_dim': True, 'use_mkldnn': True}
+        self.outputs = {
+            'Out': self.inputs['X'].sum(axis=tuple(self.attrs['dim']),
+                                        keepdims=self.attrs['keep_dim'])
+        }
+
+
+@skip_check_grad_ci(reason="not implemented")
+class TestReduceSum5DReduceAllKeepDimsONEDNNOp(TestReduceSumDefaultONEDNNOp):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.use_mkldnn = True
+        self.inputs = {'X': np.random.random((2, 5, 3, 2, 2)).astype("float32")}
+        self.attrs = {'reduce_all': True, 'keep_dim': True, 'use_mkldnn': True}
+        self.outputs = {
+            'Out': self.inputs['X'].sum(keepdims=self.attrs['keep_dim'])
+        }
+
+
+@skip_check_grad_ci(reason="not implemented")
+class TestReduceSum4DReduceAllONEDNNOp(TestReduceSumDefaultONEDNNOp):
+    def setUp(self):
+        self.op_type = "reduce_sum"
+        self.use_mkldnn = True
+        self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float32")}
+        self.attrs = {'reduce_all': True, 'use_mkldnn': self.use_mkldnn}
+        self.outputs = {'Out': self.inputs['X'].sum()}  # sum of everything
+
+
+@skip_check_grad_ci(
+    reason="reduce_max is discontinuous non-derivable function,"
+    " its gradient check is not supported by unittest framework.")
+class TestReduceMax3DONEDNNOp(TestReduceSumDefaultONEDNNOp):
+    """Remove Max with subgradient from gradient check to confirm the success of CI."""
+
+    def setUp(self):
+        self.op_type = "reduce_max"
+        self.use_mkldnn = True
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
+        self.attrs = {'dim': [-1], 'use_mkldnn': self.use_mkldnn}
+        self.outputs = {
+            'Out': self.inputs['X'].max(axis=tuple(self.attrs['dim']))
+        }
+
+
+@skip_check_grad_ci(
+    reason="reduce_max is discontinuous non-derivable function,"
+    " its gradient check is not supported by unittest framework.")
+class TestReduceMax4DNegativeAndPositiveDimsONEDNNOp(
+        TestReduceSumDefaultONEDNNOp):
+    """Remove Max with subgradient from gradient check to confirm the success of CI."""
+
+    def setUp(self):
+        self.op_type = "reduce_max"
+        self.use_mkldnn = True
+        self.inputs = {'X': np.random.random((5, 6, 10, 9)).astype("float32")}
+        self.attrs = {'dim': [-1, 0, 1], 'use_mkldnn': self.use_mkldnn}
+        self.outputs = {
+            'Out': self.inputs['X'].max(axis=tuple(self.attrs['dim']))
+        }
+
+
+@skip_check_grad_ci(
+    reason="reduce_min is discontinuous non-derivable function,"
+    " its gradient check is not supported by unittest framework.")
+class TestReduceMin3DONEDNNOp(TestReduceSumDefaultONEDNNOp):
+    """Remove Min with subgradient from gradient check to confirm the success of CI."""
+
+    def setUp(self):
+        self.op_type = "reduce_min"
+        self.use_mkldnn = True
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
+        self.attrs = {'dim': [2], 'use_mkldnn': self.use_mkldnn}
+        self.outputs = {
+            'Out': self.inputs['X'].min(axis=tuple(self.attrs['dim']))
+        }
+
+
+@skip_check_grad_ci(reason="not implemented")
+class TestReduceMean3DONEDNNOp(TestReduceSumDefaultONEDNNOp):
+    def setUp(self):
+        self.op_type = "reduce_mean"
+        self.use_mkldnn = True
+        self.inputs = {'X': np.random.random((5, 6, 10)).astype("float32")}
+        self.attrs = {'dim': [0], 'use_mkldnn': self.use_mkldnn}
+        self.outputs = {
+            'Out': self.inputs['X'].sum(axis=0) / self.inputs['X'].shape[0]
+        }
+
+
+@skip_check_grad_ci(reason="not implemented")
+class TestReduceMean4DReduceAllONEDNNOp(TestReduceSumDefaultONEDNNOp):
+    def setUp(self):
+        self.op_type = "reduce_mean"
+        self.use_mkldnn = True
+        self.inputs = {'X': np.random.random((5, 6, 8, 10)).astype("float32")}
+        self.attrs = {'reduce_all': True, 'use_mkldnn': self.use_mkldnn}
+        self.outputs = {
+            'Out':
+            self.inputs['X'].sum() / np.asarray(self.inputs['X'].shape).prod()
+        }
+
+
+@skip_check_grad_ci(reason="not implemented")
+class TestReduceMeanNoReduce1DOp(TestReduceSumDefaultONEDNNOp):
+    def setUp(self):
+        self.op_type = "reduce_mean"
+        self.use_mkldnn = True
+        self.inputs = {'X': np.random.random((1)).astype("float32")}
+        self.attrs = {'use_mkldnn': self.use_mkldnn}
+        self.outputs = {'Out': self.inputs['X']}  # 1 element: mean == input
+
+
+if __name__ == '__main__':
+    paddle.enable_static()
+    unittest.main()
diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py
index 5bb4c8a63028697468545848870750090424b2ed..5de4bffd1601c8f3fa345a34cb70cf013253c011 100644
--- a/tools/static_mode_white_list.py
+++ b/tools/static_mode_white_list.py
@@ -421,6 +421,8 @@ STATIC_MODE_TESTING_LIST = [
     'test_reader_reset',
     'test_recurrent_op',
     'test_reduce_op',
+    'test_reduce_mkldnn_op',
+    'test_reduce_bf16_mkldnn_op',
     'test_ref_by_trainer_id_op',
     'test_registry',
     'test_regularizer',