#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

from paddle.fluid import core
from paddle.fluid.tests.unittests.eager_op_test import (
    OpTest,
    OpTestTool,
    convert_float_to_uint16,
)
from paddle.fluid.tests.unittests.test_conv2d_op import (
    TestConv2DOp,
    conv2d_forward_naive,
)


def conv2d_residual_naive(out, residual):
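    # NumPy reference for the fused residual connection: elementwise add.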
    assert out.shape == residual.shape
    out = np.add(out, residual)
    return out


@unittest.skipIf(
    not core.supports_bfloat16(), "place does not support BF16 evaluation"
)
class TestConv2DBF16Op(TestConv2DOp):
    def setUp(self):
        self.op_type = "conv2d"
        self.use_cudnn = False
        self.exhaustive_search = False
        self.use_cuda = False
        self.use_mkldnn = True
        self._cpu_only = True
        self.weight_type = np.float32
        self.input_type = np.float32
        self.mkldnn_data_type = "bfloat16"
        self.force_fp32_output = False
        self.init_group()
        self.init_dilation()
        self.init_test_case()
        self.init_fuse_relu()
        self.init_fuse_residual()
        self.init_data_type()
        self.init_force_fp32_output()
        self.init_infer_or_train()

        self.conv2d_param = {
            'stride': self.stride,
            'pad': self.pad,
            'dilation': self.dilations,
        }

        self.input = np.random.random(self.input_size).astype(np.float32)
        self.filter = np.random.random(self.filter_size).astype(np.float32)

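        # Keep fp32 copies of the inputs; the grad tests use them to build reference gradients.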
        self.inputs_fp32 = {'Input': self.input, 'Filter': self.filter}

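        # fp32 reference output computed by the naive conv2d implementation.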
        conv_out, _, _, _, _ = conv2d_forward_naive(
            self.input, self.filter, self.groups, self.conv2d_param
        )
        self.conv_output_float = conv_out

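        # When residual fusion is enabled, fold the residual input into the reference output.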
        if self.fuse_residual:
            self.input_residual = np.random.random(
                self.input_residual_size
            ).astype(np.float32)
            self.conv_output_float = conv2d_residual_naive(
                self.conv_output_float, self.input_residual
            )
            self.conv_output = convert_float_to_uint16(self.conv_output_float)
            self.outputs = {'Output': self.conv_output}
        elif self.force_fp32_output:
            self.outputs = {'Output': self.conv_output_float.astype(np.float32)}
        else:
            self.outputs = {
                'Output': convert_float_to_uint16(self.conv_output_float)
            }

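        # bf16 tensors are stored as uint16 (the upper 16 bits of the fp32 bit pattern).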
        if self.input_type is not np.float32:
            self.input = convert_float_to_uint16(self.input)

        if self.weight_type is not np.float32:
            self.filter = convert_float_to_uint16(self.filter)

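        # The filter keeps self.weight_type: fp32 for inference, bf16 for training.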
        self.inputs = {
            'Input': self.input,
            'Filter': OpTest.np_dtype_to_fluid_dtype(
                self.filter.astype(self.weight_type)
            ),
        }

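        # Residual fusion is only available through the fused_conv2d op.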
        if self.fuse_residual:
            self.op_type = "fused_conv2d"
            self.inputs['ResidualData'] = OpTest.np_dtype_to_fluid_dtype(
                convert_float_to_uint16(self.input_residual)
            )

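        # Attributes consumed by the oneDNN (MKL-DNN) conv2d kernel.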
        self.attrs = {
            'strides': self.stride,
            'paddings': self.pad,
            'groups': self.groups,
            'dilations': self.dilations,
            'use_cudnn': self.use_cudnn,
            'use_mkldnn': self.use_mkldnn,
            'mkldnn_data_type': self.mkldnn_data_type,
            'force_fp32_output': self.force_fp32_output,
            'fuse_residual_connection': self.fuse_residual,
        }

        self.init_additional_attrs()

    def test_check_output(self):
        self.check_output_with_place(core.CPUPlace())

    def test_check_grad(self):
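        # Covered by TestConv2DWithGradBF16Op below, which supplies reference gradients.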
        pass

    def test_check_grad_no_filter(self):
        pass

    def test_check_grad_no_input(self):
        pass

    def init_test_case(self):
        TestConv2DOp.init_test_case(self)
        self.input_size = [1, 6, 12, 12]  # NCHW
        f_c = self.input_size[1] // self.groups
        o_c = 15
        self.input_residual_size = [1, o_c, 10, 10]
        self.filter_size = [o_c, f_c, 3, 3]

    def init_padding(self):
        pass

    def init_data_type(self):
        self.weight_type = np.float32
        self.input_type = np.uint16

    def init_force_fp32_output(self):
        self.force_fp32_output = False

    def init_fuse_relu(self):
        self.fuse_activation = "relu"

    def init_fuse_residual(self):
        self.fuse_residual = True

    def init_infer_or_train(self):
        self.weight_type = np.float32

    def init_additional_attrs(self):
        self.attrs['is_test'] = True


@OpTestTool.skip_if_not_cpu_bf16()
class TestConv2DWithGradBF16Op(TestConv2DBF16Op):
    def init_fuse_relu(self):
        self.fuse_activation = None

    def init_fuse_residual(self):
        self.fuse_residual = None

    def init_additional_attrs(self):
        self.attrs['is_test'] = False

    def init_infer_or_train(self):
        self.weight_type = np.uint16

    def test_check_grad(self):
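        # Check the kernel's gradients against the NumPy reference from conv_backward().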
        dout = self.conv_output_float
        x = self.inputs_fp32['Input']
        w = self.inputs_fp32['Filter']

        dx, dweights = conv_backward(dout, x, w, self.conv2d_param)

        self.check_grad_with_place(
            core.CPUPlace(),
            ["Input", "Filter"],
            "Output",
            user_defined_grads=[dx, dweights],
            user_defined_grad_outputs=[convert_float_to_uint16(dout)],
        )

    def test_check_grad_no_filter(self):
        dout = self.conv_output_float
        x = self.inputs_fp32['Input']
        w = self.inputs_fp32['Filter']

        dx, _ = conv_backward(dout, x, w, self.conv2d_param)

        self.check_grad_with_place(
            core.CPUPlace(),
            ["Input"],
            "Output",
            no_grad_set={'Filter'},
            user_defined_grads=[dx],
            user_defined_grad_outputs=[convert_float_to_uint16(dout)],
        )

    def test_check_grad_no_input(self):
        dout = self.conv_output_float
        x = self.inputs_fp32['Input']
        w = self.inputs_fp32['Filter']

        _, dweights = conv_backward(dout, x, w, self.conv2d_param)

        self.check_grad_with_place(
            core.CPUPlace(),
            ["Filter"],
            "Output",
            no_grad_set={'Input'},
            user_defined_grads=[dweights],
            user_defined_grad_outputs=[convert_float_to_uint16(dout)],
        )


def conv_backward(dout, x, w, params):
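    # NumPy reference for the conv2d backward pass. Assumes dilation 1 and
    # symmetric padding (pad[0] applied to all sides), which matches the
    # gradient tests above.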
    padding = params['pad'][0]
    stride = params['stride']

    dx = np.zeros_like(x)
    dweights = np.zeros_like(w)

    N, IC, H, W = x.shape
    OC, _, KH, KW = w.shape

    H_out = int(1 + (H + 2 * padding - KH) / stride[0])
    W_out = int(1 + (W + 2 * padding - KW) / stride[1])

    x_padded = np.pad(x, ((0,), (0,), (padding,), (padding,)), 'constant')

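    # Filter gradient: correlate the padded input with the output gradient.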
    for n in range(N):
        for oc in range(OC):
            for i in range(KH):
                for j in range(KW):
                    for k in range(H_out):
                        for l in range(W_out):
                            for ic in range(IC):
                                dweights[oc, ic, i, j] += (
                                    x_padded[
                                        n,
                                        ic,
                                        i + k * stride[0],
                                        j + l * stride[1],
                                    ]
                                    * dout[n, oc, k, l]
                                )

    dx_padded = np.pad(dx, ((0,), (0,), (padding,), (padding,)), 'constant')

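    # Input gradient: scatter each dout element back through the filter taps.
    # conv2d here is cross-correlation, so the kernel is not flipped.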
    for n in range(N):
        for oc in range(OC):
            for i in range(H_out):
                for j in range(W_out):
                    for kh in range(KH):
                        for kw in range(KW):
                            for ic in range(IC):
                                dx_padded[
                                    n,
                                    ic,
                                    stride[0] * i + kh,
                                    stride[1] * j + kw,
                                ] += (
                                    dout[n, oc, i, j] * w[oc, ic, kh, kw]
                                )

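    # Strip the zero padding to recover the gradient w.r.t. the original input.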
    if padding == 0:
        dx = dx_padded
    else:
        dx = dx_padded[:, :, padding:-padding, padding:-padding]

    return dx.astype(np.float32), dweights.astype(np.float32)


class TestConv2DBF16WithPadding1(TestConv2DWithGradBF16Op):
    def init_test_case(self):
        TestConv2DWithGradBF16Op.init_test_case(self)
        self.pad = [1, 1]


class TestConv2DBF16WithStride2(TestConv2DWithGradBF16Op):
    def init_test_case(self):
        TestConv2DWithGradBF16Op.init_test_case(self)
        self.stride = [2, 3]


class TestConv2D(TestConv2DBF16Op):
    def init_test_case(self):
        self.pad = [0, 0]
        self.stride = [1, 1]
        self.input_size = [2, 3, 5, 5]  # NCHW
        self.input_residual_size = [2, 6, 3, 3]
        assert np.mod(self.input_size[1], self.groups) == 0
        f_c = self.input_size[1] // self.groups
        self.filter_size = [6, f_c, 3, 3]

    def init_data_type(self):
        self.input_type = np.uint16


class TestWithPad(TestConv2D):
    def init_test_case(self):
        TestConv2D.init_test_case(self)
        self.pad = [1, 1]
        self.input_residual_size = [2, 6, 5, 5]


class TestWithGroup(TestConv2D):
    def init_group(self):
        self.groups = 3


class TestWithStride(TestConv2DBF16Op):
    def init_test_case(self):
        self.pad = [1, 1]
        self.stride = [2, 2]
        self.input_size = [2, 3, 6, 6]
        self.input_residual_size = [2, 6, 3, 3]
        assert np.mod(self.input_size[1], self.groups) == 0
        f_c = self.input_size[1] // self.groups
        self.filter_size = [6, f_c, 3, 3]

    def init_data_type(self):
        self.input_type = np.uint16


class TestWithDilations(TestConv2DBF16Op):
    def init_test_case(self):
        self.pad = [1, 1]
        self.stride = [1, 1]
        self.dilations = [2, 2]
        self.input_size = [2, 3, 10, 10]
        self.input_residual_size = [2, 6, 8, 8]
        assert np.mod(self.input_size[1], self.groups) == 0
        f_c = self.input_size[1] // self.groups
        self.filter_size = [6, f_c, 3, 3]

    def init_data_type(self):
        self.input_type = np.uint16


class TestWith1x1ForceFP32Output(TestConv2DBF16Op):
    def init_test_case(self):
        self.pad = [0, 0]
        self.stride = [1, 1]
        self.input_size = [1, 3, 5, 5]
        assert np.mod(self.input_size[1], self.groups) == 0
        f_c = self.input_size[1] // self.groups
        self.filter_size = [6, f_c, 1, 1]

    def init_force_fp32_output(self):
        self.force_fp32_output = True

    def init_fuse_residual(self):
        self.fuse_residual = False


class TestWithInput1x1Filter1x1(TestConv2DBF16Op):
    def init_test_case(self):
        self.pad = [0, 0]
        self.stride = [1, 1]
        self.input_size = [2, 3, 1, 1]
        self.input_residual_size = [2, 6, 1, 1]
        assert np.mod(self.input_size[1], self.groups) == 0
        f_c = self.input_size[1] // self.groups
        self.filter_size = [6, f_c, 1, 1]

    def init_group(self):
        self.groups = 3


if __name__ == '__main__':
    from paddle import enable_static

    enable_static()
    unittest.main()