Commit a0956538 authored by kpy, committed by kuangpeiyu

optimize infer in pynative mode

Parent 61639d90
......@@ -351,7 +351,7 @@ bool ExecuteAction(const ResourcePtr &res) {
}
auto graph_id = res->results()[kOutput].cast<GraphId>();
std::shared_ptr<compile::Backend> bc_ptr = res->results()[kBackend].cast<std::shared_ptr<compile::Backend>>();
std::shared_ptr<compile::MsBackend> msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr);
compile::MsBackend *msbc_ptr = std::dynamic_pointer_cast<compile::MsBackend>(bc_ptr).get();
MS_EXCEPTION_IF_NULL(msbc_ptr);
compile::VmEvalFuncPtr run =
std::make_shared<compile::VmEvalFunc>([msbc_ptr, graph_id](const VectorRef &args) -> BaseRef {
......
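The first hunk swaps the `std::shared_ptr<compile::MsBackend>` held by the closure for a raw pointer taken via `.get()`. This is only safe because the resource keeps the owning `shared_ptr` alive for as long as the `VmEvalFunc` can run; in exchange, the closure avoids copying a `shared_ptr` (and touching its atomic refcount) on every invocation. A minimal sketch of the pattern, with stand-in `Backend`/`MsBackend` types rather than the real MindSpore classes:

```cpp
#include <iostream>
#include <memory>

// Stand-ins for the MindSpore backend classes; Backend must be polymorphic
// for dynamic_pointer_cast to work.
struct Backend {
  virtual ~Backend() = default;
};
struct MsBackend : Backend {
  int MsRunGraph(int graph_id) const { return graph_id * 2; }  // dummy work
};

int main() {
  // The "resource" owns the backend for the whole session, so the closure
  // can capture a raw pointer instead of a shared_ptr copy.
  std::shared_ptr<Backend> bc_ptr = std::make_shared<MsBackend>();
  MsBackend *msbc_ptr = std::dynamic_pointer_cast<MsBackend>(bc_ptr).get();
  if (msbc_ptr == nullptr) {
    return 1;  // mirrors MS_EXCEPTION_IF_NULL
  }
  auto run = [msbc_ptr](int graph_id) { return msbc_ptr->MsRunGraph(graph_id); };
  std::cout << run(21) << std::endl;  // prints 42
  return 0;
}
```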
......@@ -205,6 +205,7 @@ Resource::Resource(const py::object &obj)
Resource::~Resource() {
MS_LOG(DEBUG) << "Resource clear";
std::unordered_map<std::string, Any>().swap(results_);
// If we exit normally, these global variables will be cleaned up
// in the Resource::Clean call made by MsPipeline::Compile, but if we exit with MS_LOGEXCEPTION,
// these global variables may not be cleaned up, and it may
......
......@@ -54,12 +54,12 @@ struct OpExecInfo {
AbstractBasePtr abstract;
ValuePtr value = nullptr;
py::tuple op_inputs;
py::tuple inputs_mask;
py::list op_inputs;
py::dict op_attrs;
std::vector<bool> inputs_mask;
};
using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;
OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args);
OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
const std::set<std::string> ignore_infer_prim = {"make_ref", "mixed_precision_cast"};
} // namespace pynative
......
......@@ -41,12 +41,20 @@ namespace py = pybind11;
using ResourcePtr = std::shared_ptr<pipeline::Resource>;
using GradOperationPtr = std::shared_ptr<prim::GradOperation>;
struct PrimAbsInfo {
abstract::AbstractBasePtr abs;
std::unordered_map<std::string, ValuePtr> attrs;
};
using AbstractListMap = std::unordered_map<abstract::AbstractBasePtrList, PrimAbsInfo,
abstract::AbstractBasePtrListHasher, abstract::AbstractBasePtrListEqual>;
py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
py::tuple RunOp(const py::args &args);
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *const out_args,
py::list *const out_args_list);
void ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *const out_args,
py::list *const out_args_list);
void ClearPyNativeSession();
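The new `PrimAbsInfo` and `AbstractListMap` types (together with the `prim_abs_list` member added further down) point at the actual optimization: cache each primitive's inferred abstract, keyed by the abstracts of its arguments, so repeated PyNative ops can skip re-running type/shape inference. A simplified, self-contained sketch of that memoization pattern — `Abs`, `AbsListHasher`, and `InferWithCache` are hypothetical stand-ins, not the MindSpore abstract hierarchy:

```cpp
#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-in for an inferred abstract (a shape/dtype pair).
struct Abs {
  std::string shape;
  std::string dtype;
  bool operator==(const Abs &o) const { return shape == o.shape && dtype == o.dtype; }
};
using AbsList = std::vector<Abs>;

// Combine per-element hashes into one key, in the spirit of
// AbstractBasePtrListHasher.
struct AbsListHasher {
  std::size_t operator()(const AbsList &args) const {
    std::size_t h = args.size();
    for (const auto &a : args) {
      h ^= std::hash<std::string>()(a.shape + "/" + a.dtype) + 0x9e3779b9 + (h << 6) + (h >> 2);
    }
    return h;
  }
};

// Per-primitive memo table: op name -> (argument abstracts -> inferred result).
std::unordered_map<std::string, std::unordered_map<AbsList, Abs, AbsListHasher>> prim_abs_list;

Abs InferWithCache(const std::string &op, const AbsList &args) {
  auto &cache = prim_abs_list[op];
  auto it = cache.find(args);
  if (it != cache.end()) {
    return it->second;  // cache hit: skip inference entirely
  }
  Abs out{args.front().shape, args.front().dtype};  // dummy "inference"
  cache.emplace(args, out);
  return out;
}

int main() {
  AbsList args = {{"2x3", "float32"}, {"2x3", "float32"}};
  InferWithCache("Add", args);             // miss: runs inference, fills cache
  auto out = InferWithCache("Add", args);  // hit: served from the cache
  std::cout << out.shape << " " << out.dtype << std::endl;
  return 0;
}
```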
......@@ -82,7 +90,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void ClearRes();
bool grad_flag() { return grad_flag_; }
void set_grad_flag(bool flag) { grad_flag_ = flag; }
AnfNodePtr GetInput(const py::object &obj, const py::object &op_mask);
AnfNodePtr GetInput(const py::object &obj, bool op_mask);
AnfNodePtr GetObjNode(const py::object &obj);
FuncGraphPtr curr_g() { return curr_g_; }
void set_pyobj(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); }
......@@ -95,11 +103,14 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, std::vector<int> index) {
graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index);
}
CNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out);
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks,
abstract::AbstractBasePtrList *args_spec_list);
void MakeCNode(const OpExecInfoPtr &op_exec_info, const py::object &out, const AnfNodePtr &cnode);
ValuePtr GetForwardValue(const OpExecInfoPtr &op_exec_info);
void SaveOpForwardValue(const OpExecInfoPtr &op_exec_info, const ValuePtr &value);
void SaveForwardResult(const CNodePtr &cnode, const py::object &out);
void SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out);
py::object Run(const py::tuple &args, const py::object &phase);
void Pushp();
......@@ -108,6 +119,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
size_t arg_size);
void SetTupleOutput(const py::object &obj, const AnfNodePtr &cnode, std::vector<int> idx);
AnfNodePtr MakeValueNode(const py::object &obj, const std::string &obj_id);
py::tuple RunOpInner(const py::args &args);
py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info);
~PynativeExecutor();
......@@ -123,10 +136,12 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
std::unordered_map<std::string, ValuePtr> op_forward_map_;
std::unordered_map<std::string, size_t> op_id_map_;
std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
std::stack<FuncGraphPtr> graph_p_;
FuncGraphPtr top_g_;
FuncGraphPtr df_builder_;
FuncGraphPtr curr_g_;
std::unordered_map<std::string, AbstractListMap> prim_abs_list;
};
using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>;
......
......@@ -220,6 +220,7 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
.def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr")
.def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
.def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
.def("set_is_const_value", &PrimitivePy::set_is_const_value, "Set primitive is const value.")
.def("set_signatures", &PrimitivePy::set_signatures, "Set primitive inputs signature.")
.def("register_hook", &PrimitivePy::set_hook, "Set primitive hook function.")
.def("set_instance_name", &PrimitivePy::set_instance_name, "Set primitive instance name.");
......
......@@ -21,6 +21,31 @@
namespace mindspore {
static std::string MakeId() {
// Use atomic to make id generator thread safe.
static std::atomic<uint64_t> last_id{1};
return "P" + std::to_string(last_id.fetch_add(1, std::memory_order_relaxed));
}
Primitive::Primitive(const std::string &name, const bool is_base, const PrimType prim_type)
: Named(name),
is_base_(is_base),
has_signature_(false),
prim_type_(prim_type),
record_evaluate_add_attr_(false),
is_const_value_(false),
id_(MakeId()) {}
Primitive::Primitive(const Primitive &prim)
: Named(prim),
attrs_(prim.attrs_),
instance_name_(prim.instance_name_),
is_base_(prim.is_base_),
has_signature_(prim.has_signature_),
prim_type_(prim.prim_type_),
record_evaluate_add_attr_(false),
id_(prim.id_) {}
abstract::AbstractBasePtr Primitive::ToAbstract() {
return std::make_shared<abstract::PrimitiveAbstractClosure>(shared_from_base<Primitive>(), nullptr);
}
......
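`MakeId()` gives every `Primitive` instance a distinct id ("P1", "P2", ...) via a function-local atomic counter, so instances can be told apart (e.g. as cache/map keys) without locking. A quick standalone check of the pattern — the driver in `main` is hypothetical, but `MakeId` mirrors the helper above:

```cpp
#include <atomic>
#include <cstdint>
#include <iostream>
#include <set>
#include <string>
#include <thread>
#include <vector>

// Same pattern as MakeId() above: fetch_add on a function-local atomic
// hands every caller a unique id lock-free.
static std::string MakeId() {
  static std::atomic<uint64_t> last_id{1};
  return "P" + std::to_string(last_id.fetch_add(1, std::memory_order_relaxed));
}

int main() {
  constexpr int kThreads = 4, kPerThread = 250;
  std::vector<std::string> ids(kThreads * kPerThread);
  std::vector<std::thread> workers;
  for (int t = 0; t < kThreads; ++t) {
    workers.emplace_back([&ids, t] {
      for (int i = 0; i < kPerThread; ++i) ids[t * kPerThread + i] = MakeId();
    });
  }
  for (auto &w : workers) w.join();
  std::set<std::string> unique(ids.begin(), ids.end());
  std::cout << unique.size() << std::endl;  // 1000 -> no duplicates under contention
  return 0;
}
```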
......@@ -40,22 +40,8 @@ enum PrimType {
class Primitive : public Named {
public:
explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn)
: Named(name),
is_base_(is_base),
has_signature_(false),
prim_type_(prim_type),
record_evaluate_add_attr_(false) {}
Primitive(const Primitive &prim)
: Named(prim),
attrs_(prim.attrs_),
instance_name_(prim.instance_name_),
is_base_(prim.is_base_),
has_signature_(prim.has_signature_),
prim_type_(prim.prim_type_),
record_evaluate_add_attr_(false) {}
explicit Primitive(const std::string &name, const bool is_base = true, const PrimType prim_type = kPrimTypeBuiltIn);
Primitive(const Primitive &prim);
MS_DECLARE_PARENT(Primitive, Named);
abstract::AbstractBasePtr ToAbstract();
abstract::AbstractBasePtr ToPrimAbstract(const AnfNodePtr &anf_node);
......@@ -91,6 +77,12 @@ class Primitive : public Named {
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
const std::unordered_map<std::string, ValuePtr> &evaluate_added_attrs() const { return evaluate_added_attrs_; }
void set_evaluate_added_attrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
for (auto &attr : attrs) {
MS_LOG(INFO) << " set evalu attrl " << name() << attr.first;
attrs_[attr.first] = attr.second;
}
}
// Check whether the Primitive has any attributes; Primitives like scalar_add, return, etc. do not.
bool HasAttr() const { return !attrs_.empty(); }
......@@ -117,6 +109,9 @@ class Primitive : public Named {
bool is_base() const { return is_base_; }
virtual BaseRef RunHookFunction(const VectorRef &args) const { MS_LOG(EXCEPTION) << "Call an empty function!"; }
virtual void CopyHookFunction(const PrimitivePtr &primitive) { MS_LOG(EXCEPTION) << "Call an empty function!"; }
void set_is_const_value(bool value) { is_const_value_ = value; }
bool is_const_value() const { return is_const_value_; }
std::string id() const { return id_; }
protected:
std::unordered_map<std::string, ValuePtr> attrs_;
......@@ -128,6 +123,8 @@ class Primitive : public Named {
bool has_signature_;
PrimType prim_type_;
bool record_evaluate_add_attr_;
bool is_const_value_;
std::string id_{""};
};
inline std::ostream &operator<<(std::ostream &os, const PrimitivePtr &p) {
......
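This header diff also replaces the stringly-typed `const_value` attribute (see the Python-side changes below) with a dedicated `is_const_value_` flag plus accessors. A hedged sketch of the trade-off using a simplified, hypothetical `Prim` class — the dedicated flag turns a hash-map probe on the hot infer path into a plain bool read:

```cpp
#include <iostream>
#include <string>
#include <unordered_map>

// Hypothetical simplified class, not the MindSpore Primitive.
class Prim {
 public:
  // Old style: the flag lives in a generic attribute bag, so reading it
  // costs a string hash plus a map lookup.
  void add_attr(const std::string &k, bool v) { attrs_[k] = v; }
  bool attr_const_value() const {
    auto it = attrs_.find("const_value");
    return it != attrs_.end() && it->second;
  }
  // New style: a dedicated member the infer path can branch on directly.
  void set_is_const_value(bool v) { is_const_value_ = v; }
  bool is_const_value() const { return is_const_value_; }

 private:
  std::unordered_map<std::string, bool> attrs_;
  bool is_const_value_ = false;
};

int main() {
  Prim p;
  p.set_is_const_value(true);
  std::cout << std::boolalpha << p.is_const_value() << std::endl;  // true
  return 0;
}
```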
......@@ -335,7 +335,7 @@ static void PrintTimeStat(std::ostringstream &oss, const TimeInfoGroup &group, c
void MsProfile::Print() {
GetProfile()->Print();
std::vector<std::string> items = {"substitution.", "renormalize.", "replace.", "match.",
"func_graph_cloner_run.", "meta_graph.", "manager."};
"func_graph_cloner_run.", "meta_graph.", "manager.", "pynative"};
std::vector<TimeInfoGroup> groups(items.size() + 1);
const auto &stat = GetSingleton().time_stat_;
// group all time infos
......
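Adding "pynative" to the group prefixes routes the new PyNative timings into their own bucket when the stats are printed. A hypothetical sketch of prefix-based grouping in the spirit of `MsProfile::Print` — the real grouping logic is not shown in this diff:

```cpp
#include <cstddef>
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Each stat key is matched against the group prefixes; anything that
  // matches none falls into a trailing "other" bucket.
  std::vector<std::string> items = {"substitution.", "renormalize.", "pynative"};
  std::map<std::string, double> time_stat = {
      {"substitution.arithmetic", 0.8}, {"pynative.infer", 1.5}, {"other", 0.1}};
  std::vector<double> groups(items.size() + 1, 0.0);
  for (const auto &[name, cost] : time_stat) {
    std::size_t g = items.size();
    for (std::size_t i = 0; i < items.size(); ++i) {
      if (name.compare(0, items[i].size(), items[i]) == 0) {
        g = i;
        break;
      }
    }
    groups[g] += cost;
  }
  for (std::size_t i = 0; i < groups.size(); ++i) {
    std::cout << (i < items.size() ? items[i] : "other") << ": " << groups[i] << "\n";
  }
  return 0;
}
```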
......@@ -28,7 +28,7 @@ hastype = Primitive('hastype')
cast = P.Cast()
dtype = P.DType()
isconstant = Primitive('is_constant')
isconstant.add_prim_attr('const_value', True)
isconstant.set_is_const_value(True)
issubclass_ = P.IsSubClass()
......
......@@ -1027,7 +1027,7 @@ class InvertPermutation(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
"""init InvertPermutation"""
self.const_value = True
self.set_is_const_value(True)
def __infer__(self, x):
x_shp = x['shape']
......
......@@ -352,7 +352,7 @@ def constexpr(fn=None, get_instance=True, name=None):
def __init__(self):
op_name = name if name else fn.__name__
PrimitiveWithInfer.__init__(self, op_name)
self.const_value = True
self.set_is_const_value(True)
def infer_value(self, *args):
return fn(*args)
......
......@@ -65,27 +65,7 @@ OpExecInfoPtr ConstructOpExecInfo() {
py::none py_none;
py::args args = py::make_tuple(conv_obj, op_name, op_inputs);
py::list args_input = args[PY_INPUTS];
return GenerateOpExecInfo(args, &args_input);
}
TEST_F(TestPynativeExecute, TestRunOpInVM) {
py::tuple result;
PynativeStatusCode status;
auto op_exec_info_ptr = ConstructOpExecInfo();
result = pynative::RunOpInVM(op_exec_info_ptr, &status);
ASSERT_EQ(status, PYNATIVE_SUCCESS);
}
TEST_F(TestPynativeExecute, TestRunOp) {
py::none py_none;
auto op_exec_info_ptr = ConstructOpExecInfo();
py::tuple outputs = pynative::RunOp(
py::make_tuple(op_exec_info_ptr->py_primitive, op_exec_info_ptr->op_name, op_exec_info_ptr->op_inputs));
if (outputs.size() == 0) {
FAIL();
} else {
SUCCEED();
}
return GenerateOpExecInfo(args);
}
TEST_F(TestPynativeExecute, TestCreateContext) {
......