提交 b0921c15 编写于 作者: Y yao_yf

xreshape tensor_redistrinution bug fix

上级 716def7c
......@@ -616,8 +616,8 @@ using GatherV2CostPtr = std::shared_ptr<GatherV2Cost>;
class GatherV2PCost : public OperatorCost {
public:
explicit GatherV2PCost(bool is_inputs_related) : OperatorCost(is_inputs_related) {}
GatherV2PCost() : OperatorCost(true) {}
explicit GatherV2PCost(bool is_inputs_related) : OperatorCost(is_inputs_related), axis_(0) {}
GatherV2PCost() : OperatorCost(true), axis_(0) {}
~GatherV2PCost() override = default;
double GetCommCost(const std::vector<TensorInfo> &inputs, const std::vector<TensorInfo> &outputs,
......
......@@ -33,7 +33,10 @@ class GatherV2PInfo : public OperatorInfo {
public:
GatherV2PInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2PCost>()) {}
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<GatherV2PCost>()),
axis_(0),
bias_(0),
slice_size_(0) {}
~GatherV2PInfo() override = default;
Status Init(const StrategyPtr &strategy) override;
Status InitForCostModel(const StrategyPtr &strategy) override;
......
......@@ -456,8 +456,7 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<Stra
return FAILED;
}
TensorInfo pre_out_tensor_info = pre_out_tensor_infos[out_index];
TensorLayout pre_out_tensor_layout = pre_out_tensor_info.tensor_layout();
SetInputLayout(pre_out_tensor_layout);
SetInputLayout(pre_out_tensor_info.tensor_layout());
// infer pre_node output strategy from output_layout.
Dimensions stra = pre_out_tensor_info.InferStrategy();
if (stra.empty()) {
......@@ -481,15 +480,17 @@ Status ReshapeInfo::GenetateStrategyCosts(const std::vector<std::shared_ptr<Stra
return FAILED;
}
TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index];
TensorLayout next_in_tensor_layout = next_in_tensor_info.tensor_layout();
SetOutputLayout(next_in_tensor_layout);
SetOutputLayout(next_in_tensor_info.tensor_layout());
if (Init(nullptr) == FAILED) {
MS_LOG(ERROR) << "Failure:operator reshape init failed";
return FAILED;
MS_LOG(DEBUG) << "Failure:operator reshape init failed";
continue;
}
SetCostForReshape(reshape_stra);
}
}
if (strategy_cost_.empty()) {
return FAILED;
}
return SUCCESS;
}
} // namespace parallel
......
......@@ -38,6 +38,8 @@ class ReshapeInfo : public OperatorInfo {
const PrimitiveAttrs &attrs)
: OperatorInfo(name, inputs_shape, outputs_shape, attrs, std::make_shared<ReshapeCost>(false)),
dev_num_(0),
pre_operator_index_(0),
next_operator_index_(0),
input_layout_set_flag_(false),
output_layout_set_flag_(false) {}
~ReshapeInfo() override = default;
......
......@@ -30,9 +30,18 @@ Status ReshapeLayoutTransfer::CheckValidTransfer() {
std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::UnifyDeviceArrangementAndTensorShape() const {
bool is_unified = IsSameTensorShape();
std::shared_ptr<ReshapeLayoutTransfer> out_layout_ptr = std::make_shared<ReshapeLayoutTransfer>(*this);
if (out_layout_ptr == nullptr) {
return nullptr;
}
while (!is_unified) {
std::shared_ptr<ReshapeLayoutTransfer> temp_layout_ptr = out_layout_ptr->ExtendFromTensorShapeByTo();
if (temp_layout_ptr == nullptr) {
return nullptr;
}
out_layout_ptr = temp_layout_ptr->ExtendToTensorShapeByFrom();
if (out_layout_ptr == nullptr) {
return nullptr;
}
is_unified = out_layout_ptr->IsSameTensorShape();
}
return out_layout_ptr;
......@@ -91,6 +100,9 @@ std::shared_ptr<ReshapeLayoutTransfer> ReshapeLayoutTransfer::ExtendToTensorShap
}
std::shared_ptr<ReshapeLayoutTransfer> exchanged_out =
exchanged_from_and_to_ptr->ExpandFromTensorShapeAndExpandToDeviceArrangement(*expanded_shape_ptr);
if (exchanged_out == nullptr) {
return nullptr;
}
return exchanged_out->ExchangeFromAndTo();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册