# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle
from paddle.fluid.layer_helper import LayerHelper
from paddle.incubate.autograd.primrules import _prim2orig

# Every test below builds ops into explicit static Programs via
# program_guard, so the whole module runs in static-graph mode.
paddle.enable_static()


# ------------------------ Test prim2orig rules ---------------------------- #
25 26 27 28 29 30
class TestAddPPrim2Orig(unittest.TestCase):
    def setUp(self):
        self.main_program = paddle.static.Program()
        self.startup_program = paddle.static.Program()
        self.layer_help = LayerHelper('TestPrim2Orig')

31 32 33
        with paddle.static.program_guard(
            self.main_program, self.startup_program
        ):
34 35 36 37 38 39 40 41 42
            self.init_data()

    def init_data(self):
        self.op_type = 'add_p'
        X = paddle.static.data(name='X', shape=[2, 2], dtype='float')
        Y = paddle.static.data(name='Y', shape=[2, 2], dtype='float')

        self.input = {'X': X, 'Y': Y}
        self.output = {
43 44 45
            'Z': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
46 47 48 49 50 51 52 53 54
        }
        self.attrs = {}

        self.prim2orig_args = (X, Y)
        self.all_ops = ['add_p', 'elementwise_add']
        # { prim_op_output_var: orign_op_out_index }
        self.out_map = {self.output['Z']: 0}

    def test_op(self):
55 56 57 58 59 60 61 62 63
        with paddle.static.program_guard(
            self.main_program, self.startup_program
        ):
            op = self.layer_help.append_op(
                type=self.op_type,
                inputs=self.input,
                outputs=self.output,
                attrs=self.attrs,
            )
64 65 66 67

            orig_out = _prim2orig(op, *self.prim2orig_args)
            all_ops = [op.type for op in self.main_program.block(0).ops]
            self.assertEqual(sorted(all_ops), sorted(self.all_ops))
68
            orig_out = paddle.utils.flatten(orig_out)
69 70 71 72 73 74 75 76 77 78 79 80
            for k, v in self.out_map.items():
                self.assertEqual(k.shape, orig_out[v].shape)


class TestSubPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``sub_p`` should add a single ``elementwise_sub`` op."""

    def init_data(self):
        self.op_type = 'sub_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
        Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')

        self.input = {'X': X, 'Y': Y}
        self.output = {
            'Z': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        self.attrs = {}

        self.prim2orig_args = (X, Y)
        self.all_ops = ['sub_p', 'elementwise_sub']
        self.out_map = {self.output['Z']: 0}


class TestMulPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``mul_p`` should add a single ``elementwise_mul`` op."""

    def init_data(self):
        self.op_type = 'mul_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
        Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')

        self.input = {'X': X, 'Y': Y}
        self.output = {
            'Z': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        self.attrs = {}

        self.prim2orig_args = (X, Y)
        self.all_ops = ['mul_p', 'elementwise_mul']
        self.out_map = {self.output['Z']: 0}


class TestDivPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``div_p`` should add a single ``elementwise_div`` op."""

    def init_data(self):
        self.op_type = 'div_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
        Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')

        self.input = {'X': X, 'Y': Y}
        self.output = {
            'Z': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        self.attrs = {}

        self.prim2orig_args = (X, Y)
        self.all_ops = ['div_p', 'elementwise_div']
        self.out_map = {self.output['Z']: 0}


class TestSqrtPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``sqrt_p`` should add a single ``sqrt`` op."""

    def init_data(self):
        self.op_type = 'sqrt_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')

        self.input = {
            'X': X,
        }
        self.output = {
            'Y': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        self.attrs = {}

        self.prim2orig_args = (X,)
        self.all_ops = ['sqrt_p', 'sqrt']
        self.out_map = {self.output['Y']: 0}


class TestTanhPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``tanh_p`` should add a single ``tanh`` op."""

    def init_data(self):
        self.op_type = 'tanh_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')

        self.input = {
            'X': X,
        }
        self.output = {
            'Y': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        self.attrs = {}

        self.prim2orig_args = (X,)
        self.all_ops = ['tanh_p', 'tanh']
        self.out_map = {self.output['Y']: 0}


class TestSinPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``sin_p`` should add a single ``sin`` op."""

    def init_data(self):
        self.op_type = 'sin_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')

        self.input = {
            'X': X,
        }
        self.output = {
            'Y': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        self.attrs = {}

        self.prim2orig_args = (X,)
        self.all_ops = ['sin_p', 'sin']
        self.out_map = {self.output['Y']: 0}


class TestCosPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``cos_p`` should add a single ``cos`` op."""

    def init_data(self):
        self.op_type = 'cos_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')

        self.input = {
            'X': X,
        }
        self.output = {
            'Y': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        self.attrs = {}

        self.prim2orig_args = (X,)
        self.all_ops = ['cos_p', 'cos']
        self.out_map = {self.output['Y']: 0}


class TestExpPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``exp_p`` should add a single ``exp`` op."""

    def init_data(self):
        self.op_type = 'exp_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')

        self.input = {
            'X': X,
        }
        self.output = {
            'Y': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        self.attrs = {}

        self.prim2orig_args = (X,)
        self.all_ops = ['exp_p', 'exp']
        self.out_map = {self.output['Y']: 0}


class TestErfPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``erf_p`` should add a single ``erf`` op."""

    def init_data(self):
        self.op_type = 'erf_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')

        self.input = {
            'X': X,
        }
        self.output = {
            'Y': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        self.attrs = {}

        self.prim2orig_args = (X,)
        self.all_ops = ['erf_p', 'erf']
        self.out_map = {self.output['Y']: 0}


class TestAbsPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``abs_p`` should add a single ``abs`` op."""

    def init_data(self):
        self.op_type = 'abs_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')

        self.input = {
            'X': X,
        }
        self.output = {
            'Y': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        self.attrs = {}

        self.prim2orig_args = (X,)
        self.all_ops = ['abs_p', 'abs']
        self.out_map = {self.output['Y']: 0}


class TestLogPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``log_p`` should add a single ``log`` op."""

    def init_data(self):
        self.op_type = 'log_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')

        self.input = {
            'X': X,
        }
        self.output = {
            'Y': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        self.attrs = {}

        self.prim2orig_args = (X,)
        self.all_ops = ['log_p', 'log']
        self.out_map = {self.output['Y']: 0}


class TestReshapePPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``reshape_p`` should add a single ``reshape2`` op."""

    def init_data(self):
        self.op_type = 'reshape_p'
        X = paddle.static.data(name='X', shape=[2, 8], dtype='float64')

        self.input = {
            'X': X,
        }
        self.output = {
            'Y': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        # [2, 8] -> [4, 4]: same element count (16).
        self.attrs = {'shape': [4, 4]}

        self.prim2orig_args = (X,)
        self.all_ops = ['reshape_p', 'reshape2']
        self.out_map = {self.output['Y']: 0}


class TestBroadcastPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``broadcast_p`` should add a single ``expand_v2`` op."""

    def init_data(self):
        self.op_type = 'broadcast_p'
        X = paddle.static.data(name='X', shape=[2, 8], dtype='float64')

        self.input = {
            'X': X,
        }
        self.output = {
            'Y': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        self.attrs = {'shape': [10, 2, 8]}

        self.prim2orig_args = (X,)
        self.all_ops = ['broadcast_p', 'expand_v2']
        self.out_map = {self.output['Y']: 0}


class TestTransposePPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``transpose_p`` should add a single ``transpose2`` op."""

    def init_data(self):
        self.op_type = 'transpose_p'
        X = paddle.static.data(name='X', shape=[7, 8, 9, 10], dtype='float64')

        self.input = {
            'X': X,
        }
        self.output = {
            'Y': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        self.attrs = {'axis': [1, 2, 0, 3]}

        self.prim2orig_args = (X,)
        self.all_ops = ['transpose_p', 'transpose2']
        self.out_map = {self.output['Y']: 0}


class TestSplitPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``split_p`` should add a single multi-output ``split`` op."""

    def init_data(self):
        self.op_type = 'split_p'
        X = paddle.static.data(name='X', shape=[3, 9, 5], dtype='float64')

        self.input = {
            'X': X,
        }
        # One output variable per section in num_or_sections.
        self.output = {
            'YS': [
                self.layer_help.create_variable_for_type_inference(
                    dtype=X.dtype
                )
                for _ in range(3)
            ]
        }
        # Sections 2 + 3 + 4 == 9, the size of axis 1.
        self.attrs = {'num_or_sections': [2, 3, 4], 'axis': 1}

        self.prim2orig_args = (X,)
        self.all_ops = ['split_p', 'split']
        self.out_map = {
            self.output['YS'][0]: 0,
            self.output['YS'][1]: 1,
            self.output['YS'][2]: 2,
        }


class TestConcatPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``concat_p`` should add a single ``concat`` op."""

    def init_data(self):
        self.op_type = 'concat_p'
        X = paddle.static.data(name='X', shape=[3, 9, 5], dtype='float64')
        Y = paddle.static.data(name='Y', shape=[2, 9, 5], dtype='float64')
        Z = paddle.static.data(name='Z', shape=[1, 9, 5], dtype='float64')

        self.input = {
            'XS': [X, Y, Z],
        }
        self.output = {
            'Y': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        self.attrs = {'axis': 0}

        # The inputs are passed to _prim2orig as a single tuple argument.
        self.prim2orig_args = ((X, Y, Z),)
        self.all_ops = ['concat_p', 'concat']
        self.out_map = {self.output['Y']: 0}


class TestReducePPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``reduce_sum_p`` should add a single ``reduce_sum`` op."""

    def init_data(self):
        self.op_type = 'reduce_sum_p'
        X = paddle.static.data(name='X', shape=[3, 9, 5], dtype='float64')

        self.input = {'X': X}
        self.output = {
            'Y': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        self.attrs = {'axis': [1], 'keepdim': True}

        self.prim2orig_args = (X,)
        self.all_ops = ['reduce_sum_p', 'reduce_sum']
        self.out_map = {self.output['Y']: 0}


class TestMatmulPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``matmul_p`` should add a single ``matmul_v2`` op."""

    def init_data(self):
        self.op_type = 'matmul_p'
        X = paddle.static.data(name='X', shape=[9, 5], dtype='float64')
        Y = paddle.static.data(name='Y', shape=[5, 9], dtype='float64')

        self.input = {'X': X, 'Y': Y}
        self.output = {
            'Z': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        self.attrs = {}

        self.prim2orig_args = (X, Y)
        self.all_ops = ['matmul_p', 'matmul_v2']
        self.out_map = {self.output['Z']: 0}


class TestSliceSelectPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``slice_select_p`` should add a single ``strided_slice`` op."""

    def init_data(self):
        self.op_type = 'slice_select_p'
        X = paddle.static.data(name='X', shape=[9, 5], dtype='float64')

        self.input = {
            'X': X,
        }
        self.output = {
            'Y': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        self.attrs = {'axis': [0], 'starts': [1], 'ends': [8], 'strides': [2]}

        self.prim2orig_args = (X,)
        self.all_ops = ['slice_select_p', 'strided_slice']
        self.out_map = {self.output['Y']: 0}


class TestSliceAssignPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``slice_assign_p`` should add ``assign`` and ``set_value``."""

    def init_data(self):
        self.op_type = 'slice_assign_p'
        X = paddle.static.data(name='X', shape=[9, 5], dtype='float64')
        Y = paddle.static.data(name='Y', shape=[9, 3], dtype='float64')

        self.input = {'X': X, 'Y': Y}
        self.output = {
            'Z': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        # Columns [0, 3) of X are the slice Y is assigned into.
        self.attrs = {'axis': [1], 'starts': [0], 'ends': [3], 'strides': [1]}

        self.prim2orig_args = (X, Y)
        self.all_ops = ['slice_assign_p', 'assign', 'set_value']
        self.out_map = {self.output['Z']: 0}


class TestGatherPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``gather_p`` should add a single ``gather`` op."""

    def init_data(self):
        self.op_type = 'gather_p'
        X = paddle.static.data(name='X', shape=[9, 5], dtype='float64')
        IndexTensor = paddle.static.data(
            name='IndexTensor', shape=[3], dtype='int32'
        )

        self.input = {'X': X, 'IndexTensor': IndexTensor}
        self.output = {
            'Y': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        self.attrs = {
            'axis': 0,
        }

        # Note the order: the index tensor is passed before X.
        self.prim2orig_args = (
            IndexTensor,
            X,
        )
        self.all_ops = ['gather_p', 'gather']
        self.out_map = {self.output['Y']: 0}


class TestScatterAddPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``scatter_add_p`` should add three original ops.

    The expected ops are ``fill_any_like``, ``scatter`` and
    ``elementwise_add``.
    """

    def init_data(self):
        self.op_type = 'scatter_add_p'
        X = paddle.static.data(name='X', shape=[9, 5], dtype='float64')
        Y = paddle.static.data(name='Y', shape=[3, 5], dtype='float64')
        IndexTensor = paddle.static.data(
            name='IndexTensor', shape=[3], dtype='int32'
        )

        self.input = {'X': X, 'Y': Y, 'IndexTensor': IndexTensor}
        self.output = {
            'Z': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        self.attrs = {
            'axis': 0,
        }

        # Note the order: the index tensor is passed before X and Y.
        self.prim2orig_args = (IndexTensor, X, Y)
        self.all_ops = [
            'scatter_add_p',
            'fill_any_like',
            'scatter',
            'elementwise_add',
        ]
        self.out_map = {self.output['Z']: 0}


class TestFillConstantPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``fill_constant_p`` should add a single ``fill_constant`` op."""

    def init_data(self):
        self.op_type = 'fill_constant_p'

        # Source op: no inputs; shape, value and dtype come from attrs.
        self.input = {}
        self.output = {
            'Y': self.layer_help.create_variable_for_type_inference(
                paddle.int32
            )
        }
        self.attrs = {'value': 10, 'shape': [5, 5], 'dtype': paddle.int32}

        self.prim2orig_args = ()
        self.all_ops = ['fill_constant_p', 'fill_constant']
        self.out_map = {self.output['Y']: 0}


class TestSelectPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``select_p`` should add a single ``where`` op."""

    def init_data(self):
        self.op_type = 'select_p'
        Cond = paddle.static.data(name='Condition', shape=[5, 6], dtype='bool')
        X = paddle.static.data(name='X', shape=[5, 6], dtype='float32')
        Y = paddle.static.data(name='Y', shape=[5, 6], dtype='float32')

        self.input = {'Condition': Cond, 'X': X, 'Y': Y}
        self.output = {
            'Z': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        self.attrs = {}
        self.prim2orig_args = (Cond, X, Y)
        self.all_ops = ['select_p', 'where']
        self.out_map = {self.output['Z']: 0}


class TestEqPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``eq_p`` should add a single ``equal`` op."""

    def init_data(self):
        self.op_type = 'eq_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
        Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')

        self.input = {'X': X, 'Y': Y}
        # Comparison ops produce a boolean output regardless of input dtype.
        self.output = {
            'Z': self.layer_help.create_variable_for_type_inference(
                dtype='bool'
            )
        }
        self.attrs = {}

        self.prim2orig_args = (X, Y)
        self.all_ops = ['eq_p', 'equal']
        self.out_map = {self.output['Z']: 0}


class TestNePPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``ne_p`` should add a single ``not_equal`` op."""

    def init_data(self):
        self.op_type = 'ne_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
        Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')

        self.input = {'X': X, 'Y': Y}
        # Comparison ops produce a boolean output regardless of input dtype.
        self.output = {
            'Z': self.layer_help.create_variable_for_type_inference(
                dtype='bool'
            )
        }
        self.attrs = {}

        self.prim2orig_args = (X, Y)
        self.all_ops = ['ne_p', 'not_equal']
        self.out_map = {self.output['Z']: 0}


class TestGtPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``gt_p`` should add a single ``greater_than`` op."""

    def init_data(self):
        self.op_type = 'gt_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
        Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')

        self.input = {'X': X, 'Y': Y}
        # Comparison ops produce a boolean output regardless of input dtype.
        self.output = {
            'Z': self.layer_help.create_variable_for_type_inference(
                dtype='bool'
            )
        }
        self.attrs = {}

        self.prim2orig_args = (X, Y)
        self.all_ops = ['gt_p', 'greater_than']
        self.out_map = {self.output['Z']: 0}


class TestGePPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``ge_p`` should add a single ``greater_equal`` op."""

    def init_data(self):
        self.op_type = 'ge_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
        Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')

        self.input = {'X': X, 'Y': Y}
        # Comparison ops produce a boolean output regardless of input dtype.
        self.output = {
            'Z': self.layer_help.create_variable_for_type_inference(
                dtype='bool'
            )
        }
        self.attrs = {}

        self.prim2orig_args = (X, Y)
        self.all_ops = ['ge_p', 'greater_equal']
        self.out_map = {self.output['Z']: 0}


class TestPowPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``pow_p`` should add a single ``elementwise_pow`` op."""

    def init_data(self):
        self.op_type = 'pow_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
        Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')

        self.input = {'X': X, 'Y': Y}
        self.output = {
            'Z': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        self.attrs = {}

        self.prim2orig_args = (X, Y)
        self.all_ops = ['pow_p', 'elementwise_pow']
        self.out_map = {self.output['Z']: 0}


class TestMaxPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``max_p`` should add a single ``elementwise_max`` op."""

    def init_data(self):
        self.op_type = 'max_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
        Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')

        self.input = {'X': X, 'Y': Y}
        self.output = {
            'Z': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        self.attrs = {}

        self.prim2orig_args = (X, Y)
        self.all_ops = ['max_p', 'elementwise_max']
        self.out_map = {self.output['Z']: 0}


class TestBernoulliPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``bernoulli_p`` should add ``fill_constant`` and ``bernoulli``."""

    def init_data(self):
        self.op_type = 'bernoulli_p'

        # Source op: no inputs; shape, dtype and probability come from attrs.
        self.input = {}
        self.output = {
            'Y': self.layer_help.create_variable_for_type_inference(
                dtype=paddle.float64
            )
        }
        self.attrs = {'shape': [7, 8], 'dtype': paddle.float64, 'p': 0.5}

        self.prim2orig_args = ()
        self.all_ops = ['bernoulli_p', 'fill_constant', 'bernoulli']
        self.out_map = {self.output['Y']: 0}


class TestCastPPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``cast_p`` should add a single ``cast`` op."""

    def init_data(self):
        self.op_type = 'cast_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')

        self.input = {
            'X': X,
        }
        # NOTE(review): the output var is declared with X's dtype (float64)
        # although attrs request int64; only shapes are compared in test_op.
        self.output = {
            'Y': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        self.attrs = {'dtype': paddle.int64}

        self.prim2orig_args = (X,)
        self.all_ops = ['cast_p', 'cast']
        self.out_map = {self.output['Y']: 0}


class TestRsqrtPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``rsqrt_p`` should add a single ``rsqrt`` op."""

    def init_data(self):
        self.op_type = 'rsqrt_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')

        self.input = {
            'X': X,
        }
        self.output = {
            'Y': self.layer_help.create_variable_for_type_inference(
                dtype=X.dtype
            )
        }
        self.attrs = {}

        self.prim2orig_args = (X,)
        self.all_ops = ['rsqrt_p', 'rsqrt']
        self.out_map = {self.output['Y']: 0}


class TestUniformRandomPrim2Orig(TestAddPPrim2Orig):
    """Lowering ``uniform_random_p`` should add a single ``uniform_random`` op."""

    def init_data(self):
        self.op_type = 'uniform_random_p'

        # Source op: no inputs; everything is configured through attrs.
        self.input = {}
        self.output = {
            'Out': self.layer_help.create_variable_for_type_inference(
                dtype=paddle.float64
            )
        }
        self.attrs = {
            'shape': [1, 2, 3],
            'min': -1.0,
            'max': 1.0,
            'seed': 0,
            'dtype': paddle.float64,
        }

        self.prim2orig_args = ()
        self.all_ops = ['uniform_random_p', 'uniform_random']
        self.out_map = {self.output['Out']: 0}


# Allow running this file directly as a script.
if __name__ == '__main__':
    unittest.main()