未验证 提交 52bbaae9 编写于 作者: J jakpiase 提交者: GitHub

Added shape (U)INT8/BF16/FP32 oneDNN kernel (#36033)

* added shape oneDNN kernel

* removed unnecessary import from test

* added skipping tests for GPU

* refactoring

* refactored shape kernel

* added tests in new framework

* removed one line

* minor change

* added newline at EOF

* added formatting

* added attributes as extra
上级 2db25f0d
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/shape_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
namespace paddle {
namespace operators {
using paddle::framework::Tensor;
template <typename T>
class ShapeMKLDNNKernel : public ShapeKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ShapeKernel<T>::Compute(ctx);
auto* out = ctx.Output<Tensor>("Out");
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetPlainMKLDNNFormat(out->dims().size()));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(shape, MKLDNN, paddle::platform::CPUPlace,
ops::ShapeMKLDNNKernel<float>,
ops::ShapeMKLDNNKernel<paddle::platform::bfloat16>,
ops::ShapeMKLDNNKernel<int8_t>,
ops::ShapeMKLDNNKernel<uint8_t>);
......@@ -35,6 +35,21 @@ class ShapeOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", {in_dim.size()});
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::OperatorWithKernel::IndicateVarDataType(ctx, "Input");
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
protected:
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
......@@ -58,6 +73,16 @@ Shape Operator.
Return the shape of the input.
)DOC");
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false)
.AsExtra();
AddAttr<std::string>(
"mkldnn_data_type",
"(string, default \"float32\"). Data type of mkldnn kernel")
.SetDefault("float32")
.InEnum({"float32", "bfloat16", "int8"})
.AsExtra();
}
};
......
......@@ -346,31 +346,22 @@ inline dnnl::memory::format_tag GetPlainMKLDNNFormat(int tensor_rank) {
switch (tensor_rank) {
case 1:
return dnnl::memory::format_tag::a;
break;
case 2:
return dnnl::memory::format_tag::ab;
break;
case 3:
return dnnl::memory::format_tag::abc;
break;
case 4:
return dnnl::memory::format_tag::abcd;
break;
case 5:
return dnnl::memory::format_tag::abcde;
break;
case 6:
return dnnl::memory::format_tag::abcdef;
break;
case 7:
return dnnl::memory::format_tag::abcdefg;
break;
case 8:
return dnnl::memory::format_tag::abcdefgh;
break;
case 9:
return dnnl::memory::format_tag::abcdefghi;
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Paddle support tensors with rank in range <1, 9>, but received "
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import MkldnnAutoScanTest
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
from functools import partial
import unittest
from hypothesis import given
import hypothesis.strategies as st
class TestMkldnnShapeOp(MkldnnAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
def sample_program_configs(self, *args, **kwargs):
def generate_input(*args, **kwargs):
return np.random.random(kwargs['in_shape']).astype(kwargs[
'in_dtype'])
shape_op = OpConfig(
type="shape",
inputs={"Input": ["input_data"]},
outputs={"Out": ["output_data"]})
program_config = ProgramConfig(
ops=[shape_op],
weights={},
inputs={
"input_data": TensorConfig(data_gen=partial(generate_input,
*args, **kwargs)),
},
outputs=["output_data"])
yield program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True)
yield config, (1e-5, 1e-5)
@given(
in_shape=st.lists(
st.integers(
min_value=1, max_value=3), min_size=1, max_size=9),
in_dtype=st.sampled_from([np.float32, np.uint16, np.int8, np.uint8]))
def test(self, *args, **kwargs):
self.run_test(quant=False, *args, **kwargs)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool
import paddle
from paddle.fluid import core
from paddle.fluid.op import Operator
@OpTestTool.skip_if_not_cpu_bf16()
class TestShape3DFP32OneDNNOp(OpTest):
def setUp(self):
self.op_type = "shape"
self.config()
self.attrs = {'use_mkldnn': True}
self.inputs = {'Input': np.zeros(self.shape).astype(self.dtype)}
self.outputs = {'Out': np.array(self.shape)}
def config(self):
self.shape = [5, 7, 4]
self.dtype = np.float32
def test_check_output(self):
self.check_output_with_place(core.CPUPlace())
class TestShape6DBF16OneDNNOp(TestShape3DFP32OneDNNOp):
def config(self):
self.shape = [10, 2, 3, 4, 5, 2]
self.dtype = np.uint16
class TestShape9DINT8OneDNNOp(TestShape3DFP32OneDNNOp):
def config(self):
self.shape = [1, 2, 3, 4, 5, 6, 7, 8, 9]
self.dtype = np.int8
class TestShape2DUINT8OneDNNOp(TestShape3DFP32OneDNNOp):
def config(self):
self.shape = [7, 11]
self.dtype = np.uint8
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册