提交 8d693306 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4126 Add new parameter

Merge pull request !4126 from BowenK/new_parameter
......@@ -14,6 +14,7 @@
* limitations under the License.
*/
#include "frontend/optimizer/pass_group.h"
#include "frontend/optimizer/py_pass_manager.h"
namespace mindspore {
namespace opt {
......@@ -35,14 +36,15 @@ bool PassGroup::DeletePass(const std::string &pass_name) {
return false;
}
bool PassGroup::Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes) const {
bool PassGroup::Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes,
const MatchResultPtr &res) const {
if (func_graph == nullptr) {
return false;
}
bool changed = false;
for (const auto &pass : passes) {
if (pass != nullptr) {
if (pass->Run(func_graph)) {
if (pass->Run(func_graph, res)) {
changed = true;
}
}
......@@ -54,8 +56,9 @@ bool PassGroup::Run(const FuncGraphPtr &func_graph) const {
bool changed = false;
// run all passes
bool change = true;
auto res = PyPassManager::GetInstance()->GetMatchResult();
while (change) {
change = Run(func_graph, passes_);
change = Run(func_graph, passes_, res);
changed = change || changed;
if (run_only_once_) {
break;
......
......@@ -41,12 +41,14 @@ class PassGroup {
// @return false, graph not changed
bool Run(const FuncGraphPtr &func_graph) const;
// Run the given graph passes on the input graph
// @param [inout] graph The graph to be optimized
// @param [inout] func_graph The graph to be optimized
// @param [in] passes The given graph passes
// @param [inout] res MatchResult used to collect all matched patterns and nodes
// @return true, graph changed
// @return false, graph not changed
bool Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes) const;
bool Run(const FuncGraphPtr &func_graph, const std::vector<PythonPassPtr> &passes, const MatchResultPtr &res) const;
std::string name() const { return name_; }
void SetRunOnlyOnce(bool run_only_once) { run_only_once_ = run_only_once; }
private:
const std::string name_;
......
......@@ -96,6 +96,7 @@ MatchResultPtr IsIn::match(const AnfNodePtr &node) {
for (auto &iter : patterns_) {
auto res = iter->match(node);
if (res != nullptr) {
res->add_entry(shared_from_base<IsIn>(), node);
return res;
}
}
......@@ -151,6 +152,9 @@ REGISTER_PYBIND_DEFINE(
(void)py::class_<AnyPattern, std::shared_ptr<AnyPattern>, Pattern>(*m, "AnyPattern").def(py::init<>());
(void)py::class_<NewTensor, std::shared_ptr<NewTensor>, Pattern>(*m, "NewTensor_")
.def(py::init<tensor::TensorPtr>());
(void)py::class_<NewParameter, std::shared_ptr<NewParameter>, Pattern>(*m, "NewParameter_")
.def(py::init<string, tensor::TensorPtr, bool, bool, bool>());
(void)py::class_<Imm, std::shared_ptr<Imm>, Pattern>(*m, "Imm").def(py::init<int>());
}));
} // namespace python_pass
} // namespace opt
......
......@@ -42,6 +42,10 @@ class CallWith;
using CallWithPtr = std::shared_ptr<CallWith>;
class NewTensor;
using NewTensorPtr = std::shared_ptr<NewTensor>;
class NewParameter;
using NewParameterPtr = std::shared_ptr<NewParameter>;
class Imm;
using ImmPtr = std::shared_ptr<Imm>;
struct PatternHasher;
struct PatternEqual;
using PatternNodeMap = std::unordered_map<PatternPtr, AnfNodePtr, PatternHasher, PatternEqual>;
......@@ -55,6 +59,7 @@ class Pattern : public Base {
string unique_name() const { return unique_name_; }
vector<PatternPtr> inputs() { return inputs_; }
bool should_replace() { return should_replace_; }
void set_should_replace(bool should_replace) { should_replace_ = should_replace; }
virtual void reset() {}
protected:
......@@ -86,14 +91,14 @@ class IsPrimTypeOf : public Pattern {
~IsPrimTypeOf() = default;
IsPrimTypeOf(vector<PrimitivePyPtr> prims, string name, bool should_replace)
: primitives_(prims), name_(name), matched_prim_(nullptr) {
unique_name_ = std::to_string(g_id_++) + "_" + name;
unique_name_ = std::to_string(g_id_++) + "IsPrimTypeOf_" + name;
should_replace_ = should_replace;
if (!should_replace) {
matched_prim_ = prims[0];
}
}
IsPrimTypeOf(vector<string> types, string name, bool should_replace) : types_(types), name_(name) {
unique_name_ = std::to_string(g_id_++) + "_" + name;
unique_name_ = std::to_string(g_id_++) + "IsPrimTypeOf_" + name;
// Make primitives_
for (auto &iter : types) {
primitives_.push_back(std::make_shared<PrimitivePy>(iter, py::cast(nullptr)));
......@@ -126,19 +131,20 @@ class CallWith : public Pattern {
CallWith(PatternPtr prim_pattern, vector<PatternPtr> inputs, bool should_replace) {
// NOTE: should_replace is ignored in this case, since each sub-pattern has its own setting
prim_pattern_ = prim_pattern;
unique_name_ = std::to_string(g_id_++) + prim_pattern->unique_name();
unique_name_ = std::to_string(g_id_++) + "CallWithPattern_" + prim_pattern->unique_name();
inputs_ = inputs;
should_replace_ = should_replace;
// NOTE: should_replace_ is overrided by it prim_pattern(if exists) silently.
should_replace_ = prim_pattern->should_replace();
}
CallWith(PrimitivePyPtr prim, vector<PatternPtr> inputs, bool should_replace) {
prim_ = prim;
unique_name_ = std::to_string(g_id_++) + prim_->ToString();
unique_name_ = std::to_string(g_id_++) + "CallWithPrim_" + prim_->ToString();
inputs_ = inputs;
should_replace_ = should_replace;
}
CallWith(string prim_str, vector<PatternPtr> inputs, bool should_replace) {
prim_ = std::make_shared<PrimitivePy>(prim_str, py::cast(nullptr));
unique_name_ = std::to_string(g_id_++) + prim_->ToString();
unique_name_ = std::to_string(g_id_++) + "CallWithStr_" + prim_->ToString();
inputs_ = inputs;
should_replace_ = should_replace;
}
......@@ -159,7 +165,7 @@ class IsIn : public Pattern {
IsIn() { unique_name_ = std::to_string(g_id_++); }
~IsIn() = default;
explicit IsIn(vector<PatternPtr> patterns) : patterns_(patterns) {
unique_name_ = std::to_string(g_id_++);
unique_name_ = std::to_string(g_id_++) + "IsIn";
for (auto &iter : patterns) {
unique_name_ = unique_name_ + "_" + iter->unique_name();
}
......@@ -176,9 +182,9 @@ class IsNot : public Pattern {
IsNot() { unique_name_ = std::to_string(g_id_++); }
~IsNot() = default;
explicit IsNot(vector<PatternPtr> patterns) : patterns_(patterns) {
unique_name_ = std::to_string(g_id_++);
unique_name_ = std::to_string(g_id_++) + "IsNot";
for (auto &iter : patterns) {
unique_name_ = "IsNot_" + unique_name_ + "_" + iter->unique_name();
unique_name_ = unique_name_ + "_" + iter->unique_name();
}
}
MS_DECLARE_PARENT(IsNot, Pattern);
......@@ -200,7 +206,10 @@ class NewTensor : public Pattern {
public:
NewTensor() { unique_name_ = std::to_string(g_id_++); }
~NewTensor() = default;
explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) { should_replace_ = false; }
explicit NewTensor(tensor::TensorPtr input_tensor) : input_tensor_(input_tensor) {
should_replace_ = false;
unique_name_ = std::to_string(g_id_++) + "NewTensor";
}
MS_DECLARE_PARENT(NewTensor, Pattern);
MatchResultPtr match(const AnfNodePtr &node) override {
MS_LOG(EXCEPTION) << "Find NewTensor in pattern, NewTensor should only appear in the target.\n";
......@@ -211,6 +220,54 @@ class NewTensor : public Pattern {
tensor::TensorPtr input_tensor_;
};
class NewParameter : public Pattern {
public:
NewParameter() { unique_name_ = std::to_string(g_id_++); }
explicit NewParameter(string para_name, tensor::TensorPtr default_tensor, bool requires_grad, bool layerwise_parallel,
bool should_replace)
: para_name_(para_name), requires_grad_(requires_grad), layerwise_parallel_(layerwise_parallel) {
should_replace_ = should_replace;
unique_name_ = std::to_string(g_id_++) + "NewParameter_" + para_name;
// clone input tensor
default_tensor_ = std::make_shared<tensor::Tensor>(*default_tensor.get());
built_ = false;
}
MS_DECLARE_PARENT(NewParameter, Pattern);
MatchResultPtr match(const AnfNodePtr &node) override {
MS_LOG(EXCEPTION) << "Find NewParameter in pattern, NewParameter should only appear in the target.\n";
}
string para_name() { return para_name_; }
tensor::TensorPtr default_tensor() { return default_tensor_; }
bool requires_grad() { return requires_grad_; }
bool layerwise_parallel() { return layerwise_parallel_; }
bool built() { return built_; }
void set_built(bool built) { built_ = built; }
void reset() override { built_ = false; }
private:
string para_name_;
bool requires_grad_;
bool layerwise_parallel_;
bool built_;
tensor::TensorPtr default_tensor_;
};
class Imm : public Pattern {
public:
Imm() { unique_name_ = std::to_string(g_id_++); }
explicit Imm(int value) : value_(value) {
should_replace_ = false;
unique_name_ = std::to_string(g_id_++) + "Imm_" + std::to_string(value);
}
MS_DECLARE_PARENT(Imm, Pattern);
// NOTE: Doesn't support Imm in src pattern currently.
MatchResultPtr match(const AnfNodePtr &node) override { return nullptr; }
int value() { return value_; }
private:
int value_;
};
class MatchResult {
public:
MatchResult() {}
......
......@@ -21,13 +21,26 @@
#include "ir/func_graph.h"
#include "ir/manager.h"
#include "pybind_api/ir/primitive_py.h"
#include "ir/scalar.h"
#include "ir/graph_utils.h"
#include "pipeline/jit/parse/parse_base.h"
#include "pipeline/jit/resource.h"
#include "frontend/optimizer/py_pass_manager.h"
#include "utils/info.h"
#include "debug/anf_ir_dump.h"
#include "debug/draw.h"
namespace mindspore {
namespace opt {
namespace python_pass {
namespace internal {
AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res);
const char PARAMETER_MODULE[] = "mindspore.common.parameter";
const char PARAMETER_CLASS[] = "Parameter";
const char SET_PARAM[] = "__setattr__";
AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph);
AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const MatchResultPtr &res);
void ReflectParamBackToPython(const AnfNodePtr &param, string param_name, tensor::TensorPtr default_input,
bool requires_grad, bool layerwise_parallel);
std::string GetNodeRepr(AnfNodePtr node) {
if (node != nullptr) {
......@@ -42,8 +55,10 @@ std::string GetNodeRepr(AnfNodePtr node) {
repr += ")";
return repr;
}
if (node->isa<ValueNode>()) {
return GetValueNode(node)->ToString();
if (node->isa<Parameter>()) {
return "[Parameter]" + node->ToString();
} else if (node->isa<ValueNode>()) {
return "[Value]" + GetValueNode(node)->ToString();
}
return node->ToString();
}
......@@ -82,7 +97,7 @@ AnfNodePtr BuildNewTensor(const PatternPtr &pattern, const MatchResultPtr &res)
return std::make_shared<ValueNode>(input_tensor);
}
AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res) {
AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &fg) {
auto call_with_pattern = pattern->cast<CallWithPtr>();
MS_EXCEPTION_IF_NULL(call_with_pattern);
auto prim = call_with_pattern->prim_value();
......@@ -91,15 +106,70 @@ AnfNodePtr BuildPrimitiveValueNode(const PatternPtr &pattern, const MatchResultP
}
auto prim_pattern = call_with_pattern->prim_pattern();
MS_EXCEPTION_IF_NULL(prim_pattern);
return ProcessSinglePattern(prim_pattern, res);
return ProcessSinglePattern(prim_pattern, res, fg);
}
AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res) {
AnfNodePtr BuildNewParameter(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) {
auto new_para_pattern = pattern->cast<NewParameterPtr>();
MS_EXCEPTION_IF_NULL(new_para_pattern);
if (!new_para_pattern->built()) {
static int parameter_id = 0;
auto para_name = new_para_pattern->para_name() + new_para_pattern->unique_name() + std::to_string(parameter_id++);
auto para_node = std::make_shared<Parameter>(func_graph);
MS_EXCEPTION_IF_NULL(para_node);
para_node->set_name(para_name);
// Set function graph
para_node->set_func_graph(func_graph);
// Set Debug Info
auto debug_info = std::make_shared<NodeDebugInfo>(para_name);
para_node->set_debug_info(debug_info);
// Set abstract
auto default_value = new_para_pattern->default_tensor();
MS_EXCEPTION_IF_NULL(default_value);
para_node->set_abstract(default_value->ToAbstract()->Broaden());
res->add_entry(pattern, para_node);
func_graph->add_parameter(para_node);
// Reflect back to Cell._params
internal::ReflectParamBackToPython(para_node, para_name, default_value, new_para_pattern->requires_grad(),
new_para_pattern->layerwise_parallel());
MS_LOG(WARNING) << "Adding parameter: " + para_node->ToString() + " parameter name:" + para_node->name();
new_para_pattern->set_built(true);
return para_node;
} else {
// Built, fetch the node
auto para_node = res->get_node(pattern);
MS_EXCEPTION_IF_NULL(para_node);
return para_node;
}
}
AnfNodePtr BuildImmNode(const PatternPtr &pattern, const MatchResultPtr &res) {
auto imm_pattern = pattern->cast<ImmPtr>();
MS_EXCEPTION_IF_NULL(imm_pattern);
auto value = imm_pattern->value();
auto scalar_value_ptr = std::make_shared<Int32Imm>(value);
return std::make_shared<ValueNode>(scalar_value_ptr);
}
AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr &res, const FuncGraphPtr &func_graph) {
if (pattern->should_replace()) {
// Find replacement in the MatchResult
auto target_node = res->get_node(pattern);
if (target_node == nullptr) {
MS_LOG(EXCEPTION) << "Cannot find target node in pattern match result, pattern: " + pattern->unique_name() + "\n";
// If it's base pattern(in contrast to complex pattern like CallWith/IsIn/IsNot), raise runtime exception.
if (pattern->isa<IsPrimTypeOf>() || pattern->isa<NewTensor>() || pattern->isa<NewParameter>()) {
MS_LOG(EXCEPTION) << "Cannot find target node, pattern: " + pattern->unique_name() + "\n";
return nullptr;
}
// Try to build this pattern and add to MatchResult, since this pattern is defined inside target
auto new_node = BuildTarget(pattern, func_graph, res);
if (new_node == nullptr) {
MS_LOG(EXCEPTION) << "Try to build pattern node but FAILED. pattern: " + pattern->unique_name() + "\n";
}
return new_node;
}
if (pattern->isa<NewParameter>()) {
return target_node;
}
return target_node;
}
......@@ -109,7 +179,19 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr
} else if (pattern->isa<NewTensor>()) {
return BuildNewTensor(pattern, res);
} else if (pattern->isa<CallWith>()) {
return BuildPrimitiveValueNode(pattern, res);
return BuildPrimitiveValueNode(pattern, res, func_graph);
} else if (pattern->isa<NewParameter>()) {
return BuildNewParameter(pattern, res, func_graph);
} else if (pattern->isa<Imm>()) {
return BuildImmNode(pattern, res);
}
return nullptr;
}
AnfNodePtr ProcessComplexPatternFirstInput(const PatternPtr &pattern, const MatchResultPtr &res,
const FuncGraphPtr &func_graph) {
if (pattern->isa<CallWith>()) {
return BuildPrimitiveValueNode(pattern, res, func_graph);
}
return nullptr;
}
......@@ -117,91 +199,154 @@ AnfNodePtr ProcessSinglePattern(const PatternPtr &pattern, const MatchResultPtr
AnfNodePtr BuildTarget(const PatternPtr &pattern, const FuncGraphPtr &func_graph, const MatchResultPtr &res) {
auto target_inputs = pattern->inputs();
if (target_inputs.size() == 0) {
return ProcessSinglePattern(pattern, res);
auto new_node = ProcessSinglePattern(pattern, res, func_graph);
if (new_node != nullptr) {
res->add_entry(pattern, new_node);
}
return new_node;
}
// Build up the AnfNode in a recursive manner
std::vector<AnfNodePtr> new_inputs;
auto prim_value_node = ProcessSinglePattern(pattern, res);
auto prim_value_node = ProcessComplexPatternFirstInput(pattern, res, func_graph);
MS_EXCEPTION_IF_NULL(prim_value_node);
new_inputs.push_back(prim_value_node);
for (auto &iter : target_inputs) {
if (iter == pattern) {
MS_LOG(EXCEPTION) << "Circle references: Pattern takes itself as input. Got pattern: " + pattern->unique_name() +
"\n";
MS_LOG(EXCEPTION) << "Circle references. Got pattern: " + pattern->unique_name() + "\n";
}
auto input_node = BuildTarget(iter, func_graph, res);
if (input_node == nullptr) {
MS_LOG(EXCEPTION) << "Failed to build input node for pattern : " + iter->unique_name() + "\n";
}
new_inputs.push_back(input_node);
}
auto new_node = func_graph->NewCNode(new_inputs);
res->add_entry(pattern, new_node);
return new_node;
}
void DrawNode(string name, AnfNodePtr node) {
auto context_ptr = MsContext::GetInstance();
bool save_graphs = context_ptr->save_graphs_flag();
auto save_graphs_path = context_ptr->save_graphs_path();
if (save_graphs_path.empty()) {
save_graphs_path = ".";
}
auto new_func_graph = std::make_shared<FuncGraph>();
new_func_graph->set_output(node, true);
if (save_graphs) {
auto ir_dump_path = save_graphs_path + "/" + name + ".ir";
auto dot_dump_path = save_graphs_path + "/" + name + ".dot";
DumpIR(ir_dump_path, new_func_graph);
draw::Draw(dot_dump_path, new_func_graph);
}
}
void ReflectParamBackToPython(const AnfNodePtr &param, string param_name, tensor::TensorPtr default_input,
bool requires_grad, bool layerwise_parallel) {
// 1. Get current cell object
auto ppm = opt::python_pass::PyPassManager::GetInstance();
auto resource = ppm->GetResource();
py::object top_cell = resource->input();
if (py::isinstance<py::none>(top_cell)) {
MS_LOG(EXCEPTION) << "Failed to get top cell from resource.";
}
// 2. New a Parameter object with the above-specified args
py::object parameter_class = py::module::import(PARAMETER_MODULE).attr(PARAMETER_CLASS);
py::object new_parameter = parameter_class(default_input, param_name, requires_grad, layerwise_parallel);
// 3. Add the new python Parameter object to Cell's _params atttributes
top_cell.attr(SET_PARAM)(param_name, new_parameter);
// 4. Set default_param for param_node
ValuePtr param_value = nullptr;
bool converted = parse::ConvertData(new_parameter, &param_value, false);
if (!converted) {
MS_LOG(EXCEPTION) << "Failed to convert new parameter to ValuePtr.";
}
MS_EXCEPTION_IF_NULL(param);
auto param_node = param->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_node);
param_node->set_default_param(param_value);
}
void Reset(PatternPtr pattern) {
if (pattern->isa<IsPrimTypeOf>()) {
auto prim_pattern = pattern->cast<IsPrimTypeOfPtr>();
prim_pattern->reset();
return;
} else if (pattern->isa<NewParameter>()) {
auto new_param_pattern = pattern->cast<NewParameterPtr>();
new_param_pattern->reset();
return;
} else if (pattern->isa<CallWith>()) {
auto call_with_pattern = pattern->cast<CallWithPtr>();
for (auto sub_pattern : call_with_pattern->inputs()) {
Reset(sub_pattern);
}
new_inputs.push_back(BuildTarget(iter, func_graph, res));
return;
}
return func_graph->NewCNode(new_inputs);
return;
}
} // namespace internal
AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(src_pattern_);
MS_EXCEPTION_IF_NULL(dst_pattern_);
auto res = src_pattern_->match(node);
if (res != nullptr) {
res->dump();
MS_LOG(WARNING) << "Matched pattern: " + src_pattern_->unique_name();
AnfNodePtr PythonPass::Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res) {
auto match_res = src_pattern_->match(node);
if (match_res != nullptr) {
MS_LOG(DEBUG) << "Matched pattern: " + src_pattern_->unique_name() + " node : " + internal::GetNodeRepr(node);
res->merge(match_res);
auto new_node = internal::BuildTarget(dst_pattern_, func_graph, res);
dst_pattern_->reset();
MS_LOG(DEBUG) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n";
internal::Reset(dst_pattern());
MS_LOG(WARNING) << "To be replaced node: " + internal::GetNodeRepr(new_node) + "\n";
return new_node;
}
src_pattern_->reset();
internal::Reset(src_pattern());
return nullptr;
}
bool PythonPass::Run(const FuncGraphPtr &func_graph) {
bool PythonPass::Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res) {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(dst_pattern_);
if (src_pattern_ == nullptr) {
// Add NewParameter
auto new_para_pattern = dst_pattern_->cast<NewParameterPtr>();
if (new_para_pattern == nullptr) {
MS_LOG(EXCEPTION) << "Expect NewParameter pattern for target if src pattern is null.";
}
auto para_name = new_para_pattern->para_name() + new_para_pattern->unique_name();
MS_LOG(DEBUG) << "Adding New parameter : " + para_name;
auto para_node = std::make_shared<Parameter>(func_graph);
MS_EXCEPTION_IF_NULL(para_node);
para_node->set_name(para_name);
// Set function graph
para_node->set_func_graph(func_graph);
// Set Debug Info
auto debug_info = std::make_shared<NodeDebugInfo>(para_name);
para_node->set_debug_info(debug_info);
// Set abstract
auto default_value = new_para_pattern->default_tensor();
MS_EXCEPTION_IF_NULL(default_value);
para_node->set_abstract(default_value->ToAbstract()->Broaden());
res->add_entry(dst_pattern_, para_node);
func_graph->add_parameter(para_node);
// Reflect back to Cell._params
internal::ReflectParamBackToPython(para_node, para_name, default_value, new_para_pattern->requires_grad(),
new_para_pattern->layerwise_parallel());
MS_LOG(WARNING) << "Adding parameter: " + para_node->ToString() + " parameter name:" + para_node->name();
return true;
}
FuncGraphManagerPtr manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
manager->AddFuncGraph(func_graph);
auto seen = NewSeenGeneration();
// 1024 is for the initial capacity of deque
std::deque<AnfNodePtr> todo(1024);
todo.push_back(func_graph->output());
auto graph_nodes_sorted = TopoSort(func_graph->output());
bool changes = false;
auto &all_nodes = manager->all_nodes();
while (!todo.empty()) {
AnfNodePtr node = todo.front();
todo.pop_front();
// Check whether this node has been matched.
if (node == nullptr || node->seen_ == seen || !internal::IsTraversable(node) || !all_nodes.contains(node)) {
continue;
}
node->seen_ = seen;
// Select nodes that this transform can be applied.
AnfNodePtr new_node = Run(func_graph, node);
bool change = (new_node != nullptr);
// Traverse once
for (auto &node : graph_nodes_sorted) {
AnfNodePtr new_node = Run(func_graph, node, res);
if (new_node != nullptr && new_node != node) {
internal::DrawNode(dst_pattern_->unique_name(), new_node);
(void)manager->Replace(node, new_node);
} else if (new_node == nullptr) {
new_node = node;
}
if (run_only_once_) {
return change;
}
// Find success, and add them to todo list
if (IsValueNode<FuncGraph>(node)) {
todo.push_back(GetValueNode<FuncGraphPtr>(node)->output());
}
if (node->isa<CNode>()) {
auto &inputs = node->cast<CNodePtr>()->inputs();
(void)std::copy(inputs.begin(), inputs.end(), std::back_inserter(todo));
}
auto &node_users = manager->node_users();
if (change && node_users.find(node) != node_users.end()) {
for (auto &use : node_users[node]) {
auto use_node = use.first;
if (use_node == nullptr) {
continue;
}
todo.push_back(use_node);
if (use_node->seen_ == seen) {
use_node->seen_--;
}
}
changes = true;
}
}
return changes;
......
......@@ -34,20 +34,20 @@ using NodeEquivPtr = std::shared_ptr<NodeEquiv>;
class PythonPass {
public:
explicit PythonPass(const std::string &name, const PatternPtr &src, const PatternPtr &dst, bool run_only_once = false,
bool multigraph = true)
: src_pattern_(src), dst_pattern_(dst), name_(name), run_only_once_(run_only_once), multigraph_(multigraph) {}
explicit PythonPass(const std::string &name, const PatternPtr &src, const PatternPtr &dst, bool run_only_once = false)
: src_pattern_(src), dst_pattern_(dst), name_(name), run_only_once_(run_only_once) {}
~PythonPass() = default;
bool Run(const FuncGraphPtr &func_graph);
bool Run(const FuncGraphPtr &func_graph, const MatchResultPtr &res);
std::string name() const { return name_; }
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node);
AnfNodePtr Run(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const MatchResultPtr &res);
PatternPtr src_pattern() { return src_pattern_; }
PatternPtr dst_pattern() { return dst_pattern_; }
private:
PatternPtr src_pattern_;
PatternPtr dst_pattern_;
const std::string name_;
bool run_only_once_;
bool multigraph_ = true;
};
using PythonPassPtr = std::shared_ptr<PythonPass>;
......
......@@ -45,14 +45,19 @@ PyPassManagerPtr PyPassManager::GetInstance() {
PyPassManager::PyPassManager() {
phase_to_group_[Phase::RESOLVE] = std::make_shared<PassGroup>();
phase_to_group_[Phase::OPT] = std::make_shared<PassGroup>();
res_ = std::make_shared<MatchResult>();
}
void PyPassManager::Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
Phase phase, bool run_only_once, bool multigraph) {
auto cur_pm = GetPassGroup(phase);
MS_EXCEPTION_IF_NULL(cur_pm);
PythonPassPtr new_pass = std::make_shared<PythonPass>(pass_name, pattern, target, run_only_once, multigraph);
cur_pm->AddPass(new_pass);
Phase phase, bool run_only_once) {
auto cur_pg = GetPassGroup(phase);
MS_EXCEPTION_IF_NULL(cur_pg);
cur_pg->SetRunOnlyOnce(run_only_once);
MS_EXCEPTION_IF_NULL(pattern);
MS_EXCEPTION_IF_NULL(target);
MS_EXCEPTION_IF_NULL(cur_pg);
PythonPassPtr new_pass = std::make_shared<PythonPass>(pass_name, pattern, target, run_only_once);
cur_pg->AddPass(new_pass);
}
void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) {
......@@ -63,6 +68,21 @@ void PyPassManager::Unregiste(const std::string &pass_name, Phase phase) {
}
}
void PyPassManager::GenNewParameter(const PatternPtr &parameter) {
MS_EXCEPTION_IF_NULL(parameter);
// Add new parameter after resolve
// NOTE: Add NewParameter at early stage will cause CSE problems
auto cur_pg = GetPassGroup(Phase::OPT);
MS_EXCEPTION_IF_NULL(cur_pg);
cur_pg->SetRunOnlyOnce(true);
auto new_para_pattern = parameter->cast<NewParameterPtr>();
MS_EXCEPTION_IF_NULL(new_para_pattern);
auto pass_name = new_para_pattern->para_name();
parameter->set_should_replace(false);
auto new_pass = std::make_shared<PythonPass>(pass_name, nullptr, parameter, true);
cur_pg->AddPass(new_pass);
}
void PyPassManager::ClearRes() {
MS_LOG(INFO) << "Clear PyPassManager resources!";
global_instance = nullptr;
......@@ -75,7 +95,9 @@ REGISTER_PYBIND_DEFINE(
(void)py::class_<PyPassManager, std::shared_ptr<PyPassManager>>(*m, "PyPassManager_")
.def(py::init([]() { return PyPassManager::GetInstance(); }))
.def("registe", &PyPassManager::Registe, "Registe python pass")
.def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass");
.def("unregiste", &PyPassManager::Unregiste, "Delete Python Pass")
.def("gen_new_parameter", &PyPassManager::GenNewParameter, "Generate new parameter")
.def("set_renorm", &PyPassManager::SetRenorm, "Set whether or not to do renorm after modified graph");
}));
} // namespace python_pass
} // namespace opt
......
......@@ -27,7 +27,7 @@
#include "ir/graph_utils.h"
#include "utils/ms_utils.h"
#include "pipeline/jit/parse/resolve.h"
#include "pipeline/jit/resource.h"
#include "frontend/optimizer/pattern.h"
#include "frontend/optimizer/py_pass.h"
#include "frontend/optimizer/pass_group.h"
......@@ -53,12 +53,21 @@ class PyPassManager {
static PyPassManagerPtr GetInstance();
virtual ~PyPassManager() = default;
void Registe(const std::string &pass_name, const PatternPtr &pattern, const PatternPtr &target,
Phase phase = Phase::RESOLVE, bool run_only_once = false, bool multigraph = true);
Phase phase = Phase::RESOLVE, bool run_only_once = false);
void Unregiste(const std::string &pass_name, Phase phase);
void GenNewParameter(const PatternPtr &parameter);
PassGroupPtr GetPassGroup(Phase phase);
void ClearRes();
MatchResultPtr GetMatchResult() { return res_; }
void SetRenorm(bool should_renorm) { should_renorm_ = should_renorm; }
bool ShouldRenorm() { return should_renorm_; }
void SetResource(pipeline::ResourcePtr resource) { resource_ = resource; }
pipeline::ResourcePtr GetResource() { return resource_; }
private:
bool should_renorm_ = true;
MatchResultPtr res_;
pipeline::ResourcePtr resource_;
static std::unordered_map<Phase, PassGroupPtr> phase_to_group_;
};
} // namespace python_pass
......
......@@ -448,8 +448,21 @@ void ActionPyStub(const ResourcePtr &res, opt::python_pass::Phase phase) {
MS_EXCEPTION_IF_NULL(res->manager());
MS_EXCEPTION_IF_NULL(res->func_graph());
auto ppm = opt::python_pass::PyPassManager::GetInstance();
ppm->SetResource(res);
if (!ppm->GetPassGroup(phase)->Run(res->func_graph())) {
MS_LOG(DEBUG) << "No match.\n";
} else if (phase == opt::python_pass::Phase::OPT && opt::python_pass::PyPassManager::GetInstance()->ShouldRenorm()) {
MS_LOG(DEBUG) << "Entered PyStub Renorm";
// Renomalize
MS_EXCEPTION_IF_NULL(res->func_graph());
FuncGraphPtr func_graph = res->func_graph();
abstract::AbstractBasePtrList args_spec;
auto parameters = func_graph->parameters();
(void)std::transform(parameters.begin(), parameters.end(), std::back_inserter(args_spec),
[](const AnfNodePtr &p) -> AbstractBasePtr { return p->abstract(); });
FuncGraphPtr new_fg = Renormalize(res, func_graph, args_spec);
res->set_func_graph(new_fg);
res->set_args_spec(args_spec);
}
}
......@@ -477,6 +490,7 @@ static std::vector<ActionItem> CommonPipeline() {
}
// Add resolve-stage python pass stub
actions.emplace_back(std::make_pair("py_resolve", ResolveActionPyStub));
actions.emplace_back(std::make_pair("inference_opt_prepare", InferenceOptPrepareAction));
// Evaluate type and shape, and specialize
actions.emplace_back(std::make_pair("abstract_specialize", AbstractSpecializeAction));
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Top-level reference to python pass."""
......@@ -15,7 +15,8 @@
"""Patterns for describing graphs"""
from mindspore.ops import Primitive
from mindspore.common.tensor import Tensor
from mindspore._c_expression import Pattern, IsIn_, IsPrimTypeOf_, CallWith_, IsNot_, AnyPattern, NewTensor_
from mindspore._c_expression import Pattern, IsIn_, IsPrimTypeOf_, CallWith_, IsNot_, AnyPattern, NewTensor_,\
NewParameter_, Imm
__all__ = [
"IsIn",
......@@ -24,17 +25,25 @@ __all__ = [
"IsNot",
"AnyPattern",
"NewTensor",
"NewParameter",
"Imm"
]
class IsIn(IsIn_):
"""
r"""
Express a pattern which allows a list of patterns.
"""
def __init__(self, patterns=None, should_replace=True):
r"""
Args:
patterns(list/tuple): list of allowed patterns
patterns(Union[tuple[:class:`mindspore.graph_utils.graph_pattern`],
list[:class:`mindspore.graph_utils.graph_pattern`]]): list of allowed patterns,
each element should be one of the exposed Pattern instance.
should_replace(bool): added this for interface consistency. Should only set this in sub-patterns.
Raises:
ValueError: raise if should_replace is False
TypeError: raise type error for invalid inputs.
"""
if not should_replace:
raise ValueError("IsIn pattern does not have its own should_replace attribute. Set should_replace in \
......@@ -52,19 +61,28 @@ class IsIn(IsIn_):
class IsPrimTypeOf(IsPrimTypeOf_):
r"""
Express a pattern of certain primitive type(s).
NOTE: This pattern will match and only match the primitive value node. If matching primitive CNode is needed,
please refer to CallWith pattern.
NOTE:
This pattern will match and only match the primitive value node. If matching primitive CNode is needed,
please refer to CallWith pattern.
"""
def __init__(self, types, name=None, should_replace=True):
r"""
Args:
types (str/(list/tuple of Primitives)): Specify allowed types.
types (Union[str, :class:`mindspore.ops.Primitive`, list[:class:`mindspore.ops.Primitive`],
tuple[:class:`mindspore.ops.Primitive`]):
Specify allowed types.
If it is a string, the form could be
1) a single primitive type, e.g. 'Conv2D'
2) a set of primitive types separated by '|', e.g. 'MatMul|Conv2D'
It can also be a list of Primitives, e.g. [ops.Conv2D(1, 6)]
name (str): name of the pattern, optional
should_replace
It can also be a Primitive or a list/tuple of Primitives, e.g. [ops.Conv2D(1, 6)]
name (str): name of the pattern, optional. Default: None.
should_replace(bool): If pattern is part of the pass replacement target, this would set how this pattern is
used when building the replacement target node. Use captured node if True, build from scratch otherwise.
Default: True.
Raises:
TypeError: raise type error for invalid argument.
"""
if name is not None and not isinstance(name, str):
raise TypeError(f"Expect string, got : {name}")
......@@ -91,12 +109,21 @@ class CallWith(CallWith_):
r"""
Express a primitive CNode.
"""
def __init__(self, prim_pattern, inputs=None, should_replace=False):
def __init__(self, prim_pattern, inputs=None, should_replace=True):
r"""
Args:
prim_pattern (Pattern/Primitive/str): Primitive ValueNode in the Primitive CNode.
inputs (list/tuple): Specify inputs pattern for the primitive(s), optional. If None, accepts any inputs;
if specified, input patterns should be of right order.
prim_pattern (Union[str, :class:`mindspore.graph_utils.graph_pattern.IsPrimTypeOf`,
:class:`mindspore.ops.Primitive`]): Primitive ValueNode in the Primitive CNode.
inputs (Union[list[:class:`mindspore.graph_utils.graph_pattern`],
tuple[:class:`mindspore.graph_utils.graph_pattern`]]):
Specify inputs pattern for the primitive(s), optional. If None, accepts any inputs; if specified, input
patterns should be of right order and each element should be one of the exposed Pattern instance.
should_replace(bool): If pattern is part of the pass replacement target, this would set how this pattern is
used when building the replacement target node. Use captured node if True, build from scratch otherwise.
Default: True.
Raises:
TypeError: raise type error for invalid argument.
"""
if not isinstance(prim_pattern, (Pattern, str, Primitive)):
raise TypeError(f"Expect prim_pattern to be Pattern, Primitive or string, got : {prim_pattern}")
......@@ -110,17 +137,23 @@ class CallWith(CallWith_):
raise TypeError(f"Expect inputs to be a list of Patterns, got : {inputs}")
CallWith_.__init__(self, self.prim_pattern, self.inputs, should_replace)
class IsNot(IsNot_):
r"""
Express a pattern which forbids a list of patterns.
NOTE: IsNot pattern should not be the root pattern.
NOTE:
IsNot pattern should not be the root pattern.
"""
def __init__(self, patterns=None, should_replace=True):
r"""
Args:
patterns(list/tuple): list of forbiden patterns
patterns(Union[list[:class:`mindspore.graph_utils.graph_pattern`]]: list of forbiden patterns, each element
should be one of the exposed Pattern instance.
should_replace(bool): added this for interface consistency. Should only set this in sub-patterns.
Raises:
ValueError: raise if should_replace is False.
TypeError: raise type error for invalid argument.
"""
if not should_replace:
raise ValueError("IsNot pattern does not have its own should_replace attribute. Set should_replace in \
......@@ -142,13 +175,48 @@ class NewTensor(NewTensor_):
def __init__(self, input_tensor, should_replace=False):
r"""
Args:
input_tensor(Tensor): new tensor to be used in the target
input_tensor(:class:`mindspore.common.tensor.Tensor`): new tensor to be used in the target
should_replace(bool): added this for interface consistency. NewTensor should only appear in the target.
Raises:
ValueError: raise if should_replace is True
TypeError: raise type error for invalid argument.
"""
if should_replace:
raise ValueError("NewTensor should only appear in the target, thus should_replace can onlyu be False.")
raise ValueError("NewTensor should only appear in the target, thus should_replace can only be False.")
self.input_tensor = input_tensor
if isinstance(input_tensor, Tensor):
NewTensor_.__init__(self, input_tensor)
else:
raise TypeError(f"Expect input_tensor to be a Tensor, got : {input_tensor}")
class NewParameter(NewParameter_):
r"""
New Parameter to be used in the target.
"""
def __init__(self, para_name, default_tensor, requires_grad=False, layerwise_parallel=False, should_replace=False):
r"""
Args:
para_name(str): name for the new Parameter
default_tensor(:class:`mindspore.common.tensor.Tensor`): default value for the new Parameter
requires_grad(bool): True if the parameter requires gradient. Default: True
layerwise_parallel(bool): switch for layerwise parallel mode. Default: False
should_replace(bool): gen new parameter once and replace after if set to be true; otherwise build a new
parameter everytime a pass target got built. Default: False
Raises:
TypeError: raise type error for invalid argument.
"""
self.para_name = para_name
self.default_tensor = default_tensor
self.requires_grad = requires_grad
self.layerwise_parallel = layerwise_parallel
self.should_replace = should_replace
if isinstance(para_name, str) and isinstance(default_tensor, Tensor) and isinstance(requires_grad, bool) and\
isinstance(layerwise_parallel, bool) and isinstance(should_replace, bool):
NewParameter_.__init__(self, self.para_name, self.default_tensor, self.requires_grad,
self.layerwise_parallel, self.should_replace)
else:
raise TypeError(f"Expect para_name(str), default_tensor(Tensor), requires_grad(bool), \
layerwise_parallel(bool) should_replace(bool), got : {para_name}, {default_tensor}, \
{requires_grad}, {layerwise_parallel}, {should_replace}")
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Top-level reference to python pass."""
from .python_pass_register import registe_pass, unregiste_pass, gen_new_parameter, cancel_new_parameter, set_renorm
__all__ = [
"registe_pass",
"unregiste_pass",
"gen_new_parameter",
"cancel_new_parameter",
"set_renorm"
]
......@@ -14,10 +14,17 @@
# ============================================================================
"""Python pass register"""
from inspect import isfunction
from mindspore.common.graph_pattern import Pattern
from mindspore._c_expression import PyPassManager_
from mindspore._c_expression import phase
from mindspore.graph_utils.graph_pattern import Pattern, NewParameter
from mindspore._c_expression import PyPassManager_, phase
__all__ = [
"registe_pass",
"unregiste_pass",
"gen_new_parameter",
"cancel_new_parameter",
"set_renorm"
]
class PyPassManager(PyPassManager_):
r"""
Used to registe and unregiste python passes which can be used to alter graphs.
......@@ -30,52 +37,134 @@ class PyPassManager(PyPassManager_):
Raises:
TypeError: If argument has invalid type.
"""
def __init__(self, pipeline_phase=phase.opt, run_only_once=False, multi_graph=True):
def __init__(self, pipeline_phase=phase.opt, run_only_once=False):
if not isinstance(pipeline_phase, phase):
raise TypeError(f"Expecting phase, got : ({type(pipeline_phase)}){pipeline_phase}")
raise TypeError(f"Expect phase, got : ({type(pipeline_phase)}){pipeline_phase}")
if not isinstance(run_only_once, bool):
raise TypeError(f"Expecting bool, got : ({type(run_only_once)}){run_only_once}")
if not isinstance(multi_graph, bool):
raise TypeError(f"Expecting bool, got : ({type(multi_graph)}){multi_graph}")
raise TypeError(f"Expect bool, got : ({type(run_only_once)}){run_only_once}")
PyPassManager_.__init__(self)
self.phase_ = pipeline_phase
self.run_only_once_ = run_only_once
self.multi_graph_ = multi_graph
def registe(self, py_pass):
if not isfunction(py_pass):
raise TypeError(f"Expecting function pass, got : ({type(py_pass)}){py_pass}")
raise TypeError(f"Expect function pass, got : ({type(py_pass)}){py_pass}")
pattern, target = py_pass()
pass_name = py_pass.__name__
if not isinstance(pattern, Pattern):
raise TypeError(f"Expecting pattern of Pattern type, got : ({type(pattern)}){pattern}")
raise TypeError(f"Expect pattern of Pattern type, got : ({type(pattern)}){pattern}")
if not isinstance(target, Pattern):
raise TypeError(f"Expecting target of Pattern type, got : ({type(target)}){target}")
super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_, self.multi_graph_)
raise TypeError(f"Expect target of Pattern type, got : ({type(target)}){target}")
super().registe(pass_name, pattern, target, self.phase_, self.run_only_once_)
def unregiste(self, py_pass, pipeline_phase=phase.opt):
if not isinstance(pipeline_phase, phase):
raise TypeError(f"Expecting phase, got : ({type(pipeline_phase)}){pipeline_phase}")
raise TypeError(f"Expect phase, got : ({type(pipeline_phase)}){pipeline_phase}")
if isinstance(py_pass, str):
super().unregiste(py_pass, pipeline_phase)
return
if isfunction(py_pass):
super().unregiste(py_pass.__name__, pipeline_phase)
return
raise TypeError(f"Expecting py_pass to be string or function, got ({type(py_pass)}){py_pass}")
raise TypeError(f"Expect py_pass to be string or function, got ({type(py_pass)}){py_pass}")
def __call__(self, py_pass):
self.registe(py_pass)
return py_pass
def registe_pass(pipeline_phase=phase.opt, run_only_once=False, multi_graph=True):
def gen_new_parameter(self, pattern):
if not isinstance(pattern, NewParameter):
raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
super().gen_new_parameter(pattern)
def set_renorm(self, should_renorm):
if not isinstance(should_renorm, bool):
raise TypeError(f"Expect should_renorm to be a bool, got {should_renorm}")
super().set_renorm(should_renorm)
def registe_pass(pipeline_phase=phase.opt, run_only_once=False):
"""
Registe python pass to specified pipeline phase which would be used in compilation.
Args:
pipeline_phase(:class:`mindspore._c_expression.phase`): To which compilation pipeline stage the pass is
registed. Support phase.resolve and phase.opt. Default: phase.opt.
run_only_once(bool): Run this pass only once if set true. Otherwise run the pass until converge. Default: False.
Returns:
This function should be used as a decorator, return the decoratorated pass function.
Examples:
>>> from mindspore.graph_utils.graph_pattern import IsPrimTypeOf
>>> @registe_pass()
>>> def toy_pass():
>>> pattern = IsPrimTypeOf("ReLU")
>>> target = IsPrimTypeOf("ReLU6")
>>> return pattern, target
"""
return PyPassManager(pipeline_phase, run_only_once)
def unregiste_pass(py_pass, pipeline_phase=phase.opt):
"""
Unregiste python pass.
Args:
py_pass(Union(str, function)): target python pass to unregiste.
pipeline_phase(:class:`mindspore._c_expression.phase`): To which compilation pipeline stage the pass is
unregisted. Support phase.resolve and phase.opt. Default: phase.opt.
"""
ppm = PyPassManager()
ppm.unregiste(py_pass, pipeline_phase)
def gen_new_parameter(pattern):
"""
Generate specified parameter every time a network gets compiled.
NOTE:
In this way, every pass uses this pattern would be using the same Parameter. If use NewParameter without
gen_new_parameter, every pass match would build a new Parameter.
This would registe a pass to add new parameter in the compilation pipeline, so later compilation would
ALSO add this parameter unless the pass is unregisted. To unregiste this pass, call
cancel_new_parameter(pattern)
Args:
pattern (NewParameter): NewParameter type, could be used to build nested patterns across multiple passes
after gen_new_parameter.
Raises:
TypeError: If argument has invalid type.
Examples:
>>> from mindspore.graph_utils.graph_pattern import NewParameter
>>> abc = NewParameter("abc")
>>> gen_new_parameter(abc)
"""
ppm = PyPassManager()
ppm.gen_new_parameter(pattern)
def cancel_new_parameter(pattern):
"""
Use with gen_new_parameter to unregiste gen_new_parameter pass.
Args:
pattern (NewParameter): NewParameter type, cancel the pass which would add new parameter as this pattern
describes.
Examples:
>>> from mindspore.graph_utils.graph_pattern import NewParameter
>>> abc = NewParameter("abc")
>>> gen_new_parameter(abs)
>>> # some compilations
>>> cancel_new_parameter(abc)
"""
if not isinstance(pattern, NewParameter):
raise TypeError(f"Expect pattern to be a NewParameter Pattern, got {pattern}")
ppm = PyPassManager()
ppm.unregiste(pattern.para_name)
def set_renorm(should_renorm):
"""
Examples:
>>> @registe_pass()
>>> def toy_pass():
>>> def pattern():
>>> pass
>>> def target():
>>> pass
Set whether or not to do renorm after modified graph in python pass(es).
"""
return PyPassManager(pipeline_phase, run_only_once, multi_graph)
ppm = PyPassManager()
ppm.set_renorm(should_renorm)
......@@ -152,7 +152,7 @@ class Primitive(Primitive_):
Check if certain inputs should go to the backend. Subclass in need should override this method.
Args:
*args(Primitive args): Same as arguments of current Primitive.
args(Primitive args): Same as arguments of current Primitive.
Returns:
A tuple consisting of two elements. The first element indicates whether we should filter out current
......
......@@ -19,10 +19,12 @@ import mindspore.nn as nn
from mindspore import context
from mindspore.common.tensor import Tensor
from mindspore.ops import operations as P
from mindspore.common.python_pass_register import registe_pass, PyPassManager
from mindspore.graph_utils.python_pass import registe_pass, unregiste_pass, set_renorm, gen_new_parameter,\
cancel_new_parameter
from mindspore.common.api import _generate_pip_args
from mindspore._c_expression import generate_key, Executor_
from mindspore.common.graph_pattern import IsIn, IsPrimTypeOf, CallWith, IsNot, AnyPattern, NewTensor
from mindspore.graph_utils.graph_pattern import IsIn, IsPrimTypeOf, CallWith, IsNot, AnyPattern, NewTensor,\
NewParameter, Imm
context.set_context(mode=context.GRAPH_MODE)
......@@ -56,12 +58,39 @@ def test_softmax_relu():
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
ppm = PyPassManager()
ppm.unregiste(softmax_relu_pass)
unregiste_pass(softmax_relu_pass)
assert "ReLU" in transformed_repr
assert "Softmax" not in transformed_repr
def test_isin_pattern():
def test_softmax_relu_sigmoid():
"""
Use python pass to transform from Softmax(x) to ReLU(Sigmoid(x)).
NOTE:
Sigmoid pattern only exists in the target.
"""
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
@registe_pass(run_only_once=True)
def softmax_relu_pass():
x = AnyPattern()
softmax_pattern = IsPrimTypeOf(P.Softmax())
pattern = CallWith(softmax_pattern, inputs=[x])
sigmoid_pattern = IsPrimTypeOf(P.Sigmoid(), should_replace=False)
call_sigmoid = CallWith(sigmoid_pattern, [x])
relu_pattern = IsPrimTypeOf(P.ReLU(), should_replace=False)
target = CallWith(relu_pattern, inputs=[call_sigmoid])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(3)
unregiste_pass(softmax_relu_pass)
assert "ReLU" in transformed_repr
assert "Sigmoid" in transformed_repr
assert "Softmax" not in transformed_repr
def test_isin_pattern_0():
"""
Test IsIn pattern which expresses the IsIn/OneOf semantics.
"""
......@@ -81,16 +110,41 @@ def test_isin_pattern():
target = CallWith(relu6_pattern, inputs=[x])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
ppm = PyPassManager()
ppm.unregiste(softmax_relu_pass)
unregiste_pass(softmax_relu_pass)
assert "ReLU6" in transformed_repr
assert "Softmax" not in transformed_repr
def test_isin_pattern_1():
"""
Test IsIn. IsIn is used as nested inputs for the target in this case.
"""
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
@registe_pass(run_only_once=True)
def softmax_neg_pass():
x = AnyPattern()
softmax_pattern = IsPrimTypeOf(P.Softmax())
call_softmax = CallWith(softmax_pattern, inputs=[x])
relu_pattern = IsPrimTypeOf(P.ReLU())
call_relu = CallWith(relu_pattern, inputs=[x])
pattern = IsIn([call_softmax, call_relu])
neg_ops = IsPrimTypeOf(P.Neg(), should_replace=False)
target = CallWith(neg_ops, inputs=[pattern])
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(4)
print(transformed_repr)
unregiste_pass(softmax_neg_pass)
assert "Neg" in transformed_repr
assert "Softmax" in transformed_repr
def test_isnot_pattern_0():
"""
Test IsNot pattern which expresses the IsNot semantics.
Case: IsNot pass failed to match
"""
set_renorm(False)
class ConvBN(nn.Cell):
def __init__(self):
super(ConvBN, self).__init__()
......@@ -132,11 +186,11 @@ def test_isnot_pattern_0():
return pattern, target
transformed_repr = get_func_graph(conv_bn_model, inputs).get_return().expanded_str(5)
ppm = PyPassManager()
ppm.unregiste(single_bn_pass)
ppm.unregiste(bn_pass)
unregiste_pass(single_bn_pass)
unregiste_pass(bn_pass)
assert "ReLU6" not in transformed_repr
assert "Softmax" in transformed_repr
set_renorm(True)
def test_isnot_pattern_1():
"""
......@@ -160,12 +214,15 @@ def test_isnot_pattern_1():
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
ppm = PyPassManager()
ppm.unregiste(single_bn_pass)
unregiste_pass(single_bn_pass)
assert "ReLU6" in transformed_repr
assert "Softmax" not in transformed_repr
def test_newtensor_pattern():
"""
Test NewTensor pattern in the target
"""
set_renorm(False)
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
......@@ -181,7 +238,84 @@ def test_newtensor_pattern():
target = CallWith(addn_ops, inputs=[x, new_weight], should_replace=False)
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(2)
ppm = PyPassManager()
ppm.unregiste(softmax_addn_pass)
unregiste_pass(softmax_addn_pass)
assert "AddN" in transformed_repr
assert "Softmax" not in transformed_repr
set_renorm(True)
def test_newparameter_pattern():
"""
Test NewParameter pattern in the target
"""
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
@registe_pass(run_only_once=True)
def softmax_addn_pass():
x = AnyPattern()
softmax = P.Softmax()
pattern = CallWith(softmax, inputs=[x])
default_tensor0 = Tensor(np.ones((4, 4)), mindspore.float32)
default_tensor1 = Tensor(np.ones((4, 4)), mindspore.float32)
new_para_0 = NewParameter("Merlin", default_tensor0)
new_para_1 = NewParameter("Arthur", default_tensor1)
target_0 = CallWith(P.MatMul(), inputs=[new_para_0, new_para_1], should_replace=False)
target = CallWith("make_tuple", inputs=[target_0], should_replace=False)
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
print(transformed_repr)
unregiste_pass(softmax_addn_pass)
assert "MatMul" in transformed_repr
assert "make_tuple" in transformed_repr
assert "Softmax" not in transformed_repr
def test_imm_pattern():
"""
Test NewParameter pattern in the target
"""
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
@registe_pass(run_only_once=True)
def softmax_addn_pass():
x = AnyPattern()
softmax = P.Softmax()
pattern = CallWith(softmax, inputs=[x])
imm = Imm(0)
target_0 = CallWith("make_tuple", inputs=[pattern], should_replace=False)
target = CallWith("tuple_getitem", inputs=[target_0, imm], should_replace=False)
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
print(transformed_repr)
unregiste_pass(softmax_addn_pass)
assert "make_tuple" in transformed_repr
assert "tuple_getitem" in transformed_repr
assert "Softmax" in transformed_repr
def test_gen_new_parameter():
"""
Test gen_new_parameter
"""
inputs = Tensor(np.ones([42]), mindspore.float16)
softmax_model = nn.Softmax()
default_tensor = Tensor(np.ones((4, 4)), mindspore.float32)
new_para = NewParameter("Merlin", default_tensor, should_replace=True)
gen_new_parameter(new_para)
@registe_pass(run_only_once=True)
def softmax_make_tuple_pass():
x = AnyPattern()
softmax = P.Softmax()
pattern = CallWith(softmax, inputs=[x])
target = CallWith("make_tuple", inputs=[pattern, new_para], should_replace=False)
return pattern, target
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
print(transformed_repr)
assert "Merlin" in transformed_repr
unregiste_pass(softmax_make_tuple_pass)
cancel_new_parameter(new_para)
transformed_repr = get_func_graph(softmax_model, inputs).get_return().expanded_str(5)
print(transformed_repr)
assert "Merlin" not in transformed_repr
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册