提交 7dc31684 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!1073 fix reshape tensor redistribution bug

Merge pull request !1073 from yao_yf/reshape_bug_fix
...@@ -616,8 +616,8 @@ using GatherV2CostPtr = std::shared_ptr<GatherV2Cost>; ...@@ -616,8 +616,8 @@ using GatherV2CostPtr = std::shared_ptr<GatherV2Cost>;
class GatherV2PCost : public OperatorCost { class GatherV2PCost : public OperatorCost {
public: public:
explicit GatherV2PCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {} explicit GatherV2PCost(bool is_inputs_related) : OperatorCost(is_inputs_related), axis_(0) {}
GatherV2PCost() : OperatorCost(true) {} GatherV2PCost() : OperatorCost(true), axis_(0) {}
~GatherV2PCost() override = default; ~GatherV2PCost() override = default;
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs, double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
......
...@@ -33,7 +33,10 @@ class GatherV2PInfo : public OperatorInfo { ...@@ -33,7 +33,10 @@ class GatherV2PInfo : public OperatorInfo {
public: public:
GatherV2PInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape, GatherV2PInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs) const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2PCost>()) {} : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2PCost>()),
axis_(0),
bias_(0),
slice_size_(0) {}
~GatherV2PInfo() override = default; ~GatherV2PInfo() override = default;
Status Init(const StrategyPtr &strategy) override; Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override; Status InitForCostModel(const StrategyPtr &strategy) override;
......
...@@ -456,8 +456,7 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<Stra ...@@ -456,8 +456,7 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<Stra
return FAILED; return FAILED;
} }
TensorInfo pre_out_tensor_info = pre_out_tensor_infos[out_index]; TensorInfo pre_out_tensor_info = pre_out_tensor_infos[out_index];
TensorLayout pre_out_tensor_layout = pre_out_tensor_info.tensor_layout(); SetInputLayout(pre_out_tensor_info.tensor_layout());
SetInputLayout(pre_out_tensor_layout);
// infer pre_node output strategy from output_layout. // infer pre_node output strategy from output_layout.
Dimensions stra = pre_out_tensor_info.InferStrategy(); Dimensions stra = pre_out_tensor_info.InferStrategy();
if (stra.empty()) { if (stra.empty()) {
...@@ -481,15 +480,17 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<Stra ...@@ -481,15 +480,17 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<Stra
return FAILED; return FAILED;
} }
TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index]; TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index];
TensorLayout next_in_tensor_layout = next_in_tensor_info.tensor_layout(); SetOutputLayout(next_in_tensor_info.tensor_layout());
SetOutputLayout(next_in_tensor_layout);
if (Init(nullptr) == FAILED) { if (Init(nullptr) == FAILED) {
MS_LOG(ERROR) << "Failure:operator reshape init failed"; MS_LOG(DEBUG) << "Failure:operator reshape init failed";
return FAILED; continue;
} }
SetCostForReshape(reshape_stra); SetCostForReshape(reshape_stra);
} }
} }
if (strategy_cost_.empty()) {
return FAILED;
}
return SUCCESS; return SUCCESS;
} }
} // namespace parallel } // namespace parallel
......
...@@ -38,6 +38,8 @@ class ReshapeInfo : public OperatorInfo { ...@@ -38,6 +38,8 @@ class ReshapeInfo : public OperatorInfo {
const PrimitiveAttrs &attrs) const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>(false)), : OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>(false)),
dev_num_(0), dev_num_(0),
pre_operator_index_(0),
next_operator_index_(0),
input_layout_set_flag_(false), input_layout_set_flag_(false),
output_layout_set_flag_(false) {} output_layout_set_flag_(false) {}
~ReshapeInfo() override = default; ~ReshapeInfo() override = default;
......
...@@ -30,9 +30,18 @@ Status ReshapeLayoutTransfer::CheckValidTransfer() { ...@@ -30,9 +30,18 @@ Status ReshapeLayoutTransfer::CheckValidTransfer() {
std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::UnifyDeviceArrangementAndTensorShape() const { std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::UnifyDeviceArrangementAndTensorShape() const {
bool is_unified = IsSameTensorShape(); bool is_unified = IsSameTensorShape();
std::shared_ptr<ReshapeLayoutTransfer> out_layout_ptr = std::make_shared<ReshapeLayoutTransfer>(*this); std::shared_ptr<ReshapeLayoutTransfer> out_layout_ptr = std::make_shared<ReshapeLayoutTransfer>(*this);
if (out_layout_ptr == nullptr) {
return nullptr;
}
while (!is_unified) { while (!is_unified) {
std::shared_ptr<ReshapeLayoutTransfer> temp_layout_ptr = out_layout_ptr->ExtendFromTensorShapeByTo(); std::shared_ptr<ReshapeLayoutTransfer> temp_layout_ptr = out_layout_ptr->ExtendFromTensorShapeByTo();
if (temp_layout_ptr == nullptr) {
return nullptr;
}
out_layout_ptr = temp_layout_ptr->ExtendToTensorShapeByFrom(); out_layout_ptr = temp_layout_ptr->ExtendToTensorShapeByFrom();
if (out_layout_ptr == nullptr) {
return nullptr;
}
is_unified = out_layout_ptr->IsSameTensorShape(); is_unified = out_layout_ptr->IsSameTensorShape();
} }
return out_layout_ptr; return out_layout_ptr;
...@@ -91,6 +100,9 @@ std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::ExtendToTensorShap ...@@ -91,6 +100,9 @@ std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::ExtendToTensorShap
} }
std::shared_ptr<ReshapeLayoutTransfer> exchanged_out = std::shared_ptr<ReshapeLayoutTransfer> exchanged_out =
exchanged_from_and_to_ptr->ExpandFromTensorShapeAndExpandToDeviceArrangement(*expanded_shape_ptr); exchanged_from_and_to_ptr->ExpandFromTensorShapeAndExpandToDeviceArrangement(*expanded_shape_ptr);
if (exchanged_out == nullptr) {
return nullptr;
}
return exchanged_out->ExchangeFromAndTo(); return exchanged_out->ExchangeFromAndTo();
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册