# soft_entropy_loss_expand_parallel.py
# Copyright 2019 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import pytest
import numpy as np
import mindspore as ms
from numpy import allclose
from mindspore.nn import Cell
from mindspore import context
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.tensor import Tensor
import mindspore.communication.management as distributedTool
from mindspore.common.parameter import ParameterTuple, Parameter
from mindspore.ops import composite as C
from mindspore.common import dtype as mstype
from mindspore.train import Model, ParallelMode
from mindspore.nn.optim.momentum import Momentum
from mindspore.train.callback import Callback

# Print arrays in full (no "..." summarisation) in assertion output.
np.set_printoptions(threshold=np.inf)

# Size of the communication group; setup_module() overwrites it with the
# real group size after distributed initialisation.
device_num = 2
# Local device index exported by the launcher. int() raises TypeError when
# DEVICE_ID is unset, so this module cannot be imported without it.
device_id = int(os.getenv('DEVICE_ID'))
# Rank of this process; setup_module() overwrites it via get_rank().
rank_id = 0
embed = 128
classes = 32
batch_size = 32*2
# MatmulNet's weight shape; its MatMul uses transpose_b=True.
MatmulParamShape = (classes, embed)


def setup_module():
    """Configure the Ascend graph-mode context and initialise communication.

    Runs once before the tests in this module (pytest ``setup_module`` hook).
    Overwrites the module-level ``rank_id`` and ``device_num`` with the values
    reported by the communication group, then passes them to the
    auto-parallel context.
    """
    global device_num
    global rank_id
    np.random.seed(0)
    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
    context.set_context(enable_task_sink=True,
                        device_id=device_id)
    context.set_context(enable_ir_fusion=True)
    context.set_context(enable_loop_sink=False)
    distributedTool.init()
    rank_id = distributedTool.get_rank()
    device_num = distributedTool.get_group_size()
    # global_rank expects the rank inside the communication group. DEVICE_ID
    # is only the local device index, which can differ from the rank on
    # multi-host runs, so use the value returned by get_rank().
    context.set_auto_parallel_context(device_num=device_num,
                                      global_rank=rank_id)


def teardown_module():
    """Release the communication resources initialised in setup_module()."""
    distributedTool.release()


class DataGenerator():
    """Builds deterministic test arrays and slices them into per-device shards."""

    def get_parallel_blocks(self, input_, strategy):
        """Split ``input_`` into blocks according to a sharding strategy.

        ``strategy[i]`` is the number of equal parts axis ``i`` is cut into.
        Blocks are returned with the index varying fastest along the last
        sharded axis, i.e. in row-major order over the strategy grid.

        Args:
            input_: numpy array to split.
            strategy: iterable of split counts, one per leading axis.

        Returns:
            list of numpy arrays (``prod(strategy)`` of them).

        Raises:
            ValueError: from ``np.split`` when an axis is not evenly divisible.
        """
        blocks = [input_]
        for axis, parts in enumerate(strategy):
            # Refine every existing block along this axis, preserving order.
            blocks = [piece
                      for block in blocks
                      for piece in np.split(block, parts, axis=axis)]
        return blocks

    def generate_data(self, shape):
        """Return a float64 array of ``shape`` filled with a repeating ramp.

        Element with flat index ``k`` is ``(k % n) / n`` where
        ``n = min(size, 1000)``, so every value lies in [0, 1).
        """
        # np.prod also handles scalar shapes (), where np.cumprod(...)[-1]
        # would raise IndexError on the empty result.
        size = int(np.prod(shape, dtype=np.int64))
        num_range = min(size, 1000)
        data = (np.arange(0, size) % num_range) / num_range
        return np.reshape(data, shape)

    def _shard_batch(self, data):
        """Split ``data`` into ``device_num`` parts along axis 0; return this rank's part."""
        strategy = [1] * data.ndim
        strategy[0] = device_num
        return self.get_parallel_blocks(data, strategy)[rank_id]

    def input_data(self, shape):
        """Return (full, local shard) Tensors of float32 features in [0, 0.1)."""
        data = (self.generate_data(shape) * 0.1).astype(np.float32)
        return Tensor(data), Tensor(self._shard_batch(data))

    def label_data(self, shape, embed):
        """Return (full, local shard) Tensors of int32 labels in [0, embed - 1)."""
        data = (self.generate_data(shape) * (embed - 1)).astype(np.int32)
        return Tensor(data), Tensor(self._shard_batch(data))


class Dataset():
    """Minimal in-memory dataset that yields the same batch ``length`` times.

    Besides the iterator protocol it exposes ``reset``, ``get_dataset_size``
    and ``get_repeat_count``; presumably these are what ``Model.train``
    queries in non-sink mode (TODO confirm against the MindSpore version used).
    """

    def __init__(self, predict, label, length=1, input_num=2):
        self.predict = predict
        self.label = label
        self.index = 0              # batches already yielded in this pass
        self.length = length        # batches per pass
        self.input_num = input_num  # 2 -> (predict, label); otherwise (predict,)

    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= self.length:
            raise StopIteration
        self.index += 1
        if self.input_num == 2:
            return self.predict, self.label
        return (self.predict,)

    def reset(self):
        """Rewind so the dataset can be iterated again."""
        self.index = 0

    def get_dataset_size(self):
        """Number of batches yielded per pass."""
        return self.length

    def get_repeat_count(self):
        # NOTE(review): returns the same value as get_dataset_size(); confirm
        # this is the intended repeat count.
        return self.length


class ModelCallback(Callback):
    """Training callback that records the mean network output at every epoch end."""

    def __init__(self):
        super(ModelCallback, self).__init__()
        self.loss_list = []  # one entry per epoch_end call, in call order

    def epoch_end(self, run_context, *args):
        """Append the mean of the current ``net_outputs`` to ``loss_list``."""
        cb_params = run_context.original_args()
        result = cb_params.net_outputs
        self.loss_list.append(result.asnumpy().mean())


class SoftmaxCrossEntropyExpand(Cell):
    """Softmax cross-entropy built from primitive ops so each one can be sharded.

    Args:
        sparse: if True, ``label`` is one-hot encoded to the width of
            ``logit``'s axis 1 before use; otherwise ``label`` is multiplied
            element-wise with the log-softmax as given.
        stra_list: sharding strategies indexed as
            [exp (unused), reduce_sum, onehot, div, log, sum_cross_entropy,
            mul, mul2, reduce_mean, reduce_max, sub]. ``None`` or any list
            with fewer than 11 entries means no explicit strategy for any op.
    """

    def __init__(self, sparse=False, stra_list=None):
        super(SoftmaxCrossEntropyExpand, self).__init__()
        # None replaces the former shared mutable default ([]); a short list
        # is treated as "no strategies", as before.
        if stra_list is None or len(stra_list) < 11:
            stra_list = [None]*11
        # NOTE(review): stra_list[0] (the caller's Exp strategy) is never
        # applied here; Exp presumably gets its layout from the planner.
        self.exp = P.Exp()
        self.reduce_sum = P.ReduceSum(keep_dims=True).set_strategy(strategy=stra_list[1])
        self.onehot = P.OneHot().set_strategy(strategy=stra_list[2])
        self.on_value = Tensor(1.0, mstype.float32)
        self.off_value = Tensor(0.0, mstype.float32)
        self.div = P.Div().set_strategy(strategy=stra_list[3])
        self.log = P.Log().set_strategy(strategy=stra_list[4])
        self.sum_cross_entropy = P.ReduceSum(keep_dims=False).set_strategy(strategy=stra_list[5])
        self.mul = P.Mul().set_strategy(strategy=stra_list[6])
        self.mul2 = P.Mul().set_strategy(strategy=stra_list[7])
        self.cast = P.Cast()
        self.reduce_mean = P.ReduceMean(keep_dims=False).set_strategy(strategy=stra_list[8])
        self.sparse = sparse
        self.reduce_max = P.ReduceMax(keep_dims=True).set_strategy(strategy=stra_list[9])
        self.sub = P.Sub().set_strategy(strategy=stra_list[10])

    def construct(self, logit, label):
        # Softmax over the last axis, shifted by the row max for stability.
        logit_max = self.reduce_max(logit, -1)
        exp = self.exp(self.sub(logit, logit_max))
        exp_sum = self.reduce_sum(exp, -1)
        softmax_result = self.div(exp, exp_sum)
        if self.sparse:
            label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
        softmax_result_log = self.log(softmax_result)
        # Per-row cross entropy: -sum(label * log(softmax)) over the last axis.
        loss = self.sum_cross_entropy((self.mul(softmax_result_log, label)), -1)
        loss = self.mul2(F.scalar_to_array(-1.0), loss)
        # Mean over the remaining axis (the batch axis for 2-D logits).
        loss = self.reduce_mean(loss, -1)
        return loss


class MatmulNet(Cell):
    """``x @ weight.T`` followed by SoftmaxCrossEntropyExpand (sparse labels).

    Args:
        matmul_stra: sharding strategy for the MatMul, or None.
        loss_stra_list: per-op strategies forwarded to
            SoftmaxCrossEntropyExpand; None or [] means no explicit strategies.
    """

    def __init__(self, matmul_stra=None, loss_stra_list=None):
        super(MatmulNet, self).__init__()
        # Avoid a shared mutable default; forward an empty list as before.
        if loss_stra_list is None:
            loss_stra_list = []
        self.matmul = P.MatMul(transpose_b=True).set_strategy(strategy=matmul_stra)
        self.loss = SoftmaxCrossEntropyExpand(sparse=True, stra_list=loss_stra_list)
        # Deterministic all-ones init so every run starts from the same weights.
        self.weight = Parameter(Tensor(np.ones(MatmulParamShape), dtype=ms.float32), name="weight")

    def construct(self, x, label):
        loss_input = self.matmul(x, self.weight)
        out = self.loss(loss_input, label)
        return out


class LossFactory():
    """Trains MatmulNet under different parallel settings and returns loss curves.

    Every run trains for 6 epochs with Momentum on a fixed batch: the full
    batch for the single-device run, this rank's shard for the parallel runs.
    """

    def __init__(self):
        data_gen = DataGenerator()
        self.input_full, self.input_part = data_gen.input_data((batch_size, embed))
        # NOTE(review): labels are scaled by `embed` (128) while the one-hot
        # depth is the logit width `classes` (32), so most labels fall outside
        # [0, classes). `classes` was presumably intended; confirm before
        # changing, since it alters the expected loss values.
        self.label_full, self.label_part = data_gen.label_data((batch_size,), embed)

    @staticmethod
    def _train(net, dataset):
        """Train ``net`` on ``dataset`` for 6 epochs; return per-epoch mean losses."""
        callback = ModelCallback()
        optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
        model = Model(net, optimizer=optimizer)
        epoch_size = 6
        model.train(epoch_size, dataset, callbacks=callback, dataset_sink_mode=False)
        return np.array(callback.loss_list)

    @staticmethod
    def _loss_strategies(sub_stra):
        """Build the 11-entry strategy list for SoftmaxCrossEntropyExpand.

        Only the Sub strategy differs between the parallel variants below.
        """
        exp_stra = ((1, device_num),)
        reduce_sum_stra = ((1, device_num),)
        onehot_stra = ((1, device_num), (), ())
        div_stra = ((1, device_num), (1, 1))
        log_stra = ((1, device_num),)
        sum_cross_entropy_stra = ((1, device_num),)
        mul_stra = ((1, device_num), (1, device_num))
        mul2_stra = ((), (device_num,))
        reduce_mean_stra = ((device_num,),)
        reduce_max_stra = ((1, device_num),)
        return [exp_stra, reduce_sum_stra, onehot_stra, div_stra, log_stra,
                sum_cross_entropy_stra, mul_stra, mul2_stra, reduce_mean_stra,
                reduce_max_stra, sub_stra]

    def single_matmul_trains(self):
        """Train on the full batch without setting any parallel mode."""
        net = MatmulNet()
        return self._train(net, Dataset(self.input_full, self.label_full))

    def data_parallel_matmul_trains(self):
        """Train on this rank's shard in semi_auto_parallel mode, no explicit strategies."""
        context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
        net = MatmulNet()
        return self._train(net, Dataset(self.input_part, self.label_part))

    def model_parallel_matmul_trains(self):
        """Train in auto_parallel mode with the MatMul weight split along its class axis."""
        matmul_stra = ((1, 1), (device_num, 1))
        loss_stra_list = self._loss_strategies(sub_stra=((1, device_num), (1, 1)))
        context.set_auto_parallel_context(parallel_mode="auto_parallel")
        net = MatmulNet(matmul_stra=matmul_stra, loss_stra_list=loss_stra_list)
        return self._train(net, Dataset(self.input_part, self.label_part))

    def mix_parallel_matmul_trains(self):
        """Train in auto_parallel mode with the MatMul input split along the batch axis."""
        matmul_stra = ((device_num, 1), (1, 1))
        loss_stra_list = self._loss_strategies(sub_stra=((device_num, 1), (device_num, 1)))
        context.set_auto_parallel_context(parallel_mode="auto_parallel")
        net = MatmulNet(matmul_stra=matmul_stra, loss_stra_list=loss_stra_list)
        return self._train(net, Dataset(self.input_part, self.label_part))


def test_all_trains():
    """Single-device, model-parallel and mixed-parallel runs must yield equal losses."""
    loss_factory = LossFactory()
    context.reset_auto_parallel_context()
    single_loss = loss_factory.single_matmul_trains()
    model_parallel_loss = loss_factory.model_parallel_matmul_trains()
    mix_parallel_loss = loss_factory.mix_parallel_matmul_trains()
    # Attach both curves to the failure message so a mismatch is diagnosable.
    assert allclose(single_loss, model_parallel_loss), (single_loss, model_parallel_loss)
    assert allclose(single_loss, mix_parallel_loss), (single_loss, mix_parallel_loss)