#!/usr/bin/env python3

# Copyright (c) 2021 CINN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import unittest
import numpy as np
from op_test import OpTest, OpTestTool
from op_test_helper import TestCaseHelper
import paddle
import cinn
from cinn.frontend import *
from cinn.common import *


@OpTestTool.skip_if(
    not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
class TestBatchNormTrainOp(OpTest):
    def setUp(self):
        print(f"\nRunning {self.__class__.__name__}: {self.case}")
        self.prepare_inputs()

    def prepare_inputs(self):
        self.x_np = self.random(
            shape=self.case["x_shape"], dtype=self.case["x_dtype"]
        )

    def build_paddle_program(self, target):
        x = paddle.to_tensor(self.x_np)
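        # A freshly constructed BatchNorm layer starts from scale=1, bias=0,
        # running mean=0 and running variance=1.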
        batch_norm = paddle.nn.BatchNorm(
            self.case["x_shape"][1], act=None, is_test=False
        )
        out = batch_norm(x)

        self.paddle_outputs = [out]

    # Note: If the forward and backward operators are run in the same program,
    # the forward result will be incorrect.
    def build_cinn_program(self, target):
        builder = NetBuilder("batch_norm")
        x = builder.create_input(
            self.nptype2cinntype(self.case["x_dtype"]),
            self.case["x_shape"],
            "x",
        )
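        # The constant scale/bias/mean/variance below match the default
        # parameters of the freshly constructed Paddle BatchNorm layer.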
        scale = builder.fill_constant(
            [self.case["x_shape"][1]], 1.0, 'scale', 'float32'
        )
        bias = builder.fill_constant(
            [self.case["x_shape"][1]], 0.0, 'bias', 'float32'
        )
        mean = builder.fill_constant(
            [self.case["x_shape"][1]], 0.0, 'mean', 'float32'
        )
        variance = builder.fill_constant(
            [self.case["x_shape"][1]], 1.0, 'variance', 'float32'
        )

        out = builder.batchnorm(x, scale, bias, mean, variance, is_test=False)

        prog = builder.build()
        forward_res = self.get_cinn_output(
            prog, target, [x], [self.x_np], out, passes=[]
        )
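        # Training-mode batchnorm returns several outputs; only the
        # normalized result (forward_res[0]) is compared with Paddle.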
        self.cinn_outputs = [forward_res[0]]

    def test_check_results(self):
        max_relative_error = self.case.get("max_relative_error", 1e-5)
        self.check_outputs_and_grads(max_relative_error=max_relative_error)


class TestBatchNormTrainOpAll(TestCaseHelper):
    def init_attrs(self):
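        # TestCaseHelper generates one test case per combination of the
        # inputs, dtypes and attrs defined below.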
        self.class_name = "TestBatchNormTrainOpCase"
        self.cls = TestBatchNormTrainOp

        self.inputs = [
            {
                "x_shape": [2, 16, 8, 8],
            },
            {
                "x_shape": [2, 16, 8, 1],
            },
            {
                "x_shape": [2, 16, 2048, 8],
            },
        ]
        self.dtypes = [
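            # Lower-precision dtypes get looser tolerances.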
            {"x_dtype": "float16", "max_relative_error": 1e-3},
            {"x_dtype": "float32", "max_relative_error": 1e-5},
            {"x_dtype": "bfloat16", "max_relative_error": 1e-2},
        ]
        self.attrs = []


@OpTestTool.skip_if(
    not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
class TestBatchNormBackwardOp(OpTest):
    def setUp(self):
        print(f"\nRunning {self.__class__.__name__}: {self.case}")
        self.prepare_inputs()

    def prepare_inputs(self):
        self.x_np = self.random(
            shape=self.case["x_shape"], dtype=self.case["x_dtype"]
        )
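        # y_np is used as the upstream gradient (dout) for both the Paddle
        # and the CINN backward passes.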
        self.y_np = self.random(
            shape=self.case["x_shape"], dtype=self.case["x_dtype"]
        )

    def build_paddle_program(self, target):
        x = paddle.to_tensor(self.x_np, stop_gradient=False)
        batch_norm = paddle.nn.BatchNorm(
            self.case["x_shape"][1], act=None, is_test=False
        )
        out = batch_norm(x)

        self.paddle_outputs = [out]
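        # Reference gradient of the output w.r.t. x, with y_np as the
        # upstream gradient.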
        self.paddle_grads = self.get_paddle_grads([out], [x], [self.y_np])

    # Note: If the forward and backward operators are run in the same program,
    # the forward result will be incorrect.
    def build_cinn_program(self, target):
        builder = NetBuilder("batch_norm")
        x = builder.create_input(
            self.nptype2cinntype(self.case["x_dtype"]),
            self.case["x_shape"],
            "x",
        )
        scale = builder.fill_constant(
            [self.case["x_shape"][1]], 1.0, 'scale', 'float32'
        )
        bias = builder.fill_constant(
            [self.case["x_shape"][1]], 0.0, 'bias', 'float32'
        )
        mean = builder.fill_constant(
            [self.case["x_shape"][1]], 0.0, 'mean', 'float32'
        )
        variance = builder.fill_constant(
            [self.case["x_shape"][1]], 1.0, 'variance', 'float32'
        )

        out = builder.batchnorm(x, scale, bias, mean, variance, is_test=False)
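        # In training mode the forward op also produces the saved batch
        # mean/variance (out[1], out[2]) consumed by the grad program below.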

        prog = builder.build()
        forward_res = self.get_cinn_output(
            prog, target, [x], [self.x_np], out, passes=[]
        )
        self.cinn_outputs = [forward_res[0]]

        builder_grad = NetBuilder("batch_norm_grad")
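        # The gradient program takes dout, x, scale and the saved batch
        # statistics from the forward run as inputs.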
        dout = builder_grad.create_input(
            self.nptype2cinntype(self.case["x_dtype"]),
            self.case["x_shape"],
            "dout",
        )
        x_g = builder_grad.create_input(
            self.nptype2cinntype(self.case["x_dtype"]),
            self.case["x_shape"],
            "x_g",
        )
        scale_g = builder_grad.fill_constant(
            scale.shape(), 1.0, 'scale_g', 'float32'
        )
        save_mean = builder_grad.create_input(
            self.nptype2cinntype('float32'), out[1].shape(), "save_mean"
        )
        save_variance = builder_grad.create_input(
            self.nptype2cinntype('float32'), out[2].shape(), "save_variance"
        )

        out_grad = builder_grad.batch_norm_grad(
            dout, x_g, scale_g, save_mean, save_variance
        )
        prog = builder_grad.build()
        backward_res = self.get_cinn_output(
            prog,
            target,
            [dout, x_g, save_mean, save_variance],
            [self.y_np, self.x_np, forward_res[1], forward_res[2]],
            out_grad,
            passes=[],
        )
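        # Only the gradient w.r.t. x (backward_res[0]) is checked against
        # the Paddle gradient.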
        self.cinn_grads = [backward_res[0]]

    def test_check_results(self):
        max_relative_error = self.case.get("max_relative_error", 1e-5)
        self.check_outputs_and_grads(max_relative_error=max_relative_error)


class TestBatchNormBackwardOpAll(TestCaseHelper):
    def init_attrs(self):
        self.class_name = "TestBatchNormBackwardOpCase"
        self.cls = TestBatchNormBackwardOp

        self.inputs = [
            {
                "x_shape": [2, 16, 8, 8],
            },
            {
                "x_shape": [2, 16, 8, 1],
            },
            {
                "x_shape": [2, 16, 2048, 8],
            },
        ]
        self.dtypes = [
            {"x_dtype": "float16", "max_relative_error": 1e-3},
            {"x_dtype": "float32", "max_relative_error": 1e-5},
        ]
        self.attrs = []


@OpTestTool.skip_if(
    not is_compiled_with_cuda(), "x86 test will be skipped due to timeout."
)
class TestBatchNormInferOp(OpTest):
    def setUp(self):
        print(f"\nRunning {self.__class__.__name__}: {self.case}")
        self.prepare_inputs()

    def prepare_inputs(self):
        self.x_np = self.random(
            shape=self.case["x_shape"], dtype=self.case["x_dtype"]
        )

    def build_paddle_program(self, target):
        x = paddle.to_tensor(self.x_np)
        batch_norm = paddle.nn.BatchNorm(
            self.case["x_shape"][1], act=None, is_test=True
        )
        out = batch_norm(x)

        self.paddle_outputs = [out]

    # Note: If the forward and backward operators are run in the same program,
    # the forward result will be incorrect.
    def build_cinn_program(self, target):
        builder = NetBuilder("batch_norm")
        x = builder.create_input(
            self.nptype2cinntype(self.case["x_dtype"]),
            self.case["x_shape"],
            "x",
        )
        scale = builder.fill_constant(
            [self.case["x_shape"][1]], 1.0, 'scale', 'float32'
        )
        bias = builder.fill_constant(
            [self.case["x_shape"][1]], 0.0, 'bias', 'float32'
        )
        mean = builder.fill_constant(
            [self.case["x_shape"][1]], 0.0, 'mean', 'float32'
        )
        variance = builder.fill_constant(
            [self.case["x_shape"][1]], 1.0, 'variance', 'float32'
        )

        out = builder.batchnorm(x, scale, bias, mean, variance, is_test=False)

        prog = builder.build()
        forward_res = self.get_cinn_output(
            prog, target, [x], [self.x_np], out, passes=[]
        )
        self.cinn_outputs = [forward_res[0]]

    def test_check_results(self):
        self.check_outputs_and_grads()


class TestBatchNormInferOpAll(TestCaseHelper):
    def init_attrs(self):
        self.class_name = "TestBatchNormInferOpCase"
        self.cls = TestBatchNormInferOp

        self.inputs = [
            {
                "x_shape": [2, 16, 8, 8],
            },
            {
                "x_shape": [2, 16, 8, 1],
            },
            {
                "x_shape": [2, 16, 2048, 8],
            },
        ]
        self.dtypes = [
            {"x_dtype": "float32", "max_relative_error": 1e-5},
        ]
        self.attrs = []


if __name__ == "__main__":
    TestBatchNormTrainOpAll().run()
    TestBatchNormBackwardOpAll().run()
    TestBatchNormInferOpAll().run()