提交 6c85fc9f 编写于 作者: Y Yi Huaijie

DropoutDoMask: replace only the first input of

dropout_gen_mask in the subgraph instead of
the whole subgraph.
上级 553432c9
......@@ -204,7 +204,7 @@ Status DropoutDoMaskInfo::InitForCostModel(const StrategyPtr &strategy) {
PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
}
......@@ -215,8 +215,7 @@ PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) {
}
auto dropout_gen_mask_cnode = dropout_gen_mask->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(dropout_gen_mask_cnode);
if (dropout_gen_mask_cnode->inputs().size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) {
if (dropout_gen_mask_cnode->size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) {
MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE;
}
if (!IsValueNode<Primitive>(dropout_gen_mask_cnode->input(0))) {
......@@ -233,11 +232,45 @@ PrimitivePtr GetDropoutGenMaskPrim(const CNodePtr &cnode) {
return prim;
}
void SetGenMaskShape(const CNodePtr &cnode, const Shape &input_slice_shape) {
MS_EXCEPTION_IF_NULL(cnode);
if (cnode->size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
MS_LOG(EXCEPTION) << "The size of dropout do mask cnode's inputs must be " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
}
AnfNodePtr dropout_gen_mask = cnode->input(DROPOUT_GEN_MASK_INDEX);
MS_EXCEPTION_IF_NULL(dropout_gen_mask);
if (!dropout_gen_mask->isa<CNode>()) {
MS_LOG(EXCEPTION) << "The dropout do mask cnode's input[" << DROPOUT_GEN_MASK_INDEX << "] must be a cnode.";
}
auto dropout_gen_mask_cnode = dropout_gen_mask->cast<CNodePtr>();
if (dropout_gen_mask_cnode->size() != DROPOUT_GEN_MASK_CNODE_INPUT_SIZE) {
MS_LOG(EXCEPTION) << "The size of dropout gen mask cnode's inputs must be " << DROPOUT_GEN_MASK_CNODE_INPUT_SIZE;
}
if (!IsValueNode<ValueTuple>(dropout_gen_mask_cnode->input(1))) {
MS_LOG(EXCEPTION) << "The input[1] of dropout gen mask cnode is not ValueTuple.";
}
FuncGraphPtr func_graph = cnode->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
FuncGraphManagerPtr manager = func_graph->manager();
if (manager == nullptr) {
MS_LOG(EXCEPTION) << "Failure: AddNode error since manager is nullptr.";
}
ValuePtr new_shape = MakeValue(input_slice_shape);
AnfNodePtr val = NewValueNode(new_shape);
(void)manager->Replace(dropout_gen_mask_cnode->input(1), val);
}
// DropoutDoMask needs to be used together with DropoutGenMask. Only the first input tensor of DropoutGenMask is
// split. Find the DropoutGenMask node in the anf graph according to DropoutDoMask node, and modify the input shape
// of DropoutGenMask according to the strategy of DropoutDoMask. When the DropoutDoMask performs repeated calculation
// and both seeds of DropoutGenMask are 0, two new seeds are automatically generated for DropoutGenMask.
Operator DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) {
std::vector<Operator> DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) {
std::vector<Operator> replace_ops;
MS_EXCEPTION_IF_NULL(cnode);
PrimitivePtr prim = GetDropoutGenMaskPrim(cnode);
MS_EXCEPTION_IF_NULL(prim);
......@@ -260,15 +293,20 @@ Operator DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) {
if ((attr.find(SEED0) == attr.end()) || (attr.find(SEED1) == attr.end())) {
MS_LOG(EXCEPTION) << "The attrs of dropout gen mask must be have seed0 and seed1";
}
Shape input_slice_shape = inputs_tensor_info_[0].slice_shape();
int32_t seed_0 = GetValue<int32_t>(attr[SEED0]);
int32_t seed_1 = GetValue<int32_t>(attr[SEED1]);
if ((seed_0 == 0) && (seed_1 == 0) && (repeated_calc_num_ > 1)) {
seed_0 = SEED_NUM;
seed_1 = SEED_NUM;
SEED_NUM++;
} else {
SetGenMaskShape(cnode, input_slice_shape);
MS_LOG(DEBUG) << "The input slice shape droupout is " << ShapeToString(input_slice_shape);
return replace_ops;
}
Shape input_slice_shape = inputs_tensor_info_[0].slice_shape();
ValuePtr new_shape = MakeValue(input_slice_shape);
Attr attr_0 = std::make_pair(SEED0, MakeValue(seed_0));
Attr attr_1 = std::make_pair(SEED1, MakeValue(seed_1));
......@@ -278,7 +316,8 @@ Operator DropoutDoMaskInfo::GetDropoutGenMaskReplaceOp(const CNodePtr &cnode) {
OperatorParams params = {std::make_pair(param_0, 1), std::make_pair(param_1, 2)};
OperatorArgs args = std::make_pair(attrs, params);
Operator replace_op = {std::make_pair(DROPOUT_GEN_MASK, args)};
return replace_op;
replace_ops.push_back(replace_op);
return replace_ops;
}
} // namespace parallel
} // namespace mindspore
......@@ -41,7 +41,7 @@ class DropoutDoMaskInfo : public OperatorInfo {
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override;
Operator GetDropoutGenMaskReplaceOp(const CNodePtr &cnode);
std::vector<Operator> GetDropoutGenMaskReplaceOp(const CNodePtr &cnode);
protected:
Status CheckStrategy(const StrategyPtr &strategy) override;
......
......@@ -1876,11 +1876,15 @@ void HandleDropoutNode(const OperatorInfoPtr &distribute_operator, const CNodePt
DropoutDoMaskInfoPtr dropout_do_mask = std::dynamic_pointer_cast<DropoutDoMaskInfo>(distribute_operator);
MS_EXCEPTION_IF_NULL(dropout_do_mask);
Operator replace_op = dropout_do_mask->GetDropoutGenMaskReplaceOp(cnode);
std::vector<Operator> replace_op = dropout_do_mask->GetDropoutGenMaskReplaceOp(cnode);
if (replace_op.empty()) {
MS_LOG(DEBUG) << "No need to replace dropout_gen_mask";
return;
}
if (cnode->inputs().size() != DROPOUT_DO_MASK_CNODE_INPUT_SIZE) {
MS_LOG(EXCEPTION) << "The size of drop out do mask cnode's input is not " << DROPOUT_DO_MASK_CNODE_INPUT_SIZE;
}
ReplaceOneOp(replace_op, cnode->input(DROPOUT_GEN_MASK_INDEX)->cast<CNodePtr>());
ReplaceOneOp(replace_op[0], cnode->input(DROPOUT_GEN_MASK_INDEX)->cast<CNodePtr>());
}
void HandleSpecialNode(const OperatorInfoPtr &distribute_operator, const CNodePtr &cnode) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册