提交 0099da2c 编写于 作者: H huangdongrun

add support for tuple parameter transform

add support for pynative pass

add testcases
上级 1d0e0ae2
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "frontend/optimizer/graph_transform.h"
#include <vector>
#include <algorithm>
#include <string>
#include "ir/graph_utils.h"
namespace mindspore {
/* namespace to support opt */
namespace opt {
// check cnode input values, whether it is tuple input
bool CNodeHasTupleInput(const CNodePtr &cnode) {
auto &inputs = cnode->inputs();
for (size_t i = 1; i < inputs.size(); i++) {
if (IsValueNode<FuncGraph>(inputs[i])) {
continue;
}
if (IsValueNode<Primitive>(inputs[i])) {
// unexpected high order primitvie as cnode input when transform graph
MS_LOG(WARNING) << "CheckTupleInput, got unexpected primitve as input" << cnode->DebugString();
return false;
}
auto abs = inputs[i]->abstract();
if (abs == nullptr) {
MS_LOG(WARNING) << "CheckTupleInput, got abstract nullptr for node:" << cnode->DebugString();
return false;
}
if (abs->isa<abstract::AbstractTuple>()) {
return true;
}
}
return false;
}
bool FuncGraphHasTupleInput(const FuncGraphPtr &fg) {
auto &params = fg->parameters();
for (auto &param : params) {
if (param->abstract()->isa<abstract::AbstractTuple>()) {
return true;
}
}
return false;
}
std::vector<AnfNodePtr> TransformTupleArgument(const FuncGraphPtr &fg, const AnfNodePtr &node,
const abstract::AbstractTuplePtr &abs) {
auto &elements = abs->elements();
std::vector<AnfNodePtr> tuple_node_expanded;
for (size_t i = 0; i < elements.size(); i++) {
auto elem_node = fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(SizeToInt(i))});
elem_node->set_abstract(elements[i]);
if (elements[i]->isa<abstract::AbstractTuple>()) {
auto nodes = TransformTupleArgument(fg, elem_node, elements[i]->cast<abstract::AbstractTuplePtr>());
tuple_node_expanded.insert(tuple_node_expanded.end(), nodes.begin(), nodes.end());
} else {
tuple_node_expanded.push_back(elem_node);
}
}
return tuple_node_expanded;
}
AnfNodePtr TransformCallGraph(const FuncGraphPtr &trans_fg, const CNodePtr &cnode) {
auto &cinputs = cnode->inputs();
auto fg = cnode->func_graph();
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(trans_fg));
for (size_t i = 1; i < cinputs.size(); i++) {
auto abs = cinputs[i]->abstract();
if (abs == nullptr) {
MS_LOG(EXCEPTION) << "TransformCallGraph:Node abstract should not be nullptr" << cinputs[i]->DebugString();
}
if (abs->isa<abstract::AbstractTuple>()) {
auto nodes = TransformTupleArgument(fg, cinputs[i], abs->cast<abstract::AbstractTuplePtr>());
inputs.insert(inputs.end(), nodes.begin(), nodes.end());
} else {
inputs.push_back(cinputs[i]);
}
}
auto new_node = fg->NewCNode(inputs);
new_node->set_abstract(cnode->abstract());
return new_node;
}
AnfNodePtr TransformPartial(const FuncGraphPtr &trans_fg, const CNodePtr &cnode) {
auto &cinputs = cnode->inputs();
auto fg = cnode->func_graph();
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(prim::kPrimPartial));
inputs.push_back(NewValueNode(trans_fg));
for (size_t i = 2; i < cinputs.size(); i++) {
auto abs = cinputs[i]->abstract();
if (abs == nullptr) {
MS_LOG(EXCEPTION) << "TransformPartial:Node abstract should not be nullptr" << cinputs[i]->DebugString();
}
if (abs->isa<abstract::AbstractTuple>()) {
auto nodes = TransformTupleArgument(fg, cinputs[i], abs->cast<abstract::AbstractTuplePtr>());
inputs.insert(inputs.end(), nodes.begin(), nodes.end());
} else {
inputs.push_back(cinputs[i]);
}
}
auto new_node = fg->NewCNode(inputs);
new_node->set_abstract(cnode->abstract());
return new_node;
}
AnfNodePtr TransformSwitchCall(const AnfNodePtr &swtich_node, const CNodePtr &cnode) {
auto &cinputs = cnode->inputs();
auto fg = cnode->func_graph();
std::vector<AnfNodePtr> inputs;
inputs.push_back(swtich_node);
for (size_t i = 1; i < cinputs.size(); i++) {
auto abs = cinputs[i]->abstract();
if (abs == nullptr) {
MS_LOG(EXCEPTION) << "TransformSwitchCall:Node abstract should not be nullptr" << cinputs[i]->DebugString();
}
if (abs->isa<abstract::AbstractTuple>()) {
auto nodes = TransformTupleArgument(fg, cinputs[i], abs->cast<abstract::AbstractTuplePtr>());
inputs.insert(inputs.end(), nodes.begin(), nodes.end());
} else {
inputs.push_back(cinputs[i]);
}
}
auto new_node = fg->NewCNode(inputs);
new_node->set_abstract(cnode->abstract());
return new_node;
}
} // namespace opt
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H
#include <unordered_map>
#include <string>
#include <vector>
#include <algorithm>
#include <memory>
#include "frontend/optimizer/optimizer.h"
namespace mindspore {
namespace opt {
bool CNodeHasTupleInput(const CNodePtr &cnode);
bool FuncGraphHasTupleInput(const FuncGraphPtr &fg);
std::vector<AnfNodePtr> TransformTupleArgument(const FuncGraphPtr &fg, const AnfNodePtr &node,
const abstract::AbstractTuplePtr &abs);
AnfNodePtr TransformCallGraph(const FuncGraphPtr &trans_fg, const CNodePtr &cnode);
AnfNodePtr TransformPartial(const FuncGraphPtr &trans_fg, const CNodePtr &cnode);
AnfNodePtr TransformSwitchCall(const AnfNodePtr &swtich_node, const CNodePtr &cnode);
class GraphTupleParamTransform {
public:
GraphTupleParamTransform() : cache_() {}
~GraphTupleParamTransform() { cache_.clear(); }
FuncGraphPtr operator()(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) {
if (cache_.find(fg) != cache_.end()) {
return cache_[fg];
}
auto new_fg = TransformGraphParam(fg, mng);
cache_[fg] = new_fg;
return new_fg;
}
AnfNodePtr GenerateTupleParams(const abstract::AbstractTuplePtr &tuple_abs, const FuncGraphPtr &fg,
std::vector<AnfNodePtr> *params) {
std::vector<AnfNodePtr> inputs;
inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
auto &elements = tuple_abs->elements();
for (auto &item : elements) {
if (item->isa<abstract::AbstractTuple>()) {
inputs.push_back(GenerateTupleParams(item->cast<abstract::AbstractTuplePtr>(), fg, params));
} else {
auto p = std::make_shared<Parameter>(fg);
p->set_abstract(item);
params->push_back(p);
inputs.push_back(params->back());
}
}
auto node = fg->NewCNode(inputs);
node->set_abstract(tuple_abs);
return node;
}
FuncGraphPtr TransformGraphParam(const FuncGraphPtr &fg, const FuncGraphManagerPtr &mng) {
Cloner cloner({fg}, false, false, false, std::make_shared<TraceCopy>(), std::make_shared<TraceCopy>());
auto new_fg = cloner[fg];
auto &params = new_fg->parameters();
std::vector<AnfNodePtr> new_params;
std::unordered_map<AnfNodePtr, AnfNodePtr> repl;
for (auto &param : params) {
auto abs = param->abstract();
if (abs != nullptr && abs->isa<abstract::AbstractTuple>()) {
auto tuple_abs = abs->cast<abstract::AbstractTuplePtr>();
std::vector<AnfNodePtr> tuple_params;
repl.emplace(param, GenerateTupleParams(tuple_abs, new_fg, &tuple_params));
std::transform(tuple_params.begin(), tuple_params.end(), std::back_inserter(new_params),
[](AnfNodePtr p) { return p; });
} else {
new_params.push_back(param);
}
}
auto tmp_mng = mindspore::Manage(new_fg, false);
auto tr = tmp_mng->Transact();
for (auto &item : repl) {
bool ret = tr.Replace(item.first, item.second);
if (ret == false) {
MS_LOG(ERROR) << "replace failed" << item.first->DebugString() << " with__" << item.second->DebugString(2);
}
}
tr.SetParameters(new_fg, new_params);
tr.Commit();
mng->AddFuncGraph(new_fg);
return new_fg;
}
std::unordered_map<FuncGraphPtr, FuncGraphPtr> cache_;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_GRAPH_TRANSFORM_H
......@@ -44,6 +44,7 @@
#include "frontend/optimizer/irpass/row_tensor_eliminate.h"
#include "frontend/optimizer/irpass/sparse_tensor_eliminate.h"
#include "frontend/optimizer/irpass/switch_layer_defer_inline.h"
#include "frontend/optimizer/irpass/call_graph_tuple_transform.h"
namespace mindspore {
namespace opt {
......@@ -158,6 +159,10 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
unused_output_eliminate_ =
MakeSubstitution(std::make_shared<UnusedOutputEliminater>(), "unused_output_eliminate", IsCNodeGraphKernel);
// tuple parameter graph transform
call_graph_tuple_transform_ =
MakeSubstitution(std::make_shared<CallGraphTupleTransform>(), "graph_param_transorm", IsCNode);
// AddN eliminate
addn_eliminate_ = MakeSubstitution(std::make_shared<AddNEliminater>(), "addn_eliminate", IsCNodeGraphKernel);
......
......@@ -103,6 +103,9 @@ class OptimizeIRPassLib {
SubstitutionPtr unused_parameter_eliminate_;
SubstitutionPtr unused_output_eliminate_;
// tuple parameter graph transform
SubstitutionPtr call_graph_tuple_transform_;
// AddN eliminate
SubstitutionPtr addn_eliminate_;
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_
#include <algorithm>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "frontend/optimizer/optimizer_caller.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/graph_transform.h"
namespace mindspore {
namespace opt {
namespace irpass {
// {G, Xs}-->transform graph call tuple inputs to flat inputs.
class GraphCallTupleTransform : public AnfVisitor {
public:
explicit GraphCallTupleTransform(GraphTupleParamTransform &transformer) : graph_transform_(transformer) {}
~GraphCallTupleTransform() override = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto &inputs = cnode->inputs();
auto fg = GetValueNode<FuncGraphPtr>(inputs[0]);
if (fg == nullptr) {
return nullptr;
}
if (!CNodeHasTupleInput(node->cast<CNodePtr>())) {
return nullptr;
}
FuncGraphPtr transformed_fg = graph_transform_(fg, optimizer->manager());
auto new_node = TransformCallGraph(transformed_fg, node->cast<CNodePtr>());
return new_node;
}
private:
GraphTupleParamTransform &graph_transform_;
};
// {{switch, cond, true_branch, false_branch}, Xs} -->transform switch graph call tuple inputs to flat inputs.
class SwitchCallTupleTransform : public AnfVisitor {
public:
explicit SwitchCallTupleTransform(GraphTupleParamTransform &transformer) : graph_transform_(transformer) {}
~SwitchCallTupleTransform() override = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
return nullptr;
}
auto switch_call_cnode = node->cast<CNodePtr>();
auto call_inputs = switch_call_cnode->inputs();
if (call_inputs.size() < 1) {
return nullptr;
}
if (!IsPrimitiveCNode(call_inputs[0], prim::kPrimSwitch)) {
return nullptr;
}
auto swich_cnode = call_inputs[0]->cast<CNodePtr>();
auto switch_inputs = swich_cnode->inputs();
if (switch_inputs.size() != 4) {
return nullptr;
}
AnfNodePtr transformed = nullptr;
bool true_br_changed = TransformBranchNode(switch_inputs[2], optimizer->manager(), &transformed);
if (true_br_changed) {
switch_inputs[2] = transformed;
}
bool false_br_changed = TransformBranchNode(switch_inputs[3], optimizer->manager(), &transformed);
if (false_br_changed) {
switch_inputs[3] = transformed;
}
if (true_br_changed || false_br_changed) {
call_inputs[0] = swich_cnode->func_graph()->NewCNode(switch_inputs);
}
if (CNodeHasTupleInput(switch_call_cnode)) {
return TransformSwitchCall(call_inputs[0], switch_call_cnode);
}
if (true_br_changed || false_br_changed) {
return switch_call_cnode->func_graph()->NewCNode(call_inputs);
}
return nullptr;
}
bool TransformBranchNode(AnfNodePtr node, FuncGraphManagerPtr mng, AnfNodePtr *trans_node) {
if (IsValueNode<FuncGraph>(node)) {
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(node);
if (FuncGraphHasTupleInput(fg)) {
FuncGraphPtr transformed_fg = graph_transform_(fg, mng);
*trans_node = NewValueNode(transformed_fg);
return true;
}
return false;
}
if (IsPrimitiveCNode(node, prim::kPrimPartial)) {
auto partial_inputs = node->cast<CNodePtr>()->inputs();
if (IsValueNode<FuncGraph>(partial_inputs[1])) {
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(partial_inputs[1]);
if (FuncGraphHasTupleInput(fg)) {
fg = graph_transform_(fg, mng);
}
if (CNodeHasTupleInput(node->cast<CNodePtr>())) {
*trans_node = TransformPartial(fg, node->cast<CNodePtr>());
return true;
}
}
return false;
}
MS_LOG(WARNING) << "Got unexpected switch branch node " << node->DebugString();
return false;
}
private:
GraphTupleParamTransform &graph_transform_;
};
// {{switch_layer, index, {make_tuple, br1, br2,...,}}, Xs} ->
// transform switch layer graph call tuple inputs to flat inputs.
class SwitchLayerCallTupleTransform : public AnfVisitor {
public:
explicit SwitchLayerCallTupleTransform(GraphTupleParamTransform &transformer) : graph_transform_(transformer) {}
~SwitchLayerCallTupleTransform() override = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
if (!node->isa<CNode>() || node->func_graph() == nullptr) {
return nullptr;
}
auto switch_layer_call_cnode = node->cast<CNodePtr>();
auto call_inputs = switch_layer_call_cnode->inputs();
if (call_inputs.size() < 1) {
return nullptr;
}
if (!IsPrimitiveCNode(call_inputs[0], prim::kPrimSwitchLayer)) {
return nullptr;
}
auto swich_layer_cnode = call_inputs[0]->cast<CNodePtr>();
auto switch_layer_inputs = swich_layer_cnode->inputs();
if (switch_layer_inputs.size() != 3) {
return nullptr;
}
AnfNodePtr transformed = nullptr;
bool layer_changed = TransformLayerNode(switch_layer_inputs[2], optimizer->manager(), &transformed);
if (layer_changed) {
switch_layer_inputs[2] = transformed;
call_inputs[0] = switch_layer_call_cnode->func_graph()->NewCNode(switch_layer_inputs);
}
if (CNodeHasTupleInput(switch_layer_call_cnode)) {
return TransformSwitchCall(call_inputs[0], switch_layer_call_cnode);
}
if (layer_changed) {
return switch_layer_call_cnode->func_graph()->NewCNode(call_inputs);
}
return nullptr;
}
bool TransformLayerNode(AnfNodePtr node, FuncGraphManagerPtr mng, AnfNodePtr *trans_node) {
if (!IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
MS_LOG(WARNING) << "SwitchLayer input is not MakeTuple";
return false;
}
auto tuple_inputs = node->cast<CNodePtr>()->inputs();
bool changed = false;
for (size_t i = 1; i < tuple_inputs.size(); i++) {
if (!IsValueNode<FuncGraph>(tuple_inputs[i])) {
MS_LOG(WARNING) << "SwitchLayer input is not FuncGraph";
return false;
}
FuncGraphPtr fg = GetValueNode<FuncGraphPtr>(tuple_inputs[i]);
if (FuncGraphHasTupleInput(fg)) {
FuncGraphPtr transformed_fg = graph_transform_(fg, mng);
tuple_inputs[i] = NewValueNode(transformed_fg);
changed = true;
}
}
if (changed) {
*trans_node = node->func_graph()->NewCNode(tuple_inputs);
}
return changed;
}
private:
GraphTupleParamTransform &graph_transform_;
};
class CallGraphTupleTransform : public OptimizerCaller {
public:
CallGraphTupleTransform()
: graph_transformer_(),
graph_call_transform_(std::make_shared<GraphCallTupleTransform>(graph_transformer_)),
switch_call_transform_(std::make_shared<SwitchCallTupleTransform>(graph_transformer_)),
switch_layer_call_transform_(std::make_shared<SwitchLayerCallTupleTransform>(graph_transformer_)) {
transformers_.emplace_back(graph_call_transform_);
transformers_.emplace_back(switch_call_transform_);
transformers_.emplace_back(switch_layer_call_transform_);
}
~CallGraphTupleTransform() = default;
AnfNodePtr operator()(const OptimizerPtr &optimizer, const AnfNodePtr &node) override {
AnfNodePtr new_node;
for (auto &transform : transformers_) {
new_node = (*transform)(optimizer, node);
if (new_node != nullptr) {
return new_node;
}
}
return nullptr;
}
private:
GraphTupleParamTransform graph_transformer_;
OptimizerCallerPtr graph_call_transform_;
OptimizerCallerPtr switch_call_transform_;
OptimizerCallerPtr switch_layer_call_transform_;
std::vector<OptimizerCallerPtr> transformers_{};
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_CALL_GRAPH_TRANSFORM_H_
......@@ -277,6 +277,7 @@ bool OptimizeAction(const ResourcePtr &res, const std::vector<PassItem> &passes)
MS_EXCEPTION_IF_NULL(func_graph);
func_graph->DumpFuncGraph(fg_name);
DumpIR(fg_name + ".ir", func_graph);
ExportIR(fg_name + ".dat", "", func_graph);
MS_LOG(DEBUG) << "Dump " << fg_name << " func graph.";
}
counter++;
......
......@@ -33,6 +33,7 @@
#include "frontend/optimizer/clean.h"
#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/control_depend.h"
#include "frontend/optimizer/graph_transform.h"
#include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/step_auto_parallel.h"
#include "frontend/parallel/allreduce_fusion/step_allreduce_fusion.h"
......@@ -166,12 +167,23 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
OptPassGroupMap GetOptPassesAfterCconv(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig c_1 = opt::OptPassConfig({
// Safe inlining
// Safe inlining,
irpass.inline_,
irpass.partial_eliminate_,
});
OptPassGroupMap map_a({{"c_1", c_1}, {"renormalize", opt::OptPassConfig::Renormalize()}});
OptPassGroupMap map_a({{"c_1", c_1},
{"cse", opt::OptPassConfig(opt::CSEPass(false))},
{"renormalize", opt::OptPassConfig::Renormalize()}});
return map_a;
}
OptPassGroupMap GetOptPassesTransformGraph(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig d_1 = opt::OptPassConfig({// Safe inlining
irpass.call_graph_tuple_transform_, irpass.item_tuple_eliminate_});
OptPassGroupMap map_a({{"d_1", d_1}, {"renormalize", opt::OptPassConfig::Renormalize()}});
return map_a;
}
......@@ -262,6 +274,8 @@ void InitOpt(const ResourcePtr &res) {
g_pass_opts["opt_b"] = Optimizer::MakeOptimizer("opt_b", res, GetOptPassesB(irpass), false, true);
g_pass_opts["opt_after_cconv"] =
Optimizer::MakeOptimizer("opt_after_cconv", res, GetOptPassesAfterCconv(irpass), false, true);
g_pass_opts["opt_trans_graph"] =
Optimizer::MakeOptimizer("opt_trans_graph", res, GetOptPassesTransformGraph(irpass), true, true);
g_pass_opts["opt_graph_kernel_a"] =
Optimizer::MakeOptimizer("opt_graph_kernel_a", res, GetOptPassesGraphKernelA(irpass), true);
g_pass_opts["opt_graph_kernel_b"] =
......@@ -307,6 +321,7 @@ bool OptPassGroup(const ResourcePtr &res, const std::string &name) {
bool OptPassAGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_a"); }
bool OptPassBGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_b"); }
bool OptPassAfterCconvGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_after_cconv"); }
bool OptPassTransformGraphGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_trans_graph"); }
bool OptPassGraphKernelGroupA(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_a"); }
bool OptPassGraphKernelGroupB(const ResourcePtr &res) { return OptPassGroup(res, "opt_graph_kernel_b"); }
bool ControlGroup(const ResourcePtr &res) { return OptPassGroup(res, "opt_control"); }
......@@ -365,6 +380,24 @@ bool CconvPass(const ResourcePtr &res) {
return true;
}
bool TransformTopGraphPass(const ResourcePtr &res) {
if (res->func_graph() == nullptr) {
MS_LOG(EXCEPTION) << "Transform top graph error.";
}
FuncGraphPtr func_graph = res->func_graph();
if (opt::FuncGraphHasTupleInput(func_graph)) {
opt::GraphTupleParamTransform graph_trans;
func_graph = graph_trans(func_graph, res->manager());
res->set_func_graph(func_graph);
AbstractBasePtrList abs_spec_list;
auto &params = func_graph->parameters();
std::transform(params.begin(), params.end(), std::back_inserter(abs_spec_list),
[](AnfNodePtr node) { return node->abstract(); });
res->set_args_spec(abs_spec_list);
}
return true;
}
bool ValidatePass(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res->func_graph());
FuncGraphPtr func_graph = res->func_graph();
......@@ -388,6 +421,7 @@ std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStru
{"cconv", CconvPass},
{"opt_after_cconv", OptPassAfterCconvGroup},
{"remove_dup_value", RemoveValueNodeDuplicationsPass},
{"tuple_transform", OptPassTransformGraphGroup},
{"opt_graph_kernel_a", OptPassGraphKernelGroupA},
{"opt_graph_kernel_b", OptPassGraphKernelGroupB},
{"add_control_depend", AddControlDependPass}};
......@@ -401,6 +435,10 @@ std::vector<PassItem> kGePasses = {{"simplify_data_structures", SimplifyDataStru
{"opt_prepare", PrepareGroup},
{"cconv", CconvPass}};
std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup}, {"opt_b", OptPassBGroup}, {"cconv", CconvPass}};
std::vector<PassItem> kPynativePasses = {{"opt_a", OptPassAGroup},
{"opt_b", OptPassBGroup},
{"cconv", CconvPass},
{"transform_top", TransformTopGraphPass},
{"transform_graph", OptPassTransformGraphGroup}};
} // namespace pipeline
} // namespace mindspore
......@@ -1351,9 +1351,46 @@ void PynativeExecutor::ClearRes() {
resource_.reset();
}
size_t GetTupleSize(const py::tuple &args) {
size_t count = 0;
for (size_t i = 0; i < args.size(); i++) {
if (py::isinstance<py::tuple>(args[i])) {
count += GetTupleSize(args[i]);
} else {
count += 1;
}
}
return count;
}
void ConvertTupleArg(py::tuple *res, size_t *index, const py::tuple &arg) {
for (size_t i = 0; i < arg.size(); i++) {
if (py::isinstance<py::tuple>(arg[i])) {
ConvertTupleArg(res, index, arg[i]);
} else {
(*res)[(*index)++] = arg[i];
}
}
}
py::tuple ConvertArgs(const py::tuple &args) {
size_t tuple_size = GetTupleSize(args);
py::tuple res(tuple_size);
size_t index = 0;
for (size_t i = 0; i < args.size(); i++) {
if (py::isinstance<py::tuple>(args[i])) {
ConvertTupleArg(&res, &index, args[i]);
} else {
res[index++] = args[i];
}
}
return res;
}
py::object PynativeExecutor::Run(const py::tuple &args, const py::object &phase) {
VectorRef arg_list;
pipeline::ProcessVmArgInner(args, resource_, &arg_list);
py::tuple converted_args = ConvertArgs(args);
pipeline::ProcessVmArgInner(converted_args, resource_, &arg_list);
if (resource_->results().find(pipeline::kOutput) == resource_->results().end() ||
!resource_->results()[pipeline::kOutput].is<compile::VmEvalFuncPtr>()) {
MS_LOG(EXCEPTION) << "Can't find run graph func for ";
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
from mindspore import RowTensor
from mindspore import context, nn, Tensor, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.common import ms_function
from mindspore.ops import operations as P
from mindspore.ops import composite as C
def setup_module():
context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False)
class _Grad(nn.Cell):
def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
super().__init__()
self.network = network
self.grad = grad
self.sens_param = self.grad.sens_param
self.wrt_params = wrt_params
self.real_inputs_count = real_inputs_count
if self.wrt_params:
self.params = ParameterTuple(self.network.trainable_params())
def construct(self, *inputs):
if self.wrt_params:
if self.real_inputs_count is None or self.sens_param is False:
return self.grad(self.network, self.params)(*inputs)
real_inputs = inputs[:self.real_inputs_count]
sense_param_inputs = inputs[self.real_inputs_count:]
return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs)
if self.real_inputs_count is None or self.sens_param is False:
return self.grad(self.network)(*inputs)
real_inputs = inputs[:self.real_inputs_count]
sense_param_inputs = inputs[self.real_inputs_count:]
return self.grad(self.network)(*real_inputs, sense_param_inputs)
class GradOfFirstInput(_Grad):
"""
get grad of first input
"""
def __init__(self, network, sens_param=True, real_inputs_count=None):
super().__init__(grad=C.GradOperation(sens_param=sens_param),
network=network, real_inputs_count=real_inputs_count)
class GradOfAllInputs(_Grad):
"""
get grad of first input
"""
def __init__(self, network, sens_param=True, real_inputs_count=None):
super().__init__(grad=C.GradOperation(get_all=True, sens_param=sens_param),
network=network, real_inputs_count=real_inputs_count)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_row_tensor_in_while():
class RowTensorValuesDouble(nn.Cell):
def construct(self, x):
indices = x.indices
values = x.values * 2
dense_shape = x.dense_shape
return RowTensor(indices, values, dense_shape)
class RowTensorValuesAdd2(nn.Cell):
def construct(self, x):
indices = x.indices
values = x.values + 2
dense_shape = x.dense_shape
return RowTensor(indices, values, dense_shape)
class RowTensorWithControlWhile(nn.Cell):
def __init__(self, dense_shape):
super().__init__()
self.op1 = RowTensorValuesDouble()
self.op2 = RowTensorValuesAdd2()
self.dense_shape = dense_shape
@ms_function
def construct(self, a, b, indices, values):
x = RowTensor(indices, values, self.dense_shape)
x = self.op2(x)
while a > b:
x = self.op1(x)
b = b + 1
return x.indices, x.values, x.dense_shape
a = Tensor(np.array(3).astype(np.int32))
b = Tensor(np.array(0).astype(np.int32))
indices = Tensor(np.array([0, 2]).astype(np.int32))
values = Tensor(np.ones([2, 2]).astype(np.float32))
dense_shape = (5, 2)
net = RowTensorWithControlWhile(dense_shape)
out = net(a, b, indices, values)
assert np.allclose(indices.asnumpy(), out[0].asnumpy(), .0, .0)
assert np.allclose(values.asnumpy()*24, out[1].asnumpy(), .0, .0)
assert dense_shape == out[2]
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_parser_switch_layer_inputs_tuple():
class Add(nn.Cell):
def __init__(self):
super().__init__()
self.op = P.TensorAdd()
def construct(self, x):
y = self.op(x[0], x[1])
return self.op(x[0], y)
class Mul(nn.Cell):
def __init__(self):
super().__init__()
self.op = P.Mul()
def construct(self, x):
y = self.op(x[0], x[1])
return self.op(x[0], y)
class MulTwoInput(nn.Cell):
def __init__(self):
super().__init__()
self.op = P.Mul()
@ms_function
def construct(self, x, y):
y = self.op(x, y)
return self.op(x, y)
class TwoInputTupleFinalNet(nn.Cell):
def __init__(self, funcs):
super().__init__()
self.funcs = funcs
@ms_function
def construct(self, i, inputa, inputb):
inputs = (inputa, inputb)
x = self.funcs[i](inputs)
return x
func1 = Add()
func2 = Mul()
funcs = (func1, func2)
net = TwoInputTupleFinalNet(funcs)
input_data = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
input2 = Tensor(np.random.randn(2, 3, 4, 5).astype(np.float32))
i = Tensor(1, mstype.int32)
netout = net(i, input_data, input2)
net_good = MulTwoInput()
goodout = net_good(input_data, input2)
assert np.allclose(goodout.asnumpy(), netout.asnumpy(), 0, 0)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_imagenet():
class ImageGradients(nn.Cell):
def __init__(self):
super().__init__()
self.imagegradients = nn.ImageGradients()
def construct(self, inputs):
return self.imagegradients(inputs)
net = ImageGradients()
net_me = GradOfFirstInput(net, real_inputs_count=1)
net_me.set_train()
input_data = Tensor(np.ones([32, 16, 8, 8]), dtype=mstype.float32)
output_grad = (Tensor(np.ones([32, 16, 8, 8]), dtype=mstype.float32),
Tensor(np.ones([32, 16, 8, 8]), dtype=mstype.float32))
net_me(input_data, *output_grad)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
from mindspore import RowTensor
from mindspore import context, nn, Tensor, ParameterTuple
from mindspore.common import dtype as mstype
from mindspore.common import ms_function
from mindspore.ops import composite as C
def setup_module():
context.set_context(mode=context.PYNATIVE_MODE, enable_sparse=False)
class _Grad(nn.Cell):
def __init__(self, grad, network, wrt_params=False, real_inputs_count=None):
super().__init__()
self.network = network
self.grad = grad
self.sens_param = self.grad.sens_param
self.wrt_params = wrt_params
self.real_inputs_count = real_inputs_count
if self.wrt_params:
self.params = ParameterTuple(self.network.trainable_params())
def construct(self, *inputs):
if self.wrt_params:
if self.real_inputs_count is None or self.sens_param is False:
return self.grad(self.network, self.params)(*inputs)
real_inputs = inputs[:self.real_inputs_count]
sense_param_inputs = inputs[self.real_inputs_count:]
return self.grad(self.network, self.params)(*real_inputs, sense_param_inputs)
if self.real_inputs_count is None or self.sens_param is False:
return self.grad(self.network)(*inputs)
real_inputs = inputs[:self.real_inputs_count]
sense_param_inputs = inputs[self.real_inputs_count:]
return self.grad(self.network)(*real_inputs, sense_param_inputs)
class GradOfFirstInput(_Grad):
"""
get grad of first input
"""
def __init__(self, network, sens_param=True, real_inputs_count=None):
super().__init__(grad=C.GradOperation(sens_param=sens_param),
network=network, real_inputs_count=real_inputs_count)
class GradOfAllInputs(_Grad):
"""
get grad of first input
"""
def __init__(self, network, sens_param=True, real_inputs_count=None):
super().__init__(grad=C.GradOperation(get_all=True, sens_param=sens_param),
network=network, real_inputs_count=real_inputs_count)
def test_row_tensor_in_while():
class RowTensorValuesDouble(nn.Cell):
def __init__(self):
super().__init__()
def construct(self, x):
indices = x.indices
values = x.values * 2
dense_shape = x.dense_shape
return RowTensor(indices, values, dense_shape)
class RowTensorValuesAdd2(nn.Cell):
def __init__(self):
super().__init__()
def construct(self, x):
indices = x.indices
values = x.values + 2
dense_shape = x.dense_shape
return RowTensor(indices, values, dense_shape)
class RowTensorWithControlWhile(nn.Cell):
def __init__(self, dense_shape):
super().__init__()
self.op1 = RowTensorValuesDouble()
self.op2 = RowTensorValuesAdd2()
self.dense_shape = dense_shape
@ms_function
def construct(self, a, b, indices, values):
x = RowTensor(indices, values, self.dense_shape)
x = self.op2(x)
while (a > b):
x = self.op1(x)
b = b + 1
return x.indices, x.values, x.dense_shape
a = Tensor(np.array(3).astype(np.int32))
b = Tensor(np.array(0).astype(np.int32))
indices = Tensor(np.array([0, 2]).astype(np.int32))
values = Tensor(np.ones([2, 2]).astype(np.float32))
dense_shape = (5, 2)
net = RowTensorWithControlWhile(dense_shape)
net(a, b, indices, values)
def test_multi_out_sens():
class ImageGradients(nn.Cell):
def __init__(self):
super().__init__()
def construct(self, x, y, z):
resa = x * y
resb = y * z
resc = x * z
return resa, (resb, resc)
net = ImageGradients()
net_me = GradOfAllInputs(net, real_inputs_count=3)
net_me.set_train()
input_data = Tensor(np.ones([32]), dtype=mstype.float32)
output_grad = (Tensor(np.ones([32]), dtype=mstype.float32),
(Tensor(np.ones([32]), dtype=mstype.float32), Tensor(np.ones([32]), dtype=mstype.float32)))
net_me(input_data, input_data, input_data, *output_grad)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册