提交 26a65bee 编写于 作者: L lichenever 提交者: 高东海

fix auto parallel st

上级 d3400cde
......@@ -130,9 +130,7 @@ class OneHotFactory:
context.reset_auto_parallel_context()
assert np.allclose(out_mindspore_single, out_mindspore_parallel, 0.0001, 0.0001)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_reid_onehot_forward_int32_128_depth1024_model_parallel():
fact = OneHotFactory(batch_size=128,
classes=1024,
......@@ -142,9 +140,7 @@ def test_reid_onehot_forward_int32_128_depth1024_model_parallel():
strategy=((1,device_num),(),()))
fact.forward_cmp()
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_reid_onehot_forward_int32_1024_depth128_model_parallel():
fact = OneHotFactory(batch_size=1024,
classes=128,
......@@ -153,4 +149,3 @@ def test_reid_onehot_forward_int32_1024_depth128_model_parallel():
axis=-1,
strategy=((1,device_num),(),()))
fact.forward_cmp()
......@@ -18,7 +18,6 @@ BASE_PATH=$(cd "$(dirname $0)"; pwd)
CONFIG_PATH=/home/workspace/mindspore_config
export DEVICE_NUM=8
export RANK_SIZE=$DEVICE_NUM
ulimit -n 65535
source ${BASE_PATH}/env.sh
unset SLOG_PRINT_TO_STDOUT
export MINDSPORE_HCCL_CONFIG_PATH=$CONFIG_PATH/hccl/rank_table_${DEVICE_NUM}p.json
......@@ -27,7 +26,7 @@ process_pid=()
for((i=0; i<$DEVICE_NUM; i++)); do
rm -rf ${BASE_PATH}/loss_expand${i}
mkdir ${BASE_PATH}/loss_expand${i}
cp -r soft_entropy_loss_expand_parallel.py ${BASE_PATH}/loss_expand${i}/
cp -r ${BASE_PATH}/soft_entropy_loss_expand_parallel.py ${BASE_PATH}/loss_expand${i}/
cd ${BASE_PATH}/loss_expand${i}
export RANK_ID=${i}
export DEVICE_ID=${i}
......
......@@ -27,7 +27,7 @@ process_pid=()
for((i=0; i<$DEVICE_NUM; i++)); do
rm -rf ${BASE_PATH}/resnet50_expand_loss${i}
mkdir ${BASE_PATH}/resnet50_expand_loss${i}
cp -r resnet50_expand_loss.py ${BASE_PATH}/resnet50_expand_loss${i}/
cp -r ${BASE_PATH}/resnet50_expand_loss.py ${BASE_PATH}/resnet50_expand_loss${i}/
cd ${BASE_PATH}/resnet50_expand_loss${i}
export RANK_ID=${i}
export DEVICE_ID=${i}
......
......@@ -27,7 +27,7 @@ process_pid=()
for((i=0; i<$DEVICE_NUM; i++)); do
rm -rf ${BASE_PATH}/onehot_model_parallel${i}
mkdir ${BASE_PATH}/onehot_model_parallel${i}
cp -r onehot_model_parallel.py ${BASE_PATH}/onehot_model_parallel${i}/
cp -r ${BASE_PATH}/onehot_model_parallel.py ${BASE_PATH}/onehot_model_parallel${i}/
cd ${BASE_PATH}/onehot_model_parallel${i}
export RANK_ID=${i}
export DEVICE_ID=${i}
......
......@@ -118,6 +118,9 @@ class Dataset():
def get_dataset_size(self):
return self.length
def get_repeat_count(self):
return self.length
class ModelCallback(Callback):
def __init__(self):
super(ModelCallback, self).__init__()
......@@ -177,7 +180,6 @@ class LossFactory():
dataGen = DataGenerator()
self.input_full, self.input_part = dataGen.input_data((batch_size, embed))
self.label_full, self.label_part = dataGen.label_data((batch_size,),embed)
self.expect_out = np.array([0.9205861 , 0.9205861 , 0.9205861 , 0.9201946 , 0.91951686, 0.919343])
def single_matmul_trains(self):
single_callback = ModelCallback()
......@@ -187,7 +189,8 @@ class LossFactory():
epoch_size = 6
dataset = Dataset(self.input_full, self.label_full)
model.train(epoch_size, dataset, callbacks=single_callback, dataset_sink_mode=False)
print("---loss---",single_callback.loss_list)
loss_value = np.array(single_callback.loss_list)
return loss_value
def data_parallel_matmul_trains(self):
parallel_callback = ModelCallback()
......@@ -199,7 +202,7 @@ class LossFactory():
dataset = Dataset(self.input_part, self.label_part)
model.train(epoch_size, dataset, callbacks=parallel_callback, dataset_sink_mode=False)
loss_value = np.array(parallel_callback.loss_list)
assert allclose(loss_value, self.expect_out, 0.00001, 0.00001)
return loss_value
def model_parallel_matmul_trains(self):
parallel_callback = ModelCallback()
......@@ -224,7 +227,7 @@ class LossFactory():
dataset = Dataset(self.input_part, self.label_part)
model.train(epoch_size, dataset, callbacks=parallel_callback, dataset_sink_mode=False)
loss_value = np.array(parallel_callback.loss_list)
assert allclose(loss_value, self.expect_out, 0.00001, 0.00001)
return loss_value
def mix_parallel_matmul_trains(self):
parallel_callback = ModelCallback()
......@@ -249,28 +252,13 @@ class LossFactory():
dataset = Dataset(self.input_part, self.label_part)
model.train(epoch_size, dataset, callbacks=parallel_callback, dataset_sink_mode=False)
loss_value = np.array(parallel_callback.loss_list)
assert allclose(loss_value, self.expect_out, 0.00001, 0.00001)
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_matmul_loss_data_parallel_trains():
loss_factory = LossFactory()
context.reset_auto_parallel_context()
loss_factory.data_parallel_matmul_trains()
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_matmul_loss_model_parallel_trains():
loss_factory = LossFactory()
context.reset_auto_parallel_context()
loss_factory.model_parallel_matmul_trains()
return loss_value
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_matmul_loss_mix_parallel_trains():
def test_all_trains():
loss_factory = LossFactory()
context.reset_auto_parallel_context()
loss_factory.mix_parallel_matmul_trains()
single_loss = loss_factory.single_matmul_trains()
model_parallel_loss = loss_factory.model_parallel_matmul_trains()
mix_parallel_loss = loss_factory.mix_parallel_matmul_trains()
assert allclose(single_loss, model_parallel_loss)
assert allclose(single_loss, mix_parallel_loss)
......@@ -18,7 +18,9 @@ import pytest
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.env_single
def test_expand_loss():
ret = os.system("sh run_auto_parallel_loss_expand.sh")
sh_path = os.path.split(os.path.realpath(__file__))[0]
ret = os.system(f"sh {sh_path}/run_auto_parallel_loss_expand.sh")
assert(ret==0)
......@@ -16,9 +16,6 @@
import os
import pytest
@pytest.mark.level0
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single
def test_expand_loss():
ret = os.system("sh run_onehot_model_parallel.sh")
assert(ret==0)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册