# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers.utils import flatten
from paddle.incubate.autograd.primrules import _orig2prim, _prim2orig, _jvp, _transpose

# Every test below builds and inspects paddle.static.Program objects, so
# switch Paddle to static-graph mode once, at import time.
paddle.enable_static()


############################ Test prim2orig rules ############################
class TestAddPPrim2Orig(unittest.TestCase):
27

28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
    def setUp(self):
        self.main_program = paddle.static.Program()
        self.startup_program = paddle.static.Program()
        self.layer_help = LayerHelper('TestPrim2Orig')

        with paddle.static.program_guard(self.main_program,
                                         self.startup_program):
            self.init_data()

    def init_data(self):
        self.op_type = 'add_p'
        X = paddle.static.data(name='X', shape=[2, 2], dtype='float')
        Y = paddle.static.data(name='Y', shape=[2, 2], dtype='float')

        self.input = {'X': X, 'Y': Y}
        self.output = {
            'Z':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {}

        self.prim2orig_args = (X, Y)
        self.all_ops = ['add_p', 'elementwise_add']
        # { prim_op_output_var: orign_op_out_index }
        self.out_map = {self.output['Z']: 0}

    def test_op(self):
        with paddle.static.program_guard(self.main_program,
                                         self.startup_program):
57 58 59 60
            op = self.layer_help.append_op(type=self.op_type,
                                           inputs=self.input,
                                           outputs=self.output,
                                           attrs=self.attrs)
61 62 63 64 65 66 67 68 69 70

            orig_out = _prim2orig(op, *self.prim2orig_args)
            all_ops = [op.type for op in self.main_program.block(0).ops]
            self.assertEqual(sorted(all_ops), sorted(self.all_ops))
            orig_out = flatten(orig_out)
            for k, v in self.out_map.items():
                self.assertEqual(k.shape, orig_out[v].shape)


class TestSubPPrim2Orig(TestAddPPrim2Orig):
    """Expect ``_prim2orig`` to emit ``elementwise_sub`` for ``sub_p``."""

    def init_data(self):
        """Describe a ``sub_p`` op on two float64 [7, 8] inputs."""
        self.op_type = 'sub_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
        Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')

        self.input = {'X': X, 'Y': Y}
        self.output = {
            'Z':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {}

        self.prim2orig_args = (X, Y)
        self.all_ops = ['sub_p', 'elementwise_sub']
        self.out_map = {self.output['Z']: 0}


class TestMulPPrim2Orig(TestAddPPrim2Orig):
    """Expect ``_prim2orig`` to emit ``elementwise_mul`` for ``mul_p``."""

    def init_data(self):
        """Describe a ``mul_p`` op on two float64 [7, 8] inputs."""
        self.op_type = 'mul_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
        Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')

        self.input = {'X': X, 'Y': Y}
        self.output = {
            'Z':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {}

        self.prim2orig_args = (X, Y)
        self.all_ops = ['mul_p', 'elementwise_mul']
        self.out_map = {self.output['Z']: 0}


class TestDivPPrim2Orig(TestAddPPrim2Orig):
    """Expect ``_prim2orig`` to emit ``elementwise_div`` for ``div_p``."""

    def init_data(self):
        """Describe a ``div_p`` op on two float64 [7, 8] inputs."""
        self.op_type = 'div_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')
        Y = paddle.static.data(name='Y', shape=[7, 8], dtype='float64')

        self.input = {'X': X, 'Y': Y}
        self.output = {
            'Z':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {}

        self.prim2orig_args = (X, Y)
        self.all_ops = ['div_p', 'elementwise_div']
        self.out_map = {self.output['Z']: 0}


class TestSqrtPPrim2Orig(TestAddPPrim2Orig):
    """Expect ``_prim2orig`` to emit ``sqrt`` for ``sqrt_p``."""

    def init_data(self):
        """Describe a unary ``sqrt_p`` op on a float64 [7, 8] input."""
        self.op_type = 'sqrt_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')

        self.input = {'X': X}
        self.output = {
            'Y':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {}

        self.prim2orig_args = (X, )
        self.all_ops = ['sqrt_p', 'sqrt']
        self.out_map = {self.output['Y']: 0}


class TestTanhPPrim2Orig(TestAddPPrim2Orig):
    """Expect ``_prim2orig`` to emit ``tanh`` for ``tanh_p``."""

    def init_data(self):
        """Describe a unary ``tanh_p`` op on a float64 [7, 8] input."""
        self.op_type = 'tanh_p'
        X = paddle.static.data(name='X', shape=[7, 8], dtype='float64')

        self.input = {'X': X}
        self.output = {
            'Y':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {}

        self.prim2orig_args = (X, )
        self.all_ops = ['tanh_p', 'tanh']
        self.out_map = {self.output['Y']: 0}


class TestReshapePPrim2Orig(TestAddPPrim2Orig):
    """Expect ``_prim2orig`` to emit ``reshape2`` for ``reshape_p``."""

    def init_data(self):
        """Describe reshaping a float64 [2, 8] input to shape [4, 4]."""
        self.op_type = 'reshape_p'
        X = paddle.static.data(name='X', shape=[2, 8], dtype='float64')

        self.input = {'X': X}
        self.output = {
            'Y':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {'shape': [4, 4]}

        self.prim2orig_args = (X, )
        self.all_ops = ['reshape_p', 'reshape2']
        self.out_map = {self.output['Y']: 0}


class TestBroadcastPPrim2Orig(TestAddPPrim2Orig):
    """Expect ``_prim2orig`` to emit ``expand_v2`` for ``broadcast_p``."""

    def init_data(self):
        """Describe broadcasting a float64 [2, 8] input to [10, 2, 8]."""
        self.op_type = 'broadcast_p'
        X = paddle.static.data(name='X', shape=[2, 8], dtype='float64')

        self.input = {'X': X}
        self.output = {
            'Y':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {'shape': [10, 2, 8]}

        self.prim2orig_args = (X, )
        self.all_ops = ['broadcast_p', 'expand_v2']
        self.out_map = {self.output['Y']: 0}


class TestTransposePPrim2Orig(TestAddPPrim2Orig):
    """Expect ``_prim2orig`` to emit ``transpose2`` for ``transpose_p``."""

    def init_data(self):
        """Describe permuting a float64 [7, 8, 9, 10] input by [1, 2, 0, 3]."""
        self.op_type = 'transpose_p'
        X = paddle.static.data(name='X', shape=[7, 8, 9, 10], dtype='float64')

        self.input = {'X': X}
        self.output = {
            'Y':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {'axis': [1, 2, 0, 3]}

        self.prim2orig_args = (X, )
        self.all_ops = ['transpose_p', 'transpose2']
        self.out_map = {self.output['Y']: 0}


class TestSplitPPrim2Orig(TestAddPPrim2Orig):
    """Expect ``_prim2orig`` to emit ``split`` for ``split_p``."""

    def init_data(self):
        """Describe splitting a float64 [3, 9, 5] input into three outputs.

        Axis 1 (size 9) is cut into sections of 2, 3 and 4, so ``split_p``
        has three outputs under the ``YS`` key.
        """
        self.op_type = 'split_p'
        X = paddle.static.data(name='X', shape=[3, 9, 5], dtype='float64')

        self.input = {'X': X}
        self.output = {
            'YS': [
                self.layer_help.create_variable_for_type_inference(
                    dtype=X.dtype) for _ in range(3)
            ]
        }
        self.attrs = {'num_or_sections': [2, 3, 4], 'axis': 1}

        self.prim2orig_args = (X, )
        self.all_ops = ['split_p', 'split']
        # The i-th prim output pairs with the i-th flattened _prim2orig output.
        self.out_map = {
            out_var: index
            for index, out_var in enumerate(self.output['YS'])
        }


class TestConcatPPrim2Orig(TestAddPPrim2Orig):
    """Expect ``_prim2orig`` to emit ``concat`` for ``concat_p``."""

    def init_data(self):
        """Describe concatenating three float64 inputs along axis 0.

        The inputs have leading sizes 3, 2 and 1 and share trailing dims
        [9, 5]. They are passed to ``_prim2orig`` as a single tuple argument.
        """
        self.op_type = 'concat_p'
        X = paddle.static.data(name='X', shape=[3, 9, 5], dtype='float64')
        Y = paddle.static.data(name='Y', shape=[2, 9, 5], dtype='float64')
        Z = paddle.static.data(name='Z', shape=[1, 9, 5], dtype='float64')

        self.input = {'XS': [X, Y, Z]}
        self.output = {
            'Y':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {'axis': 0}

        self.prim2orig_args = ((X, Y, Z), )
        self.all_ops = ['concat_p', 'concat']
        self.out_map = {self.output['Y']: 0}


class TestReducePPrim2Orig(TestAddPPrim2Orig):
    """Expect ``_prim2orig`` to emit ``reduce_sum`` for ``reduce_p``."""

    def init_data(self):
        """Describe reducing a float64 [3, 9, 5] input over axis 1, keepdim."""
        self.op_type = 'reduce_p'
        X = paddle.static.data(name='X', shape=[3, 9, 5], dtype='float64')

        self.input = {'X': X}
        self.output = {
            'Y':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {'axis': [1], 'keepdim': True}

        self.prim2orig_args = (X, )
        self.all_ops = ['reduce_p', 'reduce_sum']
        self.out_map = {self.output['Y']: 0}


class TestMatmulPPrim2Orig(TestAddPPrim2Orig):
    """Expect ``_prim2orig`` to emit ``matmul_v2`` for ``matmul_p``."""

    def init_data(self):
        """Describe a float64 [9, 5] x [5, 9] ``matmul_p`` op."""
        self.op_type = 'matmul_p'
        X = paddle.static.data(name='X', shape=[9, 5], dtype='float64')
        Y = paddle.static.data(name='Y', shape=[5, 9], dtype='float64')

        self.input = {'X': X, 'Y': Y}
        self.output = {
            'Z':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {}

        self.prim2orig_args = (X, Y)
        self.all_ops = ['matmul_p', 'matmul_v2']
        self.out_map = {self.output['Z']: 0}


class TestSliceSelectPPrim2Orig(TestAddPPrim2Orig):
    """Expect ``_prim2orig`` to emit ``strided_slice`` for ``slice_select_p``."""

    def init_data(self):
        """Describe a strided slice of a float64 [9, 5] input.

        The slice runs along axis 0 from start 1 to end 8 with stride 2.
        """
        self.op_type = 'slice_select_p'
        X = paddle.static.data(name='X', shape=[9, 5], dtype='float64')

        self.input = {'X': X}
        self.output = {
            'Y':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {'axis': [0], 'starts': [1], 'ends': [8], 'strides': [2]}

        self.prim2orig_args = (X, )
        self.all_ops = ['slice_select_p', 'strided_slice']
        self.out_map = {self.output['Y']: 0}


class TestSliceAssignPPrim2Orig(TestAddPPrim2Orig):
    """Expect ``slice_assign_p`` to lower to ``assign`` plus ``set_value``."""

    def init_data(self):
        """Describe writing a [9, 3] Y into columns 0:3 of a [9, 5] X."""
        self.op_type = 'slice_assign_p'
        X = paddle.static.data(name='X', shape=[9, 5], dtype='float64')
        Y = paddle.static.data(name='Y', shape=[9, 3], dtype='float64')

        self.input = {'X': X, 'Y': Y}
        self.output = {
            'Z':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {'axis': [1], 'starts': [0], 'ends': [3], 'strides': [1]}

        self.prim2orig_args = (X, Y)
        self.all_ops = ['slice_assign_p', 'assign', 'set_value']
        self.out_map = {self.output['Z']: 0}


class TestGatherPPrim2Orig(TestAddPPrim2Orig):
    """Expect ``_prim2orig`` to emit ``gather`` for ``gather_p``."""

    def init_data(self):
        """Describe gathering rows of a float64 [9, 5] X by an int32 index."""
        self.op_type = 'gather_p'
        X = paddle.static.data(name='X', shape=[9, 5], dtype='float64')
        IndexTensor = paddle.static.data(name='IndexTensor',
                                         shape=[3],
                                         dtype='int32')

        self.input = {'X': X, 'IndexTensor': IndexTensor}
        self.output = {
            'Y':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {'axis': 0}

        # _prim2orig receives the index tensor first, then X.
        self.prim2orig_args = (IndexTensor, X)
        self.all_ops = ['gather_p', 'gather']
        self.out_map = {self.output['Y']: 0}


class TestScatterAddPPrim2Orig(TestAddPPrim2Orig):
    """Expect ``scatter_add_p`` to lower to fill_any_like/scatter/add ops."""

    def init_data(self):
        """Describe scatter-adding a [3, 5] Y into a [9, 5] X along axis 0."""
        self.op_type = 'scatter_add_p'
        X = paddle.static.data(name='X', shape=[9, 5], dtype='float64')
        Y = paddle.static.data(name='Y', shape=[3, 5], dtype='float64')
        IndexTensor = paddle.static.data(name='IndexTensor',
                                         shape=[3],
                                         dtype='int32')

        self.input = {'X': X, 'Y': Y, 'IndexTensor': IndexTensor}
        self.output = {
            'Z':
            self.layer_help.create_variable_for_type_inference(dtype=X.dtype)
        }
        self.attrs = {'axis': 0}

        # _prim2orig receives the index tensor first, then X and Y.
        self.prim2orig_args = (IndexTensor, X, Y)
        self.all_ops = [
            'scatter_add_p', 'fill_any_like', 'scatter', 'elementwise_add'
        ]
        self.out_map = {self.output['Z']: 0}


class TestFillConstantPPrim2Orig(TestAddPPrim2Orig):
    """Expect ``_prim2orig`` to emit ``fill_constant`` for ``fill_constant_p``."""

    def init_data(self):
        """Describe an input-less op filling a [5, 5] int32 tensor with 10."""
        self.op_type = 'fill_constant_p'

        self.input = {}
        self.output = {
            'Y':
            self.layer_help.create_variable_for_type_inference(paddle.int32)
        }
        self.attrs = {'value': 10, 'shape': [5, 5], 'dtype': paddle.int32}

        # The op has no inputs, so _prim2orig gets no extra arguments.
        self.prim2orig_args = ()
        self.all_ops = ['fill_constant_p', 'fill_constant']
        self.out_map = {self.output['Y']: 0}


if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()