# test_strided_slice_op.py
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from op_test import OpTest
import numpy as np
import unittest
import paddle.fluid as fluid
import paddle

# Default to static-graph mode for the OpTest and fluid-based cases below.
paddle.enable_static()


def strided_slice_native_forward(input, axes, starts, ends, strides):
    """Reference NumPy implementation of the ``strided_slice`` operator.

    Every axis of ``input`` defaults to its full range with stride 1. For each
    position ``i`` in ``axes``, the slice on axis ``axes[i]`` is replaced by
    ``starts[i]:ends[i]:strides[i]``. Negative and out-of-range bounds follow
    NumPy basic-slicing rules.

    Args:
        input: NumPy array of any rank.
        axes: Axes to slice.
        starts: Start index per entry of ``axes``.
        ends: End index (exclusive) per entry of ``axes``.
        strides: Step per entry of ``axes``.

    Returns:
        The result of NumPy basic indexing of ``input`` with the built slices.
    """
    # Start with one full, unit-stride slice per dimension, so any rank is
    # supported rather than a fixed set of ranks.
    slices = [slice(0, extent, 1) for extent in input.shape]
    for i, axis in enumerate(axes):
        slices[axis] = slice(starts[i], ends[i], strides[i])
    return input[tuple(slices)]


class TestStrideSliceOp(OpTest):
    """Base OpTest case for the ``strided_slice`` operator.

    Subclasses override ``initTestCase`` to set ``input``, ``axes``,
    ``starts``, ``ends``, ``strides`` and ``infer_flags``. The expected
    output is computed with ``strided_slice_native_forward``.
    """

    def setUp(self):
        self.initTestCase()
        self.op_type = 'strided_slice'
        self.python_api = paddle.strided_slice
        self.output = strided_slice_native_forward(
            self.input, self.axes, self.starts, self.ends, self.strides)

        self.inputs = {'Input': self.input}
        self.outputs = {'Out': self.output}
        self.attrs = {
            'axes': self.axes,
            'starts': self.starts,
            'ends': self.ends,
            'strides': self.strides,
            'infer_flags': self.infer_flags
        }

    def test_check_output(self):
        self.check_output(check_eager=True)

    def test_check_grad(self):
        self.check_grad(set(['Input']), 'Out', check_eager=True)

    def initTestCase(self):
        # 1-D input; negative start/end select one element: [-4:-3].
        self.input = np.random.rand(100)
        self.axes = [0]
        self.starts = [-4]
        self.ends = [-3]
        self.strides = [1]
        self.infer_flags = [1]


class TestStrideSliceOp1(TestStrideSliceOp):
    """1-D input, forward slice ``[3:8]``."""

    def initTestCase(self):
        self.input = np.random.rand(100)
        self.axes = [0]
        self.starts = [3]
        self.ends = [8]
        self.strides = [1]
        self.infer_flags = [1]


class TestStrideSliceOp2(TestStrideSliceOp):
    """1-D input walked backwards: ``[5:0:-1]``."""

    def initTestCase(self):
        self.input = np.random.rand(100)
        self.axes = [0]
        self.starts = [5]
        self.ends = [0]
        self.strides = [-1]
        self.infer_flags = [1]


class TestStrideSliceOp3(TestStrideSliceOp):
    """1-D input, negative bounds with a negative stride: ``[-1:-3:-1]``."""

    def initTestCase(self):
        self.input = np.random.rand(100)
        self.axes = [0]
        self.starts = [-1]
        self.ends = [-3]
        self.strides = [-1]
        self.infer_flags = [1]


class TestStrideSliceOp4(TestStrideSliceOp):
    """3-D input with the middle axis reversed: ``[0:2, -1:-3:-1, 0:5]``."""

    def initTestCase(self):
        self.input = np.random.rand(3, 4, 10)
        self.axes = [0, 1, 2]
        self.starts = [0, -1, 0]
        self.ends = [2, -3, 5]
        self.strides = [1, -1, 1]
        self.infer_flags = [1, 1, 1]


class TestStrideSliceOp5(TestStrideSliceOp):
    """5x5x5 input, forward slice on every axis: ``[1:2, 0:1, 0:3]``."""

    def initTestCase(self):
        self.input = np.random.rand(5, 5, 5)
        self.axes = [0, 1, 2]
        self.starts = [1, 0, 0]
        self.ends = [2, 1, 3]
        self.strides = [1, 1, 1]
        self.infer_flags = [1, 1, 1]


class TestStrideSliceOp6(TestStrideSliceOp):
    """5x5x5 input with the middle axis reversed: ``[1:2, -1:-3:-1, 0:3]``."""

    def initTestCase(self):
        self.input = np.random.rand(5, 5, 5)
        self.axes = [0, 1, 2]
        self.starts = [1, -1, 0]
        self.ends = [2, -3, 3]
        self.strides = [1, -1, 1]
        self.infer_flags = [1, 1, 1]


class TestStrideSliceOp7(TestStrideSliceOp):
    """5x5x5 input, forward slice ``[1:2, 0:2, 0:3]``."""

    def initTestCase(self):
        self.input = np.random.rand(5, 5, 5)
        self.axes = [0, 1, 2]
        self.starts = [1, 0, 0]
        self.ends = [2, 2, 3]
        self.strides = [1, 1, 1]
        self.infer_flags = [1, 1, 1]


class TestStrideSliceOp8(TestStrideSliceOp):
    """(1, 100, 1) input, slicing only the middle axis: ``[:, 1:2, :]``."""

    def initTestCase(self):
        self.input = np.random.rand(1, 100, 1)
        self.axes = [1]
        self.starts = [1]
        self.ends = [2]
        self.strides = [1]
        self.infer_flags = [1]


class TestStrideSliceOp9(TestStrideSliceOp):
    """(1, 100, 1) input, last element of the middle axis: ``[:, -1:-2:-1, :]``."""

    def initTestCase(self):
        self.input = np.random.rand(1, 100, 1)
        self.axes = [1]
        self.starts = [-1]
        self.ends = [-2]
        self.strides = [-1]
        self.infer_flags = [1]


class TestStrideSliceOp10(TestStrideSliceOp):
    """2-D input, forward slice ``[1:2, 0:2]``."""

    def initTestCase(self):
        self.input = np.random.rand(10, 10)
        self.axes = [0, 1]
        self.starts = [1, 0]
        self.ends = [2, 2]
        self.strides = [1, 1]
        self.infer_flags = [1, 1]


class TestStrideSliceOp11(TestStrideSliceOp):
    """4-D input with stride 2 on the last axis: ``[1:2, 0:2, 0:3, 0:4:2]``."""

    def initTestCase(self):
        self.input = np.random.rand(3, 3, 3, 4)
        self.axes = [0, 1, 2, 3]
        self.starts = [1, 0, 0, 0]
        self.ends = [2, 2, 3, 4]
        self.strides = [1, 1, 1, 2]
        self.infer_flags = [1, 1, 1, 1]


class TestStrideSliceOp12(TestStrideSliceOp):
    """5-D input, forward slice ``[1:2, 0:2, 0:3, 0:4, 0:4]``."""

    def initTestCase(self):
        self.input = np.random.rand(3, 3, 3, 4, 5)
        self.axes = [0, 1, 2, 3, 4]
        self.starts = [1, 0, 0, 0, 0]
        self.ends = [2, 2, 3, 4, 4]
        self.strides = [1, 1, 1, 1, 1]
        # One flag per sliced axis (previously one entry short).
        self.infer_flags = [1, 1, 1, 1, 1]


class TestStrideSliceOp13(TestStrideSliceOp):
    """6-D input with stride 2 on the last axis.

    Slice: ``[1:2, 0:2, 0:3, 0:1, 1:2, 2:8:2]``.
    """

    def initTestCase(self):
        self.input = np.random.rand(3, 3, 3, 6, 7, 8)
        self.axes = [0, 1, 2, 3, 4, 5]
        self.starts = [1, 0, 0, 0, 1, 2]
        self.ends = [2, 2, 3, 1, 2, 8]
        self.strides = [1, 1, 1, 1, 1, 2]
        # One flag per sliced axis (previously one entry short).
        self.infer_flags = [1, 1, 1, 1, 1, 1]


class TestStrideSliceOp14(TestStrideSliceOp):
    """4-D input with starts beyond the axis length in magnitude.

    Axis 0 is left whole. The starts -5 and -7 exceed the extent 4 of their
    axes; the NumPy reference clamps them to 0.
    """

    def initTestCase(self):
        self.input = np.random.rand(4, 4, 4, 4)
        self.axes = [1, 2, 3]
        self.starts = [-5, 0, -7]
        self.ends = [-1, 2, 4]
        self.strides = [1, 1, 1]
        self.infer_flags = [1, 1, 1]


class TestStrideSliceOpBool(TestStrideSliceOp):
    """Base class for boolean-input cases; only the forward output is checked."""

    def test_check_grad(self):
        # Gradients are not defined for boolean tensors, so skip this check.
        pass


class TestStrideSliceOpBool1D(TestStrideSliceOpBool):
    """Boolean version of the 1-D ``[3:8]`` slice."""

    def initTestCase(self):
        # randint(0, 2) yields a mix of True/False. rand(...).astype("bool")
        # would be all True, hiding wrongly selected elements.
        self.input = np.random.randint(0, 2, size=100).astype("bool")
        self.axes = [0]
        self.starts = [3]
        self.ends = [8]
        self.strides = [1]
        self.infer_flags = [1]


class TestStrideSliceOpBool2D(TestStrideSliceOpBool):
    """Boolean version of the 2-D ``[1:2, 0:2]`` slice."""

    def initTestCase(self):
        # Mixed True/False data so an incorrect slice is detectable.
        self.input = np.random.randint(0, 2, size=(10, 10)).astype("bool")
        self.axes = [0, 1]
        self.starts = [1, 0]
        self.ends = [2, 2]
        self.strides = [1, 1]
        self.infer_flags = [1, 1]


class TestStrideSliceOpBool3D(TestStrideSliceOpBool):
    """Boolean 3-D input with a reversed middle axis: ``[0:2, -1:-3:-1, 0:5]``."""

    def initTestCase(self):
        # Mixed True/False data so an incorrect slice is detectable.
        self.input = np.random.randint(0, 2, size=(3, 4, 10)).astype("bool")
        self.axes = [0, 1, 2]
        self.starts = [0, -1, 0]
        self.ends = [2, -3, 5]
        self.strides = [1, -1, 1]
        self.infer_flags = [1, 1, 1]


class TestStrideSliceOpBool4D(TestStrideSliceOpBool):
    """Boolean 4-D input with stride 2 on the last axis."""

    def initTestCase(self):
        # Mixed True/False data so an incorrect slice is detectable.
        self.input = np.random.randint(
            0, 2, size=(3, 3, 3, 4)).astype("bool")
        self.axes = [0, 1, 2, 3]
        self.starts = [1, 0, 0, 0]
        self.ends = [2, 2, 3, 4]
        self.strides = [1, 1, 1, 2]
        self.infer_flags = [1, 1, 1, 1]


class TestStrideSliceOpBool5D(TestStrideSliceOpBool):
    """Boolean 5-D input, forward slice on every axis."""

    def initTestCase(self):
        # Mixed True/False data so an incorrect slice is detectable.
        self.input = np.random.randint(
            0, 2, size=(3, 3, 3, 4, 5)).astype("bool")
        self.axes = [0, 1, 2, 3, 4]
        self.starts = [1, 0, 0, 0, 0]
        self.ends = [2, 2, 3, 4, 4]
        self.strides = [1, 1, 1, 1, 1]
        # One flag per sliced axis (previously one entry short).
        self.infer_flags = [1, 1, 1, 1, 1]


class TestStrideSliceOpBool6D(TestStrideSliceOpBool):
    """Boolean 6-D input with stride 2 on the last axis."""

    def initTestCase(self):
        # Mixed True/False data so an incorrect slice is detectable.
        self.input = np.random.randint(
            0, 2, size=(3, 3, 3, 6, 7, 8)).astype("bool")
        self.axes = [0, 1, 2, 3, 4, 5]
        self.starts = [1, 0, 0, 0, 1, 2]
        self.ends = [2, 2, 3, 1, 2, 8]
        self.strides = [1, 1, 1, 1, 1, 2]
        # One flag per sliced axis (previously one entry short).
        self.infer_flags = [1, 1, 1, 1, 1, 1]


class TestStridedSliceOp_starts_ListTensor(OpTest):
    """``starts`` fed as a list of 1-element int32 tensors (StartsTensorList).

    The ``starts`` attribute holds placeholder values (``starts_infer``).
    Entry 1 differs from the real start (10 vs 0) and is flagged -1 in
    ``infer_flags``, presumably so that shape inference does not rely on it.
    """

    def setUp(self):
        self.op_type = "strided_slice"
        self.config()

        # One (name, 1-element int32 array) pair per start value.
        starts_tensor = [("x" + str(index), np.array([ele], dtype='int32'))
                         for index, ele in enumerate(self.starts)]

        self.inputs = {'Input': self.input, 'StartsTensorList': starts_tensor}
        self.outputs = {'Out': self.output}
        self.attrs = {
            'axes': self.axes,
            'starts': self.starts_infer,
            'ends': self.ends,
            'strides': self.strides,
            'infer_flags': self.infer_flags
        }

    def config(self):
        self.input = np.random.random([3, 4, 5, 6]).astype("float64")
        self.starts = [1, 0, 2]
        self.ends = [3, 3, 4]
        self.axes = [0, 1, 2]
        self.strides = [1, 1, 1]
        self.infer_flags = [1, -1, 1]
        self.output = strided_slice_native_forward(
            self.input, self.axes, self.starts, self.ends, self.strides)

        self.starts_infer = [1, 10, 2]

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['Input'], 'Out', max_relative_error=0.006)


class TestStridedSliceOp_ends_ListTensor(OpTest):
    """``ends`` fed as a list of 1-element int32 tensors (EndsTensorList).

    The ``ends`` attribute holds placeholder values (``ends_infer``).
    Entry 1 differs from the real end (1 vs 3) and is flagged -1 in
    ``infer_flags``, presumably so that shape inference does not rely on it.
    """

    def setUp(self):
        self.op_type = "strided_slice"
        self.config()

        # One (name, 1-element int32 array) pair per end value.
        ends_tensor = [("x" + str(index), np.array([ele], dtype='int32'))
                       for index, ele in enumerate(self.ends)]

        self.inputs = {'Input': self.input, 'EndsTensorList': ends_tensor}
        self.outputs = {'Out': self.output}
        self.attrs = {
            'axes': self.axes,
            'starts': self.starts,
            'ends': self.ends_infer,
            'strides': self.strides,
            'infer_flags': self.infer_flags
        }

    def config(self):
        self.input = np.random.random([3, 4, 5, 6]).astype("float64")
        self.starts = [1, 0, 0]
        self.ends = [3, 3, 4]
        self.axes = [0, 1, 2]
        self.strides = [1, 1, 2]
        self.infer_flags = [1, -1, 1]
        self.output = strided_slice_native_forward(
            self.input, self.axes, self.starts, self.ends, self.strides)

        self.ends_infer = [3, 1, 4]

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['Input'], 'Out', max_relative_error=0.006)


class TestStridedSliceOp_starts_Tensor(OpTest):
    """``starts`` fed as a single int32 tensor input (StartsTensor)."""

    def setUp(self):
        self.op_type = "strided_slice"
        self.config()
        self.inputs = {
            'Input': self.input,
            "StartsTensor": np.array(
                self.starts, dtype="int32")
        }
        self.outputs = {'Out': self.output}
        self.attrs = {
            'axes': self.axes,
            # 'starts' attr intentionally omitted: supplied via StartsTensor.
            'ends': self.ends,
            'strides': self.strides,
            'infer_flags': self.infer_flags,
        }

    def config(self):
        self.input = np.random.random([3, 4, 5, 6]).astype("float64")
        self.starts = [1, 0, 2]
        self.ends = [2, 3, 4]
        self.axes = [0, 1, 2]
        self.strides = [1, 1, 1]
        self.infer_flags = [-1, -1, -1]
        self.output = strided_slice_native_forward(
            self.input, self.axes, self.starts, self.ends, self.strides)

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['Input'], 'Out', max_relative_error=0.006)


class TestStridedSliceOp_ends_Tensor(OpTest):
    """``ends`` fed as a single int32 tensor input (EndsTensor)."""

    def setUp(self):
        self.op_type = "strided_slice"
        self.config()
        self.inputs = {
            'Input': self.input,
            "EndsTensor": np.array(
                self.ends, dtype="int32")
        }
        self.outputs = {'Out': self.output}
        self.attrs = {
            'axes': self.axes,
            'starts': self.starts,
            # 'ends' attr intentionally omitted: supplied via EndsTensor.
            'strides': self.strides,
            'infer_flags': self.infer_flags,
        }

    def config(self):
        self.input = np.random.random([3, 4, 5, 6]).astype("float64")
        self.starts = [1, 0, 2]
        self.ends = [2, 3, 4]
        self.axes = [0, 1, 2]
        self.strides = [1, 1, 1]
        self.infer_flags = [-1, -1, -1]
        self.output = strided_slice_native_forward(
            self.input, self.axes, self.starts, self.ends, self.strides)

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['Input'], 'Out', max_relative_error=0.006)


class TestStridedSliceOp_listTensor_Tensor(OpTest):
    """``starts`` as one tensor (StartsTensor), ``ends`` as a tensor list."""

    def setUp(self):
        self.config()
        # One (name, 1-element int32 array) pair per end value.
        ends_tensor = [("x" + str(index), np.array([ele], dtype='int32'))
                       for index, ele in enumerate(self.ends)]
        self.op_type = "strided_slice"

        self.inputs = {
            'Input': self.input,
            "StartsTensor": np.array(
                self.starts, dtype="int32"),
            "EndsTensorList": ends_tensor
        }
        self.outputs = {'Out': self.output}
        self.attrs = {
            'axes': self.axes,
            # 'starts' and 'ends' attrs intentionally omitted: both are
            # supplied through tensor inputs above.
            'strides': self.strides,
            'infer_flags': self.infer_flags,
        }

    def config(self):
        self.input = np.random.random([3, 4, 5, 6]).astype("float64")
        self.starts = [1, 0, 2]
        self.ends = [2, 3, 4]
        self.axes = [0, 1, 2]
        self.strides = [1, 1, 1]
        self.infer_flags = [-1, -1, -1]
        self.output = strided_slice_native_forward(
            self.input, self.axes, self.starts, self.ends, self.strides)

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['Input'], 'Out', max_relative_error=0.006)


class TestStridedSliceOp_strides_Tensor(OpTest):
    """``strides`` fed as a single int32 tensor (StridesTensor), one negative."""

    def setUp(self):
        self.op_type = "strided_slice"
        self.config()
        self.inputs = {
            'Input': self.input,
            "StridesTensor": np.array(
                self.strides, dtype="int32")
        }
        self.outputs = {'Out': self.output}
        self.attrs = {
            'axes': self.axes,
            'starts': self.starts,
            'ends': self.ends,
            # 'strides' attr intentionally omitted: supplied via StridesTensor.
            'infer_flags': self.infer_flags,
        }

    def config(self):
        self.input = np.random.random([3, 4, 5, 6]).astype("float64")
        self.starts = [1, -1, 2]
        self.ends = [2, 0, 4]
        self.axes = [0, 1, 2]
        self.strides = [1, -1, 1]
        self.infer_flags = [-1, -1, -1]
        self.output = strided_slice_native_forward(
            self.input, self.axes, self.starts, self.ends, self.strides)

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['Input'], 'Out', max_relative_error=0.006)


# Test python API
class TestStridedSliceAPI(unittest.TestCase):
    """Checks ``paddle.strided_slice`` and ``__getitem__`` slicing."""

    def test_1(self):
        """Static-graph slicing with int, tensor and fed bounds."""
        np_input = np.random.random([3, 4, 5, 6]).astype("float64")
        minus_1 = fluid.layers.fill_constant([1], "int32", -1)
        minus_3 = fluid.layers.fill_constant([1], "int32", -3)
        starts = fluid.layers.data(
            name='starts', shape=[3], dtype='int32', append_batch_size=False)
        ends = fluid.layers.data(
            name='ends', shape=[3], dtype='int32', append_batch_size=False)
        strides = fluid.layers.data(
            name='strides', shape=[3], dtype='int32', append_batch_size=False)

        x = fluid.layers.data(
            name="x",
            shape=[3, 4, 5, 6],
            append_batch_size=False,
            dtype="float64")
        # Plain Python ints for every bound.
        out_1 = paddle.strided_slice(
            x,
            axes=[0, 1, 2],
            starts=[-3, 0, 2],
            ends=[3, 100, -1],
            strides=[1, 1, 1])
        # A 1-element tensor mixed into the starts list.
        out_2 = paddle.strided_slice(
            x,
            axes=[0, 1, 3],
            starts=[minus_3, 0, 2],
            ends=[3, 100, -1],
            strides=[1, 1, 1])
        # 1-element tensors in both the starts and ends lists.
        out_3 = paddle.strided_slice(
            x,
            axes=[0, 1, 3],
            starts=[minus_3, 0, 2],
            ends=[3, 100, minus_1],
            strides=[1, 1, 1])
        # Whole starts/ends/strides vectors fed at run time.
        out_4 = paddle.strided_slice(
            x, axes=[0, 1, 2], starts=starts, ends=ends, strides=strides)

        # __getitem__ syntax, including tensor bounds and negative steps.
        out_5 = x[-3:3, 0:100:2, -1:2:-1]
        out_6 = x[minus_3:3:1, 0:100:2, :, minus_1:2:minus_1]
        out_7 = x[minus_1, 0:100:2, :, -1:2:-1]

        exe = fluid.Executor(place=fluid.CPUPlace())
        res_1, res_2, res_3, res_4, res_5, res_6, res_7 = exe.run(
            fluid.default_main_program(),
            feed={
                "x": np_input,
                'starts': np.array([-3, 0, 2]).astype("int32"),
                # 2147483648 == 2**31 does not fit in int32, so it is fed as
                # int64. NOTE(review): the 'ends' placeholder is declared
                # int32; this presumably checks that an oversized end is
                # clamped. Confirm the intended dtype.
                'ends': np.array([3, 2147483648, -1]).astype("int64"),
                'strides': np.array([1, 1, 1]).astype("int32")
            },
            fetch_list=[out_1, out_2, out_3, out_4, out_5, out_6, out_7])
        # assertTrue rather than bare assert so checks survive `python -O`.
        self.assertTrue(np.array_equal(res_1, np_input[-3:3, 0:100, 2:-1, :]))
        self.assertTrue(np.array_equal(res_2, np_input[-3:3, 0:100, :, 2:-1]))
        self.assertTrue(np.array_equal(res_3, np_input[-3:3, 0:100, :, 2:-1]))
        self.assertTrue(np.array_equal(res_4, np_input[-3:3, 0:100, 2:-1, :]))
        self.assertTrue(
            np.array_equal(res_5, np_input[-3:3, 0:100:2, -1:2:-1, :]))
        self.assertTrue(
            np.array_equal(res_6, np_input[-3:3, 0:100:2, :, -1:2:-1]))
        self.assertTrue(
            np.array_equal(res_7, np_input[-1, 0:100:2, :, -1:2:-1]))

    def test_dygraph_op(self):
        x = paddle.zeros(shape=[3, 4, 5, 6], dtype="float32")
        axes = [1, 2, 3]
        starts = [-3, 0, 2]
        ends = [3, 2, 4]
        strides_1 = [1, 1, 1]
        sliced_1 = paddle.strided_slice(
            x, axes=axes, starts=starts, ends=ends, strides=strides_1)
        # NOTE(review): no dygraph guard is entered here, so this presumably
        # runs in the module's static mode, where the shape is a tuple.
        self.assertEqual(sliced_1.shape, (3, 2, 2, 2))

    @unittest.skipIf(not paddle.is_compiled_with_cuda(),
                     "Cannot use CUDAPinnedPlace in CPU only version")
    def test_cuda_pinned_place(self):
        with paddle.fluid.dygraph.guard():
            x = paddle.to_tensor(
                np.random.randn(2, 10), place=paddle.CUDAPinnedPlace())
            self.assertTrue(x.place.is_cuda_pinned_place())
            y = x[:, ::2]
            # NOTE(review): x is expected to no longer be pinned once it has
            # been sliced. Presumably the dygraph tracer moves pinned inputs
            # to the device in place; confirm before relying on this.
            self.assertFalse(x.place.is_cuda_pinned_place())
            self.assertFalse(y.place.is_cuda_pinned_place())

class ArrayLayer(paddle.nn.Layer):
    """Layer of ``array_size`` Linear sub-layers whose outputs go through a tensor array.

    Subclasses override ``array_slice`` to apply a particular slice to the
    tensor array built in ``forward``.
    """

    def __init__(self, input_size=224, output_size=10, array_size=1):
        super(ArrayLayer, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.array_size = array_size
        # Register the sub-layers as attributes linear_0 .. linear_{n-1}.
        for i in range(self.array_size):
            setattr(self,
                    self.create_name(i),
                    paddle.nn.Linear(input_size, output_size))

    def create_name(self, index):
        # Attribute name under which sub-layer `index` is registered.
        return 'linear_' + str(index)

    def forward(self, inps):
        # NOTE: this method is also converted by paddle.jit.to_static in the
        # tests, so its structure is kept deliberately simple.
        array = []
        for i in range(self.array_size):
            linear = getattr(self, self.create_name(i))
            array.append(linear(inps))

        tensor_array = self.create_tensor_array(array)

        tensor_array = self.array_slice(tensor_array)

        array1 = paddle.concat(tensor_array)
        array2 = paddle.concat(tensor_array[::-1])
        return array1 + array2 * array2

    def get_all_grads(self, param_name='weight'):
        """Return the gradient of ``param_name`` for every Linear sub-layer.

        Each entry is a NumPy array, or None when that parameter has no
        gradient.
        """
        grads = []
        for i in range(self.array_size):
            linear = getattr(self, self.create_name(i))
            param = getattr(linear, param_name)

            g = param.grad
            if g is not None:
                g = g.numpy()

            grads.append(g)

        return grads

    def clear_all_grad(self):
        """Clear the weight and bias gradients of every Linear sub-layer."""
        param_names = ['weight', 'bias']
        for i in range(self.array_size):
            linear = getattr(self, self.create_name(i))
            for p in param_names:
                param = getattr(linear, p)
                param.clear_gradient()

    def array_slice(self, array):
        # Identity by default; subclasses return a slice of `array`.
        return array

    def create_tensor_array(self, tensors):
        # Write each tensor into a tensor array at index i; the first write
        # creates the array.
        tensor_array = None
        for i, tensor in enumerate(tensors):
            index = paddle.full(shape=[1], dtype='int64', fill_value=i)
            if tensor_array is None:
                tensor_array = paddle.tensor.array_write(tensor, i=index)
            else:
                paddle.tensor.array_write(tensor, i=index, array=tensor_array)
        return tensor_array


class TestStridedSliceTensorArray(unittest.TestCase):
    """Compares dygraph and paddle.jit.to_static results for tensor-array slicing."""

    def setUp(self):
        paddle.disable_static()

    def tearDown(self):
        # The module switches to static mode at import; restore that mode so
        # later tests are not left running in dygraph mode.
        paddle.enable_static()

    def grad_equal(self, g1, g2):
        """Compare two gradients, treating a missing (None) one as zeros."""
        if g1 is None:
            g1 = np.zeros_like(g2)
        if g2 is None:
            g2 = np.zeros_like(g1)
        return np.array_equal(g1, g2)

    def is_grads_equal(self, g1, g2):
        """Assert that two lists of gradients match element-wise."""
        self.assertEqual(
            len(g1),
            len(g2),
            msg="gradient lists differ in length: {} vs {}".format(
                len(g1), len(g2)))
        for grad_1, grad_2 in zip(g1, g2):
            self.assertTrue(
                self.grad_equal(grad_1, grad_2),
                msg="gradient_1:\n{} \ngradient_2:\n{}".format(grad_1, grad_2))

    def is_grads_equal_zeros(self, grads):
        """Assert that every gradient in ``grads`` is all zeros."""
        for g in grads:
            self.assertTrue(
                self.grad_equal(np.zeros_like(g), g),
                msg="The gradient should be zeros, but received \n{}".format(g))

    def create_case(self, net):
        """Run ``net`` eagerly and via to_static; compare outputs and grads."""
        inps1 = paddle.randn([1, net.input_size], dtype='float32')
        inps2 = inps1.detach().clone()
        l1 = net(inps1)
        s1 = l1.numpy()
        l1.sum().backward()
        grads_dy = net.get_all_grads()
        net.clear_all_grad()
        grads_zeros = net.get_all_grads()

        self.is_grads_equal_zeros(grads_zeros)

        func = paddle.jit.to_static(net.forward)
        l2 = func(inps2)
        s2 = l2.numpy()
        l2.sum().backward()
        grads_static = net.get_all_grads()
        net.clear_all_grad()
        # compare result of dygraph and static
        self.is_grads_equal(grads_static, grads_dy)
        self.assertTrue(
            np.array_equal(s1, s2),
            msg="dygraph graph result:\n{} \nstatic dygraph result:\n{}".format(
                l1.numpy(), l2.numpy()))

    # Reported as skipped (instead of silently passing) on CPU-only builds.
    @unittest.skipIf(not paddle.device.is_compiled_with_cuda(),
                     "Cannot use CUDAPinnedPlace in CPU only version")
    def test_strided_slice_tensor_array_cuda_pinned_place(self):
        with paddle.fluid.dygraph.guard():

            class Simple(paddle.nn.Layer):
                def __init__(self):
                    super(Simple, self).__init__()

                def forward(self, inps):
                    tensor_array = None
                    for i, tensor in enumerate(inps):
                        index = paddle.full(
                            shape=[1], dtype='int64', fill_value=i)
                        if tensor_array is None:
                            tensor_array = paddle.tensor.array_write(
                                tensor, i=index)
                        else:
                            paddle.tensor.array_write(
                                tensor, i=index, array=tensor_array)

                    array1 = paddle.concat(tensor_array)
                    array2 = paddle.concat(tensor_array[::-1])
                    return array1 + array2 * array2

            net = Simple()
            func = paddle.jit.to_static(net.forward)

            inps1 = paddle.to_tensor(
                np.random.randn(2, 10),
                place=paddle.CUDAPinnedPlace(),
                stop_gradient=False)
            inps2 = paddle.to_tensor(
                np.random.randn(2, 10),
                place=paddle.CUDAPinnedPlace(),
                stop_gradient=False)

            self.assertTrue(inps1.place.is_cuda_pinned_place())
            self.assertTrue(inps2.place.is_cuda_pinned_place())

            result = func([inps1, inps2])

            self.assertFalse(result.place.is_cuda_pinned_place())

    def test_strided_slice_tensor_array(self):
        # Each NetXX applies one slice pattern to the tensor array; the
        # slice literals are kept inline because to_static converts them.
        class Net01(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[::-1]

        self.create_case(Net01(array_size=10))

        class Net02(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[::-2]

        self.create_case(Net02(input_size=112, array_size=11))

        class Net03(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[::-3]

        self.create_case(Net03(input_size=112, array_size=9))

        class Net04(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[1::-4]

        self.create_case(Net04(input_size=112, array_size=9))

        class Net05(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[:7:-4]

        self.create_case(Net05(input_size=112, array_size=9))

        class Net06(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[8:0:-4]

        self.create_case(Net06(input_size=112, array_size=9))

        class Net07(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[8:1:-4]

        self.create_case(Net07(input_size=112, array_size=9))

        class Net08(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[::2]

        self.create_case(Net08(input_size=112, array_size=11))

        class Net09(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[::3]

        self.create_case(Net09(input_size=112, array_size=9))

        class Net10(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[1::4]

        self.create_case(Net10(input_size=112, array_size=9))

        class Net11(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[:8:4]

        self.create_case(Net11(input_size=112, array_size=9))

        class Net12(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[1:8:4]

        self.create_case(Net12(input_size=112, array_size=9))

        class Net13(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[8:10:4]

        self.create_case(Net13(input_size=112, array_size=13))

        class Net14(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[3:10:4]

        self.create_case(Net14(input_size=112, array_size=13))

        class Net15(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[2:10:4]

        self.create_case(Net15(input_size=112, array_size=13))

        class Net16(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[3:10:3]

        self.create_case(Net16(input_size=112, array_size=13))

        class Net17(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[3:15:3]

        self.create_case(Net17(input_size=112, array_size=13))

        class Net18(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[0:15:3]

        self.create_case(Net18(input_size=112, array_size=13))

        class Net19(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[-1:-5:-3]

        self.create_case(Net19(input_size=112, array_size=13))

        class Net20(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[-1:-6:-3]

        self.create_case(Net20(input_size=112, array_size=13))

        class Net21(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[-3:-6:-3]

        self.create_case(Net21(input_size=112, array_size=13))

        class Net22(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[-5:-1:3]

        self.create_case(Net22(input_size=112, array_size=13))

        class Net23(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[-6:-1:3]

        self.create_case(Net23(input_size=112, array_size=13))

        class Net24(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[-6:-3:3]

        self.create_case(Net24(input_size=112, array_size=13))

        class Net25(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[0::3]

        self.create_case(Net25(input_size=112, array_size=13))

        class Net26(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[-60:20:3]

        self.create_case(Net26(input_size=112, array_size=13))

        class Net27(ArrayLayer):
            def array_slice(self, tensors):
                return tensors[-3:-60:-3]

        self.create_case(Net27(input_size=112, array_size=13))


if __name__ == "__main__":
    # Run every TestCase defined in this module.
    unittest.main()