#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
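
"""Unit tests for the batch_norm operator.

NumPy reference implementations of batch normalization (inference
forward, training forward, and backward) are cross-checked between the
NHWC and NCHW layouts and then compared against the batch_norm operator
on CPU and, when available, CUDA places.
"""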

import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from paddle.fluid.framework import grad_var_name


def get_backward_op(scope, op, no_grad_set):
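    """Build the backward op for ``op`` and make sure every variable the
    backward op reads or writes exists in ``scope``."""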
    backward_op = core.Operator.backward(op, no_grad_set)
    for input in backward_op.input_vars():
        var = scope.var(input)
        var.get_tensor()
    for output in backward_op.output_vars():
        var = scope.var(output)
        var.get_tensor()
    return backward_op


def _reference_testing(x, scale, offset, mean, var, epsilon, data_format):
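    """NumPy reference for inference-mode batch norm:

        y = scale * (x - mean) / sqrt(var + epsilon) + offset

    2-D inputs are reshaped to 4-D with singleton spatial dims first."""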
    x_shape = x.shape
    if len(x_shape) == 2:
        if data_format == "NCHW":
            x = np.reshape(x, (x.shape[0], x.shape[1], 1, 1))
        else:
            x = np.reshape(x, (x.shape[0], 1, 1, x.shape[1]))

    if data_format == "NCHW":
        n, c, h, w = x.shape
        mean_tile = np.reshape(mean, (1, c, 1, 1))
        mean_tile = np.tile(mean_tile, (n, 1, h, w))
        var_tile = np.reshape(var, (1, c, 1, 1))
        var_tile = np.tile(var_tile, (n, 1, h, w))
        normalized = (x - mean_tile) / np.sqrt(var_tile + epsilon)
        scale_tile = np.reshape(scale, (1, c, 1, 1))
        scale_tile = np.tile(scale_tile, (n, 1, h, w))
        offset_tile = np.reshape(offset, (1, c, 1, 1))
        offset_tile = np.tile(offset_tile, (n, 1, h, w))
        y = normalized * scale_tile + offset_tile
    elif data_format == "NHWC":
        normalized = (x - mean) / np.sqrt(var + epsilon)
        y = normalized * scale + offset
    else:
        raise ValueError("Unknown data order.")

    if len(x_shape) == 2:
        y = np.reshape(y, x_shape)
    return y


def _reference_training(x, scale, offset, epsilon, data_format):
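    """NumPy reference for the training-mode forward pass; returns
    ``(y, batch_mean, batch_var)`` with statistics computed over the
    batch and spatial dimensions."""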
    x_shape = x.shape
    if len(x_shape) == 2:
        if data_format == "NCHW":
            x = np.reshape(x, (x.shape[0], x.shape[1], 1, 1))
        else:
            x = np.reshape(x, (x.shape[0], 1, 1, x.shape[1]))

    if data_format == "NCHW":
        n, c, h, w = x.shape
        x_square = x * x
        x_square_sum = np.sum(x_square, (0, 2, 3))
        x_sum = np.sum(x, axis=(0, 2, 3))
        element_count = np.size(x) / int(np.shape(x)[1])
        mean = x_sum / element_count
        var = x_square_sum / element_count - mean * mean
        mean_tile = np.reshape(mean, (1, c, 1, 1))
        mean_tile = np.tile(mean_tile, (n, 1, h, w))
        var_tile = np.reshape(var, (1, c, 1, 1))
        var_tile = np.tile(var_tile, (n, 1, h, w))
        normalized = (x - mean_tile) / np.sqrt(var_tile + epsilon)
        scale_tile = np.reshape(scale, (1, c, 1, 1))
        scale_tile = np.tile(scale_tile, (n, 1, h, w))
        offset_tile = np.reshape(offset, (1, c, 1, 1))
        offset_tile = np.tile(offset_tile, (n, 1, h, w))
        y = normalized * scale_tile + offset_tile
        if len(x_shape) == 2:
            y = np.reshape(y, (y.shape[0], y.shape[1]))
        return y, mean, var
    elif data_format == "NHWC":
        x_square = x * x
        x_square_sum = np.sum(x_square, (0, 1, 2))
        x_sum = np.sum(x, axis=(0, 1, 2))
        element_count = np.size(x) / int(np.shape(x)[-1])
        mean = x_sum / element_count
        var = x_square_sum / element_count - mean * mean
        normalized = (x - mean) / np.sqrt(var + epsilon)
        y = normalized * scale + offset
        if len(x_shape) == 2:
            y = np.reshape(y, x_shape)
        return y, mean, var
    else:
        raise ValueError("Unknown data order.")


def _reference_grad(x, grad_y, scale, mean, var, epsilon, data_format):
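    """NumPy reference for the training-mode backward pass; returns
    ``(grad_x, grad_scale, grad_offset)`` given the upstream gradient
    ``grad_y`` and the saved batch ``mean``/``var``."""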
    # Use the following formulas to calculate gradients:
    # grad_scale =
    #   sum(grad_y * (x - mean)) * rsqrt(var + epsilon)
    #
    # grad_offset = sum(grad_y)
    #
    # grad_x =
    #   1/N * scale * rsqrt(var + epsilon) * (N * grad_y - sum(grad_y) -
    #   (x - mean) * sum(grad_y * (x - mean)) / (var + epsilon))

    # transfer from (N, C, H, W) to (N, H, W, C) to simplify computation
    x_shape = x.shape

    if len(x_shape) == 2:
        if data_format == "NCHW":
            x = np.reshape(x, (x.shape[0], x.shape[1], 1, 1))
            grad_y = np.reshape(grad_y,
                                (grad_y.shape[0], grad_y.shape[1], 1, 1))
        else:
            x = np.reshape(x, (x.shape[0], 1, 1, x.shape[1]))
            grad_y = np.reshape(grad_y,
                                (grad_y.shape[0], 1, 1, grad_y.shape[1]))

    if data_format == "NCHW":
        x = np.transpose(x, (0, 2, 3, 1))
        grad_y = np.transpose(grad_y, (0, 2, 3, 1))

    grad_x = scale * (grad_y - np.mean(
        grad_y, axis=(0, 1, 2)) - (x - mean) * np.mean(
            grad_y * (x - mean), axis=(0, 1, 2)) /
                      (var + epsilon)) / np.sqrt(var + epsilon)
    grad_scale = np.sum(grad_y * (x - mean) / np.sqrt(var + epsilon),
                        axis=(0, 1, 2))
    grad_offset = np.sum(grad_y, axis=(0, 1, 2))

    # transfer back to N, C, H, W
    if data_format == "NCHW":
        grad_x = np.transpose(grad_x, (0, 3, 1, 2))
        x = np.transpose(x, (0, 3, 1, 2))
        grad_y = np.transpose(grad_y, (0, 3, 1, 2))

    if len(x_shape) == 2:
        grad_x = np.reshape(grad_x, x_shape)
    return grad_x, grad_scale, grad_offset


def create_or_get_tensor(scope, var_name, var, place):
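    """Get (creating if needed) the tensor for ``var_name`` in ``scope``;
    if ``var`` is given, copy that numpy array into it on ``place``."""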
    tensor = scope.var(var_name).get_tensor()
    if var is not None:
        assert isinstance(var, np.ndarray)
        tensor.set_lod([[]])
        tensor.set_dims(var.shape)
        tensor.set(var, place)
    return tensor


def set_output_grad(scope, outputs, place, feed_dict=None):
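    """Create the gradient variable for each name in ``outputs`` and fill
    it with ones, or with the value supplied in ``feed_dict``."""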
    def __set_tensor__(name, data=None):
        out_tensor = scope.find_var(name).get_tensor()
        grad_tensor = scope.var(grad_var_name(name)).get_tensor()
        out_dtype = out_tensor.dtype()
        if data is None:
            if out_dtype == core.VarDesc.VarType.FP64:
                data = np.ones(out_tensor.shape(), dtype=np.float64)
            elif out_dtype == core.VarDesc.VarType.FP32:
                data = np.ones(out_tensor.shape(), dtype=np.float32)
            else:
                raise ValueError("Not supported data type " + str(out_dtype))
        grad_tensor.set(data, place)

    for output in outputs:
        data = None
        if feed_dict is not None and output in feed_dict:
            data = feed_dict[output]
        __set_tensor__(output, data)


class TestBatchNormOpInference(OpTest):
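    """Compares the is_test=True batch_norm operator with the NumPy
    reference for both layouts on 4-D and 2-D inputs."""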
    def setUp(self):
        self.dtype = np.float32

    def __assert_close(self, tensor, np_array, msg, atol=1e-4):
        self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)

    def check_with_place(self, place, data_layout, dtype, shape):
        epsilon = 0.00001
        if len(shape) == 2:
            x_shape = shape
            c = x_shape[1]
        else:
            n, h, w, c = shape[0], shape[1], shape[2], shape[3]
            if data_layout == "NHWC":
                x_shape = [n, h, w, c]
            elif data_layout == "NCHW":
                x_shape = [n, c, h, w]
            else:
                raise ValueError("Unknown data layout.")
        scale_shape = [c]

        x_val = np.random.random_sample(x_shape).astype(dtype)
        scale_val = np.random.random_sample(scale_shape).astype(np.float32)
        bias_val = np.random.random_sample(scale_shape).astype(np.float32)

        mean = np.zeros(scale_shape).astype(np.float32)
        variance = np.ones(scale_shape).astype(np.float32)

        y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance,
                                   epsilon, data_layout).astype(dtype)

        scope = core.Scope()

        # create input
        x_tensor = create_or_get_tensor(scope, "x_val",
                                        OpTest.np_dtype_to_fluid_dtype(x_val),
                                        place)
        scale_tensor = create_or_get_tensor(
            scope, "scale_val",
            OpTest.np_dtype_to_fluid_dtype(scale_val), place)
        bias_tensor = create_or_get_tensor(
            scope, "bias_val", OpTest.np_dtype_to_fluid_dtype(bias_val), place)
        mean_tensor = create_or_get_tensor(scope, "mean",
                                           OpTest.np_dtype_to_fluid_dtype(mean),
                                           place)
        variance_tensor = create_or_get_tensor(
            scope, "variance", OpTest.np_dtype_to_fluid_dtype(variance), place)

        # create output
        y_tensor = create_or_get_tensor(scope, "y_out", None, place)
        saved_mean_tensor = create_or_get_tensor(scope, "saved_mean", None,
                                                 place)
        saved_variance_tensor = create_or_get_tensor(scope, "saved_variance",
                                                     None, place)
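        # MeanOut and VarianceOut are updated in place, so they share the
        # tensors of the Mean and Variance inputs.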
        mean_out_tensor = mean_tensor
        variance_out_tensor = variance_tensor

        batch_norm_op = Operator(
            "batch_norm",
            # inputs
            X="x_val",
            Scale="scale_val",
            Bias="bias_val",
            Mean="mean",
            Variance="variance",
            # outputs
            Y="y_out",
            MeanOut="mean",
            VarianceOut="variance",
            SavedMean="saved_mean",
            SavedVariance="saved_variance",
            # attrs
            is_test=True,
            data_layout=data_layout,
            epsilon=epsilon)

        batch_norm_op.run(scope, place)

        # check inference result
        self.__assert_close(
            y_tensor,
            y_out,
            "inference output differs at " + str(place) + ", " +
            data_layout + ", " + str(np.dtype(dtype)) +
            str(np.array(y_tensor)) + str(y_out),
            atol=1e-3)

    def test_check_output(self):
        places = [core.CPUPlace()]
        if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
            places.append(core.CUDAPlace(0))

        for place in places:
            for data_format in ["NCHW", "NHWC"]:
                self.check_with_place(place, data_format, self.dtype,
                                      [2, 3, 4, 5])
                self.check_with_place(place, data_format, self.dtype, [2, 3])


class TestFP16BatchNormOpInference(TestBatchNormOpInference):
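    """Repeats the inference checks in float16 on CUDA places that
    support it, keeping scale/bias/statistics in float32."""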
    def setUp(self):
        self.dtype = np.float16

    def test_check_output(self):
        places = []
        if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
            place = core.CUDAPlace(0)
            if core.is_float16_supported(place):
                places.append(place)

        for place in places:
            for data_format in ["NCHW", "NHWC"]:
                self.check_with_place(place, data_format, self.dtype,
                                      [2, 3, 4, 5])
                self.check_with_place(place, data_format, self.dtype, [2, 3])


class TestBatchNormOpTraining(OpTest):
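    """Cross-checks the NumPy references between NHWC and NCHW, then runs
    the training-mode operator forward and backward against them."""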
    def __assert_close(self, tensor, np_array, msg, atol=1e-4):
        self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)

    def test_python_testing(self):
        data_format = "NHWC"
        epsilon = 0.00001

        n, h, w, c = 2, 3, 4, 5
        x_shape = [n, h, w, c]
        scale_shape = [c]

        x_val = np.random.random_sample(x_shape).astype(np.float32)
        scale_val = np.random.random_sample(scale_shape).astype(np.float32)
        bias_val = np.random.random_sample(scale_shape).astype(np.float32)

        mean = np.zeros(scale_shape).astype(np.float32)
        variance = np.ones(scale_shape).astype(np.float32)

        y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance,
                                   epsilon, "NHWC")

        # running N, C, H, W case
        # should produce the same results
        x_shape2 = [n, c, h, w]
        x_val2 = np.transpose(x_val, (0, 3, 1, 2))
        y_out2 = _reference_testing(x_val2, scale_val, bias_val, mean, variance,
                                    epsilon, "NCHW")

        # transfer (N, C, H, W) back to (N, H, W, C)
        y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1))
        self.__assert_close(y_out, y_out2_trans, "inference output")
        print('python: NHWC, NCHW, inference checking passed')

    def test_python_training(self):
        data_format = "NHWC"
        epsilon = 0.00001
        momentum = 0.9

        # N, H, W, C: 2, 3, 4, 5
        n, h, w, c = 2, 3, 4, 5
        x_shape = [n, h, w, c]
        scale_shape = [c]

        x_val = np.random.random_sample(x_shape).astype(np.float32)
        scale_val = np.random.random_sample(scale_shape).astype(np.float32)
        bias_val = np.random.random_sample(scale_shape).astype(np.float32)

        mean = np.zeros(scale_shape).astype(np.float32)
        variance = np.ones(scale_shape).astype(np.float32)

        # run forward
        y_out, saved_mean, var_ref = _reference_training(
            x_val, scale_val, bias_val, epsilon, "NHWC")

        # update the moving mean and variance with momentum
        mean_out = saved_mean * (1. - momentum) + momentum * mean
        variance_out = var_ref * (1. - momentum) + momentum * variance
        saved_variance = 1. / np.sqrt(var_ref + epsilon)

        # running N, C, H, W case
        # should produce the same results
        x_shape2 = [n, c, h, w]
        x_val2 = np.transpose(x_val, (0, 3, 1, 2))
        y_out2, saved_mean2, var_ref2 = _reference_training(
            x_val2, scale_val, bias_val, epsilon, "NCHW")

        self.__assert_close(saved_mean, saved_mean2, "batch mean")
        self.__assert_close(var_ref, var_ref2, "batch variance")

        # transfer (N, C, H, W) back to (N, H, W, C)
        y_out2_trans = np.transpose(y_out2, (0, 2, 3, 1))
        self.__assert_close(y_out, y_out2_trans, "batch output")
        print('python: NHWC, NCHW, forward checking passed')

        # test backward now
        # NHWC
        self.y_grad = np.random.random_sample(x_shape).astype(np.float32)
        y_grad = self.y_grad
        x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad(
            x_val, y_grad, scale_val, saved_mean, var_ref, epsilon, "NHWC")

        # NCHW
        y_grad2 = np.transpose(y_grad, (0, 3, 1, 2))
        x_grad_ref2, scale_grad_ref2, bias_grad_ref2 = _reference_grad(
            x_val2, y_grad2, scale_val, saved_mean2, var_ref2, epsilon, "NCHW")

        self.__assert_close(scale_grad_ref, scale_grad_ref2, "scale gradient")
        self.__assert_close(bias_grad_ref, bias_grad_ref2, "bias gradient")

        x_grad_transpose = np.transpose(x_grad_ref2, (0, 2, 3, 1))
        self.__assert_close(x_grad_ref, x_grad_transpose, "x gradient")
        print('python: NHWC, NCHW, backward checking passed')

    def test_forward_backward(self):
        def test_with_place(place, data_layout, shape):
            # attr
            epsilon = 0.00001
            momentum = 0.9

            if len(shape) == 2:
                x_shape = shape
                c = shape[1]
            else:
                n, h, w, c = shape[0], shape[1], shape[2], shape[3]
                if data_layout == "NHWC":
                    x_shape = [n, h, w, c]
                elif data_layout == "NCHW":
                    x_shape = [n, c, h, w]
                else:
                    raise ValueError("Unknown data layout.")
            scale_shape = [c]

            x_val = np.random.random_sample(x_shape).astype(np.float32)
            scale_val = np.random.random_sample(scale_shape).astype(np.float32)
            bias_val = np.random.random_sample(scale_shape).astype(np.float32)

            mean = np.zeros(scale_shape).astype(np.float32)
            variance = np.ones(scale_shape).astype(np.float32)

            # run forward
            y_out, saved_mean, var_ref = _reference_training(
                x_val, scale_val, bias_val, epsilon, data_layout)

            # update moving mean and variance
            mean_out = saved_mean * (1. - momentum) + momentum * mean
            variance_out = var_ref * (1. - momentum) + momentum * variance
            saved_variance = 1. / np.sqrt(var_ref + epsilon)

            # for gradient test
            y_grad = np.zeros(x_shape).astype(np.float32)
            if len(y_grad.shape) == 2:
                y_grad[0, 0] = 1.
            else:
                y_grad[0, 0, 0, 0] = 1.
            x_grad_ref, scale_grad_ref, bias_grad_ref = _reference_grad(
                x_val, y_grad, scale_val, saved_mean, var_ref, epsilon,
                data_layout)

            scope = core.Scope()

            # create input
            x_tensor = create_or_get_tensor(scope, "x_val", x_val, place)
            scale_tensor = create_or_get_tensor(scope, "scale_val", scale_val,
                                                place)
            bias_tensor = create_or_get_tensor(scope, "bias_val", bias_val,
                                               place)
            mean_tensor = create_or_get_tensor(scope, "mean", mean, place)
            variance_tensor = create_or_get_tensor(scope, "variance", variance,
                                                   place)

            # create output
            y_tensor = create_or_get_tensor(scope, "y_out", None, place)
            saved_mean_tensor = create_or_get_tensor(scope, "saved_mean", None,
                                                     place)
            saved_variance_tensor = create_or_get_tensor(
                scope, "saved_variance", None, place)
            mean_out_tensor = mean_tensor
            variance_out_tensor = variance_tensor

            batch_norm_op = Operator(
                "batch_norm",
                # inputs
                X="x_val",
                Scale="scale_val",
                Bias="bias_val",
                Mean="mean",
                Variance="variance",
                # outputs
                Y="y_out",
                MeanOut="mean",
                VarianceOut="variance",
                SavedMean="saved_mean",
                SavedVariance="saved_variance",
                # attrs
                is_test=False,
                data_layout=data_layout,
                momentum=momentum,
                epsilon=epsilon)

            batch_norm_op.run(scope, place)

            # check forward result
            self.__assert_close(y_tensor, y_out, "y_out")
            self.__assert_close(saved_mean_tensor, saved_mean, "saved_mean")
            self.__assert_close(saved_variance_tensor, saved_variance,
                                "saved_variance")
            self.__assert_close(mean_out_tensor, mean_out, "mean_out")
            if isinstance(place, core.CUDAPlace):
                atol = 5e-2
            else:
                atol = 1e-4
            self.__assert_close(variance_out_tensor, variance_out,
                                "variance_out", atol)
            print "op test forward passed: ", str(place), data_layout

            # run backward
            batch_norm_op_grad = get_backward_op(scope, batch_norm_op, set())
            set_output_grad(
                scope,
                ["y_out", "mean", "variance", "saved_mean", "saved_variance"],
                place,
                feed_dict={"y_out": y_grad})
            batch_norm_op_grad.run(scope, place)

            x_grad_tensor = create_or_get_tensor(scope,
                                                 grad_var_name("x_val"), None,
                                                 place)
            scale_grad_tensor = create_or_get_tensor(scope,
                                                     grad_var_name("scale_val"),
                                                     None, place)
            bias_grad_tensor = create_or_get_tensor(scope,
                                                    grad_var_name("bias_val"),
                                                    None, place)

            # check gradient output
            self.__assert_close(x_grad_tensor, x_grad_ref, "x_grad")
            self.__assert_close(scale_grad_tensor, scale_grad_ref, "scale_grad")
            self.__assert_close(bias_grad_tensor, bias_grad_ref, "bias_grad")
            print "op test backward passed: ", str(place), data_layout

        places = [core.CPUPlace()]
        if core.is_compiled_with_cuda() and core.op_support_gpu("batch_norm"):
            places.append(core.CUDAPlace(0))

        for place in places:
            for data_format in ["NCHW", "NHWC"]:
                test_with_place(place, data_format, [2, 3, 4, 5])
                test_with_place(place, data_format, [2, 3])


if __name__ == '__main__':
    unittest.main()