#!/usr/bin/env python3

# Copyright (c) 2021 CINN Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.static as static
from cinn.frontend import *
from cinn import Target
from cinn.framework import *
import unittest
import cinn
from cinn import runtime
from cinn import ir
from cinn import lang
from cinn.common import *
import numpy as np
import sys

# The test driver passes exactly one flag ("ON" selects the NVIDIA GPU target,
# anything else the host target). It is popped so that unittest.main(), which
# parses sys.argv by default, does not treat it as a test name. An explicit
# check is used instead of `assert` because asserts are stripped under -O.
if len(sys.argv) != 2:
    sys.exit(
        "expected exactly one argument (ON for the NVIDIA GPU target, "
        "anything else for the host target), got %d" % (len(sys.argv) - 1)
    )
enable_gpu = sys.argv.pop()


class TestBenchmark(unittest.TestCase):
    """Benchmarks for CINN operators, some checked against reference results.

    Every benchmark method is named ``atest_*``, so the default unittest
    loader (which only collects methods starting with ``test``) skips them;
    rename a method to ``test_*`` to run it. The execution target is chosen
    from the module-level ``enable_gpu`` flag in :meth:`setUp`.
    """

    def setUp(self):
        # "ON" selects the NVIDIA GPU target; any other flag value uses host.
        if enable_gpu == "ON":
            self.target = DefaultNVGPUTarget()
        else:
            self.target = DefaultHostTarget()

    @staticmethod
    def _random_inputs(*shapes):
        """Return one float32 array of uniform [0, 1) samples per shape."""
        return [np.random.random(shape).astype("float32") for shape in shapes]

    def paddle_verify(self, result):
        """Check a CINN conv2d output against Paddle's static-graph conv2d.

        ``result[0]`` is the input (reshaped to (1, 128, 28, 28)),
        ``result[1]`` the filter (reshaped to (256, 128, 1, 1)) and
        ``result[-1]`` the flattened CINN output. The Paddle reference uses
        stride 2, padding 0 and dilation 1, matching the conv2d benchmarks in
        this class.
        """
        paddle.enable_static()

        # Build the reference network in a fresh program pair so repeated
        # calls (one per conv2d benchmark) do not keep appending ops and a
        # duplicate "A" feed variable to Paddle's global default programs.
        main_program = static.Program()
        startup_program = static.Program()
        with static.program_guard(main_program, startup_program):
            a = static.data(name='A', shape=[1, 128, 28, 28], dtype='float32')
            e = paddle.nn.initializer.NumpyArrayInitializer(
                np.array(result[1]).reshape((256, 128, 1, 1)).astype("float32")
            )
            res = static.nn.conv2d(
                input=a,
                num_filters=256,
                filter_size=1,
                stride=2,
                padding=0,
                dilation=1,
                param_attr=e,
            )

        exe = static.Executor(paddle.CPUPlace())
        exe.run(startup_program)

        x = np.array(result[0]).reshape((1, 128, 28, 28)).astype("float32")
        output = exe.run(main_program, feed={"A": x}, fetch_list=[res])
        output = np.array(output).reshape(-1)
        expected = np.asarray(result[-1])
        print("result in conv2d paddle_verify: \n")
        # Report every element that differs by more than the tolerance before
        # asserting, so a failure shows where the two outputs diverge.
        for i in np.nonzero(np.abs(output - expected) > 1e-4)[0]:
            print(
                "Error! ",
                i,
                "-th data has diff with target data:\n",
                output[i],
                " vs: ",
                expected[i],
                ". Diff is: ",
                output[i] - expected[i],
            )
        self.assertTrue(np.allclose(expected, output, atol=1e-4))

    def atest_conv2d_cinn(self):
        """Benchmark conv2d compiled by CINN and verify it against Paddle."""
        prog = Program()
        a = Variable("A").set_type(Float(32)).set_shape([1, 128, 28, 28])
        b = Variable("E").set_type(Float(32)).set_shape([256, 128, 1, 1])
        c = prog.conv2d(
            a, b, {"stride": [2, 2], "dilation": [1, 1], "padding": [0, 0]}
        )
        tensor_data = self._random_inputs([1, 128, 28, 28], [256, 128, 1, 1])
        result = prog.test_benchmark(
            self.target,
            [a, b],
            tensor_data,
            c,
            20000,
            "TESTING [conv2d] time cost with shape [1, 128, 28, 28]...",
        )
        result = result.numpy(self.target).reshape(-1)
        tensor_data.append(result)
        self.paddle_verify(tensor_data)

    def atest_conv2d_cinn_code(self):
        """Benchmark a hand-supplied CINN-style conv2d kernel, verify vs Paddle."""
        prog = Program()
        a = Variable("X").set_type(Float(32)).set_shape([1, 128, 28, 28])
        b = Variable("Y").set_type(Float(32)).set_shape([256, 128, 1, 1])
        c = prog.conv2d(
            a, b, {"stride": [2, 2], "dilation": [1, 1], "padding": [0, 0]}
        )
        tensor_data = self._random_inputs([1, 128, 28, 28], [256, 128, 1, 1])
        result = prog.test_benchmark_with_code(
            self.target,
            [a, b],
            tensor_data,
            c,
            20000,
            "TESTING [conv2d of cinn schedule] time cost with shape [1, 128, 28, 28]...",
            """
extern "C" {

#include "cinn_cuda_runtime_source.cuh"

#ifdef __CUDACC_RTC__
typedef int int32_t;
typedef char int8_t;
#endif



__global__
void fn_conv2d_0_kernel(const float* __restrict__ X, const float* __restrict__ Y, float* __restrict__ COD)
{
  __shared__ float _input_pad_0_read_cache [ 224 ];
  float _COD_write_cache [ 2 ];
  __shared__ float _Y_read_cache [ 256 ];
  float* COD_write_cache = _COD_write_cache;
  float* COD_write_cache__reduce_init = _COD_write_cache;
  float* Y_read_cache = _Y_read_cache;
  float* input_pad_0_read_cache = _input_pad_0_read_cache;
  if ((blockIdx.z < 8)) {
    if ((blockIdx.y < 14)) {
      if ((threadIdx.z < 16)) {
        if ((threadIdx.x < 14)) {
        {
          for (int32_t rc_outer = 0; rc_outer < 2; rc_outer += 1) {
            COD_write_cache__reduce_init[rc_outer] = 0;
          };
          for (int32_t rc_outer = 0; rc_outer < 16; rc_outer += 1) {
            {
              __syncthreads();
              if ((threadIdx.z < 8)) {
                input_pad_0_read_cache[((2 * threadIdx.x) + (28 * threadIdx.z))] = X[((56 * blockIdx.y) + ((6272 * rc_outer) + ((2 * threadIdx.x) + (784 * threadIdx.z))))];
              };
            };
            for (int32_t rc_inner = 0; rc_inner < 2; rc_inner += 1) {
              if ((threadIdx.x < 8)) {
                Y_read_cache[((threadIdx.x / 2) + ((8 * (threadIdx.x % 2)) + ((4 * rc_inner) + (16 * threadIdx.z))))] = Y[((threadIdx.x / 2) + ((128 * (threadIdx.x % 2)) + ((4096 * blockIdx.z) + ((4 * rc_inner) + ((8 * rc_outer) + (256 * threadIdx.z))))))];
              };
            };
            __syncthreads();
            for (int32_t rc_inner = 0; rc_inner < 8; rc_inner += 1) {
              for (int32_t j_inner = 0; j_inner < 2; j_inner += 1) {
                COD_write_cache[j_inner] = (COD_write_cache[j_inner] + (input_pad_0_read_cache[((28 * rc_inner) + (2 * threadIdx.x))] * Y_read_cache[((8 * j_inner) + ((16 * threadIdx.z) + rc_inner))]));
              };
            };
          };
          for (int32_t rc_outer = 0; rc_outer < 2; rc_outer += 1) {
            COD[((14 * blockIdx.y) + ((6272 * blockIdx.z) + ((196 * rc_outer) + ((392 * threadIdx.z) + threadIdx.x))))] = COD_write_cache[rc_outer];
          };
        }
        };
      };
    };
  };
}

}
            """,
        )
        result = result.numpy(self.target).reshape(-1)
        tensor_data.append(result)
        self.paddle_verify(tensor_data)

    def atest_conv2d_tvm_code(self):
        """Benchmark a hand-supplied TVM-style conv2d kernel, verify vs Paddle."""
        prog = Program()
        a = (
            Variable("placeholder")
            .set_type(Float(32))
            .set_shape([1, 128, 28, 28])
        )
        b = (
            Variable("placeholder1")
            .set_type(Float(32))
            .set_shape([256, 128, 1, 1])
        )
        c = prog.conv2d(
            a, b, {"stride": [2, 2], "dilation": [1, 1], "padding": [0, 0]}
        )
        tensor_data = self._random_inputs([1, 128, 28, 28], [256, 128, 1, 1])
        result = prog.test_benchmark_with_code(
            self.target,
            [a, b],
            tensor_data,
            c,
            20000,
            "TESTING [conv2d of tvm schedule] time cost with shape [1, 128, 28, 28]...",
            """
extern "C" {

#include "cinn_cuda_runtime_source.cuh"

#ifdef __CUDACC_RTC__
typedef int int32_t;
typedef char int8_t;
#endif



__global__ void fn_conv2d_0_kernel(float* __restrict__ placeholder, float* __restrict__ placeholder1, float* __restrict__ Conv2d_nchw_out) {
  float compute_local[2];
  __shared__ float pad_temp_shared[216];
  __shared__ float placeholder_shared[256];
  for (int ff_c_init = 0; ff_c_init < 2; ++ff_c_init) {
    compute_local[(ff_c_init)] = 0.000000e+00f;
  }
  for (int rc_outer = 0; rc_outer < 16; ++rc_outer) {
    __syncthreads();
    if (((((int)threadIdx.z) * 14) + ((int)threadIdx.x)) < 216) {
      pad_temp_shared[(((((int)threadIdx.z) * 14) + ((int)threadIdx.x)))] = placeholder[(((((rc_outer * 6272) + ((((((int)threadIdx.z) * 14) + ((int)threadIdx.x)) / 27) * 784)) + (((int)blockIdx.y) * 56)) + (((((int)threadIdx.z) * 14) + ((int)threadIdx.x)) % 27)))];
    }
    for (int ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner = 0; ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner < 2; ++ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner) {
      if (((((int)threadIdx.z) * 2) + (((((int)threadIdx.x) * 2) + ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner) >> 3)) < 32) {
        if ((((((int)threadIdx.z) * 16) + (((int)threadIdx.x) * 2)) + ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner) < 256) {
          if (((((int)threadIdx.x) * 2) + ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner) < 16) {
            placeholder_shared[((((((int)threadIdx.z) * 16) + (((int)threadIdx.x) * 2)) + ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner))] = placeholder1[((((((((int)blockIdx.z) * 4096) + (((int)threadIdx.z) * 256)) + ((((((int)threadIdx.x) * 2) + ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner) >> 3) * 128)) + (rc_outer * 8)) + (((((int)threadIdx.x) * 2) + ax0_ax1_fused_ax2_fused_ax3_fused_inner_inner_inner) & 7)))];
          }
        }
      }
    }
    __syncthreads();
    for (int rc_inner = 0; rc_inner < 8; ++rc_inner) {
      for (int ff_c = 0; ff_c < 2; ++ff_c) {
        compute_local[(ff_c)] = (compute_local[(ff_c)] + (pad_temp_shared[(((rc_inner * 27) + (((int)threadIdx.x) * 2)))] * placeholder_shared[((((((int)threadIdx.z) * 16) + (ff_c * 8)) + rc_inner))]));
      }
    }
  }
  for (int ff_inner_inner_inner = 0; ff_inner_inner_inner < 2; ++ff_inner_inner_inner) {
    Conv2d_nchw_out[((((((((int)blockIdx.z) * 6272) + (((int)threadIdx.z) * 392)) + (ff_inner_inner_inner * 196)) + (((int)blockIdx.y) * 14)) + ((int)threadIdx.x)))] = compute_local[(ff_inner_inner_inner)];
  }
}

}
            """,
        )
        result = result.numpy(self.target).reshape(-1)
        tensor_data.append(result)
        self.paddle_verify(tensor_data)

    def atest_softmax(self):
        """Benchmark softmax on a [1024, 2048] input (timing only)."""
        prog = Program()
        a = Variable("A").set_type(Float(32)).set_shape([1024, 2048])
        c = prog.softmax(a, {})
        tensor_data = self._random_inputs([1024, 2048])
        prog.test_benchmark(
            self.target,
            [a],
            tensor_data,
            c,
            200,
            "TESTING [softmax] time cost with shape [1024,2048]...",
        )

    def atest_matmul(self):
        """Benchmark mul of two [512, 512] inputs (timing only)."""
        prog = Program()
        a = Variable("A").set_type(Float(32)).set_shape([512, 512])
        b = Variable("B").set_type(Float(32)).set_shape([512, 512])
        c = prog.mul(a, b, 1, 1)
        tensor_data = self._random_inputs([512, 512], [512, 512])
        prog.test_benchmark(
            self.target,
            [a, b],
            tensor_data,
            c,
            200,
            "TESTING [matmul] time cost with shape [512,512]...",
        )

    def atest_matmul2(self):
        """Benchmark mul of [128, 512] and [256, 512] followed by add."""
        prog = Program()
        a = Variable("A").set_type(Float(32)).set_shape([128, 512])
        b = Variable("B").set_type(Float(32)).set_shape([256, 512])
        c = Variable("C").set_type(Float(32)).set_shape([128, 256])
        d = prog.mul(a, b, 1, 1)
        e = prog.add(d, c)
        tensor_data = self._random_inputs([128, 512], [256, 512], [128, 256])
        prog.test_benchmark(
            self.target,
            [a, b, c],
            tensor_data,
            e,
            200,
            "TESTING [mul and add] time cost with shape [128,512]*[256,512]...",
        )

    def atest_matmul_with_code(self):
        """Benchmark a hand-supplied mul kernel for two [512, 512] inputs.

        Previously also named ``atest_matmul``, which silently replaced the
        timing-only benchmark above; it now has its own name.
        """
        prog = Program()
        a = Variable("A").set_type(Float(32)).set_shape([512, 512])
        b = Variable("B").set_type(Float(32)).set_shape([512, 512])
        d = prog.mul(a, b, 1, 1)
        tensor_data = self._random_inputs([512, 512], [512, 512])
        # The kernel zero-initialises Mul_output, then accumulates
        # A[row, k] * B[k, col] into it exactly once per k.
        prog.test_benchmark_with_code(
            self.target,
            [a, b],
            tensor_data,
            d,
            200,
            "TESTING [matmul] time cost with shape [512,512]...",
            '''
            extern "C" {
#include "cinn_cuda_runtime_source.cuh"
#ifdef __CUDACC_RTC__
typedef int int32_t;
typedef char int8_t;
#endif

 __global__
 void fn_mul_0_kernel(const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ Mul_output)
 {
   const float* A_reshape = A;
   const float* B_reshape = B;
   float* Mul_output__reduce_init = Mul_output;
   if ((blockIdx.x < 512)) {
   {
     if ((threadIdx.x < 256)) {
     {
       for (int32_t j_inner = 0; j_inner < 2; j_inner += 1) {
         Mul_output__reduce_init[((512 * blockIdx.x) + ((2 * threadIdx.x) + j_inner))] = 0;
       };
     }
     };
   }
   };
   if ((blockIdx.x < 512)) {
   {
     if ((threadIdx.x < 256)) {
     {
       for (int32_t j_inner = 0; j_inner < 2; j_inner += 1) {
        for (int32_t axis_k = 0; axis_k < 512; axis_k += 1) {
          Mul_output[((512 * blockIdx.x) + ((2 * threadIdx.x) + j_inner))] = (Mul_output[((512 * blockIdx.x) + ((2 * threadIdx.x) + j_inner))] + (A_reshape[((512 * blockIdx.x) + axis_k)] * B_reshape[((512 * axis_k) + ((2 * threadIdx.x) + j_inner))]));
         };
       };
     }
     };
  }
  };
 }
 }''',
        )

    def atest_pool2d(self):
        """Benchmark 3x3/stride-2 max pool2d on [2, 64, 112, 112] (timing only)."""
        prog = Program()
        a = Variable("A").set_type(Float(32)).set_shape([2, 64, 112, 112])
        c = prog.pool2d(
            a,
            {
                "kernel_size": (3, 3),
                "stride_size": (2, 2),
                "padding_size": (1, 1, 1, 1),
                "pool_type": "max",
            },
        )
        tensor_data = self._random_inputs([2, 64, 112, 112])
        prog.test_benchmark(
            self.target,
            [a],
            tensor_data,
            c,
            2000,
            "TESTING [pool2d] time cost with shape [2, 64, 112, 112]...",
        )

    def atest_elementwise1(self):
        """Benchmark elementwise add on [64, 64] and check it against NumPy."""
        prog = Program()
        a = Variable("A").set_type(Float(32)).set_shape([64, 64])
        b = Variable("B").set_type(Float(32)).set_shape([64, 64])
        c = prog.add(a, b)
        tensor_data = self._random_inputs([64, 64], [64, 64])
        result = prog.test_benchmark(
            self.target,
            [a, b],
            tensor_data,
            c,
            200,
            "TESTING [elementwise_add] time cost with shape [64, 64]...",
        )
        result = result.numpy(self.target).reshape(-1)
        self.assertTrue(
            np.allclose(
                (tensor_data[0] + tensor_data[1]).reshape(-1), result, atol=1e-4
            )
        )

    def atest_elementwise2(self):
        """Benchmark elementwise add on [2, 512, 112, 112], check vs NumPy."""
        prog = Program()
        a = Variable("A").set_type(Float(32)).set_shape([2, 512, 112, 112])
        b = Variable("B").set_type(Float(32)).set_shape([2, 512, 112, 112])
        c = prog.add(a, b)
        tensor_data = self._random_inputs([2, 512, 112, 112], [2, 512, 112, 112])
        result = prog.test_benchmark(
            self.target,
            [a, b],
            tensor_data,
            c,
            200,
            "TESTING [elementwise_add] time cost with shape [2, 512, 112, 112]...",
        )
        result = result.numpy(self.target).reshape(-1)
        self.assertTrue(
            np.allclose(
                (tensor_data[0] + tensor_data[1]).reshape(-1), result, atol=1e-4
            )
        )

    def atest_elementwise_with_code(self):
        """Benchmark a hand-supplied elementwise add kernel on [4, 1024].

        Previously also named ``atest_elementwise2``, which silently replaced
        the [2, 512, 112, 112] benchmark above; it now has its own name.
        """
        prog = Program()
        a = Variable("A").set_type(Float(32)).set_shape([4, 1024])
        b = Variable("B").set_type(Float(32)).set_shape([4, 1024])
        c = prog.add(a, b)
        tensor_data = self._random_inputs([4, 1024], [4, 1024])
        prog.test_benchmark_with_code(
            self.target,
            [a, b],
            tensor_data,
            c,
            200,
            "TESTING [elementwise_add] time cost with input code...",
            '''extern "C" {

__global__
void fn_elementwise_add_0_kernel(const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ EleAdd_Out_0)
{

      EleAdd_Out_0[1024 * blockIdx.x + threadIdx.x] = (A[1024 * blockIdx.x + threadIdx.x] + B[1024 * blockIdx.x + threadIdx.x]);
}

}''',
        )

    def atest_batchnorm(self):
        """Benchmark batchnorm on [2, 512, 32, 32] with four [512] params."""
        prog = Program()
        a = Variable("A").set_type(Float(32)).set_shape([2, 512, 32, 32])
        b = Variable("B").set_type(Float(32)).set_shape([512])
        c = Variable("C").set_type(Float(32)).set_shape([512])
        d = Variable("D").set_type(Float(32)).set_shape([512])
        e = Variable("E").set_type(Float(32)).set_shape([512])
        f = prog.batchnorm(a, b, c, d, e, {})
        tensor_data = self._random_inputs(
            [2, 512, 32, 32], [512], [512], [512], [512]
        )
        prog.test_benchmark(
            self.target,
            [a, b, c, d, e],
            tensor_data,
            f,
            1000,
            "TESTING [batchnorm] time cost with shape [2, 512, 32, 32]...",
        )

    def atest_batchnorm2(self):
        """Benchmark batchnorm on [2, 64, 8, 8] with four [64] params."""
        prog = Program()
        a = Variable("A").set_type(Float(32)).set_shape([2, 64, 8, 8])
        b = Variable("B").set_type(Float(32)).set_shape([64])
        c = Variable("C").set_type(Float(32)).set_shape([64])
        d = Variable("D").set_type(Float(32)).set_shape([64])
        e = Variable("E").set_type(Float(32)).set_shape([64])
        f = prog.batchnorm(a, b, c, d, e, {})
        tensor_data = self._random_inputs([2, 64, 8, 8], [64], [64], [64], [64])
        prog.test_benchmark(
            self.target,
            [a, b, c, d, e],
            tensor_data,
            f,
            200,
            "TESTING [batchnorm] time cost with shape [2, 64, 8, 8]...",
        )

    def atest_relu3(self):
        """Benchmark relu on [2, 512, 112, 112] (timing only)."""
        prog = Program()
        a = Variable("A").set_type(Float(32)).set_shape([2, 512, 112, 112])
        c = prog.relu(a)
        tensor_data = self._random_inputs([2, 512, 112, 112])
        prog.test_benchmark(
            self.target,
            [a],
            tensor_data,
            c,
            200,
            "TESTING [relu] time cost with shape [2,512,112,112]...",
        )

    def atest_relu(self):
        """Benchmark sigmoid (despite the method name) on [64, 64]."""
        prog = Program()
        a = Variable("A").set_type(Float(32)).set_shape([64, 64])
        c = prog.sigmoid(a)
        tensor_data = self._random_inputs([64, 64])
        prog.test_benchmark(
            self.target,
            [a],
            tensor_data,
            c,
            200,
            "TESTING [sigmoid] time cost with shape [64,64]...",
        )

    def atest_relu2(self):
        """Benchmark sigmoid (despite the method name) on [2, 512, 112, 112]."""
        prog = Program()
        a = Variable("A").set_type(Float(32)).set_shape([2, 512, 112, 112])
        c = prog.sigmoid(a)
        tensor_data = self._random_inputs([2, 512, 112, 112])
        prog.test_benchmark(
            self.target,
            [a],
            tensor_data,
            c,
            200,
            "TESTING [sigmoid] time cost with shape [2,512,112,112]...",
        )


if __name__ == "__main__":
    # Run the suite when executed as a script. The target flag has already
    # been popped from sys.argv at import time, so only the program name is
    # left for unittest to parse.
    unittest.main(argv=sys.argv)