提交 0fd57bd1 编写于 作者: L liubuyu

fix remove reshape pair pass

上级 09034282
......@@ -37,7 +37,7 @@ session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) {
std::string op_name = AnfAlgo::GetCNodeName(cnode);
auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE);
// deal ref op
if (op_info->is_ref()) {
if (op_info != nullptr && op_info->is_ref()) {
auto ref_infos = op_info->ref_infos();
if (ref_infos.count(cur_out_index) != 0) {
auto in_index = ref_infos.at(cur_out_index);
......
......@@ -23,33 +23,33 @@
namespace mindspore {
namespace opt {
const BaseRef RemoveReshapePair::DefinePattern() const {
const auto prim_reshape = std::make_shared<Primitive>(prim::kPrimReshape->name());
VectorRef reshape({prim_reshape, input_varptr_});
return VectorRef({prim::kPrimReshape, reshape});
VarPtr X = std::make_shared<Var>();
MS_EXCEPTION_IF_NULL(X);
return VectorRef({prim::kPrimReshape, VectorRef({prim::kPrimReshape, X})});
}
const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &equiv) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(equiv);
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto reshape_op_1 = CheckAnfNodeIfCNodeAndInputSize(node, kBackendReshapeInputNum);
MS_EXCEPTION_IF_NULL(reshape_op_1);
// If reshape operator used by more than one other operators, reshape operator cant not be deleted directly
auto users = manager->node_users()[reshape_op_1];
if (users.size() > 1) {
if (IsUsedByOthers(func_graph, reshape_op_1)) {
return nullptr;
}
auto reshape_op_2 = CheckAnfNodeIfCNodeAndInputSize(reshape_op_1->input(1), kBackendReshapeInputNum);
MS_EXCEPTION_IF_NULL(reshape_op_2);
users = manager->node_users()[reshape_op_2];
if (users.size() > 1) {
if (IsUsedByOthers(func_graph, reshape_op_2)) {
return nullptr;
}
auto output_shape = AnfAlgo::GetOutputDeviceShape(reshape_op_2, 0);
auto input_shape = AnfAlgo::GetInputDeviceShape(reshape_op_1, 0);
if (input_shape == output_shape) {
auto input_node = reshape_op_2->input(1);
return input_node;
}
return nullptr;
}
} // namespace opt
} // namespace mindspore
......@@ -28,15 +28,10 @@ namespace mindspore {
namespace opt {
class RemoveReshapePair : public PatternProcessPass {
public:
explicit RemoveReshapePair(bool multigraph = true) : PatternProcessPass("remove_reshape_pair", multigraph) {
input_varptr_ = std::make_shared<Var>();
}
explicit RemoveReshapePair(bool multigraph = true) : PatternProcessPass("remove_reshape_pair", multigraph) {}
~RemoveReshapePair() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
VarPtr input_varptr_;
};
} // namespace opt
} // namespace mindspore
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册