提交 91a0fa75 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5594 add st about loss_scale and parser for pynative mode

Merge pull request !5594 from Simson/push-to-opensource
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.nn.wrap.cell_wrapper import WithLossCell
from mindspore.nn.wrap.loss_scale import TrainOneStepWithLossScaleCell
from mindspore.ops import operations as P
from mindspore.nn.optim import Momentum, RMSProp
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindspore.train import Model
from mindspore.nn.optim import Lamb
from mindspore.train.loss_scale_manager import DynamicLossScaleManager
class MindData:
""" Stub for MindData """
def __init__(self, size=None, batch_size=None, repeat_count=1,
np_types=None, output_shapes=None, input_indexes=(), func_name=''):
self._size = size
self._batch_size = batch_size
self._repeat_count = repeat_count
self._np_types = np_types
self._output_shapes = output_shapes
self._input_indexes = input_indexes
self._func_name = func_name
self._iter_num = 0
def get_dataset_size(self):
return self._size
def get_repeat_count(self):
return self._repeat_count
def get_batch_size(self):
return self._batch_size
def output_types(self):
return self._np_types
def output_shapes(self):
return self._output_shapes
def create_tuple_iterator(self):
return self
@property
def input_indexes(self):
return self._input_indexes
@property
def func_name(self):
return self._func_name
def send(self):
pass
def __len__(self):
return self._size
def __iter__(self):
return self
def __next__(self):
if self._size < self._iter_num:
raise StopIteration
self._iter_num += 1
next_value = []
for shape, typ in zip(self._output_shapes, self._np_types):
next_value.append(Tensor(np.ndarray(shape, typ)))
return tuple(next_value)
def next(self):
return self.__next__()
def reset(self):
self._iter_num = 0
class MindDataSet(MindData):
def __init__(self, dataset_types, dataset_shapes):
super(MindDataSet, self).__init__(size=2, batch_size=32,
np_types=dataset_types,
output_shapes=dataset_shapes,
input_indexes=(0, 1), func_name='')
def __next__(self):
if self._size < self._iter_num:
raise StopIteration
self._iter_num += 1
res = []
for shape, t in zip(self._output_shapes, self._np_types):
res.append(Tensor(np.ones(shape).astype(t)))
return tuple(res)
class NetFP16(nn.Cell):
def __init__(self, in_features, out_features):
super(NetFP16, self).__init__()
self.weight = Parameter(Tensor(np.ones([out_features, in_features]).astype(np.float32)), name="weight")
self.bias = Parameter(Tensor(np.ones([out_features]).astype(np.float32)), name="bias")
self.matmul = P.MatMul()
self.add = P.TensorAdd()
self.cast = P.Cast()
def construct(self, x):
output = self.cast(self.add(self.matmul(self.cast(x, mstype.float16),
self.cast(self.weight, mstype.float16)),
self.cast(self.bias, mstype.float16)), mstype.float32)
return output
def get_axis(x):
shape_op = P.Shape()
shape = shape_op(x)
length = F.tuple_len(shape)
perm = F.make_range(0, length)
return perm
class MSELoss(nn.Cell):
def __init__(self):
super(MSELoss, self).__init__()
self.sum = P.ReduceSum()
self.square = P.Square()
self.reduce_mean = P.ReduceMean()
def construct(self, data, label):
diff = data - label
return self.reduce_mean(self.square(diff), get_axis(diff))
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_loss_scale_fp16_lr_overflow():
inputs = Tensor(np.ones([16, 16]).astype(np.float32))
label = Tensor(np.zeros([16, 16]).astype(np.float32))
scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32)
lr = Tensor(np.ones([1], np.float32) * 0.1)
net = NetFP16(16, 16)
net.set_train()
loss = MSELoss()
optimizer = Momentum(net.trainable_params(), learning_rate=lr, momentum=0.9)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer)
output_1 = train_network(inputs, label, scaling_sens)
output_2 = train_network(inputs, label, scaling_sens)
assert output_1[0].asnumpy() == output_2[0].asnumpy()
assert output_1[1].asnumpy() == output_2[1].asnumpy() == True
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_loss_scale_fp16_model_train_overflow():
dataset_types = (np.float32, np.float32)
dataset_shapes = ((16, 16), (16, 16))
dataset = MindDataSet(dataset_types, dataset_shapes)
net = NetFP16(16, 16)
net.set_train()
loss = MSELoss()
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
scale_manager = DynamicLossScaleManager(init_loss_scale=16, scale_factor=2, scale_window=2)
model = Model(net, loss_fn=loss, optimizer=optimizer, metrics=None, loss_scale_manager=scale_manager)
model.train(2, dataset, dataset_sink_mode=False)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_loss_scale_fp16_opt_rmsprop_overflow():
inputs = Tensor(np.ones([16, 16]).astype(np.float32))
label = Tensor(np.zeros([16, 16]).astype(np.float32))
scaling_sens = Tensor(np.full(1, np.finfo(np.float32).max), dtype=mstype.float32)
net = NetFP16(16, 16)
net.set_train()
loss = MSELoss()
optimizer = RMSProp(net.trainable_params(), learning_rate=0.1)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer)
output_1 = train_network(inputs, label, scaling_sens)
output_2 = train_network(inputs, label, scaling_sens)
assert output_1[0].asnumpy() == output_2[0].asnumpy()
assert output_1[1].asnumpy() == output_2[1].asnumpy() == True
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_loss_scale_fp16_overflow():
inputs = Tensor(np.ones([16, 16]).astype(np.float32))
label = Tensor(np.zeros([16, 16]).astype(np.float32))
scaling_sens = Tensor(np.full((1), np.finfo(np.float32).max), dtype=mstype.float32)
net = NetFP16(16, 16)
net.set_train()
loss = MSELoss()
optimizer = Lamb(net.trainable_params(), learning_rate=0.01)
net_with_loss = WithLossCell(net, loss)
net_with_loss.set_grad()
train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer)
output_1 = train_network(inputs, label, scaling_sens)
output_2 = train_network(inputs, label, scaling_sens)
assert output_1[0].asnumpy() == output_2[0].asnumpy()
assert output_1[1].asnumpy() == output_2[1].asnumpy() == True
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
from mindspore.nn import ReLU
from mindspore.nn import Cell
from mindspore.common.tensor import Tensor
from mindspore.common.api import ms_function
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_pynative_staging_together():
class NetPynative(Cell):
def __init__(self):
super().__init__()
self.relu = ReLU()
def construct(self, x):
return self.relu(x)
class NetStaging(Cell):
def __init__(self):
super().__init__()
self.relu = ReLU()
@ms_function
def construct(self, x):
return self.relu(x)
input1 = np.random.randn(2, 2).astype(np.float32)
net1 = NetPynative()
out_me_pynative = net1(Tensor(input1)).asnumpy()
net2 = NetStaging()
out_me_staging = net2(Tensor(input1)).asnumpy()
assert np.allclose(out_me_pynative, out_me_staging, 0.001, 0.001)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore as ms
from mindspore.nn import ReLU
from mindspore.nn import Cell
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
import numpy as np
import pytest
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parser_tensor_assign_slice():
class Net(Cell):
def __init__(self, U):
super(Net, self).__init__()
self.relu = ReLU()
self.U = U
def construct(self, x):
x = self.relu(x)
x[..., :2] = U
return x
input_np_x = np.random.rand(4, 4, 4)
input_me_x = Tensor(input_np_x, ms.float32)
U = 1.0
net = Net(U)
out_me = net(input_me_x)
input_np_x[..., :2] = U
assert np.allclose(out_me.asnumpy(), input_np_x, rtol=0.01, atol=0.01)
def test_parser_tensor_assign_slice_002():
class Net(Cell):
def __init__(self, U):
super(Net, self).__init__()
self.relu = ReLU()
self.U = U
def construct(self, x):
x = self.relu(x)
x[::, :, :1] = self.U
return x
input_np_x = np.random.rand(4, 4, 4)
input_me_x = Tensor(input_np_x, ms.float32)
U = 1.0
net = Net(U)
out_me = net(input_me_x)
input_np_x[::, :, :1] = U
assert np.allclose(out_me.asnumpy(), input_np_x, rtol=0.01, atol=0.01)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parser_tensor_assign_bool():
class Net(Cell):
def __init__(self, U):
super(Net, self).__init__()
self.relu = ReLU()
self.U = U
def construct(self, x, tensorB):
x = self.relu(x)
x[tensorB] = self.U
return x
input_np_x = np.random.rand(4, 4, 4)
input_me_x = Tensor(input_np_x, ms.float32)
numpy_B = np.random.randn(4, 4, 4) > 0
tensor_B = Tensor(numpy_B)
U = np.array([1])
net = Net(Tensor(U))
out_me = net(input_me_x, tensor_B)
input_np_x[numpy_B] = U
assert np.allclose(out_me.asnumpy(), input_np_x, rtol=0.01, atol=0.01)
def test_parser_tensor_assign_bool_002():
class Net(Cell):
def __init__(self, U):
super(Net, self).__init__()
self.relu = ReLU()
self.U = U
self.fill = P.Fill()
def construct(self, x, tensorB):
x = self.relu(x)
x[tensorB] = self.U
return x
input_np_x = np.random.rand(2, 2, 2)
input_me_x = Tensor(input_np_x, ms.float32)
numpy_B = np.random.randn(2, 2, 2) > 0
tensor_B = Tensor(numpy_B)
U = 1
net = Net(U)
out_me = net(input_me_x, tensor_B)
input_np_x[numpy_B] = U
assert np.allclose(out_me.asnumpy(), input_np_x, rtol=0.01, atol=0.01)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册