提交 e6f82af8 编写于 作者: W Wei Luning

add cell class to c++

上级 879a5191
...@@ -194,9 +194,12 @@ def get_object_key(obj): ...@@ -194,9 +194,12 @@ def get_object_key(obj):
obj_key = "%s_ID" % (str(obj.__class__.__name__) + str(obj.__name__) + obj.cell_init_args) obj_key = "%s_ID" % (str(obj.__class__.__name__) + str(obj.__name__) + obj.cell_init_args)
obj_id = "%s_ID%d" % (str(obj.__class__.__name__) + str(obj.__name__), id(obj)) obj_id = "%s_ID%d" % (str(obj.__class__.__name__) + str(obj.__name__), id(obj))
else: else:
# `<class 'xxxxxxx'>`
# -> `xxxxxxx`
tag = str(obj.__class__)[8:-2]
if hasattr(obj, "cell_init_args"): if hasattr(obj, "cell_init_args"):
obj_key = "%s_ID" % (str(obj.__class__.__name__) + obj.cell_init_args) obj_key = "%s_ID" % (tag + obj.cell_init_args)
obj_id = "%s_ID%d" % (str(obj.__class__.__name__), id(obj)) obj_id = "%s_ID%d" % (tag, id(obj))
logger.debug("obj_key %s obj_id = %s", obj_key, obj_id) logger.debug("obj_key %s obj_id = %s", obj_key, obj_id)
# method has same id of different instance # method has same id of different instance
......
...@@ -316,7 +316,6 @@ class IncorporateGetitemFromParam : public AnfVisitor { ...@@ -316,7 +316,6 @@ class IncorporateGetitemFromParam : public AnfVisitor {
} }
} }
// (void)mng->Replace(new_fg_parameters[param_i], new_param);
new_parameters.push_back(new_param); new_parameters.push_back(new_param);
curr_input_idx++; curr_input_idx++;
} }
......
...@@ -25,6 +25,7 @@ ...@@ -25,6 +25,7 @@
#include "ir/func_graph_cloner.h" #include "ir/func_graph_cloner.h"
#include "ir/param_info.h" #include "ir/param_info.h"
#include "ir/cell.h"
#include "frontend/parallel/costmodel_context.h" #include "frontend/parallel/costmodel_context.h"
#include "frontend/parallel/context.h" #include "frontend/parallel/context.h"
#include "pipeline/jit/pass.h" #include "pipeline/jit/pass.h"
...@@ -122,17 +123,29 @@ bool ParseAction(const ResourcePtr &res) { ...@@ -122,17 +123,29 @@ bool ParseAction(const ResourcePtr &res) {
parse::python_adapter::set_python_env_flag(true); parse::python_adapter::set_python_env_flag(true);
parse::python_adapter::SetPythonPath(dir); parse::python_adapter::SetPythonPath(dir);
FuncGraphPtr fg = parse::ConvertToFuncGraph(input); ValuePtr converted_ret = nullptr;
if (fg == nullptr) { bool converted = parse::ConvertData(input, &converted_ret, true);
MS_LOG(EXCEPTION) << "Parse error."; if (!converted) {
MS_LOG(EXCEPTION) << "Attribute convert error with type:" << std::string(py::str(input));
} }
res->set_func_graph(fg);
FuncGraphPtr top_graph = nullptr;
if (py::isinstance<Cell>(input)) {
top_graph = parse::MakeTopGraph(input, converted_ret);
} else if (converted_ret->isa<FuncGraph>()) {
top_graph = converted_ret->cast<FuncGraphPtr>();
} else {
MS_LOG(EXCEPTION) << "Object to parse " << std::string(py::str(input)) << " is not function or cell.";
}
parse::Parser::UpdateTopFuncGraph(top_graph);
res->set_func_graph(top_graph);
FuncGraphManagerPtr manager = res->manager(); FuncGraphManagerPtr manager = res->manager();
if (manager == nullptr) { if (manager == nullptr) {
MS_LOG(EXCEPTION) << "Manager is nullptr."; MS_LOG(EXCEPTION) << "Manager is nullptr.";
} }
manager->AddFuncGraph(fg); manager->AddFuncGraph(top_graph);
return true; return true;
} }
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include "frontend/operator/ops.h" #include "frontend/operator/ops.h"
#include "frontend/operator/composite/composite.h" #include "frontend/operator/composite/composite.h"
#include "ir/func_graph_cloner.h" #include "ir/func_graph_cloner.h"
#include "ir/cell.h"
#include "utils/symbolic.h" #include "utils/symbolic.h"
#include "utils/ms_context.h" #include "utils/ms_context.h"
...@@ -223,7 +224,8 @@ bool ConvertSlice(const py::object &obj, ValuePtr *const data) { ...@@ -223,7 +224,8 @@ bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
return true; return true;
} }
bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) { bool ConvertCellObjToFuncGraph(const CellPtr &cell, ValuePtr *const data) {
auto obj = py::cast(cell);
FuncGraphPtr func_graph = ConvertToFuncGraph(obj); FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
if (func_graph == nullptr) { if (func_graph == nullptr) {
MS_LOG(ERROR) << "Parse resolve function error."; MS_LOG(ERROR) << "Parse resolve function error.";
...@@ -271,10 +273,6 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) { ...@@ -271,10 +273,6 @@ bool ConvertOtherObj(py::object obj, ValuePtr *const data) {
if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) { if (obj_type == RESOLVE_TYPE_CLASS_INSTANCE) {
// Create the namespace for common class instance // Create the namespace for common class instance
// When the obj is Cell, default parse the 'construct' // When the obj is Cell, default parse the 'construct'
if (data_converter::IsCellInstance(obj)) {
return ConvertCellObjToFuncGraph(obj, data);
}
py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE); py::module mod = python_adapter::GetPyModule(PYTHON_MOD_PARSE_MODULE);
py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj); py::object namespace_var = python_adapter::CallPyModFn(mod, PYTHON_MOD_GET_MEMBER_NAMESPACE_SYMBOL, obj);
*data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var); *data = std::make_shared<NameSpace>(RESOLVE_NAMESPACE_NAME_CLASS_MEMBER, namespace_var);
...@@ -404,6 +402,8 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature ...@@ -404,6 +402,8 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
ret = ConvertTuple(obj, &converted, use_signature); ret = ConvertTuple(obj, &converted, use_signature);
} else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) { } else if (py::hasattr(obj, PYTHON_CELL_AS_LIST)) {
ret = ConvertCellList(obj, &converted, use_signature); ret = ConvertCellList(obj, &converted, use_signature);
} else if (py::isinstance<Cell>(obj)) {
return ConvertCellObjToFuncGraph(obj.cast<CellPtr>(), data);
} else if (py::isinstance<py::list>(obj)) { } else if (py::isinstance<py::list>(obj)) {
ret = ConvertList(obj, &converted, use_signature); ret = ConvertList(obj, &converted, use_signature);
} else if (py::isinstance<py::module>(obj)) { } else if (py::isinstance<py::module>(obj)) {
......
...@@ -140,34 +140,80 @@ void Parser::CleanParserResource() { ...@@ -140,34 +140,80 @@ void Parser::CleanParserResource() {
ScopeManager::GetInstance().ClearScope(); ScopeManager::GetInstance().ClearScope();
} }
FuncGraphPtr Parser::ParseFuncGraph() { AnfNodePtr AppendParameterObj(const FuncGraphPtr &func_graph, const py::object &obj) {
// get ast FunctionDef node MS_EXCEPTION_IF_NULL(func_graph);
py::object node = ast_->GetAstNode(); auto value = py::cast<tensor::MetaTensorPtr>(obj);
FunctionBlockPtr pFnBlock = ParseFunction(node); // parameter object should not be none
if (errcode() != PARSE_SUCCESS) { if (value == nullptr || !value->is_parameter()) {
MS_LOG(ERROR) << "Parse function error, code is " << errcode(); MS_LOG(EXCEPTION) << "Parameter error: because obj is not Parameter object.";
return nullptr; }
// get the parameter name from parameter object
auto param_name = value->param_info()->name();
auto top_graph = func_graph;
// if the parameter node has been created , return it
AnfNodePtr para_node = nullptr;
for (auto param : top_graph->parameters()) {
auto param_node = dyn_cast<Parameter>(param);
if (param_node != nullptr && param_node->name() == param_name) {
para_node = param;
break;
}
} }
if (para_node == nullptr) {
auto node = top_graph->AddWeightParameter(param_name);
RemoveUnnecessaryPhis(); node->set_default_param(value);
// set_abstract for parameter
auto abs = value->ToAbstract();
// boarden value
abs = abs->Broaden();
node->set_abstract(abs);
para_node = node;
}
return para_node;
}
MS_EXCEPTION_IF_NULL(pFnBlock); void UpdataParam(const FuncGraphPtr &top_graph, const py::object &cell) {
auto params = py::list(cell.attr("get_parameters")()).cast<std::vector<py::object>>();
for (size_t i = 0; i < params.size(); i++) {
(void)AppendParameterObj(top_graph, params[i]);
}
}
void CheckFuncReturn(const FuncGraphPtr &fn, const std::shared_ptr<ParseAst> &ast) {
// check whether the functions refered by this function and itself are missing 'return' statement // check whether the functions refered by this function and itself are missing 'return' statement
auto mng = Manage(pFnBlock->func_graph(), false); auto mng = Manage(fn, false);
for (auto func_graph : mng->func_graphs()) { for (auto func_graph : mng->func_graphs()) {
if (func_graph->get_return() != nullptr) { if (func_graph->get_return() != nullptr) {
continue; continue;
} }
py::list ret = ast_->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node); py::object node = ast->GetAstNode();
py::list ret = ast->CallParserObjMethod(PYTHON_PARSE_GET_LOCATION, node);
py::str desc = py::str desc =
python_adapter::CallPyModFn(ast_->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast_->function(), ret[0], ret[1]); python_adapter::CallPyModFn(ast->module(), PYTHON_MOD_GET_OBJECT_DESCRIPTION, ast->function(), ret[0], ret[1]);
MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << "."; MS_EXCEPTION(TypeError) << "Missing return statement in " << desc.cast<std::string>() << ".";
} }
// clear manager info after checking missing return // clear manager info after checking missing return
for (auto fg : mng->func_graphs()) { for (auto fg : mng->func_graphs()) {
fg->ClearAllManagerInfo(); fg->ClearAllManagerInfo();
} }
}
FuncGraphPtr Parser::ParseFuncGraph() {
// get ast FunctionDef node
py::object node = ast_->GetAstNode();
FunctionBlockPtr pFnBlock = ParseFunction(node);
if (errcode() != PARSE_SUCCESS) {
MS_LOG(ERROR) << "Parse function error, code is " << errcode();
return nullptr;
}
RemoveUnnecessaryPhis();
MS_EXCEPTION_IF_NULL(pFnBlock);
CheckFuncReturn(pFnBlock->func_graph(), ast_);
return pFnBlock->func_graph(); return pFnBlock->func_graph();
} }
...@@ -591,19 +637,24 @@ AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &no ...@@ -591,19 +637,24 @@ AnfNodePtr Parser::ParseCall(const FunctionBlockPtr &block, const py::object &no
return GenerateAnfNodeForCall(block, call_function_anf_node, packed_arguments, group_arguments, need_unpack); return GenerateAnfNodeForCall(block, call_function_anf_node, packed_arguments, group_arguments, need_unpack);
} }
AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node, CNodePtr MakeUnpackCall(const FuncGraphPtr &func_graph, const AnfNodePtr &call_function_anf_node,
const std::vector<AnfNodePtr> &packed_arguments, const std::vector<AnfNodePtr> &packed_arguments) {
const std::vector<AnfNodePtr> &group_arguments, bool need_unpack) const {
// if there is keyword arguments or starred, using an unpack_call op to unpack the argument
if (need_unpack) {
std::vector<AnfNodePtr> unpack_call_nodes; std::vector<AnfNodePtr> unpack_call_nodes;
auto unpack_call_op = NewValueNode(std::make_shared<prim::UnpackCall>(NAMED_METAGRAPH_UNPACKCALL)); auto unpack_call_op = NewValueNode(std::make_shared<prim::UnpackCall>(NAMED_METAGRAPH_UNPACKCALL));
unpack_call_nodes.push_back(unpack_call_op); unpack_call_nodes.push_back(unpack_call_op);
unpack_call_nodes.push_back(call_function_anf_node); unpack_call_nodes.push_back(call_function_anf_node);
(void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes), (void)std::transform(packed_arguments.begin(), packed_arguments.end(), std::back_inserter(unpack_call_nodes),
[](AnfNodePtr node) -> AnfNodePtr { return node; }); [](AnfNodePtr node) -> AnfNodePtr { return node; });
CNodePtr unpack_call = block->func_graph()->NewCNode(unpack_call_nodes); CNodePtr unpack_call = func_graph->NewCNode(unpack_call_nodes);
return unpack_call; return unpack_call;
}
AnfNodePtr Parser::GenerateAnfNodeForCall(const FunctionBlockPtr &block, const AnfNodePtr &call_function_anf_node,
const std::vector<AnfNodePtr> &packed_arguments,
const std::vector<AnfNodePtr> &group_arguments, bool need_unpack) const {
// if there is keyword arguments or starred, using an unpack_call op to unpack the argument
if (need_unpack) {
return MakeUnpackCall(block->func_graph(), call_function_anf_node, packed_arguments);
} }
// else there is no keyword arguments and starred, parsed as normal arguments without unpack // else there is no keyword arguments and starred, parsed as normal arguments without unpack
std::vector<AnfNodePtr> func_call_nodes; std::vector<AnfNodePtr> func_call_nodes;
...@@ -1739,5 +1790,41 @@ bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) { ...@@ -1739,5 +1790,41 @@ bool UpdateFuncGraphFlags(py::object obj, const FuncGraphPtr &func_graph) {
return true; return true;
} }
FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr) {
auto func_graph = std::make_shared<FuncGraph>();
func_graph->debug_info()->set_name("top");
// def top(*arg, *kwargs):
auto param_vargs = func_graph->add_parameter();
auto args_name = "args";
param_vargs->set_name(args_name);
param_vargs->debug_info()->set_name(args_name);
auto param_vkwargs = func_graph->add_parameter();
args_name = "kwargs";
param_vkwargs->set_name(args_name);
param_vkwargs->debug_info()->set_name(args_name);
func_graph->set_has_vararg(true);
func_graph->set_has_kwarg(true);
func_graph->set_kwonlyargs_count(0);
// cell_obj
parse::UpdateFuncGraphFlags(cell, func_graph);
// top graph's construct flag
if (py::hasattr(cell, "construct")) {
parse::UpdateFuncGraphFlags(cell.attr("construct"), func_graph);
}
UpdataParam(func_graph, cell);
// ret = cell_obj(*arg, *kwargs)
auto call_fn = MakeUnpackCall(func_graph, NewValueNode(cell_ptr), {param_vargs, param_vkwargs});
// return ret
func_graph->set_output(call_fn);
MS_LOG(DEBUG) << "add Flag for " << std::string(py::str(cell));
return func_graph;
}
} // namespace parse } // namespace parse
} // namespace mindspore } // namespace mindspore
...@@ -148,6 +148,9 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj, ...@@ -148,6 +148,9 @@ FuncGraphPtr ConvertToFuncGraph(const py::object &obj,
// Parse the python object to graph // Parse the python object to graph
FuncGraphPtr ParsePythonCode(const py::object &obj, FuncGraphPtr ParsePythonCode(const py::object &obj,
const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD); const std::string &python_mod_get_parse_method = PYTHON_MOD_GET_PARSE_METHOD);
// add wrap for cell top graph.
FuncGraphPtr MakeTopGraph(const py::object &cell, const ValuePtr &cell_ptr);
} // namespace parse } // namespace parse
} // namespace mindspore } // namespace mindspore
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind_api/ir/cell_py.h"
#include <string>
#include "pybind_api/api_register.h"
#include "abstract/abstract_value.h"
#include "pipeline/jit/parse/python_adapter.h"
namespace mindspore {
void CellPy::AddAttr(CellPtr cell, const std::string &name, const py::object &obj) {
std::string attr_name = name;
ValuePtr converted_ret = nullptr;
if (py::isinstance<py::module>(obj)) {
MS_LOG(EXCEPTION) << "Cell set_attr failed, attr should not be py::module";
}
bool converted = parse::ConvertData(obj, &converted_ret, true);
if (!converted) {
MS_LOG(DEBUG) << "Attribute convert error with type: " << std::string(py::str(obj));
} else {
MS_LOG(DEBUG) << cell->ToString() << " add attr " << attr_name << converted_ret->ToString();
cell->AddAttr(attr_name, converted_ret);
}
}
// Define python 'Cell' class.
REGISTER_PYBIND_DEFINE(Cell, ([](const py::module *m) {
(void)py::class_<Cell, std::shared_ptr<Cell>>(*m, "Cell_")
.def(py::init<std::string &>())
.def("__str__", &Cell::ToString)
.def("_add_attr", &CellPy::AddAttr, "Add Cell attr.")
.def("_del_attr", &Cell::DelAttr, "Delete Cell attr.")
.def(
"construct", []() { MS_LOG(EXCEPTION) << "we should define `construct` for all `cell`."; },
"construct");
}));
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_UTILS_CELL_PY_H_
#define MINDSPORE_CCSRC_UTILS_CELL_PY_H_
#include <memory>
#include <string>
#include <vector>
#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
#include "ir/cell.h"
namespace py = pybind11;
// brief mindspore namespace.
//
// mindspore namespace is the top level namespace of Mindsporeession project.
// Other namespace should be a sub namespace of mindspore namespace in the ME project.
namespace mindspore {
// Cell python wrapper and adapter class.
class CellPy {
public:
static void AddAttr(CellPtr cell, const std::string &name, const py::object &obj);
};
} // namespace mindspore
#endif // MINDSPORE_CCSRC_UTILS_CELL_PY_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ir/cell.h"
#include <utility>
#include <map>
#include <algorithm>
#include "abstract/abstract_value.h"
namespace mindspore {
using mindspore::abstract::AbstractFunction;
abstract::AbstractBasePtr Cell::ToAbstract() {
/*
std::vector<abstract::AbstractAttribute> abs_attrs;
std::transform(attrs_.begin(), attrs_.end(), std::back_inserter(abs_attrs),
[](std::pair<std::string, ValuePtr> attr) -> abstract::AbstractAttribute {
return std::make_pair(attr.first, attr.second->ToAbstract());
});
auto abs = std::make_shared<abstract::AbstractCell>(shared_from_base<Named>(), abs_attrs);
abs->set_value(shared_from_base<Value>());
return abs;
*/
return nullptr;
}
bool Cell::operator==(const Value &other) const {
if (other.isa<Cell>()) {
auto other_prim = static_cast<const Cell &>(other);
return *this == other_prim;
} else {
return false;
}
}
bool Cell::operator==(const Cell &other) const {
if (name() != other.name()) {
return false;
}
if (attrs_.size() != other.attrs_.size()) {
return false;
}
auto all = std::all_of(attrs_.begin(), attrs_.end(), [&other](const std::pair<std::string, ValuePtr> &item) -> bool {
if (item.second == nullptr) {
return false;
}
auto iter = other.attrs_.find(item.first);
if (iter == other.attrs_.end()) {
return false;
}
return *item.second == *iter->second;
});
return all;
}
std::string Cell::GetAttrString() const {
std::ostringstream buffer;
bool begin = true;
buffer << "{" << std::endl;
for (auto &attr : attrs_) {
if (!begin) {
buffer << ", " << std::endl;
} else {
begin = false;
}
buffer << attr.first << ":" << attr.second->ToString();
}
buffer << "}";
return buffer.str();
}
std::string Cell::ToString() const {
std::ostringstream buffer;
buffer << "Cell " << name();
return buffer.str();
}
void Cell::DelAttr(const std::string &name) { attrs_.erase(name); }
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_IR_CELL_H_
#define MINDSPORE_CCSRC_IR_CELL_H_
#include <unordered_map>
#include <vector>
#include <memory>
#include <string>
#include <tuple>
#include "abstract/abstract_value.h"
#include "utils/misc.h"
namespace mindspore {
using abstract::AbstractBasePtr;
using abstract::AbstractBasePtrList;
// value for Cell
class Cell : public Named {
public:
explicit Cell(const std::string &name) : Named(name) {}
MS_DECLARE_PARENT(Cell, Named);
abstract::AbstractBasePtr ToAbstract() override;
std::string ToString() const override;
std::string GetAttrString() const;
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
void set_attrs(const std::unordered_map<std::string, ValuePtr> &attrs_input) { attrs_ = attrs_input; }
void AddAttr(const std::string &name, const ValuePtr &attr) { attrs_[name] = attr; }
void DelAttr(const std::string &name);
ValuePtr GetAttr(const std::string &attr_name) const {
auto iter = attrs_.find(attr_name);
return iter == attrs_.cend() ? nullptr : iter->second;
}
bool HasAttr(const std::string &attr_name) const {
auto iter = attrs_.find(attr_name);
return !(iter == attrs_.cend());
}
bool operator==(const Value &other) const override;
bool operator==(const Cell &other) const;
~Cell() override = default;
const bool parse_info_ = true;
private:
std::unordered_map<std::string, ValuePtr> attrs_;
};
using CellPtr = std::shared_ptr<Cell>;
} // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_CELL_H_
...@@ -98,10 +98,11 @@ void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph, ...@@ -98,10 +98,11 @@ void FuncGraph::GenerateVarParams(const FuncGraphPtr &specialized_graph,
MS_LOG(EXCEPTION) << "Function:" << this->ToString() << ", variable_args_count " << variable_args_count MS_LOG(EXCEPTION) << "Function:" << this->ToString() << ", variable_args_count " << variable_args_count
<< " were given."; << " were given.";
} }
auto varg_name = specialized_graph->GetVariableArgName();
// for python variable argument input , there is no upper limit // for python variable argument input , there is no upper limit
for (int i = 0; i < variable_args_count; ++i) { for (int i = 0; i < variable_args_count; ++i) {
ParameterPtr p = std::make_shared<Parameter>(specialized_graph); ParameterPtr p = std::make_shared<Parameter>(specialized_graph);
std::string param_name = specialized_graph->GetVariableArgName() + std::to_string(i); std::string param_name = varg_name + std::to_string(i);
p->set_name(param_name); p->set_name(param_name);
MS_EXCEPTION_IF_NULL(p->debug_info()); MS_EXCEPTION_IF_NULL(p->debug_info());
p->debug_info()->set_name(param_name); p->debug_info()->set_name(param_name);
......
...@@ -49,7 +49,8 @@ NameWithTrace RootName(const DebugInfoPtr &debug_info, TraceLabelType trace_labe ...@@ -49,7 +49,8 @@ NameWithTrace RootName(const DebugInfoPtr &debug_info, TraceLabelType trace_labe
while (temp_info != nullptr) { while (temp_info != nullptr) {
if (temp_info->trace_info() != nullptr) { if (temp_info->trace_info() != nullptr) {
if (temp_info->trace_info()->isa<TraceResolve>() || temp_info->trace_info()->isa<TraceExpandJ>() || if (temp_info->trace_info()->isa<TraceResolve>() || temp_info->trace_info()->isa<TraceExpandJ>() ||
temp_info->trace_info()->isa<TraceGenMetaFuncGraph>()) { temp_info->trace_info()->isa<TraceGenMetaFuncGraph>() ||
temp_info->trace_info()->isa<TraceGenerateVarArg>() || temp_info->trace_info()->isa<TraceGenerateKwArg>()) {
break; break;
} }
trace_name.trace_labels.push_back(GetTraceName(temp_info->trace_info(), trace_label)); trace_name.trace_labels.push_back(GetTraceName(temp_info->trace_info(), trace_label));
......
...@@ -24,14 +24,14 @@ from ..common import dtype as mstype ...@@ -24,14 +24,14 @@ from ..common import dtype as mstype
from ..common.api import _executor, _pynative_exec from ..common.api import _executor, _pynative_exec
from .._checkparam import _check_str_by_regular from .._checkparam import _check_str_by_regular
from ..common.parameter import Parameter, ParameterTuple from ..common.parameter import Parameter, ParameterTuple
from .._c_expression import init_backend from .._c_expression import init_backend, Cell_
from ..ops.primitive import Primitive from ..ops.primitive import Primitive
from ..ops.operations import HookBackward from ..ops.operations import HookBackward
from ..ops.functional import cast from ..ops.functional import cast
from ..parallel._tensor import _load_tensor_by_layout from ..parallel._tensor import _load_tensor_by_layout
from ..common.tensor import Tensor from ..common.tensor import Tensor
class Cell: class Cell(Cell_):
""" """
Base class for all neural networks. Base class for all neural networks.
...@@ -58,14 +58,21 @@ class Cell: ...@@ -58,14 +58,21 @@ class Cell:
>>> def construct(self, x): >>> def construct(self, x):
>>> return self.relu(x) >>> return self.relu(x)
""" """
IGNORE_LIST = ['_scope', '_cell_init_args', '_auto_prefix', '_cells', '_params', '_construct_inputs_names',
'_construct_inputs_num', '_create_time', '_mindspore_flags', '_parallel_inputs_run',
'_parameter_layout_dict', '_already_run', '_params_list', '_phase', '_auto_parallel_mode',
'_backward_hook', '_bprop_debug', '_is_run', '_param_prefix', '_attr_synced',
'enable_hook', 'pynative', 'requires_grad', '_auto_parallel_compile_and_run', 'cell_type']
def __init__(self, auto_prefix=True, flags=None): def __init__(self, auto_prefix=True, flags=None):
Cell_.__init__(self, self._cell_tag)
self._params = OrderedDict() self._params = OrderedDict()
self._cells = OrderedDict() self._cells = OrderedDict()
self._params_list = OrderedDict() self._params_list = OrderedDict()
self.training = False self.training = False
self.requires_grad = False self.requires_grad = False
self.pynative = False self.pynative = False
self._attr_synced = False
self._param_prefix = '' self._param_prefix = ''
self._auto_prefix = auto_prefix self._auto_prefix = auto_prefix
self._scope = None self._scope = None
...@@ -92,6 +99,12 @@ class Cell: ...@@ -92,6 +99,12 @@ class Cell:
def already_run(self): def already_run(self):
return self._already_run return self._already_run
@property
def _cell_tag(self):
# `<class 'xxxxxxx'>`
# -> `xxxxxxx`
return str(self.__class__)[8:-2]
@already_run.setter @already_run.setter
def already_run(self, value): def already_run(self, value):
self._already_run = value self._already_run = value
...@@ -222,6 +235,7 @@ class Cell: ...@@ -222,6 +235,7 @@ class Cell:
del self._cells[name] del self._cells[name]
else: else:
object.__delattr__(self, name) object.__delattr__(self, name)
self._attr_synced = False
def cast_inputs(self, inputs, dst_type): def cast_inputs(self, inputs, dst_type):
res = list() res = list()
...@@ -277,6 +291,34 @@ class Cell: ...@@ -277,6 +291,34 @@ class Cell:
self._already_run = True self._already_run = True
return output return output
def _add_attr(self, name, value):
if name and name[:2] != '__' and name not in Cell.IGNORE_LIST:
super(Cell, self)._add_attr(name, value)
def _sync_attr_for_compile(self):
"""Sync the attr to c++ object."""
if self._attr_synced:
return
cells = self.__dict__.get('_cells')
for key in cells:
cell = cells[key]
cell._sync_attr_for_compile()
self._add_attr(key, cell)
params = self.__dict__.get('_params')
for key in params:
if '.' in key:
continue
param = params[key]
self._add_attr(key, param)
params_list = self.__dict__.get('_params_list')
for key in params_list:
params_list_item = params_list[key]
self._add_attr(key, params_list_item)
for key in self.__dict__:
value = self.__dict__[key]
self._add_attr(key, value)
self._attr_synced = True
def __setattr__(self, name, value): def __setattr__(self, name, value):
cells = self.__dict__.get('_cells') cells = self.__dict__.get('_cells')
params = self.__dict__.get('_params') params = self.__dict__.get('_params')
...@@ -329,6 +371,8 @@ class Cell: ...@@ -329,6 +371,8 @@ class Cell:
if isinstance(value, Primitive): if isinstance(value, Primitive):
value.set_prim_instance_name(name) value.set_prim_instance_name(name)
object.__setattr__(self, name, value) object.__setattr__(self, name, value)
if name not in Cell.IGNORE_LIST:
self._attr_synced = False
def extend_repr(self): def extend_repr(self):
""" """
...@@ -451,7 +495,7 @@ class Cell: ...@@ -451,7 +495,7 @@ class Cell:
Object, the result of executing. Object, the result of executing.
""" """
self._auto_parallel_compile_and_run = True self._auto_parallel_compile_and_run = True
_executor.compile(self, *inputs, phase=self.phase, auto_parallel_mode=self._auto_parallel_mode) self.compile(*inputs)
if self._auto_parallel_mode: if self._auto_parallel_mode:
if inputs and isinstance(inputs[0], Tensor) and inputs[0].virtual_flag: if inputs and isinstance(inputs[0], Tensor) and inputs[0].virtual_flag:
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
# ============================================================================ # ============================================================================
"""container""" """container"""
from collections import OrderedDict from collections import OrderedDict
from abc import abstractmethod, ABCMeta from abc import abstractmethod
from ..cell import Cell from ..cell import Cell
__all__ = ['SequentialCell', 'CellList'] __all__ = ['SequentialCell', 'CellList']
...@@ -34,7 +34,7 @@ def _valid_cell(cell): ...@@ -34,7 +34,7 @@ def _valid_cell(cell):
raise TypeError('Cell {} is not subclass of Cell'.format(cell)) raise TypeError('Cell {} is not subclass of Cell'.format(cell))
class _CellListBase(metaclass=ABCMeta): class _CellListBase():
""" """
An interface for base the cell as list. An interface for base the cell as list.
......
...@@ -51,7 +51,7 @@ def test_get_parameter_layout(): ...@@ -51,7 +51,7 @@ def test_get_parameter_layout():
exe.compile(net, x, phase='train', auto_parallel_mode=True) exe.compile(net, x, phase='train', auto_parallel_mode=True)
x_layout = [[2, 4], [1, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [1, -1] x_layout = [[2, 4], [1, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [1, -1]
weight_layout = [[2, 4], [0, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [0, -1] weight_layout = [[2, 4], [0, -1], [16, 32], [0], [1]] # device_arrangement = [2, 4], tensor_map = [0, -1]
expect_dict = {'x': x_layout, 'w1': weight_layout} expect_dict = {'args0': x_layout, 'w1': weight_layout}
# to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut # to be resovled: static local variable count_p is used in step_parallel.cc, it needs to be reset between each ut
assert net.parameter_layout_dict == expect_dict assert net.parameter_layout_dict == expect_dict
......
...@@ -125,7 +125,7 @@ def test_grad_sens_parameter_type(): ...@@ -125,7 +125,7 @@ def test_grad_sens_parameter_type():
y_layout = [[8, 8], [-1, 0], [32, 8], [0], [1]] y_layout = [[8, 8], [-1, 0], [32, 8], [0], [1]]
b_layout = [[8, 8], [0, -1], [8, 64], [0], [1]] b_layout = [[8, 8], [0, -1], [8, 64], [0], [1]]
sens_layout = [[8, 8], [1, -1], [16, 64], [0], [1]] sens_layout = [[8, 8], [1, -1], [16, 64], [0], [1]]
expect_dict = {'x': x_layout, 'y': y_layout, 'b': b_layout, 'sens': sens_layout} expect_dict = {'args0': x_layout, 'args1': y_layout, 'args2': b_layout, 'args3': sens_layout}
assert net.parameter_layout_dict == expect_dict assert net.parameter_layout_dict == expect_dict
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册