提交 1bdb26f9 编写于 作者: B BowenK

Warming up python pass by adding inline passes before it

上级 0118930c
......@@ -49,7 +49,15 @@ FuncGraphPtr KPrim::GetBprop(const PrimitivePtr &prim) {
auto scope = std::make_shared<Scope>(gradients_scope + ScopeManager::GetInstance().GetCurrentScope()->name() +
grad_op_child_scope_prefix + prim->name());
ScopeGuard scope_guard(scope);
py::function fn = prim->is_base() ? GetBpropFunction(prim->name()) : prim->cast<PrimitivePyPtr>()->GetBpropFunction();
py::function fn;
if (prim->is_base()) {
fn = GetBpropFunction(prim->name());
} else {
fn = prim->cast<PrimitivePyPtr>()->GetBpropFunction();
if (py::isinstance<py::none>(fn)) {
fn = GetBpropFunction(prim->name());
}
}
if (!fn || py::isinstance<py::none>(fn)) {
MS_LOG(DEBUG) << "Fail to find bprop function for " << prim->name() << ".";
return nullptr;
......
......@@ -35,8 +35,10 @@ namespace internal {
const char PARAMETER_MODULE[] = "mindspore.common.parameter";
const char PARAMETER_CLASS[] = "Parameter";
const char SET_PARAM[] = "__setattr__";
AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph);
AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const MatchResultPtr &res);
AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph,
const FuncGraphPtr &top_graph);
AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph,
const MatchResultPtr &res);
void ReflectParamBackToPython(const AnfNodePtr &param, string param_name, tensor::TensorPtr default_input,
bool requires_grad, bool layerwise_parallel);
......@@ -72,7 +74,8 @@ AnfNodePtr BuildNewTensor(const PatternPtr &pattern, const MatchResultPtr &res)
return std::make_shared<ValueNode>(input_tensor);
}
AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg) {
AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg,
const FuncGraphPtr &top_graph) {
auto call_pattern = pattern->cast<CallPtr>();
MS_EXCEPTION_IF_NULL(call_pattern);
auto prim = call_pattern->prim_value();
......@@ -81,20 +84,20 @@ AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultP
}
auto prim_pattern = call_pattern->prim_pattern();
MS_EXCEPTION_IF_NULL(prim_pattern);
return ProcessSinglePattern(prim_pattern, res, fg);
return ProcessSinglePattern(prim_pattern, res, fg, top_graph);
}
AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) {
AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &top_graph) {
auto new_para_pattern = pattern->cast<NewParameterPtr>();
MS_EXCEPTION_IF_NULL(new_para_pattern);
if (!new_para_pattern->built()) {
static int parameter_id = 0;
auto para_name = new_para_pattern->para_name() + new_para_pattern->unique_name() + std::to_string(parameter_id++);
auto para_node = std::make_shared<Parameter>(func_graph);
auto para_node = std::make_shared<Parameter>(top_graph);
MS_EXCEPTION_IF_NULL(para_node);
para_node->set_name(para_name);
// Set function graph
para_node->set_func_graph(func_graph);
para_node->set_func_graph(top_graph);
// Set Debug Info
auto debug_info = std::make_shared<NodeDebugInfo>(para_name);
para_node->set_debug_info(debug_info);
......@@ -103,7 +106,7 @@ AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &re
MS_EXCEPTION_IF_NULL(default_value);
para_node->set_abstract(default_value->ToAbstract()->Broaden());
res->add_entry(pattern, para_node);
func_graph->add_parameter(para_node);
top_graph->add_parameter(para_node);
// Reflect back to Cell._params
internal::ReflectParamBackToPython(para_node, para_name, default_value, new_para_pattern->requires_grad(),
new_para_pattern->layerwise_parallel());
......@@ -126,7 +129,8 @@ AnfNodePtr BuildImmNode(const PatternPtr &pattern, const MatchResultPtr &res) {
return std::make_shared<ValueNode>(scalar_value_ptr);
}
AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) {
AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph,
const FuncGraphPtr &top_graph) {
auto target_node = res->get_node(pattern);
if (target_node != nullptr) {
// If pattern is NewParameter, check whether it shouldn't last and is not built
......@@ -141,9 +145,10 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr
} else if (pattern->isa<NewTensor>()) {
return BuildNewTensor(pattern, res);
} else if (pattern->isa<Call>()) {
return BuildPrimitiveValueNode(pattern, res, func_graph);
return BuildPrimitiveValueNode(pattern, res, func_graph, top_graph);
} else if (pattern->isa<NewParameter>()) {
return BuildNewParameter(pattern, res, func_graph);
// Add new parameter to top graph instead of current graph
return BuildNewParameter(pattern, res, top_graph);
} else if (pattern->isa<Imm>()) {
return BuildImmNode(pattern, res);
} else {
......@@ -154,17 +159,18 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr
}
AnfNodePtr ProcessComplexPatternFirstInput(const PatternPtr &pattern, const MatchResultPtr &res,
const FuncGraphPtr &func_graph) {
const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph) {
if (pattern->isa<Call>()) {
return BuildPrimitiveValueNode(pattern, res, func_graph);
return BuildPrimitiveValueNode(pattern, res, func_graph, top_graph);
}
return nullptr;
}
AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const MatchResultPtr &res) {
AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph,
const MatchResultPtr &res) {
auto target_inputs = pattern->inputs();
if (target_inputs.size() == 0) {
auto new_node = ProcessSinglePattern(pattern, res, func_graph);
auto new_node = ProcessSinglePattern(pattern, res, func_graph, top_graph);
if (new_node != nullptr) {
res->add_entry(pattern, new_node);
}
......@@ -172,14 +178,14 @@ AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph
}
// Build up the AnfNode in a recursive manner
std::vector<AnfNodePtr> new_inputs;
auto prim_value_node = ProcessComplexPatternFirstInput(pattern, res, func_graph);
auto prim_value_node = ProcessComplexPatternFirstInput(pattern, res, func_graph, top_graph);
MS_EXCEPTION_IF_NULL(prim_value_node);
new_inputs.push_back(prim_value_node);
for (auto &iter : target_inputs) {
if (iter == pattern) {
MS_LOG(EXCEPTION) << "Circle references. Got pattern: " + pattern->unique_name() + "\n";
}
auto input_node = BuildTarget(iter, func_graph, res);
auto input_node = BuildTarget(iter, func_graph, top_graph, res);
if (input_node == nullptr) {
MS_LOG(EXCEPTION) << "Failed to build input node for pattern : " + iter->unique_name() + "\n";
}
......@@ -240,11 +246,12 @@ void Reset(PatternPtr pattern) {
} // namespace internal
AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res) {
AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph, const AnfNodePtr &node,
const MatchResultPtr &res) {
auto match_res = src_pattern_->match(node);
if (match_res != nullptr) {
res->merge(match_res);
auto new_node = internal::BuildTarget(dst_pattern_, func_graph, res);
auto new_node = internal::BuildTarget(dst_pattern_, func_graph, top_graph, res);
internal::Reset(dst_pattern());
return new_node;
}
......@@ -284,16 +291,19 @@ bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res)
}
FuncGraphManagerPtr manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(func_graph);
auto graph_nodes_sorted = TopoSort(func_graph->output());
auto func_graphs = manager->func_graphs();
bool changes = false;
// Traverse once
for (auto &node : graph_nodes_sorted) {
AnfNodePtr new_node = Run(func_graph, node, res);
if (new_node != nullptr && new_node != node) {
(void)manager->Replace(node, new_node);
changes = true;
for (auto &fg : func_graphs) {
manager->AddFuncGraph(fg);
auto graph_nodes_sorted = TopoSort(fg->output());
// Traverse once
for (auto &node : graph_nodes_sorted) {
AnfNodePtr new_node = Run(fg, func_graph, node, res);
if (new_node != nullptr && new_node != node) {
MS_LOG(WARNING) << "Matched";
(void)manager->Replace(node, new_node);
changes = true;
}
}
}
return changes;
......
......@@ -39,7 +39,8 @@ class PythonPass {
~PythonPass() = default;
bool Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res);
std::string name() const { return name_; }
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res);
AnfNodePtr Run(const FuncGraphPtr &func_graph, const FuncGraphPtr &top_graph, const AnfNodePtr &node,
const MatchResultPtr &res);
PatternPtr src_pattern() { return src_pattern_; }
PatternPtr dst_pattern() { return dst_pattern_; }
......
......@@ -43,15 +43,19 @@ PyPassManagerPtr PyPassManager::GetInstance() {
}
PyPassManager::PyPassManager() {
phase_to_group_[Phase::RESOLVE] = std::make_shared<PassGroup>();
phase_to_group_[Phase::OPT] = std::make_shared<PassGroup>();
phase_to_group_[Phase::PREAD] = std::make_shared<PassGroup>("Pre_AD_PassGroup");
phase_to_group_[Phase::OPT] = std::make_shared<PassGroup>("After_OPT_PassGroup");
res_ = std::make_shared<MatchResult>();
}
void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
bool run_only_once) {
// NOTE: remove phase option to avoid unnecessary confusion.
auto cur_pg = GetPassGroup(Phase::OPT);
bool requires_grad, bool run_only_once) {
PassGroupPtr cur_pg;
if (requires_grad) {
cur_pg = GetPassGroup(Phase::PREAD);
} else {
cur_pg = GetPassGroup(Phase::OPT);
}
MS_EXCEPTION_IF_NULL(cur_pg);
cur_pg->SetRunOnlyOnce(run_only_once);
MS_EXCEPTION_IF_NULL(pattern);
......@@ -62,11 +66,13 @@ void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &patt
}
void PyPassManager::Unregiste(const std::string &pass_name) {
// NOTE: remove phase option to avoid unnecessary confusion.
auto cur_pm = GetPassGroup(Phase::OPT);
MS_EXCEPTION_IF_NULL(cur_pm);
if (!cur_pm->DeletePass(pass_name)) {
MS_LOG(WARNING) << "No such pass : " + pass_name + "\n";
auto opt_pm = GetPassGroup(Phase::OPT);
if (!opt_pm->DeletePass(pass_name)) {
MS_LOG(WARNING) << "Opt has no such pass : " + pass_name + "\n";
}
auto pre_ad_pm = GetPassGroup(Phase::PREAD);
if (!pre_ad_pm->DeletePass(pass_name)) {
MS_LOG(WARNING) << "Pre_AD has no such pass : " + pass_name + "\n";
}
}
......@@ -92,7 +98,7 @@ void PyPassManager::ClearRes() {
REGISTER_PYBIND_DEFINE(
PyPassManager_, ([](const py::module *m) {
(void)py::enum_<Phase>(*m, "phase", py::arithmetic()).value("resolve", Phase::RESOLVE).value("opt", Phase::OPT);
(void)py::enum_<Phase>(*m, "phase", py::arithmetic()).value("pre_ad", Phase::PREAD).value("opt", Phase::OPT);
(void)py::class_<PyPassManager, std::shared_ptr<PyPassManager>>(*m, "PyPassManager_")
.def(py::init([]() { return PyPassManager::GetInstance(); }))
.def("registe", &PyPassManager::Registe, "Registe python pass")
......
......@@ -38,7 +38,7 @@ namespace python_pass {
class PyPassManager;
using PyPassManagerPtr = std::shared_ptr<PyPassManager>;
enum Phase { RESOLVE, OPT };
enum Phase { PREAD, OPT };
class PyPassManager {
protected:
......@@ -52,8 +52,8 @@ class PyPassManager {
// Access the only global instance
static PyPassManagerPtr GetInstance();
virtual ~PyPassManager() = default;
void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
bool run_only_once = false);
void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target, bool requires_grad,
bool run_only_once);
void Unregiste(const std::string &pass_name);
void GenNewParameter(const PatternPtr &parameter);
PassGroupPtr GetPassGroup(Phase phase);
......
......@@ -288,6 +288,8 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes)
return true;
}
bool OptInlineAction(const ResourcePtr &res) { return OptimizeAction(res, kInlinePasses); }
bool GeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kGePasses); }
bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPasses); }
......@@ -460,7 +462,12 @@ bool ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) {
return ppm->GetPassGroup(phase)->Run(res->func_graph());
}
bool ResolveActionPyStub(const ResourcePtr &res) { return true || ActionPyStub(res, opt::python_pass::Phase::RESOLVE); }
bool PreAdActionPyStub(const ResourcePtr &res) {
if (!ActionPyStub(res, opt::python_pass::Phase::PREAD)) {
MS_LOG(DEBUG) << "No Match.";
}
return true;
}
bool OptActionVmPyStub(const ResourcePtr &res) {
if (ActionPyStub(res, opt::python_pass::Phase::OPT)) {
......@@ -516,12 +523,14 @@ static std::vector<ActionItem> CommonPipeline() {
if (!multi_graphs) {
actions.emplace_back(std::make_pair("combine_like_graphs", CombineLikeGraphs));
}
// Add resolve-stage python pass stub
actions.emplace_back(std::make_pair("py_resolve", ResolveActionPyStub));
actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction));
// Evaluate type and shape, and specialize
actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));
// Do data structure simplifications and inline
actions.emplace_back(std::make_pair("inline", OptInlineAction));
// Add pre-ad, post-inline python pass stub
actions.emplace_back(std::make_pair("py_pre_ad", PreAdActionPyStub));
return actions;
}
......
......@@ -165,6 +165,12 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
return map_a;
}
OptPassGroupMap GetA1A2(const opt::irpass::OptimizeIRPassLib &irpass) {
auto opt_a = GetOptPassesA(irpass);
OptPassGroupMap a1_a2({opt_a[0], opt_a[1]});
return a1_a2;
}
OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig c_1 = opt::OptPassConfig({
// Safe inlining,
......@@ -270,6 +276,7 @@ static std::unordered_map<std::string, std::shared_ptr<Optimizer>> g_pass_opts =
void InitOpt(const ResourcePtr &res) {
if (g_pass_opts.size() == 0) {
opt::irpass::OptimizeIRPassLib irpass;
g_pass_opts["a1a2"] = Optimizer::MakeOptimizer("a1a2", res, GetA1A2(irpass));
g_pass_opts["opt_a"] = Optimizer::MakeOptimizer("opt_a", res, GetOptPassesA(irpass));
g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true);
g_pass_opts["opt_after_cconv"] =
......@@ -318,6 +325,7 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
return true;
}
bool OptPassA1A2(const ResourcePtr &res) { return OptPassGroup(res, "a1a2"); }
bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); }
bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); }
bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); }
......@@ -440,5 +448,7 @@ std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup},
{"cconv", CconvPass},
{"transform_top", TransformTopGraphPass},
{"transform_graph", OptPassTransformGraphGroup}};
std::vector<PassItem> kInlinePasses = {{"simplify_data_structures", SimplifyDataStructuresPass}, {"a1a2", OptPassA1A2}};
} // namespace pipeline
} // namespace mindspore
......@@ -29,6 +29,7 @@ using PassItem = std::pair<std::string, std::function<bool(ResourcePtr)>>;
extern std::vector<PassItem> kGePasses;
extern std::vector<PassItem> kVmPasses;
extern std::vector<PassItem> kInlinePasses;
extern std::vector<PassItem> kPynativePasses;
bool CconvPass(const ResourcePtr &res);
......
......@@ -13,14 +13,14 @@
# limitations under the License.
# ============================================================================
"""Reference for python pass registration."""
from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm,\
set_reopt
from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, _set_renorm,\
_set_reopt
__all__ = [
"registe_pass",
"unregiste_pass",
"gen_new_parameter",
"cancel_new_parameter",
"set_renorm",
"set_reopt"
"_set_renorm",
"_set_reopt"
]
......@@ -23,22 +23,26 @@ __all__ = [
"unregiste_pass",
"gen_new_parameter",
"cancel_new_parameter",
"set_renorm",
"set_reopt"
"_set_renorm",
"_set_reopt"
]
class PyPassManager(PyPassManager_):
r"""
Used to registe and unregiste python passes which can be used to alter graphs.
Args:
requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True
run_only_once (bool): Specify whether or not to run pass only once. Default: False.
Raises:
TypeError: If argument has invalid type.
"""
def __init__(self, run_only_once=False):
def __init__(self, requires_grad=True, run_only_once=False):
if not isinstance(requires_grad, bool):
raise TypeError(f"Expect bool, got : ({type(requires_grad)}){requires_grad}")
if not isinstance(run_only_once, bool):
raise TypeError(f"Expect bool, got : ({type(run_only_once)}){run_only_once}")
self.requires_grad = requires_grad
self.run_only_once_ = run_only_once
PyPassManager_.__init__(self)
......@@ -51,7 +55,7 @@ class PyPassManager(PyPassManager_):
raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}")
if not isinstance(target, Pattern):
raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}")
super().registe(pass_name, pattern, target, self.run_only_once_)
super().registe(pass_name, pattern, target, self.requires_grad, self.run_only_once_)
def unregiste(self, py_pass):
if isinstance(py_pass, str):
......@@ -81,11 +85,12 @@ class PyPassManager(PyPassManager_):
raise TypeError(f"Expect do_reopt to be a bool, got {do_reopt}")
super().set_reopt(do_reopt)
def registe_pass(run_only_once=False):
def registe_pass(requires_grad=True, run_only_once=False):
"""
Registe python pass to specified pipeline phase which would be used in compilation.
Args:
requires_grad(bool): Do automatic-differentiation after modified graph if true. Default: True
run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default: False.
Returns:
......@@ -99,7 +104,7 @@ def registe_pass(run_only_once=False):
>>> target = IsPrimTypeOf("ReLU6")
>>> return pattern, target
"""
return PyPassManager(run_only_once)
return PyPassManager(requires_grad, run_only_once)
def unregiste_pass(py_pass):
"""
......@@ -157,7 +162,7 @@ def cancel_new_parameter(pattern):
ppm = PyPassManager()
ppm.unregiste(pattern.para_name)
def set_renorm(should_renorm):
def _set_renorm(should_renorm):
"""
Set whether or not to do renormalization after modified graph in python pass(es).
......@@ -171,7 +176,7 @@ def set_renorm(should_renorm):
ppm = PyPassManager()
ppm.set_renorm(should_renorm)
def set_reopt(do_reopt):
def _set_reopt(do_reopt):
"""
Set whether or not to do optimization after modified graph in python pass(es).
......
......@@ -19,8 +19,8 @@ import mindspore.nn as nn
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\
cancel_new_parameter, set_reopt
from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, _set_renorm, gen_new_parameter,\
cancel_new_parameter, _set_reopt
from mindspore.common.api import _generate_pip_args
from mindspore._c_expression import generate_key, Executor_
from mindspore.graph_utils.graph_pattern import OneOf, Prim, Call, NoneOf, Any, NewTensor, NewParameter, Imm
......@@ -157,8 +157,8 @@ def test_isnot_pattern_0():
Test IsNot pattern which expresses the IsNot semantics.
Case: IsNot pass failed to match
"""
set_renorm(False)
set_reopt(False)
_set_renorm(False)
_set_reopt(False)
class ConvBN(nn.Cell):
def __init__(self):
super(ConvBN, self).__init__()
......@@ -176,7 +176,7 @@ def test_isnot_pattern_0():
inputs = Tensor(np.random.normal(0, 1, (10, 32, 32, 32)), mindspore.float32)
conv_bn_model = ConvBN()
@registe_pass(run_only_once=True)
@registe_pass(requires_grad=False, run_only_once=True)
def single_bn_pass():
"""
Sub a BN which does NOT take Conv as inputs to ReLU6.
......@@ -188,7 +188,7 @@ def test_isnot_pattern_0():
target = Call(P.ReLU6(), [pattern_0])
return pattern, target
@registe_pass(run_only_once=True)
@registe_pass(requires_grad=False, run_only_once=True)
def bn_pass():
"""
Sub a BN to Softmax.
......@@ -202,7 +202,7 @@ def test_isnot_pattern_0():
unregiste_pass(bn_pass)
assert "ReLU6" not in transformed_repr
assert "Softmax" in transformed_repr
set_renorm(True)
_set_renorm(True)
def test_isnot_pattern_1():
"""
......@@ -234,12 +234,12 @@ def test_newtensor_pattern():
"""
Test NewTensor pattern in the target
"""
set_renorm(False)
set_reopt(False)
_set_renorm(False)
_set_reopt(False)
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
@registe_pass(run_only_once=True)
@registe_pass(requires_grad=False, run_only_once=True)
def softmax_addn_pass():
x = Any()
pattern = Call(P.Softmax(), [x])
......@@ -252,7 +252,7 @@ def test_newtensor_pattern():
unregiste_pass(softmax_addn_pass)
assert "AddN" in transformed_repr
assert "Softmax" not in transformed_repr
set_renorm(True)
_set_renorm(True)
def test_newparameter_pattern():
"""
......@@ -261,9 +261,9 @@ def test_newparameter_pattern():
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
set_renorm(False)
set_reopt(False)
@registe_pass(run_only_once=True)
_set_renorm(False)
_set_reopt(False)
@registe_pass(requires_grad=False, run_only_once=True)
def softmax_addn_pass():
x = Any()
pattern = Call(P.Softmax(), [x])
......@@ -288,9 +288,9 @@ def test_imm_target():
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
set_renorm(False)
set_reopt(False)
@registe_pass(run_only_once=True)
_set_renorm(False)
_set_reopt(False)
@registe_pass(requires_grad=False, run_only_once=True)
def softmax_pass():
x = Any()
pattern = Call(P.Softmax(), [x])
......@@ -313,10 +313,10 @@ def test_gen_new_parameter():
default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
new_para = NewParameter("Merlin", default_tensor)
set_renorm(False)
set_reopt(False)
_set_renorm(False)
_set_reopt(False)
gen_new_parameter(new_para)
@registe_pass(run_only_once=True)
@registe_pass(requires_grad=False, run_only_once=True)
def softmax_make_tuple_pass():
x = Any()
softmax = P.Softmax()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册