From 52bbaae9c86b214df12a6d7d84decbd83866b6cf Mon Sep 17 00:00:00 2001 From: jakpiase Date: Fri, 11 Feb 2022 10:48:28 +0100 Subject: [PATCH] Added shape (U)INT8/BF16/FP32 oneDNN kernel (#36033) * added shape oneDNN kernel * removed unnecessary import from test * added skipping tests for GPU * refactoring * refactored shape kernel * added tests in new framework * removed one line * minor change * added newline at EOF * added formatting * added attributes as extra --- .../fluid/operators/mkldnn/shape_mkldnn_op.cc | 43 +++++++++++++ paddle/fluid/operators/shape_op.cc | 25 ++++++++ paddle/fluid/platform/mkldnn_helper.h | 9 --- .../ir/inference/test_mkldnn_shape_op.py | 63 +++++++++++++++++++ .../unittests/mkldnn/test_shape_mkldnn_op.py | 62 ++++++++++++++++++ 5 files changed, 193 insertions(+), 9 deletions(-) create mode 100644 paddle/fluid/operators/mkldnn/shape_mkldnn_op.cc create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shape_op.py create mode 100644 python/paddle/fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py diff --git a/paddle/fluid/operators/mkldnn/shape_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/shape_mkldnn_op.cc new file mode 100644 index 00000000000..780c6e7f153 --- /dev/null +++ b/paddle/fluid/operators/mkldnn/shape_mkldnn_op.cc @@ -0,0 +1,43 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/shape_op.h" +#include "paddle/fluid/platform/mkldnn_helper.h" + +namespace paddle { +namespace operators { + +using paddle::framework::Tensor; + +template +class ShapeMKLDNNKernel : public ShapeKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + ShapeKernel::Compute(ctx); + + auto* out = ctx.Output("Out"); + out->set_layout(framework::DataLayout::kMKLDNN); + out->set_format(platform::GetPlainMKLDNNFormat(out->dims().size())); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_KERNEL(shape, MKLDNN, paddle::platform::CPUPlace, + ops::ShapeMKLDNNKernel, + ops::ShapeMKLDNNKernel, + ops::ShapeMKLDNNKernel, + ops::ShapeMKLDNNKernel); diff --git a/paddle/fluid/operators/shape_op.cc b/paddle/fluid/operators/shape_op.cc index dd135b89714..5b7ccdde810 100644 --- a/paddle/fluid/operators/shape_op.cc +++ b/paddle/fluid/operators/shape_op.cc @@ -35,6 +35,21 @@ class ShapeOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Out", {in_dim.size()}); } + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input_data_type = + framework::OperatorWithKernel::IndicateVarDataType(ctx, "Input"); + +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); + } +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } + protected: framework::OpKernelType GetKernelTypeForVar( const std::string &var_name, const framework::Tensor &tensor, @@ -58,6 +73,16 @@ Shape Operator. Return the shape of the input. )DOC"); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false) + .AsExtra(); + AddAttr( + "mkldnn_data_type", + "(string, default \"float32\"). Data type of mkldnn kernel") + .SetDefault("float32") + .InEnum({"float32", "bfloat16", "int8"}) + .AsExtra(); } }; diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index 7a528cf8d6b..9dbfe7013fa 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -346,31 +346,22 @@ inline dnnl::memory::format_tag GetPlainMKLDNNFormat(int tensor_rank) { switch (tensor_rank) { case 1: return dnnl::memory::format_tag::a; - break; case 2: return dnnl::memory::format_tag::ab; - break; case 3: return dnnl::memory::format_tag::abc; - break; case 4: return dnnl::memory::format_tag::abcd; - break; case 5: return dnnl::memory::format_tag::abcde; - break; case 6: return dnnl::memory::format_tag::abcdef; - break; case 7: return dnnl::memory::format_tag::abcdefg; - break; case 8: return dnnl::memory::format_tag::abcdefgh; - break; case 9: return dnnl::memory::format_tag::abcdefghi; - break; default: PADDLE_THROW(platform::errors::Unimplemented( "Paddle support tensors with rank in range <1, 9>, but received " diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shape_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shape_op.py new file mode 100644 index 00000000000..5b23669b98d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shape_op.py @@ -0,0 +1,63 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import MkldnnAutoScanTest +from program_config import TensorConfig, ProgramConfig, OpConfig +import numpy as np +from functools import partial +import unittest +from hypothesis import given +import hypothesis.strategies as st + + +class TestMkldnnShapeOp(MkldnnAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def sample_program_configs(self, *args, **kwargs): + def generate_input(*args, **kwargs): + return np.random.random(kwargs['in_shape']).astype(kwargs[ + 'in_dtype']) + + shape_op = OpConfig( + type="shape", + inputs={"Input": ["input_data"]}, + outputs={"Out": ["output_data"]}) + + program_config = ProgramConfig( + ops=[shape_op], + weights={}, + inputs={ + "input_data": TensorConfig(data_gen=partial(generate_input, + *args, **kwargs)), + }, + outputs=["output_data"]) + + yield program_config + + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_mkldnn=True) + yield config, (1e-5, 1e-5) + + @given( + in_shape=st.lists( + st.integers( + min_value=1, max_value=3), min_size=1, max_size=9), + in_dtype=st.sampled_from([np.float32, np.uint16, np.int8, np.uint8])) + def test(self, *args, **kwargs): + self.run_test(quant=False, *args, **kwargs) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py new file mode 100644 index 00000000000..41e6344a0a1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_shape_mkldnn_op.py @@ -0,0 +1,62 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool +import paddle +from paddle.fluid import core +from paddle.fluid.op import Operator + + +@OpTestTool.skip_if_not_cpu_bf16() +class TestShape3DFP32OneDNNOp(OpTest): + def setUp(self): + self.op_type = "shape" + self.config() + self.attrs = {'use_mkldnn': True} + self.inputs = {'Input': np.zeros(self.shape).astype(self.dtype)} + self.outputs = {'Out': np.array(self.shape)} + + def config(self): + self.shape = [5, 7, 4] + self.dtype = np.float32 + + def test_check_output(self): + self.check_output_with_place(core.CPUPlace()) + + +class TestShape6DBF16OneDNNOp(TestShape3DFP32OneDNNOp): + def config(self): + self.shape = [10, 2, 3, 4, 5, 2] + self.dtype = np.uint16 + + +class TestShape9DINT8OneDNNOp(TestShape3DFP32OneDNNOp): + def config(self): + self.shape = [1, 2, 3, 4, 5, 6, 7, 8, 9] + self.dtype = np.int8 + + +class TestShape2DUINT8OneDNNOp(TestShape3DFP32OneDNNOp): + def config(self): + self.shape = [7, 11] + self.dtype = np.uint8 + + +if __name__ == '__main__': + paddle.enable_static() + unittest.main() -- GitLab