提交 f4bb43bb 编写于 作者: Y yangzhenzhang

add concat op

上级 a3959071
......@@ -199,6 +199,8 @@ class SoftmaxCost : public OperatorCost {
using SoftmaxCostPtr = std::shared_ptr<SoftmaxCost>;
using TileCost = SoftmaxCost;
using TileCostPtr = std::shared_ptr<TileCost>;
using ConcatCost = TileCost;
using ConcatCostPtr = std::shared_ptr<ConcatCost>;
class TmpIdentityCost : public OperatorCost {
public:
......
......@@ -136,6 +136,7 @@ REGISTER(EmbeddingLookupInfo);
REGISTER(TileInfo);
REGISTER(StridedSliceInfo);
REGISTER(DropoutInfo);
REGISTER(ConcatInfo);
} // namespace parallel
} // namespace mindspore
......
......@@ -24,7 +24,6 @@
namespace mindspore {
namespace parallel {
const std::set<std::string> BLACK_LIST = {TUPLE_GETITEM,
MAKE_TUPLE,
J,
LIST_GETITEM,
ARRAY_GETITEM,
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/parallel/ops_info/concat_info.h"
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "frontend/parallel/device_matrix.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/tensor_layout/tensor_redistribution.h"
#include "pipeline/jit/resource.h"
namespace mindspore {
namespace parallel {
Status ConcatInfo::GetAttrs() {
int axis = 0;
auto axis_iter = attrs_.find(AXIS);
if (axis_iter != attrs_.end()) {
MS_EXCEPTION_IF_NULL(axis_iter->second);
if (axis_iter->second->isa<Int32Imm>()) {
axis = axis_iter->second->cast<Int32ImmPtr>()->value();
} else {
MS_LOG(ERROR) << name_ << ": The value of axis is not int";
return FAILED;
}
} else {
MS_LOG(ERROR) << name_ << ": Can not find the axis attr";
return FAILED;
}
if (inputs_shape_.empty()) {
MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
return FAILED;
}
int dim = SizeToInt(inputs_shape_[0].size());
if (axis < 0) {
axis = axis + dim;
}
axis_ = SizeToInt(axis);
return SUCCESS;
}
Status ConcatInfo::CheckStrategy(const StrategyPtr &strategy) {
MS_EXCEPTION_IF_NULL(strategy);
if (CheckStrategyValue(strategy, inputs_shape_, is_auto_parallel_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Invalid strategy";
return FAILED;
}
std::vector<Dimensions> stra = strategy->GetInputDim();
if (stra.empty()) {
MS_LOG(ERROR) << name_ << ": The strategy is empty";
return FAILED;
}
if (stra.size() != inputs_shape_.size()) {
MS_LOG(ERROR) << name_ << ": The size of strategy must be equal to the size of inputs shape";
return FAILED;
}
for (size_t i = 0; i < stra.size(); ++i) {
auto strategy_ele = stra[i];
auto input_shape_ele = inputs_shape_[i];
if (strategy_ele.size() != input_shape_ele.size()) {
MS_LOG(ERROR) << name_ << ": The size of strategy element must be equal to the size of input shape";
return FAILED;
}
if (axis_ >= strategy_ele.size()) {
MS_LOG(ERROR) << name_ << ": The axis is out of range, the axis is " << axis_;
return FAILED;
}
if (strategy_ele[axis_] != 1) {
MS_LOG(ERROR) << name_ << ": The axis can not be split";
return FAILED;
}
for (size_t j = 0; j < strategy_ele.size(); ++j) {
if (strategy_ele[j] != stra[0][j]) {
MS_LOG(ERROR) << name_ << ": The strategy of each input tensor must be equal";
return FAILED;
}
}
}
return SUCCESS;
}
Status ConcatInfo::InferDevMatrixShape() {
MS_EXCEPTION_IF_NULL(strategy_);
std::vector<Dimensions> stra = strategy_->GetInputDim();
if (stra.empty()) {
MS_LOG(ERROR) << name_ << "The strategy is empty";
return FAILED;
}
dev_matrix_shape_ = stra[0];
return SUCCESS;
}
Status ConcatInfo::InferTensorMap() {
TensorMap tensor_map;
if (inputs_shape_.empty()) {
MS_LOG(ERROR) << name_ << "The inputs shape is empty";
return FAILED;
}
// cannot use dev_matrix_shape_ replace inputs_shape_[0], because it may not be fully split in all devices.
int32_t size = SizeToInt(inputs_shape_[0].size());
for (int i = 0; i < size; ++i) {
tensor_map.push_back(size - i - 1);
}
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
inputs_tensor_map_.push_back(tensor_map);
}
outputs_tensor_map_.push_back(tensor_map);
return SUCCESS;
}
Status ConcatInfo::InferMirrorOps() {
mirror_ops_.clear();
if (inputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": The inputs tensor map is empty";
return FAILED;
}
Shape input_tensor_map = inputs_tensor_map_[0];
std::vector<Group> group;
if (CreateGroupByTensorMap(input_tensor_map, &group) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Create group for input failed.";
return FAILED;
}
if (group.empty()) {
MS_LOG(INFO) << name_ << ": The mirror group is empty.";
return SUCCESS;
}
OperatorVector input_op;
input_op = CreateMirrorOps(group[0].name(), group[0].GetDevNum());
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
mirror_ops_.push_back(input_op);
}
return SUCCESS;
}
Status ConcatInfo::InferTensorInfo() {
if (inputs_shape_.empty() || outputs_shape_.empty() || inputs_tensor_map_.empty() || outputs_tensor_map_.empty()) {
MS_LOG(ERROR) << name_ << ": Invalid args";
return FAILED;
}
TensorLayout input_layout, output_layout;
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
// infer tensor layout
if (input_layout.InitFromVector(dev_matrix_shape_, inputs_tensor_map_[i], inputs_shape_[i]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer input tensor layout failed.";
return FAILED;
}
TensorInfo input_tensor_info(input_layout);
inputs_tensor_info_.push_back(input_tensor_info);
}
if (output_layout.InitFromVector(dev_matrix_shape_, outputs_tensor_map_[0], outputs_shape_[0]) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer output tensor layout failed.";
return FAILED;
}
TensorInfo output_tensor_info(output_layout);
outputs_tensor_info_.push_back(output_tensor_info);
return SUCCESS;
}
void ConcatInfo::ReComputeBatchSplitFlagList() {
for (size_t i = 0; i < inputs_shape_.size(); i++) {
split_flag_list_[i] = true;
}
}
Status ConcatInfo::SetCostUnderStrategy(const StrategyPtr &strategy) {
if (SetCostUnderStrategyBase(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Set cost under strategy failed.";
return FAILED;
}
return SUCCESS;
}
Status ConcatInfo::GenerateStrategies(int32_t stage_id) {
if (InferAttrs() != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Infer attrs failed";
return FAILED;
}
if (inputs_shape_.empty()) {
MS_LOG(ERROR) << name_ << ": The inputs shape is empty";
return FAILED;
}
Shape input_split;
for (size_t i = 0; i < inputs_shape_[0].size(); ++i) {
if (i == axis_) {
input_split.push_back(0);
} else {
input_split.push_back(1);
}
}
Shapes splittable_inputs;
for (size_t i = 0; i < inputs_shape_.size(); ++i) {
splittable_inputs.push_back(input_split);
}
std::vector<StrategyPtr> sp_vector;
is_auto_parallel_ = true;
if (GenerateStrategiesWithBroadcast(stage_id, inputs_shape_, splittable_inputs, &sp_vector) != SUCCESS) {
return FAILED;
}
size_t success = 0;
for (auto &sp : sp_vector) {
PrintStrategy(sp);
if (SetCostUnderStrategy(sp) == SUCCESS) {
success++;
MS_LOG(INFO) << name_ << ": Successfully generated " << success << " strategy.";
PrintStrategy(sp);
}
}
return SUCCESS;
}
Status ConcatInfo::Init(const StrategyPtr &strategy) {
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}
Status ConcatInfo::InitForCostModel(const StrategyPtr &strategy) {
if (InitForCostModelWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init for cost model failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init for cost model success.";
return SUCCESS;
}
} // namespace parallel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONCAT_INFO_H_
#define MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONCAT_INFO_H_
#include <string>
#include <memory>
#include <unordered_map>
#include <vector>
#include "ir/value.h"
#include "frontend/parallel/auto_parallel/operator_costmodel.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
namespace mindspore {
namespace parallel {
class ConcatInfo : public OperatorInfo {
public:
ConcatInfo(const std::string &operator_name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(operator_name, inputs_shape, outputs_shape, attrs, std::make_shared<ConcatCost>(false)) {}
~ConcatInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
Status GenerateStrategies(int32_t) override;
Status SetCostUnderStrategy(const StrategyPtr &) override;
void ReComputeBatchSplitFlagList() override;
protected:
Status GetAttrs() override;
Status CheckStrategy(const StrategyPtr &strategy) override;
Status InferMirrorOps() override;
Status InferForwardCommunication() override { return SUCCESS; }
Status InferTensorInfo() override;
Status InferDevMatrixShape() override;
Status InferTensorMap() override;
private:
size_t axis_ = 0;
};
using ConcatInfoPtr = std::shared_ptr<ConcatInfo>;
} // namespace parallel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_CONCAT_INFO_H_
......@@ -39,5 +39,6 @@
#include "frontend/parallel/ops_info/gather_v2_p_info.h"
#include "frontend/parallel/ops_info/tile_info.h"
#include "frontend/parallel/ops_info/strided_slice_info.h"
#include "frontend/parallel/ops_info/concat_info.h"
#endif // MINDSPORE_CCSRC_FRONTEND_PARALLEL_OPS_INFO_HEAD_FILES_H_
......@@ -118,6 +118,9 @@ bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
std::vector<bool> is_parameter;
std::vector<AnfNodePtr> node_inputs{node->inputs()};
if ((node_inputs.size() == 2) && AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE)) {
node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
}
for (size_t i = 1; i < node_inputs.size(); ++i) {
auto input = node_inputs[i];
......@@ -192,6 +195,10 @@ std::vector<size_t> ExtractInputTypeLengthByNode(const CNodePtr &node) {
std::vector<size_t> inputs_type_len;
std::vector<AnfNodePtr> node_inputs{node->inputs()};
if ((node_inputs.size() == 2) && AnfNodeIsPrimitive(node_inputs[1], MAKE_TUPLE)) {
node_inputs = node_inputs[1]->cast<CNodePtr>()->inputs();
}
// extract input element length
for (auto &input : node_inputs) {
if (IsValueNode<RefKey>(input)) {
......@@ -255,7 +262,7 @@ bool IsSplittableOperator(const std::string &op_name) {
FLOORDIV, L2_NORMALIZE, TENSOR_ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP,
LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT,
LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT,
STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT,
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS};
// clang-format on
......@@ -275,7 +282,7 @@ bool IsAutoParallelCareNode(const CNodePtr &cnode) {
return false;
}
bool bool_result = IsParallelCareNode(cnode) && !IsSplittableOperator(prim->name());
if (bool_result) {
if (bool_result && (prim->name() != MAKE_TUPLE)) {
MS_LOG(EXCEPTION) << "Should implementing OperatorInfo for: " << prim->name();
} else if (prim->name() == CAST) {
if (cnode->fullname_with_scope().find(OPTIMIZER_SUB_STRING) != std::string::npos) {
......
......@@ -267,6 +267,33 @@ TensorLayout GetTensorInLayout(const CNodePtr &middle_node, const PrimitivePtr &
return tensorinfo_in.tensor_layout();
}
bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name) {
MS_EXCEPTION_IF_NULL(anf_node);
auto cnode = anf_node->cast<CNodePtr>();
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
return false;
}
auto value_node = cnode->input(0)->cast<ValueNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(value_node);
MS_EXCEPTION_IF_NULL(prim);
if (prim->name() == prim_name) {
return true;
}
return false;
}
std::string GetPrimName(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!IsValueNode<Primitive>(node->input(0))) {
MS_LOG(EXCEPTION) << "The node is not a primitive";
}
auto value_node = node->input(0)->cast<ValueNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(value_node);
MS_EXCEPTION_IF_NULL(prim);
return prim->name();
}
OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!IsParallelCareNode(node)) {
......@@ -274,7 +301,7 @@ OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) {
}
OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>();
if (distribute_operator == nullptr) {
MS_LOG(EXCEPTION) << "GetDistributeOperator:distribute_operator is nullptr";
MS_LOG(EXCEPTION) << "Distribute operator is nullptr, the prim is " << GetPrimName(node);
}
return distribute_operator;
}
......@@ -423,6 +450,11 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
MS_EXCEPTION_IF_NULL(manager);
AnfNodeIndexSet node_set = manager->node_users()[node];
CNodePtr insert_node_new;
if (AnfNodeIsPrimitive(node, MAKE_TUPLE)) {
MS_LOG(INFO) << "No need to insert redistribution op betweend make_tuple node and the next node";
return;
}
if (IsValueNode<Primitive>(node->input(0))) {
auto current_value = node->input(0)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(current_value);
......@@ -875,9 +907,15 @@ void InsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node) {
MS_EXCEPTION_IF_NULL(func_graph);
FuncGraphManagerPtr manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
if ((node->inputs().size() == 2) && AnfNodeIsPrimitive(node->input(1), MAKE_TUPLE)) {
MS_LOG(INFO) << "The mirror for " << GetPrimName(node) << " has handle by make_tuple node";
return;
}
if (mirror_ops.size() != node_size - 1) {
MS_LOG(EXCEPTION) << "Failure:Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size()
<< ", node_size is " << node_size;
MS_LOG(EXCEPTION) << "Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size() << ", node_size is "
<< node_size - 1;
}
for (size_t index = 1; index < node_size; ++index) {
OperatorVector backward_op = mirror_ops[index - 1];
......@@ -993,7 +1031,7 @@ OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs
const std::vector<Shapes> &shape_list) {
MS_EXCEPTION_IF_NULL(prim);
OperatorInfoPtr operator_ = OperatorInstanceByName(prim->name(), attrs, shape_list);
if (operator_ == nullptr) {
if ((operator_ == nullptr) && (prim->name() != MAKE_TUPLE)) {
MS_LOG(INFO) << "Creat " << prim->name() << " failed, use batch parallel";
operator_ = OperatorInstanceByName(BATCH_PARALLEL, attrs, shape_list);
MS_EXCEPTION_IF_NULL(operator_);
......@@ -1177,7 +1215,12 @@ std::vector<Shapes> ExtractShape(const CNodePtr &node) {
continue;
}
if (input_shapes.size() != 1) {
MS_LOG(EXCEPTION) << "ExtractShape:Get input shape failed";
if (inputs_size == 2) { // like concat
shape_inputs = input_shapes;
break;
} else {
MS_LOG(EXCEPTION) << "ExtractShape: Get input shape failed";
}
}
shape_inputs.push_back(input_shapes[0]);
}
......@@ -1269,8 +1312,8 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
}
TensorInfo tensorinfo_in = distribute_operator->inputs_tensor_info()[IntToSize(res.second - 1)];
Shape slice_shape = tensorinfo_in.slice_shape();
MS_LOG(DEBUG) << "SetParallelShape slice_shape " << parameter->ToString() << " shape "
<< MakeValue(slice_shape)->ToString();
MS_LOG(INFO) << "SetParallelShape slice_shape " << parameter->ToString() << " shape "
<< MakeValue(slice_shape)->ToString() << ", op name is " << distribute_operator->name();
std::shared_ptr<abstract::BaseShape> parallel_shape = std::make_shared<abstract::Shape>(slice_shape);
MS_EXCEPTION_IF_NULL(parallel_shape);
// Don't modify it in-place as the pointer of this AbstractValue may used as cache key in StaticAnalysis.
......@@ -1450,6 +1493,9 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
SetVirtualDatasetStrategy(cnode);
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
if (prim->name() == MAKE_TUPLE) {
continue;
}
auto attrs = prim->attrs();
MS_LOG(INFO) << "extract information: node: " << node->ToString() << " prim " << prim->name();
if (IsParallelCareNode(cnode)) {
......@@ -2045,13 +2091,13 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
if (!IsValueNode<Primitive>(cnode->input(0))) {
// the make_tuple is parallel care node, but it may have not operator info
if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
continue;
}
OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
if (distribute_operator == nullptr) {
continue;
}
MS_EXCEPTION_IF_NULL(distribute_operator);
// insert forward ops
InsertForwardOps(distribute_operator, cnode);
......@@ -2074,13 +2120,12 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
MS_EXCEPTION_IF_NULL(node);
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
if (!IsValueNode<Primitive>(cnode->input(0))) {
if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
continue;
}
OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
if (distribute_operator == nullptr) {
continue;
}
MS_EXCEPTION_IF_NULL(distribute_operator);
// StepReplace
StepReplace(distribute_operator, cnode);
}
......@@ -2330,6 +2375,44 @@ Status ParallelInit() {
return SUCCESS;
}
void HandleForwardMakeTuple(const std::vector<AnfNodePtr> &all_nodes) {
for (auto &node : all_nodes) {
if (!AnfNodeIsPrimitive(node, MAKE_TUPLE)) {
continue;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (!cnode->in_forward_flag()) {
continue;
}
FuncGraphManagerPtr manager = cnode->func_graph()->manager();
MS_EXCEPTION_IF_NULL(manager);
auto make_tuple_user = manager->node_users()[cnode];
if (make_tuple_user.size() != 1) {
MS_LOG(EXCEPTION) << "Now the make_tuple's user must be 1, but got " << make_tuple_user.size();
}
CNodePtr make_tuple_next_cnode = make_tuple_user.pop().first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple_next_cnode);
std::string make_tuple_user_prim_name = GetPrimName(make_tuple_next_cnode);
if (!IsParallelCareNode(make_tuple_next_cnode)) {
MS_LOG(INFO) << "The make_tuple's user is " << make_tuple_user_prim_name << ", no need to set operator info";
continue;
}
if (make_tuple_next_cnode->inputs().size() != 2) {
MS_LOG(EXCEPTION) << "Now the make_tuple's user only support 1 input, but got "
<< make_tuple_next_cnode->inputs().size() - 1;
}
MS_LOG(INFO) << "Set the make_tuple's operator info, and the op name is " << make_tuple_user_prim_name;
OperatorInfoPtr op_info = GetDistributeOperator(make_tuple_next_cnode);
MS_EXCEPTION_IF_NULL(op_info);
cnode->set_user_data<OperatorInfo>(op_info);
}
}
bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
MS_EXCEPTION_IF_NULL(root);
MS_EXCEPTION_IF_NULL(optimizer);
......@@ -2383,6 +2466,9 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
ExtractInformation(all_nodes);
ReshapeInit(all_nodes);
}
HandleForwardMakeTuple(all_nodes);
// save strategy as checkpoint for multi-train
if (StrategyCheckpoint::GetInstance().SaveCheckPointOn()) {
CheckpointStrategy(root);
......
......@@ -149,6 +149,8 @@ Status ParallelInit();
std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node);
std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root);
bool AnfNodeIsPrimitive(const AnfNodePtr &anf_node, const std::string &prim_name);
} // namespace parallel
} // namespace mindspore
......
......@@ -222,9 +222,17 @@ def get_bprop_virtual_div_operator(self):
dtype = P.DType()
def bprop(x, out, dout):
if F.issubclass_(F.dtype(dout), mstype.bool_):
return (dout,)
dx = op(dout, cast(F.scalar_to_array(divisor), dtype(dout)))
if F.issubclass_(F.typeof(dout), mstype.tensor):
if F.issubclass_(F.dtype(dout), mstype.bool_):
return (dout,)
dx = op(dout, cast(F.scalar_to_array(divisor), dtype(dout)))
return (dx,)
dx = ()
input_nums = F.tuple_len(dout)
for i in range(input_nums):
ele_grad = op(dout[i], cast(F.scalar_to_array(divisor), dtype(dout[i])))
dx = dx + (ele_grad,)
return (dx,)
return bprop
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore as ms
from mindspore import context, Tensor, Parameter
from mindspore.common.api import _executor
from mindspore.nn import Cell, TrainOneStepCell, Momentum
from mindspore.ops import operations as P
class Net(Cell):
def __init__(self, weight, weight2, strategy1=None, strategy2=None, is_parameter=True):
super().__init__()
self.concat = P.Concat(axis=0).set_strategy(strategy1)
if is_parameter:
self.weight = Parameter(weight, "w1")
else:
self.weight = weight
self.mul = P.Mul().set_strategy(strategy2)
self.weight2 = Parameter(weight2, "w2")
def construct(self, x, b):
out = self.concat((self.weight, self.weight2))
out = self.mul(x, out)
return out
class Net2(Cell):
def __init__(self, weight, strategy1=None, strategy2=None, axis=0):
super().__init__()
self.mul = P.Mul().set_strategy(strategy1)
self.concat = P.Concat(axis=axis).set_strategy(strategy2)
self.weight = Parameter(weight, "w")
def construct(self, x, b):
out = self.mul(x, b)
out = self.concat((out, self.weight))
return out
_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
_w1 = Tensor(np.ones([96, 64, 32]), dtype=ms.float32)
_w2 = Tensor(np.ones([32, 64, 32]), dtype=ms.float32)
_w3 = Tensor(np.ones([128, 16, 32]), dtype=ms.float32)
_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
def compile_net(net):
context.set_context(save_graphs=True)
optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
train_net = TrainOneStepCell(net, optimizer)
train_net.set_auto_parallel()
_executor.compile(train_net, _x, _b)
context.reset_auto_parallel_context()
def test_concat_parameter():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 4, 2), (1, 4, 2))
strategy2 = ((1, 4, 2), (1, 4, 2))
net = Net(_w1, _w2, strategy1, strategy2, is_parameter=True)
compile_net(net)
def test_concat_parameter_no_full_split():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 2, 2), (1, 2, 2))
strategy2 = ((1, 4, 2), (1, 4, 2))
net = Net(_w1, _w2, strategy1, strategy2, is_parameter=True)
compile_net(net)
def test_concat_tensor_and_parameter():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((1, 2, 2), (1, 2, 2))
strategy2 = ((1, 4, 2), (1, 4, 2))
net = Net(_w1, _w2, strategy1, strategy2, is_parameter=False)
compile_net(net)
def test_concat_output():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 2), (2, 2, 2))
strategy2 = ((1, 4, 2), (1, 4, 2))
net = Net2(_w1, strategy1, strategy2)
compile_net(net)
def test_concat_output_no_full_split():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 2), (2, 2, 2))
strategy2 = ((1, 2, 2), (1, 2, 2))
net = Net2(_w1, strategy1, strategy2)
compile_net(net)
def test_concat_no_strategy():
context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0)
strategy1 = ((2, 2, 2), (2, 2, 2))
strategy2 = None
net = Net2(_w3, strategy1, strategy2, axis=1)
compile_net(net)
def test_concat_auto_parallel():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
net = Net2(_w2)
compile_net(net)
def test_concat_auto_parallel2():
context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=8, global_rank=0)
strategy1 = None
strategy2 = None
net = Net2(_w3, strategy1, strategy2, axis=1)
compile_net(net)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册