提交 174acbec 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3265 Decouple ir from optimizer

Merge pull request !3265 from hewei/decouple_ir_optimizer
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_OPTIMIZER_ANF_VISITOR_H_
#define MINDSPORE_CCSRC_OPTIMIZER_ANF_VISITOR_H_
#include <vector>
#include "ir/visitor.h"
#include "frontend/optimizer/optimizer_caller.h"
namespace mindspore {
class AnfVisitor : public AnfIrVisitor, public OptimizerCaller {};
} // namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_ANF_VISITOR_H_
......@@ -21,7 +21,7 @@
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/opt.h"
#include "ir/visitor.h"
#include "frontend/optimizer/anf_visitor.h"
namespace mindspore {
namespace opt {
......
......@@ -23,9 +23,9 @@
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/irpass/prim_eliminate.h"
#include "ir/optimizer_caller.h"
#include "frontend/optimizer/optimizer_caller.h"
#include "frontend/optimizer/anf_visitor.h"
#include "ir/pattern_matcher.h"
#include "ir/visitor.h"
namespace mindspore {
namespace opt {
......
......@@ -22,7 +22,7 @@
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/optimizer_caller.h"
#include "frontend/optimizer/optimizer_caller.h"
#include "ir/pattern_matcher.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"
......
......@@ -17,7 +17,7 @@
#include "frontend/optimizer/irpass/cast_eliminate.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "ir/visitor.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
#include "ir/func_graph.h"
#include "pipeline/jit/parse/data_converter.h"
......
......@@ -17,7 +17,7 @@
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CAST_ELIMINATE_H_
#include "ir/visitor.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
......
......@@ -21,7 +21,7 @@
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/irpass.h"
#include "ir/visitor.h"
#include "frontend/optimizer/anf_visitor.h"
#include "ir/func_graph.h"
#include "frontend/operator/ops.h"
......
......@@ -25,8 +25,8 @@
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "frontend/optimizer/optimizer_caller.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
......
......@@ -24,7 +24,7 @@
#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "ir/visitor.h"
#include "frontend/optimizer/anf_visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
......
......@@ -26,7 +26,7 @@
#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "ir/visitor.h"
#include "frontend/optimizer/anf_visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
......
......@@ -23,7 +23,7 @@
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/irpass.h"
#include "ir/visitor.h"
#include "frontend/optimizer/anf_visitor.h"
#include "common/utils.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/ad/grad.h"
......
......@@ -24,7 +24,7 @@
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "ir/visitor.h"
#include "frontend/optimizer/anf_visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "frontend/operator/ops.h"
......
......@@ -25,8 +25,8 @@
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "frontend/optimizer/optimizer_caller.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
......
......@@ -22,7 +22,7 @@
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "ir/visitor.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
namespace mindspore {
......
......@@ -23,7 +23,7 @@
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "ir/visitor.h"
#include "frontend/optimizer/anf_visitor.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "frontend/operator/ops.h"
......
......@@ -21,8 +21,8 @@
#include <memory>
#include <vector>
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "frontend/optimizer/optimizer_caller.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
......
......@@ -24,7 +24,7 @@
#include "backend/session/anf_runtime_algorithm.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/irpass.h"
#include "ir/visitor.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
#include "utils/graph_utils.h"
#include "frontend/operator/composite/composite.h"
......
......@@ -23,7 +23,7 @@
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "ir/visitor.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
namespace mindspore {
......
......@@ -22,7 +22,7 @@
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/irpass.h"
#include "ir/visitor.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
namespace mindspore {
......
......@@ -21,7 +21,7 @@
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/irpass.h"
#include "ir/visitor.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
#include "pipeline/jit/parse/parse.h"
......
......@@ -23,7 +23,7 @@
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "ir/visitor.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
namespace mindspore {
......
......@@ -19,7 +19,7 @@
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/irpass.h"
#include "ir/visitor.h"
#include "frontend/optimizer/anf_visitor.h"
namespace mindspore {
namespace opt {
......
......@@ -23,7 +23,7 @@
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "ir/visitor.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
#include "abstract/dshape.h"
......
......@@ -20,8 +20,8 @@
#include <vector>
#include "ir/func_graph.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "frontend/optimizer/optimizer_caller.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
......
......@@ -22,9 +22,9 @@
#include <memory>
#include <vector>
#include "ir/optimizer_caller.h"
#include "frontend/optimizer/optimizer_caller.h"
#include "ir/pattern_matcher.h"
#include "ir/visitor.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/irpass/prim_eliminate.h"
......
......@@ -26,7 +26,7 @@
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "ir/visitor.h"
#include "frontend/optimizer/anf_visitor.h"
#include "ir/manager.h"
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
......
......@@ -22,7 +22,7 @@
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/irpass.h"
#include "ir/visitor.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
#include "pipeline/jit/parse/data_converter.h"
#include "pipeline/jit/parse/python_adapter.h"
......
......@@ -22,7 +22,7 @@
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "ir/visitor.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
namespace mindspore {
......
......@@ -22,7 +22,7 @@
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "ir/visitor.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
namespace mindspore {
......
......@@ -23,7 +23,7 @@
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/optimizer_caller.h"
#include "frontend/optimizer/optimizer_caller.h"
#include "frontend/operator/ops.h"
namespace mindspore {
......
......@@ -14,20 +14,20 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_IR_OPTIMIZER_CALLER_H_
#define MINDSPORE_CORE_IR_OPTIMIZER_CALLER_H_
#ifndef MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_CALLER_H_
#define MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_CALLER_H_
#include <memory>
#include "ir/anf.h"
#include "ir/visitor.h"
namespace mindspore {
namespace opt {
class Optimizer;
using OptimizerPtr = std::shared_ptr<Optimizer>;
using OptimizerWeakPtr = std::weak_ptr<Optimizer>;
using PredicateFuncType = std::function<bool(const AnfNodePtr &)>;
using PredicateFuncType = mindspore::PredicateFuncType;
} // namespace opt
class OptimizerCaller {
......@@ -36,4 +36,4 @@ class OptimizerCaller {
};
using OptimizerCallerPtr = std::shared_ptr<OptimizerCaller>;
} // namespace mindspore
#endif // MINDSPORE_CORE_IR_OPTIMIZER_CALLER_H_
#endif // MINDSPORE_CCSRC_OPTIMIZER_OPTIMIZER_CALLER_H_
......@@ -36,7 +36,7 @@
namespace mindspore {
namespace {
class DeepFirstSearcher : public AnfVisitor {
class DeepFirstSearcher : public AnfIrVisitor {
public:
explicit DeepFirstSearcher(const IncludeFunc &include, const FilterFunc &filter = nullptr)
: include_(include), filter_(filter) {}
......@@ -67,7 +67,7 @@ class DeepFirstSearcher : public AnfVisitor {
res_.push_back(node);
}
if (incl == FOLLOW) {
AnfVisitor::Visit(node);
AnfIrVisitor::Visit(node);
}
}
......
......@@ -67,7 +67,7 @@ class BaseRef;
class Var;
using VarPtr = std::shared_ptr<Var>;
class AnfVisitor;
class AnfIrVisitor;
class ParamValue;
using ParamValuePtr = std::shared_ptr<ParamValue>;
......@@ -100,7 +100,7 @@ class AnfNode : public Base {
~AnfNode() override = default;
MS_DECLARE_PARENT(AnfNode, Base);
virtual void accept(AnfVisitor *) {}
virtual void accept(AnfIrVisitor *) {}
FuncGraphPtr func_graph() const { return func_graph_.lock(); }
void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); }
......@@ -214,7 +214,7 @@ class CNode : public AnfNode {
~CNode() override = default;
MS_DECLARE_PARENT(CNode, AnfNode);
void accept(AnfVisitor *v) override;
void accept(AnfIrVisitor *v) override;
// check whether this cnode has some primitive value as the first input.
bool IsApply(const PrimitivePtr &) const;
......@@ -265,7 +265,7 @@ class Parameter : public ANode {
~Parameter() override = default;
MS_DECLARE_PARENT(Parameter, ANode);
void accept(AnfVisitor *v) override;
void accept(AnfIrVisitor *v) override;
std::string DebugString(int recursive_level = 1) const override;
std::string name() const { return name_; }
void set_name(const std::string &name) { name_ = name; }
......@@ -332,7 +332,7 @@ class ValueNode : public ANode {
~ValueNode() override = default;
MS_DECLARE_PARENT(ValueNode, ANode);
void accept(AnfVisitor *v) override;
void accept(AnfIrVisitor *v) override;
const ValuePtr &value() const { return value_; }
std::string fullname_with_scope() override;
......
......@@ -93,7 +93,7 @@ std::string CNode::fullname_with_scope() {
return fullname_with_scope_;
}
void CNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<CNode>()); }
void ValueNode::accept(AnfVisitor *v) { v->Visit(shared_from_base<ValueNode>()); }
void Parameter::accept(AnfVisitor *v) { v->Visit(shared_from_base<Parameter>()); }
void CNode::accept(AnfIrVisitor *v) { v->Visit(shared_from_base<CNode>()); }
void ValueNode::accept(AnfIrVisitor *v) { v->Visit(shared_from_base<ValueNode>()); }
void Parameter::accept(AnfIrVisitor *v) { v->Visit(shared_from_base<Parameter>()); }
} // namespace mindspore
......@@ -22,8 +22,7 @@
#include <tuple>
#include <vector>
#include "ir/anf.h"
#include "ir/optimizer_caller.h"
#include "ir/visitor.h"
#include "base/core_ops.h"
namespace mindspore {
......@@ -41,7 +40,7 @@ namespace mindspore {
template <typename T>
class PBase {
public:
bool CheckFunc(const opt::PredicateFuncType &func, const AnfNodePtr &node) {
bool CheckFunc(const PredicateFuncType &func, const AnfNodePtr &node) {
return func(get_object().GetNode(node));
}
......
......@@ -18,24 +18,24 @@
#include "ir/visitor.h"
namespace mindspore {
void AnfVisitor::Visit(const AnfNodePtr &node) { node->accept(this); }
void AnfIrVisitor::Visit(const AnfNodePtr &node) { node->accept(this); }
void AnfVisitor::Visit(const CNodePtr &cnode) {
void AnfIrVisitor::Visit(const CNodePtr &cnode) {
for (auto &input : cnode->inputs()) {
Visit(input);
}
}
void AnfVisitor::Visit(const ValueNodePtr &vnode) {
void AnfIrVisitor::Visit(const ValueNodePtr &vnode) {
if (IsValueNode<FuncGraph>(vnode)) {
auto func_graph = GetValueNode<FuncGraphPtr>(vnode);
Visit(func_graph->output());
}
}
void AnfVisitor::Visit(const ParameterPtr &) {}
void AnfIrVisitor::Visit(const ParameterPtr &) {}
VisitFuncType AnfVisitor::Match(const PrimitivePtr &prim, const std::vector<opt::PredicateFuncType> &funcs) {
VisitFuncType AnfIrVisitor::Match(const PrimitivePtr &prim, const std::vector<PredicateFuncType> &funcs) {
auto fn = [prim, funcs, this](const AnfNodePtr &node) {
if (!IsPrimitiveCNode(node, prim)) {
return;
......
......@@ -18,18 +18,19 @@
#define MINDSPORE_CORE_IR_VISITOR_H_
#include <vector>
#include "ir/optimizer_caller.h"
#include "ir/anf.h"
namespace mindspore {
using VisitFuncType = std::function<void(const AnfNodePtr &)>;
class AnfVisitor : public OptimizerCaller {
using PredicateFuncType = std::function<bool(const AnfNodePtr &)>;
class AnfIrVisitor {
public:
virtual void Visit(const AnfNodePtr &);
virtual void Visit(const CNodePtr &);
virtual void Visit(const ValueNodePtr &);
virtual void Visit(const ParameterPtr &);
VisitFuncType Match(const PrimitivePtr &, const std::vector<opt::PredicateFuncType> & = {});
virtual ~AnfVisitor() = default;
VisitFuncType Match(const PrimitivePtr &, const std::vector<PredicateFuncType> & = {});
virtual ~AnfIrVisitor() = default;
};
} // namespace mindspore
#endif // MINDSPORE_CORE_IR_VISITOR_H_
......@@ -24,7 +24,6 @@
#include "ir/func_graph_cloner.h"
#include "ir/manager.h"
#include "ir/value.h"
#include "ir/visitor.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"
#include "pipeline/jit/resource.h"
......
......@@ -20,9 +20,9 @@
#include "common/py_func_graph_fetcher.h"
#include "ir/anf.h"
#include "ir/visitor.h"
#include "ir/func_graph_cloner.h"
#include "frontend/optimizer/opt.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/irpass/arithmetic_simplify.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册