提交 49b8a084 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!334 Add parallel operator for LayerNorm

Merge pull request !334 from yangzhenzhang/layernorm
......@@ -101,6 +101,7 @@ REGISTER(CosInfo);
REGISTER(ACosInfo);
REGISTER(LogicalNotInfo);
REGISTER(L2NormalizeInfo);
REGISTER(LayerNormInfo);
REGISTER(ReduceMaxInfo);
REGISTER(ArgMaxWithValueInfo);
REGISTER(ArgMinWithValueInfo);
......
......@@ -195,8 +195,8 @@ Status Softmax::GetAttrs() {
// for example: tensor dimension is 4, then axis range [-4, 3]
int32_t dim = SizeToInt(inputs_shape_.at(0).size());
auto it = std::find_if(axis_.begin(), axis_.end(),
[dim](const int32_t& element) { return ((element >= dim) || (element < -dim)); });
auto it =
std::find_if(axis_.begin(), axis_.end(), [dim](int32_t element) { return ((element >= dim) || (element < -dim)); });
if (it != axis_.end()) {
MS_LOG(ERROR) << name_ << " : The axis(" << *it << ") is out of range[" << -dim << ", " << dim - 1 << "].";
return FAILED;
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "parallel/ops_info/layer_norm_info.h"
#include <algorithm>
#include <vector>
#include "parallel/device_matrix.h"
#include "parallel/strategy.h"
namespace mindspore {
namespace parallel {
Status LayerNormInfo::GetAttrs() {
auto iter = attrs_.find(BEGIN_NORM_AXIS);
if (iter == attrs_.end()) {
MS_LOG(ERROR) << name_ << ": Can not find the attr of begin norm axis";
return FAILED;
}
if ((iter->second == nullptr) || !iter->second->isa<Int32Imm>()) {
MS_LOG(ERROR) << name_ << ": The axis type is not int";
return FAILED;
}
int32_t dim = SizeToInt(input_shape_.size());
auto axis = GetValue<int32_t>(iter->second);
if ((axis >= dim) || (axis < -dim)) {
MS_LOG(ERROR) << name_ << ": The axis(" << axis << ") is out of range[" << -dim << ", " << dim - 1 << "]";
return FAILED;
}
if (axis < 0) {
axis = axis + dim;
}
begin_norm_axis_ = IntToSize(axis);
return SUCCESS;
}
Status LayerNormInfo::CheckStrategy(const StrategyPtr &strategy) {
MS_EXCEPTION_IF_NULL(strategy);
std::vector<Dimensions> stra = strategy->GetInputDim();
if (stra.size() != LAYER_NORM_INPUT_SIZE) {
MS_LOG(ERROR) << name_ << ": Invalid strategy size " << stra.size();
return FAILED;
}
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy value";
return FAILED;
}
Dimensions input_strategy = stra[LAYER_NORM_INPUT_INDEX];
Dimensions gamma_strategy = stra[LAYER_NORM_GAMMA_INDEX];
Dimensions beta_strategy = stra[LAYER_NORM_BETA_INDEX];
if (begin_norm_axis_ >= input_strategy.size()) {
MS_LOG(ERROR) << name_ << ": Invalid begin norm axis " << begin_norm_axis_;
return FAILED;
}
// check input strategy
for (size_t i = begin_norm_axis_; i < input_strategy.size(); ++i) {
if (input_strategy[begin_norm_axis_] != NO_SPLIT_STRATEGY) {
MS_LOG(ERROR) << name_ << ": Invalid input strategy " << ShapeToString(input_strategy);
return FAILED;
}
}
// check gamma and beta strategy
if ((gamma_strategy.size() > input_strategy.size()) || (beta_strategy.size() > input_strategy.size())) {
MS_LOG(ERROR) << name_ << " : The strategy size of gamma or beta is lager than input strategy";
return FAILED;
}
size_t gamma_diff = input_strategy.size() - gamma_strategy.size();
for (size_t j = 0; j < gamma_strategy.size(); ++j) {
if (gamma_strategy[j] != input_strategy[gamma_diff + j]) {
MS_LOG(ERROR) << name_ << ": Invalid gamma strategy " << ShapeToString(gamma_strategy);
return FAILED;
}
}
size_t beta_diff = input_strategy.size() - beta_strategy.size();
for (size_t k = 0; k < beta_strategy.size(); ++k) {
if (beta_strategy[k] != input_strategy[beta_diff + k]) {
MS_LOG(ERROR) << name_ << ": Invalid beta strategy " << ShapeToString(beta_strategy);
return FAILED;
}
}
return SUCCESS;
}
Status LayerNormInfo::InferDevMatrixShape() {
if (strategy_ == nullptr) {
MS_LOG(ERROR) << name_ << ": The strategy is null";
return FAILED;
}
std::vector<Dimensions> stra = strategy_->GetInputDim();
if (stra.empty()) {
MS_LOG(ERROR) << name_ << ": The strategy is empty";
return FAILED;
}
dev_matrix_shape_ = stra[0];
return SUCCESS;
}
Status LayerNormInfo::CreateTensorMap(size_t input_index) {
if (inputs_shape_.size() <= input_index) {
MS_LOG(ERROR) << name_ << ": Invalid index" << input_index;
return FAILED;
}
Shape shape = inputs_shape_[input_index];
Shape tensor_map;
for (size_t i = 0; i < shape.size(); ++i) {
tensor_map.push_back(SizeToInt(shape.size() - i - 1));
}
inputs_tensor_map_.push_back(tensor_map);
outputs_tensor_map_.push_back(tensor_map);
return SUCCESS;
}
Status LayerNormInfo::InferTensorMap() {
if ((CreateTensorMap(LAYER_NORM_INPUT_INDEX) != SUCCESS) || (CreateTensorMap(LAYER_NORM_GAMMA_INDEX) != SUCCESS) ||
(CreateTensorMap(LAYER_NORM_BETA_INDEX) != SUCCESS)) {
MS_LOG(ERROR) << name_ << ": Create tensor map failed";
return FAILED;
}
return SUCCESS;
}
Status LayerNormInfo::CreateMirrorOp(size_t input_index) {
if (inputs_tensor_map_.size() <= input_index) {
MS_LOG(ERROR) << name_ << ": Invalid index " << input_index;
return FAILED;
}
Shape tensor_map = inputs_tensor_map_[input_index];
std::vector<Group> group;
if (CreateGroupByTensorMap(tensor_map, &group) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Create group for input " << input_index << " failed";
return FAILED;
}
OperatorVector mirror_op;
if (!group.empty()) {
mirror_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
MS_LOG(INFO) << name_ << " : Create the mirror ops for input " << input_index << " success, group is "
<< group[0].name();
}
mirror_ops_.push_back(mirror_op);
return SUCCESS;
}
Status LayerNormInfo::InferMirrorOps() {
if ((CreateMirrorOp(LAYER_NORM_INPUT_INDEX) != SUCCESS) || (CreateMirrorOp(LAYER_NORM_GAMMA_INDEX) != SUCCESS) ||
(CreateMirrorOp(LAYER_NORM_BETA_INDEX) != SUCCESS)) {
MS_LOG(ERROR) << name_ << ": Create mirror op failed";
return FAILED;
}
return SUCCESS;
}
Status LayerNormInfo::CreateTensorInfo(size_t input_index) {
if ((inputs_shape_.size() <= input_index) || (inputs_tensor_map_.size() <= input_index)) {
MS_LOG(ERROR) << name_ << ": Invalid input index" << input_index;
return FAILED;
}
Shape tensor_map = inputs_tensor_map_[input_index];
Shape shape = inputs_shape_[input_index];
TensorLayout tensor_layout;
if (tensor_layout.InitFromVector(dev_matrix_shape_, tensor_map, shape) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init tensor layout for input " << input_index << " failed";
return FAILED;
}
TensorInfo tensor_info(tensor_layout);
inputs_tensor_info_.push_back(tensor_info);
outputs_tensor_info_.push_back(tensor_info);
return SUCCESS;
}
Status LayerNormInfo::InferTensorInfo() {
if ((CreateTensorInfo(LAYER_NORM_INPUT_INDEX) != SUCCESS) || (CreateTensorInfo(LAYER_NORM_GAMMA_INDEX) != SUCCESS) ||
(CreateTensorInfo(LAYER_NORM_BETA_INDEX) != SUCCESS)) {
MS_LOG(ERROR) << name_ << ": Create tensor info failed";
return FAILED;
}
return SUCCESS;
}
Status LayerNormInfo::InferAsLossDivisor() {
if (outputs_tensor_map_.size() != LAYER_NORM_INPUT_SIZE) {
MS_LOG(ERROR) << name_ << ": The size of outputs tensor map " << outputs_tensor_map_.size() << " is error";
return FAILED;
}
as_loss_divisor_ = ComputeRepeatDeviceNumByTensorMap(dev_matrix_shape_, outputs_tensor_map_[0]);
MS_LOG(INFO) << name_ << " : The dev matrix shape is " << ShapeToString(dev_matrix_shape_)
<< ", the output[0]'s tensor map is " << ShapeToString(outputs_tensor_map_[0])
<< ", as_loss_divisor_ is " << as_loss_divisor_;
return SUCCESS;
}
Status LayerNormInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << " : Set cost failed";
return FAILED;
}
return SUCCESS;
}
Status LayerNormInfo::GenerateGammaAndBetaStrategies(const std::vector<StrategyPtr> &sp_vector) {
if ((gamma_shape_.size() > input_shape_.size()) || (beta_shape_.size() > input_shape_.size())) {
MS_LOG(ERROR) << name_ << ": The dimension of gamma or beta is lager than input";
return FAILED;
}
size_t gamma_diff = input_shape_.size() - gamma_shape_.size();
size_t beta_diff = input_shape_.size() - beta_shape_.size();
for (auto &sp : sp_vector) {
if ((sp == nullptr) || sp->GetInputDim().empty()) {
MS_LOG(ERROR) << name_ << ": Invalid strategy";
return FAILED;
}
std::vector<Dimensions> tmp_strategy;
Dimensions input_strategy = sp->GetInputDim()[0];
Dimensions gamma_strategy = input_strategy;
(void)gamma_strategy.erase(gamma_strategy.begin(),
gamma_strategy.begin() + static_cast<different_type>(gamma_diff));
Dimensions beta_strategy = input_strategy;
(void)beta_strategy.erase(beta_strategy.begin(), beta_strategy.begin() + static_cast<different_type>(beta_diff));
// reset the strategy
tmp_strategy.push_back(input_strategy);
tmp_strategy.push_back(gamma_strategy);
tmp_strategy.push_back(beta_strategy);
sp->ResetInputs(tmp_strategy);
}
return SUCCESS;
}
Status LayerNormInfo::GenerateStrategies(int32_t stage_id) {
if (InitShapes() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init shapes failed";
return FAILED;
}
if (GetAttrs() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Get attrs failed";
return FAILED;
}
Shape input_split(input_shape_.size(), SPLIT_FLAG);
if (begin_norm_axis_ >= input_split.size()) {
MS_LOG(ERROR) << name_ << ": Invalid begin norm axis " << begin_norm_axis_;
return FAILED;
}
// Can not split the dimensions from begin norm axis
for (size_t i = begin_norm_axis_; i < input_split.size(); ++i) {
input_split[i] = NO_SPLIT_FLAG;
}
// Generate strategy for input
Shapes splittable_inputs = {input_split};
Shapes tmp_inputs_shape = {input_shape_};
std::vector<StrategyPtr> sp_vector;
is_auto_parallel_ = true;
if (GenerateStrategiesForIndependentInputs(stage_id, tmp_inputs_shape, splittable_inputs, &sp_vector) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Generate input strategy failed";
return FAILED;
}
// Generate the strategies for gamma and beta
if (GenerateGammaAndBetaStrategies(sp_vector) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Generate gamma and beta strategies failed";
return FAILED;
}
size_t success = 0;
for (auto &sp : sp_vector) {
if (SetCostUnderStrategy(sp) == SUCCESS) {
success++;
MS_LOG(DEBUG) << name_ << ": Successfully generated " << success << " strategy";
}
}
return SUCCESS;
}
Status LayerNormInfo::InitShapes() {
if (inputs_shape_.size() != LAYER_NORM_INPUT_SIZE) {
MS_LOG(ERROR) << name_ << ": Invalid inputs size";
return FAILED;
}
input_shape_ = inputs_shape_[LAYER_NORM_INPUT_INDEX];
gamma_shape_ = inputs_shape_[LAYER_NORM_GAMMA_INDEX];
beta_shape_ = inputs_shape_[LAYER_NORM_BETA_INDEX];
return SUCCESS;
}
Status LayerNormInfo::Init(const StrategyPtr &strategy) {
if ((InitShapes() != SUCCESS) || (InitWithAutoRepeatCalc(strategy)) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success";
return SUCCESS;
}
Status LayerNormInfo::InitForCostModel(const StrategyPtr &strategy) {
if ((InitShapes() != SUCCESS) || (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS)) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init for cost model success";
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_
#define MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_
#include <string>
#include <memory>
#include <unordered_map>
#include <vector>
#include "ir/value.h"
#include "parallel/auto_parallel/operator_costmodel.h"
#include "parallel/ops_info/operator_info.h"
#include "parallel/strategy.h"
namespace mindspore {
namespace parallel {
constexpr size_t LAYER_NORM_INPUT_SIZE = 3;
constexpr size_t LAYER_NORM_INPUT_INDEX = 0;
constexpr size_t LAYER_NORM_GAMMA_INDEX = 1;
constexpr size_t LAYER_NORM_BETA_INDEX = 2;
constexpr char BEGIN_NORM_AXIS[] = "begin_norm_axis";
// The dimensions of input tensor starting from begin norm axis cannot be split. Other dimensions can be split
// arbitrarily. Gamma and beta should match input to meet the broadcast requirements of mul and add.
class LayerNormInfo : public OperatorInfo {
public:
LayerNormInfo(const std::string& operator_name, const Shapes& inputs_shape, const Shapes& outputs_shape,
const PrimitiveAttrs& attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<LayerNormCost>(true)),
begin_norm_axis_(0) {}
~LayerNormInfo() override = default;
Status Init(const StrategyPtr& strategy) override;
Status InitForCostModel(const StrategyPtr& strategy) override;
Status GenerateStrategies(int32_t) override;
Status SetCostUnderStrategy(const StrategyPtr&) override;
protected:
Status GetAttrs() override;
Status CheckStrategy(const StrategyPtr& strategy) override;
Status InferMirrorOps() override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
Status InferAsLossDivisor() override;
Status CreateTensorMap(size_t input_index);
Status CreateTensorInfo(size_t input_index);
Status CreateMirrorOp(size_t input_index);
Status GenerateGammaAndBetaStrategies(const std::vector<StrategyPtr>& sp_vector);
Status InitShapes();
private:
size_t begin_norm_axis_;
Shape input_shape_;
Shape gamma_shape_;
Shape beta_shape_;
};
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_LAYER_NORM_INFO_H_
......@@ -27,6 +27,7 @@
#include "parallel/ops_info/gather_v2_info.h"
#include "parallel/ops_info/get_next_info.h"
#include "parallel/ops_info/l2_normalize_info.h"
#include "parallel/ops_info/layer_norm_info.h"
#include "parallel/ops_info/loss_info.h"
#include "parallel/ops_info/matmul_info.h"
#include "parallel/ops_info/onehot_info.h"
......
......@@ -26,6 +26,8 @@ constexpr int32_t PRELU_CHANNEL_INDEX = 1;
constexpr int32_t PRELU_CHANNEL_STRATEGY = 1;
constexpr int32_t NO_SPLIT_MAP = -1;
constexpr int32_t NO_SPLIT_STRATEGY = 1;
constexpr int32_t SPLIT_FLAG = 1;
constexpr int32_t NO_SPLIT_FLAG = 0;
constexpr size_t MATMUL_ATTRS_SIZE = 2;
constexpr size_t MATMUL_INPUTS_SIZE = 2;
constexpr size_t MATMUL_OUTPUTS_SIZE = 1;
......@@ -173,6 +175,7 @@ constexpr char ARGMINWITHVALUE[] = "ArgMinWithValue";
constexpr char CONV2D[] = "Conv2D";
constexpr char FUSE_BATCH_NORM[] = "FusedBatchNorm";
constexpr char BATCH_NORM[] = "BatchNorm";
constexpr char LAYER_NORM[] = "LayerNorm";
constexpr char POOLING[] = "Pooling";
constexpr char CAST[] = "Cast";
constexpr char MAX_POOL_WITH_ARGMAX[] = "MaxPoolWithArgmax";
......
......@@ -82,6 +82,7 @@ std::vector<std::string> splittable_op_ = {MATMUL,
SIMPLE_MEAN,
FLATTEN,
BATCH_NORM,
LAYER_NORM,
BIAS_ADD,
ASSIGN_SUB,
COS,
......
......@@ -245,8 +245,8 @@ void ValidRedistributionLayoutCheck(const DeviceArrangement& in_device_arrangeme
unified_out_tensor_map, unified_tensor_shape);
}
void ValidRedistributionLayoutCheckAll(const int32_t& device_pow_size, const int32_t& tensor_pow_size,
const int32_t& max_device_dim, const int32_t& max_shape_dim) {
void ValidRedistributionLayoutCheckAll(int32_t device_pow_size, int32_t tensor_pow_size,
int32_t max_device_dim, int32_t max_shape_dim) {
std::vector<std::tuple<DeviceArrangement, TensorMap, TensorShape>> layout_list;
GenerateValidLayoutByDeviceSizeAndTensorSize(device_pow_size, tensor_pow_size, max_device_dim, max_shape_dim,
&layout_list);
......
......@@ -260,8 +260,8 @@ TEST_F(TestReshapeLayoutTransfer, ValidInferUnifiedLayoutCheck11) {
ValidUnifiedLayoutCheck(device_arrangement, in_tensor_map, in_tensor_shape, out_tensor_map, out_tensor_shape);
}
void ValidInferUnifiedLayoutCheckAll(const int32_t& device_pow_size, const int32_t& tensor_pow_size,
const int32_t& max_device_dim, const int32_t& max_shape_dim) {
void ValidInferUnifiedLayoutCheckAll(int32_t device_pow_size, int32_t tensor_pow_size,
int32_t max_device_dim, int32_t max_shape_dim) {
std::vector<std::tuple<DeviceArrangement, TensorMap, TensorShape>> layout_list;
GenerateValidLayoutByDeviceSizeAndTensorSize(device_pow_size, tensor_pow_size, max_device_dim, max_shape_dim,
&layout_list);
......
......@@ -51,7 +51,7 @@ std::vector<std::vector<int32_t>> combine(const std::vector<int32_t>& in, int32_
return output;
}
void GenerateValidShapeBySizeAndDim(const int32_t& pow_size, const int32_t& dim,
void GenerateValidShapeBySizeAndDim(int32_t pow_size, int32_t dim,
std::vector<std::vector<int32_t>>* out) {
out->clear();
std::vector<int32_t> in;
......@@ -78,7 +78,7 @@ void GenerateValidShapeBySizeAndDim(const int32_t& pow_size, const int32_t& dim,
return;
}
void GenerateValidShapeBySize(const int32_t& pow_size, std::vector<std::vector<int32_t>>* out) {
void GenerateValidShapeBySize(int32_t pow_size, std::vector<std::vector<int32_t>>* out) {
out->clear();
for (int32_t dim = 1; dim <= pow_size; dim++) {
std::vector<std::vector<int32_t>> combine_result;
......@@ -148,8 +148,8 @@ void GenerateValidTensorMap(const std::vector<int32_t>& device_arrangement, cons
}
void GenerateValidLayoutByDeviceSizeAndTensorSize(
const int32_t& device_pow_size, const int32_t& tensor_pow_size, const int32_t& max_device_dim,
const int32_t& max_shape_dim,
int32_t device_pow_size, int32_t tensor_pow_size, int32_t max_device_dim,
int32_t max_shape_dim,
std::vector<std::tuple<std::vector<int32_t>, std::vector<int32_t>, std::vector<int32_t>>>* layout_list) {
layout_list->clear();
std::vector<std::vector<int32_t>> device_arrangement_list;
......
......@@ -27,10 +27,10 @@ namespace parallel {
std::vector<std::vector<int32_t>> combine(const std::vector<int32_t>& in, int32_t target);
void GenerateValidShapeBySizeAndDim(const int32_t& pow_size, const int32_t& dim,
void GenerateValidShapeBySizeAndDim(int32_t pow_size, int32_t dim,
std::vector<std::vector<int32_t>>* out);
void GenerateValidShapeBySize(const int32_t& pow_size, std::vector<std::vector<int32_t>>* out);
void GenerateValidShapeBySize(int32_t pow_size, std::vector<std::vector<int32_t>>* out);
std::vector<int32_t> GenerateTensorMap(const uint32_t& map_size, const std::vector<int32_t>& pos_index,
const std::vector<int32_t>& pos_value);
......@@ -39,8 +39,8 @@ void GenerateValidTensorMap(const std::vector<int32_t>& device_arrangement, cons
std::vector<std::vector<int32_t>>* tensor_map_list);
void GenerateValidLayoutByDeviceSizeAndTensorSize(
const int32_t& device_pow_size, const int32_t& tensor_pow_size, const int32_t& max_device_dim,
const int32_t& max_shape_dim,
int32_t device_pow_size, int32_t tensor_pow_size, int32_t max_device_dim,
int32_t max_shape_dim,
std::vector<std::tuple<std::vector<int32_t>, std::vector<int32_t>, std::vector<int32_t>>>* layout_list);
uint32_t ComputeNoneNumber(const std::vector<int32_t>& tensor_map);
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P
from mindspore.common.api import _executor
from mindspore.common.initializer import initializer
class Net(Cell):
def __init__(self, mul_weight, strategy1=None, strategy2=None, strategy3=None):
super().__init__()
self.begin_norm_axis = -1
self.begin_params_axis = 1
self.mul = P.Mul().set_strategy(strategy1)
self.layer_norm = P.LayerNorm(self.begin_norm_axis, self.begin_params_axis).set_strategy(strategy2)
self.mul2 = P.Mul().set_strategy(strategy3)
self.mul_weight = Parameter(mul_weight, "w1")
self.normalized_shape = [64, 32, 16]
self.gamma = Parameter(initializer('ones', self.normalized_shape), name="gamma")
self.beta = Parameter(initializer('zeros', self.normalized_shape), name="beta")
def construct(self, x, b):
out = self.mul(x, self.mul_weight)
out, _, _ = self.layer_norm(out, self.gamma, self.beta)
out = self.mul2(out, b)
return out
_x = Tensor(np.ones([128, 64, 32, 16]), dtype=ms.float32)
_w = Tensor(np.ones([128, 64, 32, 16]), dtype=ms.float32)
_b = Tensor(np.ones([128, 64, 32, 16]), dtype=ms.float32)
def compile(net):
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer)
_executor.compile(train_net, _x, _b)
context.reset_auto_parallel_context()
def test_layer_norm_data_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((16, 1, 1, 1), (16, 1, 1, 1))
strategy2 = ((16, 1, 1, 1), (1, 1, 1), (1, 1, 1))
strategy3 = ((16, 1, 1, 1), (16, 1, 1, 1))
net = Net(_w, strategy1, strategy2, strategy3)
compile(net)
def test_layer_norm_model_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((1, 1, 16, 1), (1, 1, 16, 1))
strategy2 = ((1, 1, 16, 1), (1, 16, 1), (1, 16, 1))
strategy3 = ((1, 1, 16, 1), (1, 1, 16, 1))
net = Net(_w, strategy1, strategy2, strategy3)
compile(net)
def test_layer_norm_hybrid_parallel():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 4, 1), (2, 2, 4, 1))
strategy2 = ((2, 2, 4, 1), (2, 4, 1), (2, 4, 1))
strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1))
net = Net(_w, strategy1, strategy2, strategy3)
compile(net)
def test_layer_norm_auto_parallel():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
net = Net(_w)
compile(net)
def test_layer_norm_repeat_calc():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
strategy1 = ((2, 2, 4, 1), (2, 2, 4, 1))
strategy2 = ((1, 2, 2, 1), (2, 2, 1), (2, 2, 1))
strategy3 = ((2, 2, 4, 1), (2, 2, 4, 1))
net = Net(_w, strategy1, strategy2, strategy3)
compile(net)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册