提交 47508610 编写于 作者: Y yangzhenzhang

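fix layernorm bug: CheckStrategy tested input_strategy[begin_norm_axis_] on every loop iteration instead of input_strategy[i]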

上级 348b0ef5
......@@ -69,7 +69,7 @@ Status LayerNormInfo::CheckStrategy(const StrategyPtr &strategy) {
}
// check input strategy
for (size_t i = begin_norm_axis_; i < input_strategy.size(); ++i) {
if (input_strategy[begin_norm_axis_] != NO_SPLIT_STRATEGY) {
if (input_strategy[i] != NO_SPLIT_STRATEGY) {
MS_LOG(ERROR) << name_ << ": Invalid input strategy " << ShapeToString(input_strategy);
return FAILED;
}
......
......@@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell, TrainOneStepCell, Momentum
......@@ -24,7 +25,7 @@ from mindspore.common.initializer import initializer
class Net(Cell):
def __init__(self, mul_weight, strategy1=None, strategy2=None, strategy3=None):
super().__init__()
self.begin_norm_axis = -1
self.begin_norm_axis = 2
self.begin_params_axis = 1
self.mul = P.Mul().set_strategy(strategy1)
self.layer_norm = P.LayerNorm(self.begin_norm_axis, self.begin_params_axis).set_strategy(strategy2)
......@@ -64,18 +65,18 @@ def test_layer_norm_data_parallel():
def test_layer_norm_model_parallel():
# Compiles Net in semi-auto-parallel mode on 16 devices with a 16-way split on a single axis.
# NOTE(review): strategy1/2/3 are each assigned twice below. This looks like a diff rendered
# without +/- markers, the first group being the pre-change values -- confirm. Read as plain
# code, the second group overrides the first.
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
# First group: splits axis 2 of the 4-D input by 16.
strategy1 = ((1, 1, 16, 1), (1, 1, 16, 1))
strategy2 = ((1, 1, 16, 1), (1, 16, 1), (1, 16, 1))
strategy3 = ((1, 1, 16, 1), (1, 1, 16, 1))
# Second group: moves the 16-way split to axis 1. Net uses begin_norm_axis = 2, and the
# strategy check rejects any input axis >= begin_norm_axis that is not NO_SPLIT_STRATEGY,
# so a split on axis 2 is presumably no longer valid -- TODO confirm NO_SPLIT_STRATEGY means 1.
strategy1 = ((1, 16, 1, 1), (1, 16, 1, 1))
# The two 3-element entries presumably belong to LayerNorm's parameter inputs, which start
# at begin_params_axis = 1 -- verify against the operator definition.
strategy2 = ((1, 16, 1, 1), (16, 1, 1), (16, 1, 1))
strategy3 = ((1, 16, 1, 1), (1, 16, 1, 1))
net = Net(_w, strategy1, strategy2, strategy3)
compile(net)
def test_layer_norm_hybrid_parallel():
# Compiles Net in semi-auto-parallel mode on 16 devices with splits on more than one axis.
# NOTE(review): strategy1/2/3 are each assigned twice below. This looks like a diff rendered
# without +/- markers, the first group being the pre-change values -- confirm. Read as plain
# code, the second group overrides the first.
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
# First group: splits axes 0, 1 and 2 by 2, 2 and 4.
strategy1 = ((2, 2, 4, 1), (2, 2, 4, 1))
strategy2 = ((2, 2, 4, 1), (2, 4, 1), (2, 4, 1))
strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1))
# Second group: splits only axes 0 and 1 (2 * 8 = 16), leaving axes 2 and 3 unsplit.
# Net uses begin_norm_axis = 2, so this keeps every axis from begin_norm_axis onward whole.
strategy1 = ((2, 8, 1, 1), (2, 8, 1, 1))
# The two 3-element entries presumably belong to LayerNorm's parameter inputs, which start
# at begin_params_axis = 1 -- verify against the operator definition.
strategy2 = ((2, 8, 1, 1), (8, 1, 1), (8, 1, 1))
strategy3 = ((2, 8, 1, 1), (2, 8, 1, 1))
net = Net(_w, strategy1, strategy2, strategy3)
compile(net)
......@@ -89,8 +90,17 @@ def test_layer_norm_auto_parallel():
def test_layer_norm_repeat_calc():
# Compiles Net where the LayerNorm strategy covers only 4 slices while device_num is 16.
# Presumably this exercises the repeated-calculation path named in the test -- TODO confirm.
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 4, 1), (2, 2, 4, 1))
# NOTE(review): strategy2 is assigned twice. This looks like a diff rendered without +/-
# markers, the first line being the pre-change value -- confirm. Read as plain code, the
# second assignment wins.
# First value splits axes 1 and 2; the second splits axes 0 and 1 only, leaving every axis
# from begin_norm_axis = 2 onward unsplit.
strategy2 = ((1, 2, 2, 1), (2, 2, 1), (2, 2, 1))
strategy2 = ((2, 2, 1, 1), (2, 1, 1), (2, 1, 1))
strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1))
net = Net(_w, strategy1, strategy2, strategy3)
compile(net)
def test_layer_norm_wrong_strategy():
    """Check that compiling Net with an invalid LayerNorm strategy raises RuntimeError.

    The LayerNorm input strategy (1, 2, 1, 2) splits axis 3 by 2. Net is built
    with begin_norm_axis = 2, so this presumably trips the strategy check that
    forbids splitting any axis from begin_norm_axis onward -- confirm against
    LayerNormInfo::CheckStrategy.
    """
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)

    # The first and third strategies share the same per-input split.
    shared_split = (2, 2, 4, 1)
    mul_strategy = (shared_split, shared_split)
    third_strategy = (shared_split, shared_split)

    # The deliberately invalid strategy handed to LayerNorm.
    bad_layer_norm_strategy = ((1, 2, 1, 2), (2, 1, 2), (2, 1, 2))

    # Build the network outside the raises-context so that only compile() is
    # expected to fail.
    net = Net(_w, mul_strategy, bad_layer_norm_strategy, third_strategy)
    with pytest.raises(RuntimeError):
        compile(net)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册