提交 d11dc827 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1181 fix gatherv2 replace graph in auto parallel

Merge pull request !1181 from yao_yf/fix_gather_v2_replace_graph
......@@ -258,16 +258,20 @@ Status GatherV2PInfo::ComputeReplaceGraph(const CNodePtr &cnode) {
return SUCCESS;
}
ReplaceGraphPtr GatherV2PInfo::replace_graph(const CNodePtr &cnode) {
auto param_strategy = strategy_->GetInputDim().at(0);
if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
return nullptr;
}
return replace_graph_;
}
Status GatherV2PInfo::Init(const StrategyPtr &strategy) {
auto param_strategy = strategy->GetInputDim().at(0);
if (InitWithAutoRepeatCalc(strategy) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": Init failed.";
return FAILED;
}
if (param_strategy.at(IntToSize(axis_)) != 1 && ComputeReplaceGraph(cnode_) != SUCCESS) {
MS_LOG(ERROR) << name_ << ": ComputeReplaceGraph failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": Init success.";
return SUCCESS;
}
......
......@@ -43,6 +43,7 @@ class GatherV2PInfo : public OperatorInfo {
Status GenerateStrategies(int32_t stage_id) override;
Status SetCostUnderStrategy(const StrategyPtr &strategy) override;
ReplaceGraphPtr replace_graph(const CNodePtr &cnode) override;
std::shared_ptr<std::vector<std::vector<int32_t>>> GenerateBatchStrategies() override;
protected:
......
......@@ -138,9 +138,9 @@ Status ReshapeInfo::ComputeReplaceOp() {
MS_LOG(ERROR) << name_ << ": tensor_redistribution init failed.";
return FAILED;
}
MS_LOG(INFO) << name_ << ": input " << input_layout_.ToString();
MS_LOG(INFO) << name_ << ": output " << output_layout_.ToString();
MS_LOG(INFO) << name_ << ": dev_list " << dev_list.size();
MS_LOG(DEBUG) << name_ << ": input " << input_layout_.ToString();
MS_LOG(DEBUG) << name_ << ": output " << output_layout_.ToString();
MS_LOG(DEBUG) << name_ << ": dev_list " << dev_list.size();
RedistributionOpListPtr redistribution_oplist_ptr = tensor_redistribution.InferTensorRedistributionOperatorList();
if (redistribution_oplist_ptr == nullptr) {
MS_LOG(ERROR) << name_ << "InferTensorRedistribution failed.";
......@@ -148,7 +148,7 @@ Status ReshapeInfo::ComputeReplaceOp() {
}
replace_op_ = redistribution_oplist_ptr->first;
replace_op_info_ = redistribution_oplist_ptr->second;
MS_LOG(INFO) << name_ << ": replace op size = " << replace_op_.size();
MS_LOG(DEBUG) << name_ << ": replace op size = " << replace_op_.size();
return SUCCESS;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册