From 513f384c43f5d850fabdfc9ca878ed7cd7f403a3 Mon Sep 17 00:00:00 2001 From: yao_yf Date: Wed, 8 Apr 2020 17:24:22 +0800 Subject: [PATCH] fix auto parallel prelu --- mindspore/ccsrc/parallel/ops_info/prelu_info.cc | 2 +- tests/ut/cpp/parallel/ops_info/prelu_test.cc | 6 ++---- tests/ut/python/parallel/test_prelu.py | 17 +++++++++++++++++ 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/mindspore/ccsrc/parallel/ops_info/prelu_info.cc b/mindspore/ccsrc/parallel/ops_info/prelu_info.cc index 9aa851333..1a44501f4 100644 --- a/mindspore/ccsrc/parallel/ops_info/prelu_info.cc +++ b/mindspore/ccsrc/parallel/ops_info/prelu_info.cc @@ -52,7 +52,7 @@ Status PReLUInfo::CheckStrategy(const StrategyPtr& strategy) { } return FAILED; } - if ((stra[0][PRELU_CHANNEL_INDEX] != PRELU_CHANNEL_STRATEGY) || (stra[1][0] != PRELU_CHANNEL_STRATEGY)) { + if (stra[0][PRELU_CHANNEL_INDEX] != stra[1][0]) { if (is_auto_parallel_) { MS_LOG(DEBUG) << name_ << ": Invalid channel strategy."; } else { diff --git a/tests/ut/cpp/parallel/ops_info/prelu_test.cc b/tests/ut/cpp/parallel/ops_info/prelu_test.cc index 5ff261234..d6db1b846 100644 --- a/tests/ut/cpp/parallel/ops_info/prelu_test.cc +++ b/tests/ut/cpp/parallel/ops_info/prelu_test.cc @@ -146,11 +146,10 @@ TEST_F(TestPReLUInfo, CheckStrategy1) { } TEST_F(TestPReLUInfo, CheckStrategy2) { - // Success: {{2,1,8,16},{1}} std::vector inputs = {{2, 4, 8, 16}, {4}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = prelu->Init(strategy); - ASSERT_EQ(ret, FAILED); + ASSERT_EQ(ret, SUCCESS); } TEST_F(TestPReLUInfo, AutoStrategy1) { @@ -252,11 +251,10 @@ TEST_F(TestPReLUInfo, CheckStrategy_2d1) { } TEST_F(TestPReLUInfo, CheckStrategy_2d2) { - // Success: {{2,1,8,16},{1}} std::vector inputs = {{128, 4}, {4}}; StrategyPtr strategy = NewStrategy(0, inputs); Status ret = prelu_2d->Init(strategy); - ASSERT_EQ(ret, FAILED); + ASSERT_EQ(ret, SUCCESS); } TEST_F(TestPReLUInfo, AutoStrategy_2d1) { diff --git a/tests/ut/python/parallel/test_prelu.py b/tests/ut/python/parallel/test_prelu.py index c60104549..d3ad1cc71 100755 --- a/tests/ut/python/parallel/test_prelu.py +++ b/tests/ut/python/parallel/test_prelu.py @@ -149,3 +149,20 @@ def test_prelu_parallel_success3(): w = Tensor(np.random.rand(16),dtype=ms.float32) net = GradWrap(NetWithLoss(Net(strategy1, strategy2))) _executor.compile(net, x, y, w) + +def test_prelu_parallel_success4(): + class Net(nn.Cell): + def __init__(self, strategy): + super().__init__() + self.prelu = P.PReLU().set_strategy(strategy) + def construct(self, x, y): + out = self.prelu(x, y) + return out + context.reset_auto_parallel_context() + context.set_auto_parallel_context(device_num=64, global_rank=0) + context.set_auto_parallel_context(parallel_mode="semi_auto_parallel") + strategy = ((2, 4, 4, 2), (4, )) + x = Tensor(np.random.rand(4, 16, 32, 64),dtype=ms.float32) + w = Tensor(np.random.rand(16),dtype=ms.float32) + net = GradWrap(NetWithLoss(Net(strategy))) + _executor.compile(net, x, w) -- GitLab