提交 6566b383 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3033 decoupling primitive of compute function

Merge pull request !3033 from lianliguang/primi-decoupling-v2
......@@ -307,7 +307,7 @@ CNodePtr KernelGraph::NewCNode(const std::vector<AnfNodePtr> &inputs) {
if (inputs.size() == 1 || !feature_map_input_indexs.empty()) {
kernel_info->SetFeatureMapFlag(true);
}
if (AnfAlgo::IsRealCNodeKernel(cnode)) {
if (AnfAlgo::IsRealKernel(cnode)) {
AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode);
AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode);
}
......
......@@ -363,19 +363,21 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
MS_LOG(INFO) << "RunOpInVM end";
return std::move(result);
}
auto func = op_exec_info->py_primitive->GetComputeFunction();
if (py::isinstance<py::none>(func)) {
MS_LOG(ERROR) << "VM failed to get func";
auto primitive = op_exec_info->py_primitive;
MS_EXCEPTION_IF_NULL(primitive);
auto result = primitive->RunPyComputeFunction(op_exec_info->op_inputs);
if (py::isinstance<py::none>(result)) {
MS_LOG(ERROR) << "VM got the result none, please check whether it is failed to get func";
*status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR;
py::tuple err_ret(0);
return std::move(err_ret);
}
// execute op
py::tuple result = py::make_tuple(func(*op_exec_info->op_inputs));
py::tuple tuple_result = py::make_tuple(result);
*status = PYNATIVE_SUCCESS;
MS_LOG(INFO) << "RunOpInVM end";
return std::move(result);
return std::move(tuple_result);
}
bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim,
......
......@@ -15,6 +15,9 @@
*/
#include "utils/primitive_utils.h"
#include <memory>
#include "pipeline/jit/parse/python_adapter.h"
#include "utils/log_adapter.h"
#include "common/utils.h"
......@@ -43,4 +46,25 @@ py::function GetComputeFunction(std::string name) {
py::object fn = mod.attr(common::SafeCStr(name));
return fn;
}
py::tuple ConvertDatatoPyTuple(const VectorRef &args) {
auto py_args = py::tuple(args.size());
size_t i = 0;
for (auto &arg : args) {
py_args[i] = BaseRefToPyData(arg);
MS_LOG(DEBUG) << "arg:" << i << ":" << arg.ToString();
i++;
}
return py_args;
}
BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args) {
auto func = GetComputeFunction(prim->name());
if (py::isinstance<py::none>(func)) {
MS_LOG(EXCEPTION) << prim->name() << " 's compute function run failed, please check whether it is not implemented";
}
auto py_args = ConvertDatatoPyTuple(args);
py::object obj = func(*py_args);
return std::make_shared<PyObjectRef>(obj);
}
} // namespace mindspore
......@@ -19,6 +19,7 @@
#include <string>
#include "pybind11/pybind11.h"
#include "utils/base_ref.h"
namespace py = pybind11;
......@@ -28,6 +29,10 @@ py::function GetBpropFunctionByObj(py::object obj);
py::function GetBpropFunction(std::string name);
py::function GetComputeFunction(std::string name);
BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args);
py::tuple ConvertDatatoPyTuple(const VectorRef &args);
} // namespace mindspore
#endif // MINDSPORE_CCSRC_UTILS_PRIMITIVE_UTILS_H_
......@@ -440,25 +440,13 @@ VectorRef VM::RunGraph(const FuncGraphPtr &g, const VectorRef &args) {
}
BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args) {
PrimitivePyPtr operation = dyn_cast<PrimitivePy>(prim);
MS_LOG(DEBUG) << "operation start " << prim->name();
auto func = operation != nullptr ? operation->GetComputeFunction() : GetComputeFunction(prim->name());
if (py::isinstance<py::none>(func)) {
MS_LOG(EXCEPTION) << prim->name() << " 's compute function is not implemented";
}
py::tuple py_args = py::tuple(args.size());
MS_LOG(DEBUG) << "input for operation:";
size_t i = 0;
for (auto &arg : args) {
py_args[i] = BaseRefToPyData(arg);
MS_LOG(DEBUG) << "arg: " << i << ":";
i++;
}
py::object obj = func(*py_args);
MS_LOG(DEBUG) << "result:" << py::str(obj);
return obj;
MS_EXCEPTION_IF_NULL(prim);
auto result = prim->RunComputeFunction(args);
if (result.is_null()) {
return RunComputeFunction(prim, args);
}
return result;
}
} // namespace compile
......
......@@ -83,6 +83,7 @@ class Primitive : public Named {
void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; }
void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); }
virtual BaseRef RunComputeFunction(const VectorRef &args) const { return nullptr; }
ValuePtr GetAttr(const std::string &attrName) const {
auto iter = attrs_.find(attrName);
......
......@@ -79,13 +79,7 @@ py::function PrimitivePy::GetBpropFunction() {
}
BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
auto py_args = py::tuple(args.size());
size_t i = 0;
for (auto &arg : args) {
py_args[i] = BaseRefToPyData(arg);
MS_LOG(DEBUG) << "arg:" << i << ":";
i++;
}
auto py_args = ConvertDatatoPyTuple(args);
py::object obj;
bool is_bprop = this->HasAttr(kBpropAttrName);
if (is_bprop) {
......@@ -123,7 +117,7 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const {
return std::make_shared<PyObjectRef>(obj);
}
py::function PrimitivePy::GetComputeFunction() {
py::function PrimitivePy::GetComputeFunction() const {
static const char *const compute_func_name = "vm_impl";
if (py::hasattr(python_obj_, compute_func_name)) {
......@@ -176,6 +170,32 @@ void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) {
this->set_hook(primitive_py->hook());
}
BaseRef PrimitivePy::RunComputeFunction(const VectorRef &args) const {
auto py_args = ConvertDatatoPyTuple(args);
auto result = this->RunPyComputeFunction(py_args);
if (py::isinstance<py::none>(result)) {
return std::make_shared<BaseRef>(nullptr);
}
return std::make_shared<PyObjectRef>(result);
}
py::object PrimitivePy::RunPyComputeFunction(const py::tuple &py_args) const {
auto func = this->GetComputeFunction();
if (py::isinstance<py::none>(func)) {
return py::none();
}
auto result = func(*py_args);
return result;
}
bool PrimitivePy::HasComputeFunction() const {
auto func = GetComputeFunction();
if (py::isinstance<py::none>(func)) {
return false;
}
return true;
}
REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
(void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
.value("unknown", PrimType::kPrimTypeUnknown)
......
......@@ -41,7 +41,6 @@ class PrimitivePy : public Primitive {
~PrimitivePy() override = default;
MS_DECLARE_PARENT(PrimitivePy, Primitive);
py::function GetBpropFunction();
py::function GetComputeFunction();
void set_signatures(
std::vector<std::tuple<std::string, SignatureEnumRW, SignatureEnumKind, py::object, SignatureEnumDType>>
......@@ -57,11 +56,15 @@ class PrimitivePy : public Primitive {
void set_hook(const py::function &hook) { hook_ = hook; }
py::function hook() const { return hook_; }
BaseRef RunHookFunction(const VectorRef &args) const override;
BaseRef RunComputeFunction(const VectorRef &args) const override;
py::object RunPyComputeFunction(const py::tuple &py_args) const;
bool HasComputeFunction() const;
const bool parse_info_ = true;
const py::object &GetPyObj() const { return python_obj_; }
bool is_tuple_input_ = false;
private:
py::function GetComputeFunction() const;
py::object python_obj_;
py::function hook_;
std::vector<Signature> signatures_;
......
......@@ -454,8 +454,7 @@ TEST_F(TestOps, GetConv2DPrimPyTest) {
ASSERT_TRUE(conv2d_ptr);
if (nullptr != conv2d_ptr) {
MS_LOG(INFO) << "Get PrimitivePyPtr: " << conv2d_ptr->name();
auto func = conv2d_ptr->GetComputeFunction();
if (py::isinstance<py::none>(func)) {
if(!conv2d_ptr->HasComputeFunction()){
MS_LOG(EXCEPTION) << "" << conv2d_ptr->name() << "'s compute function is not implemented";
}
......
......@@ -294,8 +294,7 @@ TEST_F(TestStepParallel, CreatOpInstance) {
ASSERT_TRUE(allreduce_ptr);
if (nullptr != allreduce_ptr) {
MS_LOG(INFO) << "Get PrimitivePyPtr: " << allreduce_ptr->name();
auto func = allreduce_ptr->GetComputeFunction();
if (py::isinstance<py::none>(func)) {
if (!allreduce_ptr->HasComputeFunction()) {
MS_LOG(EXCEPTION) << "" << allreduce_ptr->name() << "'s compute function is not implemented";
}
......
......@@ -57,11 +57,11 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert1) {
std::vector<BaseRef> todos(splits.size());
auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos),
[](const BaseRef& seg) -> bool { return utils::isa<VectorRef>(seg); });
[](const BaseRef &seg) -> bool { return utils::isa<VectorRef>(seg); });
todos.resize(std::distance(todos.begin(), it));
ASSERT_EQ(todos.size(), 1);
AnfNodePtrList anf_list;
AnfNodePtrList anf_list;
for (auto &item : utils::cast<VectorRef>(todos[0])) {
anf_list.push_back(utils::cast<AnfNodePtr>(item));
}
......@@ -81,11 +81,11 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert2) {
std::vector<BaseRef> todos(splits.size());
auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos),
[](const BaseRef& seg) -> bool { return utils::isa<VectorRef>(seg); });
[](const BaseRef &seg) -> bool { return utils::isa<VectorRef>(seg); });
todos.resize(std::distance(todos.begin(), it));
ASSERT_EQ(todos.size(), 1);
AnfNodePtrList anf_list;
AnfNodePtrList anf_list;
for (auto &item : utils::cast<VectorRef>(todos[0])) {
anf_list.push_back(utils::cast<AnfNodePtr>(item));
}
......@@ -105,11 +105,11 @@ TEST_F(TestCompileSegmentRunner, test_if) {
std::vector<BaseRef> todos(splits.size());
auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos),
[](const BaseRef& seg) -> bool { return utils::isa<VectorRef>(seg); });
[](const BaseRef &seg) -> bool { return utils::isa<VectorRef>(seg); });
todos.resize(std::distance(todos.begin(), it));
ASSERT_EQ(todos.size(), 1);
AnfNodePtrList anf_list;
AnfNodePtrList anf_list;
for (auto &item : utils::cast<VectorRef>(todos[0])) {
anf_list.push_back(utils::cast<AnfNodePtr>(item));
}
......@@ -122,13 +122,13 @@ TEST_F(TestCompileSegmentRunner, test_if) {
TEST_F(TestCompileSegmentRunner, test_RunOperation1) {
VectorRef args({1});
auto res = RunOperation(prim::kPrimIdentity, args);
auto res = RunOperation(std::make_shared<PrimitivePy>(py::str(prim::kPrimIdentity->name()), py::none()), args);
ASSERT_EQ(py::cast<int>(BaseRefToPyData(res)), 1);
}
TEST_F(TestCompileSegmentRunner, test_RunOperation2) {
VectorRef args({1, 2});
auto res = RunOperation(prim::kPrimScalarGt, args);
auto res = RunOperation(std::make_shared<PrimitivePy>(py::str(prim::kPrimScalarGt->name()), py::none()), args);
ASSERT_EQ(py::cast<bool>(BaseRefToPyData(res)), false);
}
} // namespace compile
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册