提交 cf0cecc1 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2478 Support writing pass in python

Merge pull request !2478 from BowenK/python_pass_clean
......@@ -391,10 +391,6 @@ class TensorAddByZero : public AnfVisitor {
}
void Visit(const AnfNodePtr &node) override {
if (IsPrimitive(node, prim::kPrimZerosLike)) {
is_zero_ = true;
return;
}
if (node->isa<ValueNode>() && CheckTensorConstant(0).IsTensorScalarConstant(GetValueNode(node))) {
is_zero_ = true;
return;
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "optimizer/pass_group.h"
namespace mindspore {
namespace opt {
namespace python_pass {
void PassGroup::AddPass(const PythonPassPtr &pass) {
if (pass != nullptr) {
passes_.push_back(pass);
}
}
bool PassGroup::DeletePass(const std::string &pass_name) {
for (auto iter = passes_.begin(); iter != passes_.end(); iter++) {
if ((*iter)->name() == pass_name) {
*iter = nullptr;
passes_.erase(iter);
return true;
}
}
return false;
}
bool PassGroup::Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes) const {
if (func_graph == nullptr) {
return false;
}
bool changed = false;
for (const auto &pass : passes) {
if (pass != nullptr) {
if (pass->Run(func_graph)) {
changed = true;
}
}
}
return changed;
}
bool PassGroup::Run(const FuncGraphPtr &func_graph) const {
bool changed = false;
// run all passes
bool change = true;
while (change) {
change = Run(func_graph, passes_);
changed = change || changed;
if (run_only_once_) {
break;
}
}
return changed;
}
} // namespace python_pass
} // namespace opt
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_
#define MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_
#include <utility>
#include <vector>
#include <string>
#include <memory>
#include "optimizer/py_pass.h"
namespace mindspore {
namespace opt {
namespace python_pass {
class PassGroup {
public:
explicit PassGroup(const std::string &name = "pass_group", bool run_only_once = false)
: name_(name), passes_{}, run_only_once_(run_only_once) {}
virtual ~PassGroup() = default;
// Add graph pass, the pass object will be freed when pass manager freed.
void AddPass(const PythonPassPtr &pass);
// Delete graph pass before the pass manager is freed.
bool DeletePass(const std::string &pass_name);
// Run passes added in pass manager on the input graph
// @param [inout] graph The graph to be optimized
// @return true, graph changed
// @return false, graph not changed
bool Run(const FuncGraphPtr &func_graph) const;
// Run the given graph passes on the input graph
// @param [inout] graph The graph to be optimized
// @param [in] passes The given graph passes
// @return true, graph changed
// @return false, graph not changed
bool Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes) const;
std::string name() const { return name_; }
private:
const std::string name_;
std::vector<PythonPassPtr> passes_;
bool run_only_once_;
};
using PassGroupPtr = std::shared_ptr<PassGroup>;
} // namespace python_pass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_PASS_GROUP_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "optimizer/py_pass.h"
#include <unordered_set>
#include <deque>
#include <algorithm>
#include <utility>
#include <vector>
#include "ir/func_graph.h"
#include "ir/manager.h"
#include "pipeline/parse/parse_base.h"
#include "pipeline/resource.h"
namespace mindspore {
namespace opt {
namespace python_pass {
namespace internal {
std::string GetNodeRepr(AnfNodePtr node) {
if (node != nullptr) {
if (node->isa<CNode>()) {
std::string repr = "(";
auto const &inputs = node->cast<CNodePtr>()->inputs();
for (auto &input : inputs) {
repr += " ";
repr += GetNodeRepr(input);
repr += " ";
}
repr += ")";
return repr;
}
if (node->isa<ValueNode>()) {
return GetValueNode(node)->ToString();
}
return node->ToString();
}
return "";
}
void ResolveFuncGraph_(const FuncGraphPtr &fg) {
auto manager = Manage(fg, false);
parse::python_adapter::set_use_signature_in_resolve(false);
parse::ResolveAll(manager);
}
bool Match(const AnfNodePtr &pattern, const AnfNodePtr &node, const NodeEquivPtr &equiv_ptr) {
if (node == nullptr) {
return false;
}
MS_EXCEPTION_IF_NULL(pattern);
if (pattern->isa<ValueNode>()) {
if (!node->isa<ValueNode>()) {
return false;
}
if (GetNodeRepr(pattern) == GetNodeRepr(node)) {
// add to equiv_ptr
equiv_ptr->insert(std::make_pair(GetValueNode(pattern)->ToString(), node));
return true;
}
return false;
} else if (pattern->isa<Parameter>()) {
MS_LOG(DEBUG) << pattern->ToString() + "\n";
// add to equiv_ptr
equiv_ptr->insert(std::make_pair(pattern->ToString(), node));
return true;
} else if (pattern->isa<CNode>()) {
// match every single sub ANode
if (!node->isa<CNode>()) {
return false;
}
auto pattern_inputs = pattern->cast<CNodePtr>()->inputs();
auto node_inputs = node->cast<CNodePtr>()->inputs();
if (pattern_inputs.size() != node_inputs.size()) {
return false;
}
for (auto p_item = pattern_inputs.begin(), node_item = node_inputs.begin(); p_item != pattern_inputs.end();
p_item++, node_item++) {
auto res = Match(*p_item, *node_item, equiv_ptr);
if (!res) {
return false;
}
}
return true;
}
MS_LOG(EXCEPTION) << "Unexpected condition, (" + pattern->ToString() + " , " + node->ToString() + ")\n";
}
AnfNodePtr BuildTarget(const FuncGraphPtr &func_graph, const AnfNodePtr cur_raw_dst_node_,
const NodeEquivPtr &equiv_ptr) {
if (cur_raw_dst_node_->isa<Parameter>()) {
auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->ToString());
if (sub_pair != equiv_ptr->end()) {
return sub_pair->second;
}
MS_LOG(EXCEPTION) << "cur_raw_dst_node_ : " + internal::GetNodeRepr(cur_raw_dst_node_) + "\n";
} else if (cur_raw_dst_node_->isa<ValueNode>()) {
// check primitive ValueNode
auto sub_pair = equiv_ptr->find(cur_raw_dst_node_->cast<ValueNodePtr>()->value()->ToString());
if (sub_pair != equiv_ptr->end()) {
return sub_pair->second;
}
return cur_raw_dst_node_;
} else if (cur_raw_dst_node_->isa<CNode>()) {
std::vector<AnfNodePtr> new_inputs;
auto inputs = cur_raw_dst_node_->cast<CNodePtr>()->inputs();
for (auto sub_node = inputs.begin(); sub_node != inputs.end(); sub_node++) {
auto subed = internal::BuildTarget(func_graph, *sub_node, equiv_ptr);
new_inputs.push_back(subed);
}
return func_graph->NewCNode(new_inputs);
}
MS_LOG(EXCEPTION) << "Unexpected node type, got : " + internal::GetNodeRepr(cur_raw_dst_node_);
}
bool isTraversable(const AnfNodePtr &node) {
if (node == nullptr) {
return false;
}
if (node->isa<CNode>() || node->isa<Parameter>()) {
return true;
}
if (IsValueNode<FuncGraph>(node) || IsValueNode<RefKey>(node)) {
return true;
}
return false;
}
} // namespace internal
void PythonPass::Build(const py::function &src, const py::function &dst) {
// 1. get FuncGraph from py::function
auto src_fg_ = parse::ParsePythonCode(src);
auto dst_fg_ = parse::ParsePythonCode(dst);
if (src_fg_ == nullptr || dst_fg_ == nullptr) {
MS_LOG(EXCEPTION) << "Failed to parse python code.\n";
}
// 2. Resolve
internal::ResolveFuncGraph_(src_fg_);
internal::ResolveFuncGraph_(dst_fg_);
// 3. from FuncGraphPtr to ValueNode
src_node_ = src_fg_->output();
dst_node_ = dst_fg_->output();
}
PythonPass::PythonPass(const std::string &name, const py::function &src, const py::function &dst, bool run_only_once,
bool multigraph)
: name_(name), run_only_once_(run_only_once), multigraph_(multigraph) {
Build(src, dst);
}
AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
auto equiv_ptr = std::make_shared<NodeEquiv>();
bool is_a_match = internal::Match(src_node_, node, equiv_ptr);
if (is_a_match) {
auto new_node = internal::BuildTarget(func_graph, dst_node_, equiv_ptr);
MS_LOG(DEBUG) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n";
return new_node;
}
return nullptr;
}
bool PythonPass::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
FuncGraphManagerPtr manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(func_graph);
auto seen = NewSeenGeneration();
// 1024 is for the initial capacity of deque
std::deque<AnfNodePtr> todo(1024);
todo.push_back(func_graph->output());
bool changes = false;
auto &all_nodes = manager->all_nodes();
while (!todo.empty()) {
AnfNodePtr node = todo.front();
todo.pop_front();
// check whether this node has been matched.
if (node == nullptr || node->seen_ == seen || !internal::isTraversable(node) || !all_nodes.contains(node)) {
continue;
}
node->seen_ = seen;
// select nodes that this transform can be applied.
AnfNodePtr new_node = Run(func_graph, node);
bool change = (new_node != nullptr);
if (new_node != nullptr && new_node != node) {
(void)manager->Replace(node, new_node);
} else if (new_node == nullptr) {
new_node = node;
}
if (run_only_once_) {
return change;
}
// find success, and add them to todo list
if (IsValueNode<FuncGraph>(node)) {
todo.push_back(GetValueNode<FuncGraphPtr>(node)->output());
}
if (node->isa<CNode>()) {
auto &inputs = node->cast<CNodePtr>()->inputs();
(void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo));
}
auto &node_users = manager->node_users();
if (change && node_users.find(node) != node_users.end()) {
for (auto &use : node_users[node]) {
auto use_node = use.first;
if (use_node == nullptr) {
continue;
}
todo.push_back(use_node);
if (use_node->seen_ == seen) {
use_node->seen_--;
}
}
}
}
return changes;
}
} // namespace python_pass
} // namespace opt
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_OPTIMIZER_PASS_H_
#define MINDSPORE_CCSRC_OPTIMIZER_PASS_H_
#include <string>
#include <memory>
#include <unordered_map>
#include "ir/anf.h"
#include "pybind_api/api_register.h"
#include "pybind_api/export_flags.h"
namespace mindspore {
namespace opt {
namespace python_pass {
class PythonPass;
using PythonPassPtr = std::shared_ptr<PythonPass>;
using NodeEquiv = std::unordered_map<std::string, AnfNodePtr>;
using NodeEquivPtr = std::shared_ptr<NodeEquiv>;
class PythonPass {
public:
explicit PythonPass(const std::string &name, const py::function &src, const py::function &dst,
bool run_only_once = false, bool multigraph = true);
~PythonPass() = default;
bool Run(const FuncGraphPtr &func_graph);
std::string name() const { return name_; }
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node);
private:
void Build(const py::function &src, const py::function &dst);
AnfNodePtr src_node_ = nullptr;
AnfNodePtr dst_node_ = nullptr;
const std::string name_;
bool run_only_once_;
bool multigraph_ = true;
};
using PythonPassPtr = std::shared_ptr<PythonPass>;
} // namespace python_pass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_PASS_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "optimizer/py_pass_manager.h"
#include <functional>
#include <algorithm>
#include <utility>
#include <initializer_list>
#include "ir/manager.h"
#include "optimizer/pass_group.h"
namespace mindspore {
namespace opt {
namespace python_pass {
PyPassManagerPtr PyPassManager::global_instance = nullptr;
std::unordered_map<Phase, PassGroupPtr> PyPassManager::phase_to_group_;
PassGroupPtr PyPassManager::GetPassGroup(Phase phase) {
auto pm = phase_to_group_.find(phase);
if (pm == phase_to_group_.end()) {
return nullptr;
}
return pm->second;
}
PyPassManagerPtr PyPassManager::GetInstance() {
if (global_instance == nullptr) {
global_instance = std::shared_ptr<PyPassManager>(new (std::nothrow) PyPassManager());
}
return global_instance;
}
PyPassManager::PyPassManager() {
phase_to_group_[Phase::RESOLVE] = std::make_shared<PassGroup>();
phase_to_group_[Phase::OPT] = std::make_shared<PassGroup>();
}
void PyPassManager::Registe(const std::string &pass_name, const py::function &pattern, const py::function &target,
Phase phase, bool run_only_once, bool multigraph) {
auto cur_pm = GetPassGroup(phase);
MS_EXCEPTION_IF_NULL(cur_pm);
PythonPassPtr new_pass = std::make_shared<PythonPass>(pass_name, pattern, target, run_only_once, multigraph);
cur_pm->AddPass(new_pass);
}
void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) {
auto cur_pm = GetPassGroup(phase);
MS_EXCEPTION_IF_NULL(cur_pm);
if (!cur_pm->DeletePass(pass_name)) {
MS_LOG(WARNING) << "No such pass : " + pass_name + "\n";
}
}
void PyPassManager::ClearRes() {
MS_LOG(INFO) << "Clear PyPassManager resources!";
global_instance = nullptr;
phase_to_group_.clear();
}
REGISTER_PYBIND_DEFINE(
PyPassManager_, ([](const py::module *m) {
(void)py::enum_<Phase>(*m, "phase", py::arithmetic()).value("resolve", Phase::RESOLVE).value("opt", Phase::OPT);
(void)py::class_<PyPassManager, std::shared_ptr<PyPassManager>>(*m, "PyPassManager_")
.def(py::init([]() { return PyPassManager::GetInstance(); }))
.def("registe", &PyPassManager::Registe, "Registe python pass")
.def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass");
}));
} // namespace python_pass
} // namespace opt
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_
#define MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_
#include <memory>
#include <string>
#include <vector>
#include <unordered_map>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "ir/primitive.h"
#include "utils/graph_utils.h"
#include "common/utils.h"
#include "pipeline/parse/resolve.h"
#include "optimizer/py_pass.h"
#include "optimizer/pass_group.h"
namespace mindspore {
namespace opt {
namespace python_pass {
class PyPassManager;
using PyPassManagerPtr = std::shared_ptr<PyPassManager>;
enum Phase { RESOLVE, OPT };
class PyPassManager {
protected:
PyPassManager();
static PyPassManagerPtr global_instance;
public:
// Singletons should not be cloneable and assignable
PyPassManager(const PyPassManager &other) = delete;
void operator=(const PyPassManager &) = delete;
// Access the only global instance
static PyPassManagerPtr GetInstance();
virtual ~PyPassManager() = default;
void Registe(const std::string &pass_name, const py::function &pattern, const py::function &target,
Phase phase = Phase::RESOLVE, bool run_only_once = false, bool multigraph = true);
void Unregiste(const std::string &pass_name, Phase phase);
PassGroupPtr GetPassGroup(Phase phase);
void ClearRes();
private:
static std::unordered_map<Phase, PassGroupPtr> phase_to_group_;
};
} // namespace python_pass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_PY_PASS_MANAGER_H_
......@@ -39,6 +39,7 @@
#include "optimizer/optimizer.h"
#include "vm/transform.h"
#include "parse/python_adapter.h"
#include "optimizer/py_pass_manager.h"
namespace mindspore {
namespace pipeline {
......@@ -423,6 +424,25 @@ bool RemoveValueNodeDuplicationsAction(const ResourcePtr &res) {
bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); }
void ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) {
MS_EXCEPTION_IF_NULL(res->manager());
MS_EXCEPTION_IF_NULL(res->func_graph());
auto ppm = opt::python_pass::PyPassManager::GetInstance();
if (!ppm->GetPassGroup(phase)->Run(res->func_graph())) {
MS_LOG(DEBUG) << "No match.\n";
}
}
bool ResolveActionPyStub(const ResourcePtr &res) {
ActionPyStub(res, opt::python_pass::Phase::RESOLVE);
return true;
}
bool OptActionPyStub(const ResourcePtr &res) {
ActionPyStub(res, opt::python_pass::Phase::RESOLVE);
return true;
}
static std::vector<ActionItem> CommonPipeline() {
std::vector<ActionItem> actions;
......@@ -435,6 +455,8 @@ static std::vector<ActionItem> CommonPipeline() {
if (!multi_graphs) {
actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
}
// Add resolve-stage python pass stub
actions.emplace_back(std::make_pair("py_resolve", ResolveActionPyStub));
actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction));
// Evaluate type and shape, and specialize
actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));
......@@ -446,6 +468,8 @@ std::vector<ActionItem> GePipeline() {
auto actions = CommonPipeline();
// optimize
actions.emplace_back(std::make_pair("optimize", GeOptimizeAction));
// Add opt-stage python pass stub
actions.emplace_back(std::make_pair("py_opt", OptActionPyStub));
actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction));
actions.emplace_back(std::make_pair("validate", ValidateAction));
return actions;
......@@ -457,6 +481,9 @@ std::vector<ActionItem> VmPipeline() {
// optimize
actions.emplace_back(std::make_pair("optimize", VmOptimizeAction));
// Add opt-stage python pass stub
actions.emplace_back(std::make_pair("py_opt", OptActionPyStub));
actions.emplace_back(std::make_pair("validate", ValidateAction));
// compile the ANF graph
......
......@@ -39,6 +39,7 @@
#include "device/kernel_runtime_manager.h"
#include "debug/trace.h"
#include "pynative/pynative_execute.h"
#include "optimizer/py_pass_manager.h"
#if (ENABLE_GE || ENABLE_D)
#include "pipeline/pipeline_ge.h"
......@@ -967,6 +968,7 @@ void ClearResAtexit() {
pipeline::ExecutorPy::ClearRes();
pipeline::ReclaimOptimizer();
pynative::PynativeExecutor::GetInstance()->ClearRes();
opt::python_pass::PyPassManager::GetInstance()->ClearRes();
#ifdef ENABLE_GE
transform::DfGraphManager::GetInstance().ClearGraph();
transform::DfGraphConvertor::get_adpt_map().clear();
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Python pass register"""
from inspect import isfunction
from mindspore._c_expression import PyPassManager_
from mindspore._c_expression import phase
class PyPassManager(PyPassManager_):
r"""
Used to registe and unregiste python passes which can be used to alter graphs.
Args:
pipeline_phase (phase): Specify the stage in which the pass will run in the pipeline. Default: phase.opt.
run_only_once (bool): Specify whether or not to run pass only once. Default: False.
multigraph (bool): Whether or not the pattern exists across graphs. Default: True.
Raises:
TypeError: If argument has invalid type.
"""
def __init__(self, pipeline_phase=phase.opt, run_only_once=False, multi_graph=True):
if not isinstance(pipeline_phase, phase):
raise TypeError(f"Expecting phase, got : ({type(pipeline_phase)}){pipeline_phase}")
if not isinstance(run_only_once, bool):
raise TypeError(f"Expecting bool, got : ({type(run_only_once)}){run_only_once}")
if not isinstance(multi_graph, bool):
raise TypeError(f"Expecting bool, got : ({type(multi_graph)}){multi_graph}")
PyPassManager_.__init__(self)
self.phase_ = pipeline_phase
self.run_only_once_ = run_only_once
self.multi_graph_ = multi_graph
def registe(self, py_pass):
if not isfunction(py_pass):
raise TypeError(f"Expecting function pass, got : ({type(py_pass)}){py_pass}")
pattern, target = py_pass()
pass_name = py_pass.__name__
if not isfunction(pattern):
raise TypeError(f"Expecting function pattern, got : ({type(pattern)}){pattern}")
if not isfunction(target):
raise TypeError(f"Expecting function target, got : ({type(target)}){target}")
super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_, self.multi_graph_)
def unregiste(self, py_pass, pipeline_phase=phase.opt):
if not isinstance(pipeline_phase, phase):
raise TypeError(f"Expecting phase, got : ({type(pipeline_phase)}){pipeline_phase}")
if isinstance(py_pass, str):
super().unregiste(py_pass, pipeline_phase)
return
if isfunction(py_pass):
super().unregiste(py_pass.__name__, pipeline_phase)
return
raise TypeError(f"Expecting py_pass to be string or function, got ({type(py_pass)}){py_pass}")
def __call__(self, py_pass):
self.registe(py_pass)
return py_pass
def registe_pass(pipeline_phase=phase.opt, run_only_once=False, multi_graph=True):
"""
Examples:
>>> @registe_pass()
>>> def toy_pass():
>>> def pattern():
>>> pass
>>> def target():
>>> pass
"""
return PyPassManager(pipeline_phase, run_only_once, multi_graph)
......@@ -173,7 +173,8 @@ class Dense(Cell):
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
activation (str): activate function applied to the output of the fully connected layer, eg. 'relu'.
Default: None.
Raises:
ValueError: If weight_init or bias_init shape is incorrect.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册