未验证 提交 caf9d398 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Add Conv Transpose BF16 (#30877)

* Add conv transpose BF16

* Share function GetWeightsTz

* Adjust to review and fix op compatibility

* Add bias to unique handler name

* Remove errors related to paddle enforce

* Add conv2d_transpose to bf16 list and kernel refator
上级 cbbe1274
......@@ -2192,9 +2192,9 @@ PDNode *patterns::Bfloat16Placement::operator()(
const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>(
{"concat", "conv2d", "elementwise_add", "elementwise_mul", "fc",
"fusion_gru", "gelu", "layer_norm", "matmul", "pool2d", "reshape2",
"softmax", "sum", "transpose2"});
{"concat", "conv2d", "conv2d_transpose", "elementwise_add",
"elementwise_mul", "fc", "fusion_gru", "gelu", "layer_norm",
"matmul", "pool2d", "reshape2", "softmax", "sum", "transpose2"});
if (!bfloat16_enabled_op_types.empty()) {
supported_op_types = bfloat16_enabled_op_types;
}
......
......@@ -160,7 +160,7 @@ REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass,
REGISTER_PASS_CAPABILITY(conv_transpose_bias_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("conv2d_transpose", 1)
.LE("conv2d_transpose", 2)
.LE("elementwise_add", 1));
REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass,
......
......@@ -329,7 +329,7 @@ REGISTER_PASS_CAPABILITY(quant_conv2d_dequant_fuse_pass)
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("conv2d", 1)
.EQ("fc", 0)
.LE("conv2d_transpose", 1)
.LE("conv2d_transpose", 2)
.EQ("fake_quantize_abs_max", 0)
.EQ("fake_quantize_range_abs_max", 0)
.EQ("fake_quantize_moving_average_abs_max", 0)
......
......@@ -390,7 +390,7 @@ REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass)
.LE("elementwise_add", 1)
.LE("elementwise_mul", 1)
.EQ("prelu", 0)
.LE("conv2d_transpose", 1)
.LE("conv2d_transpose", 2)
.LE("leaky_relu", 1)
.EQ("fc", 0)
.EQ("shuffle_channel", 0)
......
......@@ -290,6 +290,15 @@ void Conv2DTransposeOpMaker::Make() {
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("force_fp32_output",
"(bool, default false) Force BF16 kernel output FP32, only "
"used in MKL-DNN BF16")
.SetDefault(false);
AddAttr<std::string>(
"mkldnn_data_type",
"(string, default \"float32\"). Data type of mkldnn kernel")
.SetDefault("float32")
.InEnum({"float32", "bfloat16"});
AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<std::string>("fuse_activation",
......@@ -671,7 +680,17 @@ REGISTER_OP_VERSION(conv2d_transpose)
"output_padding",
"In order to add additional size to one side of each dimension "
"in the output",
std::vector<int>{}));
std::vector<int>{}))
.AddCheckpoint(
R"ROC(
Upgrade conv2d transpose to add a new attributes [force_fp32_output, mkldnn_data_type].
)ROC",
paddle::framework::compatible::OpVersionDesc()
.NewAttr("force_fp32_output",
"Force BF16 kernel output FP32, only used in MKL-DNN BF16",
false)
.NewAttr("mkldnn_data_type", "Data type of mkldnn kernel",
"float32"));
REGISTER_OP_VERSION(conv3d_transpose)
.AddCheckpoint(
......
......@@ -33,18 +33,6 @@ using mkldnn::stream;
using platform::GetMKLDNNFormat;
using platform::to_void_cast;
inline void GetWeightsTz(std::vector<int64_t>& weights_tz, // NOLINT
const int groups) {
if (groups > 1) {
// if (is_conv3d) [o, i, d, h, w]->[g, o/g, i, d, h, w]
// else [o, i, h, w] -> [g, o/g, i, h, w]
weights_tz.push_back(0);
std::rotate(weights_tz.begin(), weights_tz.end() - 1, weights_tz.end());
weights_tz[0] = groups;
weights_tz[1] = weights_tz[1] / groups;
}
}
inline MKLDNNMemoryFormat GetWeightsFormat(const MKLDNNMemoryFormat format,
const int groups,
const bool is_conv3d) {
......@@ -198,7 +186,7 @@ class ConvMKLDNNHandlerT
const auto src_tz = paddle::framework::vectorize(input->dims());
auto weights_tz = paddle::framework::vectorize(filter->dims());
GetWeightsTz(weights_tz, groups);
platform::GetGroupConvWeightsTz(weights_tz, groups);
const auto dst_tz = paddle::framework::vectorize(output->dims());
......@@ -322,7 +310,7 @@ class ConvMKLDNNHandlerT
} else {
const K* filter_data = filter->data<K>();
auto weights_tz = framework::vectorize(filter->dims());
GetWeightsTz(weights_tz, groups);
platform::GetGroupConvWeightsTz(weights_tz, groups);
auto user_src_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<K>(),
......@@ -640,7 +628,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto weights_tz = paddle::framework::vectorize(filter->dims());
int g = std::max(groups, 1);
GetWeightsTz(weights_tz, g);
platform::GetGroupConvWeightsTz(weights_tz, g);
auto dst_tz = paddle::framework::vectorize(output->dims());
std::transform(dilations.begin(), dilations.end(), dilations.begin(),
......@@ -959,7 +947,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto weights_tz = paddle::framework::vectorize(filter->dims());
int g = std::max(groups, 1);
GetWeightsTz(weights_tz, g);
platform::GetGroupConvWeightsTz(weights_tz, g);
auto dst_tz = paddle::framework::vectorize(output_grad->dims());
auto src_format = input->format();
......
......@@ -492,6 +492,19 @@ inline std::vector<std::vector<int64_t>> ToMkldnnPadding(
}
}
// The function adjusts the vector of weight dimensions for group convolutions
inline void GetGroupConvWeightsTz(std::vector<int64_t>& weights_tz, // NOLINT
const int groups) {
if (groups > 1) {
// if (is_conv3d) [o, i, d, h, w]->[g, o/g, i, d, h, w]
// else [o, i, h, w] -> [g, o/g, i, h, w]
weights_tz.push_back(0);
std::rotate(weights_tz.begin(), weights_tz.end() - 1, weights_tz.end());
weights_tz[0] = groups;
weights_tz[1] = weights_tz[1] / groups;
}
}
inline bool HasOpINT8DataType(const paddle::framework::OpDesc* op) {
return (op->GetAttrIfExists<std::string>("mkldnn_data_type") == "int8" ||
op->GetAttrIfExists<bool>("use_quantizer"));
......
......@@ -250,10 +250,12 @@ class MKLDNNHandlerT {
astream.wait();
}
template <typename F = T>
std::shared_ptr<mkldnn::memory> AcquireMemoryWithReorder(
const mkldnn::memory::desc& user_md,
const mkldnn::memory::desc& target_md, void* ptr,
const std::string& suffix, bool is_persistent = false) {
const std::string& suffix, bool is_persistent = false,
std::function<std::shared_ptr<F>(const F*)> custom_reorder_func = {}) {
const auto target_key = key_ + suffix + "_target";
const auto key_reorder_p = key_ + suffix + "reorder_p";
const auto user_key = key_ + suffix + "_user";
......@@ -262,6 +264,12 @@ class MKLDNNHandlerT {
std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(target_key));
if (target_memory_p == nullptr) {
if (custom_reorder_func) {
auto reordered_data =
custom_reorder_func(reinterpret_cast<const F*>(ptr));
dev_ctx_.SetBlob(key_reorder_p + "-custom_reorder", reordered_data);
ptr = reinterpret_cast<void*>(reordered_data.get());
}
auto user_memory_p =
std::make_shared<dnnl::memory>(user_md, engine_, ptr);
if (user_md != target_md) {
......@@ -1288,6 +1296,5 @@ static void SetDstMemoryQuantized(
dst_memory.reset(
new mkldnn::memory(*dst_md, engine, to_void_cast<T>(output_data)));
}
} // namespace platform
} // namespace paddle
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
from paddle.fluid.tests.unittests.test_conv2d_transpose_op import conv2dtranspose_forward_naive
from paddle import enable_static
def conv2d_bias_naive(out, bias):
_, out_c, _, _ = out.shape
for l in range(out_c):
out[:, l, :, :] = out[:, l, :, :] + bias[l]
return out
@unittest.skipIf(not core.supports_bfloat16(),
"place does not support BF16 evaluation")
class TestConv2DTransposeBF16MKLDNNOp(OpTest):
def test_check_output(self):
self.check_output_with_place(core.CPUPlace())
def test_check_grad(self):
pass
def test_check_grad_no_input(self):
pass
def test_check_grad_no_filter(self):
pass
def init_op_type(self):
self.data_format = "NCHW"
self.op_type = 'conv2d_transpose'
self._cpu_only = True
def init_test_case(self):
self.pad = [0, 0]
self.fuse_bias = False
self.use_mkldnn = True
self.is_test = True
self.bias_size = None
self.fuse_activation = ""
self.fuse_alpha = 0.0
self.fuse_beta = 0.0
self.stride = [1, 1]
self.dilations = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
self.groups = 1
self.output_size = None
self.output_padding = []
self.data_format = "NCHW"
self.pad = [0, 0]
self.padding_algorithm = "EXPLICIT"
self.force_fp32_output = False
def setUp(self):
self.input_type = np.uint16
self.dtype = np.uint16
self.mkldnn_data_type = "bfloat16"
self.init_op_type()
self.init_test_case()
input = np.random.random(self.input_size).astype(np.float32)
filter = np.random.random(self.filter_size).astype(np.float32)
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
'padding_algorithm': self.padding_algorithm,
'groups': self.groups,
'dilations': self.dilations,
'is_test': self.is_test,
'use_mkldnn': self.use_mkldnn,
'mkldnn_data_type': self.mkldnn_data_type,
'force_fp32_output': self.force_fp32_output,
'data_format': self.data_format,
'fuse_activation': self.fuse_activation,
'fuse_alpha': self.fuse_alpha,
'fuse_beta': self.fuse_beta
}
if self.output_size is not None:
self.attrs['output_size'] = self.output_size
if len(self.output_padding) > 0:
self.attrs['output_padding'] = self.output_padding
output = conv2dtranspose_forward_naive(input, filter,
self.attrs).astype(np.float32)
if self.input_type is not np.float32:
input = convert_float_to_uint16(input)
self.inputs = {
'Input': input.view(self.input_type),
'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
}
if self.fuse_bias and self.bias_size is not None:
bias = np.random.random(self.bias_size).astype(np.float32)
output = conv2d_bias_naive(output, bias)
output = output.astype(np.float32)
self.attrs['fuse_bias'] = self.fuse_bias
self.inputs['Bias'] = OpTest.np_dtype_to_fluid_dtype(bias)
if self.fuse_activation == "relu":
output = np.maximum(output, 0).astype(np.float32)
output = output.astype(np.float32)
if not self.force_fp32_output:
output = convert_float_to_uint16(output, self.attrs['data_format'])
self.outputs['Output'] = output
class TestMKLDNNFuseBias(TestConv2DTransposeBF16MKLDNNOp):
def init_test_case(self):
super(TestMKLDNNFuseBias, self).init_test_case()
self.pad = [1, 1]
self.fuse_bias = True
self.bias_size = [6]
class TestMKLDNNWithPad(TestConv2DTransposeBF16MKLDNNOp):
def init_test_case(self):
super(TestMKLDNNWithPad, self).init_test_case()
self.pad = [1, 1]
self.input_size = [2, 3, 10, 10]
class TestMKLDNNWithStride(TestConv2DTransposeBF16MKLDNNOp):
def init_test_case(self):
super(TestMKLDNNWithStride, self).init_test_case()
self.pad = [1, 1]
self.stride = [2, 2]
self.input_size = [2, 3, 6, 6] # NCHW
class TestMKLDNNWithAsymPad(TestConv2DTransposeBF16MKLDNNOp):
def init_test_case(self):
super(TestMKLDNNWithAsymPad, self).init_test_case()
self.pad = [0, 0, 1, 2]
self.padding_algorithm = "EXPLICIT"
class TestMKLDNNWithSamePad(TestConv2DTransposeBF16MKLDNNOp):
def init_test_case(self):
super(TestMKLDNNWithSamePad, self).init_test_case()
self.pad = [0, 0]
self.padding_algorithm = "SAME"
class TestMKLDNNWithValidPad(TestConv2DTransposeBF16MKLDNNOp):
def init_test_case(self):
super(TestMKLDNNWithValidPad, self).init_test_case()
self.pad = [1, 1]
self.padding_algorithm = "VALID"
class TestMKLDNNWithValidPad_NHWC(TestMKLDNNWithValidPad):
def init_test_case(self):
super(TestMKLDNNWithValidPad_NHWC, self).init_test_case()
self.data_format = 'NHWC'
N, C, H, W = self.input_size
self.input_size = [N, H, W, C]
class TestConv2DTransposeMKLDNNWithDilationsExplicitPad(
TestConv2DTransposeBF16MKLDNNOp):
def init_test_case(self):
super(TestConv2DTransposeMKLDNNWithDilationsExplicitPad,
self).init_test_case()
self.stride = [2, 1]
self.dilations = [1, 2]
self.groups = 1
self.input_size = [4, 3, 8, 7] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 4, 3]
self.pad = [1, 3, 2, 1]
self.padding_algorithm = "EXPLICIT"
if __name__ == '__main__':
enable_static()
unittest.main()
......@@ -82,6 +82,8 @@ class TestConv2DTransposeMKLDNNOp(TestConv2DTransposeOp):
self.attrs['fuse_activation'] = self.fuse_activation
self.attrs['fuse_alpha'] = self.fuse_alpha
self.attrs['fuse_beta'] = self.fuse_beta
self.attrs['mkldnn_data_type'] = 'float32'
self.attrs['force_fp32_output'] = False
self.outputs['Output'] = output
......@@ -150,3 +152,8 @@ class TestConv2DTransposeMKLDNNWithDilationsExplicitPad(
self.filter_size = [f_c, 6, 4, 3]
self.pad = [1, 3, 2, 1]
self.padding_algorithm = "EXPLICIT"
if __name__ == '__main__':
enable_static()
unittest.main()
......@@ -221,12 +221,18 @@ def copy_bits_from_float_to_uint16(f):
return struct.unpack('<I', struct.pack('<f', f))[0] >> 16
def convert_float_to_uint16(float_list):
def convert_float_to_uint16(float_list, data_format="NCHW"):
if data_format == "NHWC":
float_list = np.transpose(float_list, [0, 3, 1, 2])
new_output = []
for x in np.nditer(float_list):
new_output.append(np.uint16(copy_bits_from_float_to_uint16(x)))
new_output = np.reshape(new_output, float_list.shape).view(np.uint16)
return np.reshape(new_output, float_list.shape).view(np.uint16)
if data_format == "NHWC":
new_output = np.transpose(new_output, [0, 2, 3, 1])
return new_output
class OpTest(unittest.TestCase):
......
......@@ -590,6 +590,7 @@ STATIC_MODE_TESTING_LIST = [
'test_conv2d_int8_mkldnn_op',
'test_conv2d_mkldnn_op',
'test_conv2d_transpose_mkldnn_op',
'test_conv2d_transpose_bf16_mkldnn_op',
'test_conv3d_mkldnn_op',
'test_dequantize_mkldnn_op',
'test_elementwise_add_mkldnn_op',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册