未验证 提交 b0d1ac16 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Add bf16 pool2d and unify bf16 unit tests (#29039)

* Add bf16 pool2d and unify bf16 unit tests

* Add change default ops test
上级 fddea674
......@@ -2103,8 +2103,8 @@ PDNode *patterns::Bfloat16Placement::operator()(
std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>(
{"concat", "conv2d", "elementwise_add", "elementwise_mul", "fc",
"fusion_gru", "gelu", "layer_norm", "matmul", "reshape2", "softmax",
"sum", "transpose2"});
"fusion_gru", "gelu", "layer_norm", "matmul", "pool2d", "reshape2",
"softmax", "sum", "transpose2"});
if (!bfloat16_enabled_op_types.empty()) {
supported_op_types = bfloat16_enabled_op_types;
}
......
......@@ -136,7 +136,7 @@ TEST(Bfloat16PlacementPass, enabled_conv_and_pool) {
MainTest({"conv2d", "pool2d"}, 3);
}
TEST(Bfloat16PlacementPass, default_attr_value) { DefaultAttrTest(7); }
TEST(Bfloat16PlacementPass, default_attr_value) { DefaultAttrTest(10); }
} // namespace ir
} // namespace framework
......
......@@ -181,7 +181,8 @@ namespace ops = paddle::operators;
REGISTER_OP_KERNEL(pool2d, MKLDNN, ::paddle::platform::CPUPlace,
ops::PoolMKLDNNOpKernel<float>,
ops::PoolMKLDNNOpKernel<int8_t>,
ops::PoolMKLDNNOpKernel<uint8_t>);
ops::PoolMKLDNNOpKernel<uint8_t>,
ops::PoolMKLDNNOpKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(pool2d_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::PoolMKLDNNGradOpKernel<float>);
......@@ -27,7 +27,6 @@ from paddle import enable_static
"place does not support BF16 evaluation")
class TestConcatBf16Op(OpTest):
def setUp(self):
enable_static()
self.op_type = "concat"
self.use_mkldnn = True
self.mkldnn_data_type = "bfloat16"
......@@ -107,4 +106,5 @@ class TestAxis3Case(TestConcatBf16Op):
if __name__ == '__main__':
enable_static()
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import os
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16
from paddle.fluid.tests.unittests.test_pool2d_op import TestPool2D_Op, avg_pool2D_forward_naive, max_pool2D_forward_naive
from paddle import enable_static
@unittest.skipIf(not core.supports_bfloat16(),
"place does not support BF16 evaluation")
class TestPoolBf16MklDNNOp(TestPool2D_Op):
def init_kernel_type(self):
self.use_mkldnn = True
def setUp(self):
TestPool2D_Op.setUp(self)
self.dtype = np.uint16
input = np.random.random(self.shape).astype(np.float32)
output = (self.pool2D_forward_naive(
input, self.ksize, self.strides, self.paddings, self.global_pool,
self.ceil_mode, self.exclusive, self.adaptive,
"float32")).astype(np.float32)
self.inputs = {'X': convert_float_to_uint16(input)}
self.outputs = {'Out': convert_float_to_uint16(output)}
def test_check_output(self):
self.check_output_with_place(core.CPUPlace())
def test_check_grad(self):
pass
class TestCase1Avg(TestPoolBf16MklDNNOp):
def init_test_case(self):
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [0, 0]
def init_global_pool(self):
self.global_pool = False
def init_exclusive(self):
self.exclusive = True
class TestCase2Avg(TestPoolBf16MklDNNOp):
def init_test_case(self):
self.shape = [2, 3, 7, 7]
self.ksize = [3, 3]
self.strides = [1, 1]
self.paddings = [1, 1]
def init_global_pool(self):
self.global_pool = False
def init_exclusive(self):
self.exclusive = False
class TestCase0Max(TestPoolBf16MklDNNOp):
def init_pool_type(self):
self.pool_type = "max"
self.pool2D_forward_naive = max_pool2D_forward_naive
class TestCase1Max(TestCase1Avg):
def init_pool_type(self):
self.pool_type = "max"
self.pool2D_forward_naive = max_pool2D_forward_naive
class TestCase2Max(TestCase2Avg):
def init_pool_type(self):
self.pool_type = "max"
self.pool2D_forward_naive = max_pool2D_forward_naive
if __name__ == "__main__":
enable_static()
unittest.main()
......@@ -27,7 +27,6 @@ from paddle import enable_static
"place does not support BF16 evaluation")
class TestReshapeBf16Op(OpTest):
def setUp(self):
enable_static()
self.op_type = "reshape2"
self.use_mkldnn = True
self.mkldnn_data_type = "bfloat16"
......@@ -59,4 +58,5 @@ class TestReshapeBf16Op(OpTest):
if __name__ == '__main__':
enable_static()
unittest.main()
......@@ -29,6 +29,8 @@ def stable_softmax(x):
return exps / np.sum(exps)
@unittest.skipIf(not core.supports_bfloat16(),
"place does not support BF16 evaluation")
class TestSoftmaxMKLDNNOp(TestSoftmaxOp):
def get_x_shape(self):
return [10, 10]
......
......@@ -25,7 +25,6 @@ from paddle import enable_static
"place does not support BF16 evaluation")
class TestTransposeOp(OpTest):
def setUp(self):
enable_static()
self.op_type = "transpose2"
self.use_mkldnn = True
self.mkldnn_data_type = "bfloat16"
......@@ -63,4 +62,5 @@ class TestBF16Case(TestTransposeOp):
if __name__ == '__main__':
enable_static()
unittest.main()
......@@ -425,6 +425,7 @@ STATIC_MODE_TESTING_LIST = [
'test_regularizer_api',
'test_reorder_lod_tensor',
'test_reshape_op',
'test_reshape_bf16_op',
'test_retinanet_detection_output',
'test_reverse_op',
'test_rmsprop_op',
......@@ -582,6 +583,7 @@ STATIC_MODE_TESTING_LIST = [
'test_var_conv_2d',
'test_batch_norm_mkldnn_op',
'test_concat_int8_mkldnn_op',
'test_concat_bf16_mkldnn_op',
'test_concat_mkldnn_op',
'test_conv2d_bf16_mkldnn_op',
'test_conv2d_int8_mkldnn_op',
......@@ -606,6 +608,7 @@ STATIC_MODE_TESTING_LIST = [
'test_multi_gru_fuse_pass',
'test_multi_gru_seq_fuse_pass',
'test_pool2d_int8_mkldnn_op',
'test_pool2d_bf16_mkldnn_op',
'test_pool2d_mkldnn_op',
'test_quantize_mkldnn_op',
'test_requantize_mkldnn_op',
......@@ -614,6 +617,7 @@ STATIC_MODE_TESTING_LIST = [
'test_sum_mkldnn_op',
'test_sum_bf16_mkldnn_op',
'test_transpose_int8_mkldnn_op',
'test_transpose_bf16_mkldnn_op',
'test_transpose_mkldnn_op',
'test_mkldnn_conv_activation_fuse_pass',
'test_mkldnn_conv_concat_relu_mkldnn_fuse_pass',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册