test_reshape.py 22.3 KB
Newer Older
Z
zhunaipan 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
J
jinyaohui 已提交
16 17

import mindspore as ms
Z
zhunaipan 已提交
18
import mindspore.nn as nn
J
jinyaohui 已提交
19
from mindspore import Tensor
Z
zhunaipan 已提交
20 21
from mindspore import context
from mindspore.common.api import _executor
J
jinyaohui 已提交
22 23 24 25
from mindspore.common.parameter import Parameter
from mindspore.common.parameter import ParameterTuple
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.optim.momentum import Momentum
Z
zhunaipan 已提交
26 27
from mindspore.ops import composite as C
from mindspore.ops import functional as F
J
jinyaohui 已提交
28 29
from mindspore.ops import operations as P
from mindspore.ops.operations.comm_ops import _VirtualDataset
Z
zhunaipan 已提交
30
from mindspore.parallel import set_algo_parameters
J
jinyaohui 已提交
31 32 33
from mindspore.train import Model, ParallelMode
from tests.dataset_mock import MindData
from tests.ut.python.ops.test_math_ops import VirtualLoss
J
jinyaohui 已提交
34

Z
zhunaipan 已提交
35 36 37
context.set_context(mode=context.GRAPH_MODE)
context.reset_auto_parallel_context()

J
jinyaohui 已提交
38

Z
zhunaipan 已提交
39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
class Dataset(MindData):
    """Mock dataset that yields the same sample a fixed number of times.

    Args:
        predict: Object returned as the first element of every sample.
        label: Object returned as the second element when ``input_num`` is 2.
        length (int): Number of samples produced before ``StopIteration`` is
            raised; also forwarded to ``MindData`` as ``size``.
        input_num (int): When 2, samples are ``(predict, label)``; for any
            other value samples are ``(predict,)``.
    """

    def __init__(self, predict, label, length=3, input_num=2):
        super(Dataset, self).__init__(size=length)
        self.predict = predict
        self.label = label
        self.index = 0
        self.length = length
        self.input_num = input_num

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        if self.input_num == 2:
            return (self.predict, self.label)
        return (self.predict,)

    def reset(self):
        """Rewind the iterator so the samples can be produced again."""
        self.index = 0


class ReshapeNet(nn.Cell):
    """ReLU -> Reshape -> MatMul cell with one strategy per operator.

    Args:
        strategy0: Strategy passed to ``P.ReLU().set_strategy``.
        strategy1: Strategy passed to ``P.Reshape().set_strategy``.
        strategy2: Strategy passed to ``P.MatMul().set_strategy``.
    """

    def __init__(self, strategy0, strategy1, strategy2):
        super().__init__()
        weight_init = Tensor(np.ones((25088, 256)), dtype=ms.float32)
        self.relu = P.ReLU().set_strategy(strategy0)
        self.reshape = P.Reshape().set_strategy(strategy1)
        self.matmul = P.MatMul().set_strategy(strategy2)
        self.matmul_weight = Parameter(weight_init, name="weight")

    def construct(self, x):
        activated = self.relu(x)
        # Reshape to (256, 25088) so it can be multiplied by the (25088, 256) weight.
        flattened = self.reshape(activated, (256, 25088))
        return self.matmul(flattened, self.matmul_weight)


def reshape_net(strategy0, strategy1, strategy2):
    """Build a ``ReshapeNet`` from the three operator strategies."""
    net = ReshapeNet(strategy0, strategy1, strategy2)
    return net


def reshape_common(parallel_mode, strategy0, strategy1, strategy2, strategy_loss):
    """Train a ``ReshapeNet`` for two epochs on 8 devices in ``parallel_mode``.

    Args:
        parallel_mode: Value passed to ``context.set_auto_parallel_context``.
        strategy0: Strategy for the network's ReLU.
        strategy1: Strategy for the network's Reshape.
        strategy2: Strategy for the network's MatMul.
        strategy_loss: Strategy set on the loss's ``softmax_cross_entropy``.
    """
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2

    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
    predict = Tensor(np.ones([32, 512, 7, 7]), dtype=ms.float32)
    label = Tensor(np.ones([32]), dtype=ms.int32)
    dataset = Dataset(predict, label, 2)
    net = reshape_net(strategy0, strategy1, strategy2)

    loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    loss.softmax_cross_entropy.set_strategy(strategy_loss)
    loss.one_hot.set_strategy(((8, 1), (), ()))
    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    model = Model(net, loss, opt)
    model.train(epoch_size, dataset, dataset_sink_mode=False)


def test_reshape1():
    """Semi-auto-parallel run with ReLU strategy ((8, 1, 1, 1),) and no Reshape strategy."""
    strategy0 = ((8, 1, 1, 1),)
    strategy1 = None
    strategy2 = ((8, 1), (1, 1))
    strategy_loss = ((8, 1), (8, 1))
    reshape_common(ParallelMode.SEMI_AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)


def test_reshape1_strategy_1():
    """Semi-auto-parallel run with an explicit strategy set on Reshape.

    The run is allowed to fail with ValueError, TypeError or RuntimeError;
    the test also passes when no exception is raised.
    """
    strategy0 = ((8, 1, 1, 1),)
    strategy1 = ((8, 1, 1, 1),)
    strategy2 = ((8, 1), (1, 1))
    strategy_loss = ((8, 1), (8, 1))
    try:
        reshape_common(ParallelMode.SEMI_AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)
    except (ValueError, TypeError, RuntimeError):
        # Deliberately tolerated: these error types are accepted outcomes here.
        pass


def test_reshape1_strategy_2():
    """Auto-parallel run with an explicit strategy set on Reshape.

    The run is allowed to fail with ValueError, TypeError or RuntimeError;
    the test also passes when no exception is raised.
    """
    strategy0 = ((8, 1, 1, 1),)
    strategy1 = ((8, 1, 1, 1),)
    strategy2 = ((8, 1), (1, 1))
    strategy_loss = ((8, 1), (8, 1))
    try:
        reshape_common(ParallelMode.AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)
    except (ValueError, TypeError, RuntimeError):
        # Deliberately tolerated: these error types are accepted outcomes here.
        pass


def test_reshape2():
    """Semi-auto-parallel run; uses the same strategy values as ``test_reshape1``."""
    strategy0 = ((8, 1, 1, 1),)
    strategy1 = None
    strategy2 = ((8, 1), (1, 1))
    strategy_loss = ((8, 1), (8, 1))
    reshape_common(ParallelMode.SEMI_AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)


def test_reshape3():
    """Semi-auto-parallel run with ReLU strategy ((2, 1, 1, 1),) and MatMul strategy ((8, 1), (1, 1))."""
    strategy0 = ((2, 1, 1, 1),)
    strategy1 = None
    strategy2 = ((8, 1), (1, 1))
    strategy_loss = ((8, 1), (8, 1))
    reshape_common(ParallelMode.SEMI_AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)


def test_reshape4():
    """Semi-auto-parallel run with ReLU strategy ((1, 1, 1, 1),) and MatMul strategy ((8, 1), (1, 1))."""
    strategy0 = ((1, 1, 1, 1),)
    strategy1 = None
    strategy2 = ((8, 1), (1, 1))
    strategy_loss = ((8, 1), (8, 1))
    reshape_common(ParallelMode.SEMI_AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)


def test_reshape5():
    """Semi-auto-parallel run with ReLU strategy ((2, 1, 1, 1),) and MatMul strategy ((1, 8), (8, 1))."""
    strategy0 = ((2, 1, 1, 1),)
    strategy1 = None
    strategy2 = ((1, 8), (8, 1))
    strategy_loss = ((8, 1), (8, 1))
    reshape_common(ParallelMode.SEMI_AUTO_PARALLEL, strategy0, strategy1, strategy2, strategy_loss)


def test_reshape_auto():
    """Auto-parallel run with no strategy given for any operator or the loss."""
    reshape_common(ParallelMode.AUTO_PARALLEL, None, None, None, None)


class NetWithLoss(nn.Cell):
    """Wrap ``network`` so its output is fed through a ``VirtualLoss``."""

    def __init__(self, network):
        super().__init__()
        self.loss = VirtualLoss()
        self.network = network

    def construct(self, x):
        return self.loss(self.network(x))


class GradWrap(nn.Cell):
    """Cell whose output is ``C.grad_all`` of the wrapped network applied to the input."""

    def __init__(self, network):
        super().__init__()
        self.network = network

    def construct(self, x):
        grad_fn = C.grad_all(self.network)
        return grad_fn(x)


class ReshapeNet1(nn.Cell):
    """VirtualDataset -> Reshape -> MatMul -> Reshape to a 1-D tensor.

    Args:
        strategy0: Strategy set on the MatMul operator.
    """

    def __init__(self, strategy0):
        super().__init__()
        weight_init = Tensor(np.ones((25088, 256)), dtype=ms.float32)
        self.virtual_dataset = _VirtualDataset()
        self.reshape = P.Reshape()
        self.matmul = P.MatMul().set_strategy(strategy0)
        self.matmul_weight = Parameter(weight_init, name="weight")
        self.reshape2 = P.Reshape()

    def construct(self, x):
        data = self.virtual_dataset(x)
        flat = self.reshape(data, (256, 25088))
        product = self.matmul(flat, self.matmul_weight)
        return self.reshape2(product, (256 * 256,))


class ReshapeNet2(nn.Cell):
    """ReshapeNet1-style pipeline followed by ReduceSum (keep_dims=True) and a reshape to ``()``.

    Args:
        strategy0: Strategy set on the MatMul operator.
    """

    def __init__(self, strategy0):
        super().__init__()
        weight_init = Tensor(np.ones((25088, 256)), dtype=ms.float32)
        self.virtual_dataset = _VirtualDataset()
        self.reshape = P.Reshape()
        self.matmul = P.MatMul().set_strategy(strategy0)
        self.matmul_weight = Parameter(weight_init, name="weight")
        self.reshape2 = P.Reshape()
        self.reduce_sum = P.ReduceSum(keep_dims=True)
        self.reshape3 = P.Reshape()

    def construct(self, x):
        data = self.virtual_dataset(x)
        flat = self.reshape(data, (256, 25088))
        product = self.matmul(flat, self.matmul_weight)
        vector = self.reshape2(product, (256 * 256,))
        total = self.reduce_sum(vector, -1)
        return self.reshape3(total, ())


class ReshapeNet3(nn.Cell):
    """ReshapeNet1-style pipeline followed by ReduceSum (keep_dims=False) and a reshape to ``(1, 1)``.

    Args:
        strategy0: Strategy set on the MatMul operator.
    """

    def __init__(self, strategy0):
        super().__init__()
        weight_init = Tensor(np.ones((25088, 256)), dtype=ms.float32)
        self.virtual_dataset = _VirtualDataset()
        self.reshape = P.Reshape()
        self.matmul = P.MatMul().set_strategy(strategy0)
        self.matmul_weight = Parameter(weight_init, name="weight")
        self.reshape2 = P.Reshape()
        self.reduce_sum = P.ReduceSum(keep_dims=False)
        self.reshape3 = P.Reshape()

    def construct(self, x):
        data = self.virtual_dataset(x)
        flat = self.reshape(data, (256, 25088))
        product = self.matmul(flat, self.matmul_weight)
        vector = self.reshape2(product, (256 * 256,))
        total = self.reduce_sum(vector, -1)
        return self.reshape3(total, (1, 1))


class ReshapeNet4(nn.Cell):
    """MatMul whose weight parameter is itself passed through a Reshape first.

    Args:
        strategy0: Strategy set on the MatMul operator.
    """

    def __init__(self, strategy0):
        super().__init__()
        weight_init = Tensor(np.ones((25088, 256)), dtype=ms.float32)
        self.virtual_dataset = _VirtualDataset()
        self.reshape = P.Reshape()
        self.reshape2 = P.Reshape()
        self.matmul = P.MatMul().set_strategy(strategy0)
        self.matmul_weight = Parameter(weight_init, name="weight")

    def construct(self, x):
        data = self.virtual_dataset(x)
        flat = self.reshape(data, (256, 25088))
        # The weight is reshaped to the shape it already has, (25088, 256).
        weight = self.reshape2(self.matmul_weight, (25088, 256))
        return self.matmul(flat, weight)


class ReshapeNet5(nn.Cell):
    """Reshape output feeds two MatMuls: once as left operand, once as right operand.

    Args:
        strategy0: Strategy set on both MatMul operators.
    """

    def __init__(self, strategy0):
        super().__init__()
        weight_init = Tensor(np.ones((25088, 256)), dtype=ms.float32)
        self.virtual_dataset = _VirtualDataset()
        self.reshape = P.Reshape()
        self.matmul1 = P.MatMul().set_strategy(strategy0)
        self.matmul1_weight = Parameter(weight_init, name="weight")
        self.matmul2 = P.MatMul().set_strategy(strategy0)

    def construct(self, x):
        data = self.virtual_dataset(x)
        flat = self.reshape(data, (256, 25088))
        first = self.matmul1(flat, self.matmul1_weight)
        return self.matmul2(first, flat)


class ReshapeNet6(nn.Cell):
    """Reshape output feeds two parallel MatMuls whose sum is multiplied by the reshape output.

    Args:
        strategy0: Strategy set on all three MatMul operators.
    """

    def __init__(self, strategy0):
        super().__init__()
        weight_init = Tensor(np.ones((25088, 256)), dtype=ms.float32)
        self.virtual_dataset = _VirtualDataset()
        self.reshape = P.Reshape()
        self.matmul1_1 = P.MatMul().set_strategy(strategy0)
        self.matmul1_2 = P.MatMul().set_strategy(strategy0)
        self.matmul1_weight = Parameter(weight_init, name="weight")
        self.matmul2 = P.MatMul().set_strategy(strategy0)
        self.add = P.TensorAdd()

    def construct(self, x):
        data = self.virtual_dataset(x)
        flat = self.reshape(data, (256, 25088))
        branch_a = self.matmul1_1(flat, self.matmul1_weight)
        branch_b = self.matmul1_2(flat, self.matmul1_weight)
        merged = self.add(branch_a, branch_b)
        return self.matmul2(merged, flat)


Y
Yi Huaijie 已提交
314
def compile_net(net, input_):
    """Call ``net.set_auto_parallel()`` and compile ``net`` with ``input_`` via ``_executor``."""
    net.set_auto_parallel()
    _executor.compile(net, input_)
Y
yangzhenzhang 已提交
317 318


Z
zhunaipan 已提交
319 320 321 322
def reshape_net2(backbone):
    """Compile the gradient graph of ``backbone`` in semi-auto-parallel mode on 16 devices.

    Args:
        backbone: Cell to wrap in ``NetWithLoss`` and ``GradWrap`` before compiling.
    """
    batch_size = 16
    device_num = 16
    context.set_auto_parallel_context(device_num=device_num, global_rank=0)
    input_ = Tensor(np.ones([batch_size * device_num, 512, 7, 7]).astype(np.float32) * 0.01)

    net = GradWrap(NetWithLoss(backbone))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    compile_net(net, input_)
Z
zhunaipan 已提交
329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346


def test_reshape_net1_1():
    """Compile ReshapeNet1 with MatMul strategy ((1, 8), (8, 1))."""
    matmul_strategy = ((1, 8), (8, 1))
    backbone = ReshapeNet1(matmul_strategy)
    reshape_net2(backbone)


def test_reshape_net1_2():
    """Compile ReshapeNet1 with MatMul strategy ((1, 8), (8, 2))."""
    matmul_strategy = ((1, 8), (8, 2))
    backbone = ReshapeNet1(matmul_strategy)
    reshape_net2(backbone)


def test_reshape_net2_1():
    """Compile ReshapeNet2 with MatMul strategy ((1, 8), (8, 1))."""
    matmul_strategy = ((1, 8), (8, 1))
    backbone = ReshapeNet2(matmul_strategy)
    reshape_net2(backbone)


def test_reshape_net2_2():
    """Compile ReshapeNet2 with MatMul strategy ((1, 8), (8, 2))."""
    matmul_strategy = ((1, 8), (8, 2))
    backbone = ReshapeNet2(matmul_strategy)
    reshape_net2(backbone)


Y
yao_yf 已提交
347
def _test_reshape_net3_1():
    """Compile ReshapeNet3 with MatMul strategy ((1, 8), (8, 1)).

    The leading underscore keeps the name from matching the ``test_`` prefix.
    """
    reshape_net2(ReshapeNet3(((1, 8), (8, 1))))


Y
yao_yf 已提交
351
def _test_reshape_net3_2():
    """Compile ReshapeNet3 with MatMul strategy ((1, 8), (8, 2)).

    The leading underscore keeps the name from matching the ``test_`` prefix.
    """
    reshape_net2(ReshapeNet3(((1, 8), (8, 2))))


def test_reshape_net4_1():
    """Compile ReshapeNet4 with MatMul strategy ((1, 8), (8, 1)).

    ValueError, TypeError and RuntimeError are tolerated; the test also
    passes when compilation succeeds.
    """
    try:
        reshape_net2(ReshapeNet4(((1, 8), (8, 1))))
    except (ValueError, TypeError, RuntimeError):
        # Deliberately tolerated: these error types are accepted outcomes here.
        pass


def test_reshape_net4_2():
    """Compile ReshapeNet4 with MatMul strategy ((1, 8), (8, 2)).

    ValueError, TypeError and RuntimeError are tolerated; the test also
    passes when compilation succeeds.
    """
    try:
        reshape_net2(ReshapeNet4(((1, 8), (8, 2))))
    except (ValueError, TypeError, RuntimeError):
        # Deliberately tolerated: these error types are accepted outcomes here.
        pass


def test_reshape_net5_1():
    """Compile ReshapeNet5 with MatMul strategy ((1, 8), (8, 1))."""
    matmul_strategy = ((1, 8), (8, 1))
    backbone = ReshapeNet5(matmul_strategy)
    reshape_net2(backbone)


def test_reshape_net5_2():
    """Compile ReshapeNet5 with MatMul strategy ((1, 8), (8, 2))."""
    matmul_strategy = ((1, 8), (8, 2))
    backbone = ReshapeNet5(matmul_strategy)
    reshape_net2(backbone)


def test_reshape_net6_1():
    """Compile ReshapeNet6 with MatMul strategy ((1, 8), (8, 1))."""
    matmul_strategy = ((1, 8), (8, 1))
    backbone = ReshapeNet6(matmul_strategy)
    reshape_net2(backbone)


def test_reshape_net6_2():
    """Compile ReshapeNet6 with MatMul strategy ((1, 8), (8, 2))."""
    matmul_strategy = ((1, 8), (8, 2))
    backbone = ReshapeNet6(matmul_strategy)
    reshape_net2(backbone)


class TrainOneStepCell(nn.Cell):
    """
    Network training package class.

    Wraps a network together with an optimizer. Calling ``construct`` runs the
    network, computes gradients with respect to its trainable parameters and
    applies the optimizer to them.

    Args:
        network (Cell): The training network.
        optimizer (Cell): Optimizer for updating the weights.
        sens (Number): Value used to fill the sensitivity tensor passed to the
            gradient operation. Default: 1.0.

    Examples:
        >>> net = Net()
        >>> loss_fn = nn.SoftmaxCrossEntropyWithLogits()
        >>> optim = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        >>> loss_net = WithLossCell(net, loss_fn)
        >>> train_net = TrainOneStepCell(loss_net, optim)
    """

    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepCell, self).__init__(auto_prefix=False)
        self.network = network
        self.network.add_flags(defer_inline=True)
        self.weights = ParameterTuple(network.trainable_params())
        self.optimizer = optimizer
        self.grad = C.GradOperation('grad',
                                    get_by_list=True,
                                    sens_param=True)
        self.sens = sens

    def construct(self, data):
        weights = self.weights
        loss = self.network(data)
        # Sensitivity tensor: same dtype and shape as the loss, filled with self.sens.
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(data, sens)

        # Return the loss while making it depend on the optimizer update.
        return F.depend(loss, self.optimizer(grads))


def reshape_common2(parallel_mode, net):
    """Train ``net`` for two epochs on 16 devices using the local ``TrainOneStepCell``.

    Args:
        parallel_mode: Value passed to ``context.set_auto_parallel_context``.
        net: Cell to train; the dataset supplies a single input per sample.
    """
    features = Tensor(np.ones((16, 512, 7, 7)), dtype=ms.float32)
    targets = Tensor(np.ones((16,)), dtype=ms.int32)
    data = Dataset(features, targets, 2, input_num=1)

    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=16)

    optimizer = Momentum(net.trainable_params(), 0.1, 0.9)
    step_cell = TrainOneStepCell(net, optimizer).set_train()
    trainer = Model(step_cell)
    trainer.train(2, data, dataset_sink_mode=False)


def test_reshape_common2_0():
    """Train ReshapeNet1 with MatMul strategy ((1, 8), (8, 1)) in semi-auto-parallel mode."""
    backbone = ReshapeNet1(((1, 8), (8, 1)))
    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, backbone)


def test_reshape_common2_1():
    """Train ReshapeNet1 with MatMul strategy ((1, 8), (8, 2)) in semi-auto-parallel mode."""
    backbone = ReshapeNet1(((1, 8), (8, 2)))
    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, backbone)


def test_reshape_common2_2():
    """Train ReshapeNet2 with MatMul strategy ((1, 8), (8, 1)) in semi-auto-parallel mode."""
    backbone = ReshapeNet2(((1, 8), (8, 1)))
    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, backbone)


def test_reshape_common2_3():
    """Train ReshapeNet2 with MatMul strategy ((1, 8), (8, 2)) in semi-auto-parallel mode."""
    backbone = ReshapeNet2(((1, 8), (8, 2)))
    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, backbone)


Y
yao_yf 已提交
467
def _test_reshape_common2_4():
    """Train ReshapeNet3 with MatMul strategy ((1, 8), (8, 1)) in semi-auto-parallel mode.

    The leading underscore keeps the name from matching the ``test_`` prefix.
    """
    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet3(((1, 8), (8, 1))))


Y
yao_yf 已提交
471
def _test_reshape_common2_5():
    """Train ReshapeNet3 with MatMul strategy ((1, 8), (8, 2)) in semi-auto-parallel mode.

    The leading underscore keeps the name from matching the ``test_`` prefix.
    """
    reshape_common2(ParallelMode.SEMI_AUTO_PARALLEL, ReshapeNet3(((1, 8), (8, 2))))


class BatchNormReshapeNet(nn.Cell):
    """VirtualDataset -> BatchNorm1d(512) -> Reshape to (512, 256) -> PReLU(256)."""

    def __init__(self):
        super().__init__()
        self.vd = P._VirtualDataset()
        self.batch_norm = nn.BatchNorm1d(512, affine=False)
        self.reshape = P.Reshape()
        self.prelu = nn.PReLU(channel=256)

    def construct(self, x):
        data = self.vd(x)
        normed = self.batch_norm(data)
        reshaped = self.reshape(normed, (512, 256))
        return self.prelu(reshaped)


def test_batchnorm_reshape_train():
    """Compile the gradient graph of BatchNormReshapeNet in semi-auto-parallel mode on 16 devices."""
    batch_size = 16
    device_num = 16
    context.set_auto_parallel_context(device_num=device_num, global_rank=0)
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
    input_ = Tensor(np.ones([batch_size * device_num, 512]).astype(np.float32) * 0.01)

    net = GradWrap(NetWithLoss(BatchNormReshapeNet()))

    compile_net(net, input_)
Z
zhunaipan 已提交
501 502 503 504 505 506 507 508 509 510 511 512 513 514


def bn_with_initialize(out_channels):
    """Return a BatchNorm2d (momentum=0.3, eps=1e-5) flagged recursively with ``fp32=True``."""
    return nn.BatchNorm2d(out_channels, momentum=0.3, eps=1e-5).add_flags_recursive(fp32=True)


def fc_with_initialize(input_channels, out_channels):
    """Return a Dense layer flagged recursively with ``fp16=True``."""
    dense = nn.Dense(input_channels, out_channels)
    return dense.add_flags_recursive(fp16=True)


class BNReshapeDenseBNNet(nn.Cell):
    """BatchNorm2d(2) -> Reshape to (16, 2 * 32 * 32) -> Dense(2048, 512) -> BatchNorm1d(512)."""

    def __init__(self):
        super(BNReshapeDenseBNNet, self).__init__()
        self.batch_norm = bn_with_initialize(2)
        self.reshape = P.Reshape()
        # NOTE: self.cast is created here but not used in construct.
        self.cast = P.Cast()
        self.batch_norm2 = nn.BatchNorm1d(512, affine=False)
        self.fc = fc_with_initialize(2 * 32 * 32, 512)

    def construct(self, x):
        x = self.batch_norm(x)
        x = self.reshape(x, (16, 2 * 32 * 32))
        x = self.fc(x)
        x = self.batch_norm2(x)
        return x


def test_bn_reshape_dense_bn_train():
    """Compile the gradient graph of BNReshapeDenseBNNet in semi-auto-parallel mode on 16 devices."""
    batch_size = 16
    device_num = 16
    context.set_auto_parallel_context(device_num=device_num, global_rank=0)
    input_ = Tensor(np.ones([batch_size, 2, 32, 32]).astype(np.float32) * 0.01)

    net = GradWrap(NetWithLoss(BNReshapeDenseBNNet()))
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")

    compile_net(net, input_)
Z
zhunaipan 已提交
539 540 541 542 543 544 545


class ParallelReduceMeanNet(nn.Cell):
    """1x1 Conv2d -> ReduceMean -> Flatten.

    Args:
        conv_in_channel (int): ``in_channels`` of the convolution.
        conv_out_channel (int): ``out_channels`` of the convolution.
        reducemean_keep_dims (bool): ``keep_dims`` of the ReduceMean operator.
        reducemean_axis: Axis argument passed to ReduceMean in ``construct``.
        strategy: Strategy set on ReduceMean; nothing is set when it is None.
    """

    def __init__(self, conv_in_channel, conv_out_channel,
                 reducemean_keep_dims=False, reducemean_axis=-1, strategy=None):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=conv_in_channel, out_channels=conv_out_channel,
                              kernel_size=1, stride=1, pad_mode='valid', has_bias=True,
                              weight_init='ones', bias_init='ones')
        self.reduce_mean = P.ReduceMean(keep_dims=reducemean_keep_dims)
        self.flat = nn.Flatten()
        self.reducemean_axis = reducemean_axis
        if strategy is not None:
            self.reduce_mean.set_strategy(strategy)

    def construct(self, inputs):
        x = self.conv(inputs)
        x = self.reduce_mean(x, self.reducemean_axis)
        x = self.flat(x)
        return x


class CrossEntropyLoss(nn.Cell):
    """SoftmaxCrossEntropyWithLogits followed by an optional ReduceMean over the last axis.

    Args:
        reduction (str): When equal to ``'mean'`` the loss is reduced with
            ReduceMean over axis ``(-1,)``; otherwise it is returned as is.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.reduce_mean = P.ReduceMean()
        self.cross_entropy = SoftmaxCrossEntropyWithLogits()
        self.reduction = reduction

    def construct(self, logits, label):
        result = self.cross_entropy(logits, label)
        if self.reduction == 'mean':
            result = self.reduce_mean(result, (-1,))
        return result


def test_flatten_reshape(parallel_mode="auto_parallel"):
    """Train ParallelReduceMeanNet with ReduceMean strategy ((4, 2, 1, 1),) on 8 devices.

    Args:
        parallel_mode: Value passed to ``context.set_auto_parallel_context``.
    """
    batch_size = 16
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
    net = ParallelReduceMeanNet(conv_in_channel=3, conv_out_channel=64, reducemean_axis=(2, 3),
                                strategy=((4, 2, 1, 1),))
    loss = CrossEntropyLoss()
    predict = Tensor(np.ones([batch_size, 3, 32, 32]), dtype=ms.float32)
    label = Tensor(np.ones([batch_size, 64]), dtype=ms.float32)
    dataset = Dataset(predict, label, 2, input_num=2)

    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(epoch_size, dataset, dataset_sink_mode=False)


def test_flatten_reshape2(parallel_mode="auto_parallel"):
    """Train ParallelReduceMeanNet with ReduceMean strategy ((4, 1, 1, 1),) on 8 devices.

    ``fully_use_devices`` is set to False via ``set_algo_parameters`` before
    the network is built.

    Args:
        parallel_mode: Value passed to ``context.set_auto_parallel_context``.
    """
    batch_size = 16
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
    set_algo_parameters(fully_use_devices=False)
    net = ParallelReduceMeanNet(conv_in_channel=3, conv_out_channel=64, reducemean_axis=(2, 3),
                                strategy=((4, 1, 1, 1),))
    loss = CrossEntropyLoss()
    predict = Tensor(np.ones([batch_size, 3, 32, 32]), dtype=ms.float32)
    label = Tensor(np.ones([batch_size, 64]), dtype=ms.float32)
    dataset = Dataset(predict, label, 2, input_num=2)

    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(epoch_size, dataset, dataset_sink_mode=False)


class ParallelReshapeNet(nn.Cell):
    """Flatten -> Dense -> Reshape to a caller-supplied shape.

    Args:
        dense_in_channel (int): ``in_channels`` of the Dense layer.
        dense_out_channel (int): ``out_channels`` of the Dense layer.
        shape (tuple): Target shape passed to Reshape in ``construct``.
        strategy: Strategy set on the Reshape operator (passed through even when None).
    """

    def __init__(self, dense_in_channel, dense_out_channel, shape, strategy=None):
        super().__init__()
        self.flat = nn.Flatten()
        self.dense = nn.Dense(
            in_channels=dense_in_channel,
            out_channels=dense_out_channel,
            weight_init='ones',
            bias_init='ones',
            has_bias=True,
        )
        self.reshape = P.Reshape()
        self.shape = shape
        self.reshape.set_strategy(strategy)

    def construct(self, inputs):
        flattened = self.flat(inputs)
        projected = self.dense(flattened)
        return self.reshape(projected, self.shape)


# The shape of the input and output of reshape is the same, so
# reshape is optimized before step_parallel.
def test_flatten_reshape3(parallel_mode="auto_parallel"):
    """Train ParallelReshapeNet with Reshape strategy ((16, 1),) on 8 devices.

    ``fully_use_devices`` is set to False via ``set_algo_parameters`` before
    the network is built.

    Args:
        parallel_mode: Value passed to ``context.set_auto_parallel_context``.
    """
    batch_size = 16
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
    set_algo_parameters(fully_use_devices=False)
    net = ParallelReshapeNet(dense_in_channel=2048, dense_out_channel=1000, shape=(128, 1000), strategy=((16, 1),))
    loss = CrossEntropyLoss()
    predict = Tensor(np.ones([batch_size, 1, 2, 1024]), dtype=ms.float32)
    label = Tensor(np.ones([batch_size, 1000]), dtype=ms.float32)
    dataset = Dataset(predict, label, 2, input_num=2)

    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(epoch_size, dataset, dataset_sink_mode=False)


class CrossEntropyLoss2(nn.Cell):
    """Thin wrapper around SoftmaxCrossEntropyWithLogits that forwards ``reduction``.

    Args:
        reduction (str): Passed to ``SoftmaxCrossEntropyWithLogits``.
    """

    def __init__(self, reduction='mean'):
        super().__init__()
        self.cross_entropy = SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logits, label):
        return self.cross_entropy(logits, label)


def test_flatten_reshape4(parallel_mode="semi_auto_parallel"):
    """Train ParallelReduceMeanNet (keep_dims=True, default axis) with CrossEntropyLoss2 on 8 devices.

    ``fully_use_devices`` is set to False via ``set_algo_parameters`` before
    the network is built.

    Args:
        parallel_mode: Value passed to ``context.set_auto_parallel_context``.
    """
    batch_size = 16
    learning_rate = 0.1
    momentum = 0.9
    epoch_size = 2
    context.reset_auto_parallel_context()
    context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=8)
    set_algo_parameters(fully_use_devices=False)
    net = ParallelReduceMeanNet(conv_in_channel=3, conv_out_channel=64, reducemean_keep_dims=True,
                                strategy=((4, 1, 1, 1),))
    loss = CrossEntropyLoss2()
    predict = Tensor(np.ones([batch_size, 3, 32, 32]), dtype=ms.float32)
    label = Tensor(np.ones([batch_size, 2048]), dtype=ms.float32)
    dataset = Dataset(predict, label, 2, input_num=2)

    opt = Momentum(net.trainable_params(), learning_rate, momentum)
    model = Model(net, loss_fn=loss, optimizer=opt)
    model.train(epoch_size, dataset, dataset_sink_mode=False)