提交 ffe8b5d3 编写于 作者: L lvliang

pynative-add-op-supported

上级 936bae7b
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
constexpr size_t kType32Len = 4;
std::vector<int> Convert2Int(const std::vector<size_t> &v) { std::vector<int> Convert2Int(const std::vector<size_t> &v) {
std::vector<int> result; std::vector<int> result;
(void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToInt); (void)std::transform(v.begin(), v.end(), std::back_inserter(result), SizeToInt);
...@@ -264,6 +265,62 @@ void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNod ...@@ -264,6 +265,62 @@ void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNod
} }
} }
// Packs the scalars of a value tuple into a freshly allocated 1-D tensor.
//
// T               element type used to read each scalar through GetValue<T>.
// value_tuple_ptr tuple whose elements are read in order; checked for null.
// type_ptr        supplies the tensor's type id and its device-info type; checked for null.
// data_length     multiplied by the element count to obtain the number of bytes copied.
//
// Returns the new tensor, or nullptr (after a warning) as soon as a non-scalar element is met.
// A failed copy is reported through MS_LOG(EXCEPTION).
template <typename T>
tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr,
                                             size_t data_length) {
  MS_EXCEPTION_IF_NULL(value_tuple_ptr);
  MS_EXCEPTION_IF_NULL(type_ptr);

  // Collect the payload first; the tensor is only created once every element is known to be a scalar.
  std::vector<T> scalar_values;
  for (const auto &element : value_tuple_ptr->value()) {
    MS_EXCEPTION_IF_NULL(element);
    if (!element->isa<Scalar>()) {
      MS_LOG(WARNING) << "The value " << element << "of tuple is not a scalar";
      return nullptr;
    }
    ScalarPtr as_scalar = element->cast<ScalarPtr>();
    scalar_values.push_back(GetValue<T>(as_scalar));
  }

  // The result is one-dimensional, with one slot per tuple element.
  std::vector<int> shape = {SizeToInt(scalar_values.size())};
  tensor::TensorPtr result = std::make_shared<tensor::Tensor>(type_ptr->type_id(), shape);
  MS_EXCEPTION_IF_NULL(result);
  tensor::DeviceInfo info{kOpFormat_DEFAULT, type_ptr};
  result->set_device_info(info);

  auto dst = result->data_c(true);
  MS_EXCEPTION_IF_NULL(dst);
  // Despite the multiplication by element count, this is a byte count handed to memcpy_s.
  auto copy_bytes = scalar_values.size() * data_length;
  auto status = memcpy_s(dst, static_cast<size_t>(result->data().nbytes()), scalar_values.data(), copy_bytes);
  if (status != 0) {
    MS_LOG(EXCEPTION) << "Failed to copy data into Tensor.";
  }
  return result;
}
// Converts a tuple of scalars into a 1-D tensor.
//
// The element type is decided from the first element only: an IntergerImm selects an int32 tensor,
// a FloatImm selects a float32 tensor; both are built by CreateTensorWithValueTuple with a 4-byte
// element length (kType32Len).
//
// Returns nullptr (after logging) when the tuple is empty, when the first element is not a scalar,
// or when the scalar is neither an integer nor a float immediate. A nullptr produced by
// CreateTensorWithValueTuple is passed through unchanged.
tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
  MS_EXCEPTION_IF_NULL(value_tuple);
  const auto &elements = value_tuple->value();
  // Dereferencing begin() of an empty sequence is undefined behaviour, so reject the empty tuple up front.
  if (elements.empty()) {
    MS_LOG(WARNING) << "The value tuple is empty, can not create a tensor from it";
    return nullptr;
  }
  tensor::TensorPtr tensor = nullptr;
  ValuePtr v = *(elements.begin());
  MS_EXCEPTION_IF_NULL(v);
  // Currently we only deal with the scalar tuple
  if (!v->isa<Scalar>()) {
    MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar";
    return nullptr;
  }
  ScalarPtr scalar = v->cast<ScalarPtr>();
  MS_EXCEPTION_IF_NULL(scalar);
  if (scalar->isa<IntergerImm>()) {
    tensor = CreateTensorWithValueTuple<int>(value_tuple, kInt32, kType32Len);
  } else if (scalar->isa<FloatImm>()) {
    tensor = CreateTensorWithValueTuple<float>(value_tuple, kFloat32, kType32Len);
  } else {
    // Any other scalar kind is unsupported; report its type name when one is available.
    auto type = scalar->type();
    auto type_str = (type == nullptr) ? "nullptr" : type->ToString();
    MS_LOG(ERROR) << "Invalid scalar type: " << type_str;
    return nullptr;
  }
  return tensor;
}
bool IsNopNode(const AnfNodePtr &node) { bool IsNopNode(const AnfNodePtr &node) {
auto context_ptr = MsContext::GetInstance(); auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr); MS_EXCEPTION_IF_NULL(context_ptr);
......
...@@ -135,6 +135,11 @@ void CreateOutputsOfFusedBn3(const FuncGraphPtr &graph, const AnfNodePtr &data_i ...@@ -135,6 +135,11 @@ void CreateOutputsOfFusedBn3(const FuncGraphPtr &graph, const AnfNodePtr &data_i
void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &anf_node_ptr, size_t output_num, void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &kernel_graph, const AnfNodePtr &anf_node_ptr, size_t output_num,
std::vector<AnfNodePtr> *outputs); std::vector<AnfNodePtr> *outputs);
tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr,
size_t data_length);
tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple);
bool IsNopNode(const AnfNodePtr &node); bool IsNopNode(const AnfNodePtr &node);
void HideNopNode(session::KernelGraph *const graph); void HideNopNode(session::KernelGraph *const graph);
......
...@@ -17,10 +17,44 @@ ...@@ -17,10 +17,44 @@
#include <utility> #include <utility>
#include "utils/utils.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
#include "operator/ops.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
// Fills the registry with its built-in entries: every Register call below pairs an operator name
// with a set of input indices.
// NOTE(review): the indices are presumably the positions of constant inputs that are turned into
// primitive attributes — confirm against ConstInputToAttrInfoRegister and the callers of
// GetConstInputAttrInfo().
ConstInputToAttrInfoRegistry::ConstInputToAttrInfoRegistry() {
// Operators whose name is taken from a prim::kPrim* primitive object.
Register(prim::kPrimCast->name(), {1});
Register(prim::kPrimConv2DBackpropInput->name(), {2});
Register(prim::kPrimConv2DBackpropFilter->name(), {2});
Register(prim::kPrimReshape->name(), {1});
Register(prim::kPrimReduceMax->name(), {1});
Register(prim::kPrimReduceMin->name(), {1});
Register(prim::kPrimReduceSum->name(), {1});
Register(prim::kPrimReduceMean->name(), {1});
Register(prim::kPrimGatherV2->name(), {2});
Register(prim::kPrimTranspose->name(), {1});
Register(prim::kPrimUnsortedSegmentSum->name(), {2});
Register(prim::kPrimOneHot->name(), {1});
// Operators whose name is taken from a k* name constant.
Register(kUnsortedSegmentProdOpName, {2});
Register(kUnsortedSegmentMinOpName, {2});
Register(kSimpleMeanGradOpName, {1});
Register(kMeanGradOpName, {1});
Register(kSliceOpName, {1, 2});
Register(kSliceGradOpName, {2, 3});
Register(kTileOpName, {1});
Register(kScatterNdOpName, {2});
Register(kStridedSliceAssignOpName, {1, 2, 3});
Register(kStridedSliceOpName, {1, 2, 3});
Register(kStridedSliceGradOpName, {1, 2, 3, 4});
Register(kFlattenGradOpName, {1});
Register(kExpandDimsOpName, {1});
// NOTE(review): Split is the only entry that lists index 0 — confirm it uses the same index base as the others.
Register(kSplitOpName, {0});
Register(kTopKOpName, {1});
Register(kSparseApplyAdagradOpName, {2});
Register(kResizeNearestNeighborGrad, {1});
}
ConstInputToAttrInfoRegistry &ConstInputToAttrInfoRegistry::Instance() { ConstInputToAttrInfoRegistry &ConstInputToAttrInfoRegistry::Instance() {
static ConstInputToAttrInfoRegistry instance; static ConstInputToAttrInfoRegistry instance;
return instance; return instance;
......
...@@ -54,7 +54,7 @@ class ConstInputToAttrInfoRegistry { ...@@ -54,7 +54,7 @@ class ConstInputToAttrInfoRegistry {
bool GetRegisterByOpName(const std::string &op_name, ConstInputToAttrInfoRegister *reg) const; bool GetRegisterByOpName(const std::string &op_name, ConstInputToAttrInfoRegister *reg) const;
private: private:
ConstInputToAttrInfoRegistry() = default; ConstInputToAttrInfoRegistry();
~ConstInputToAttrInfoRegistry() = default; ~ConstInputToAttrInfoRegistry() = default;
DISABLE_COPY_AND_ASSIGN(ConstInputToAttrInfoRegistry) DISABLE_COPY_AND_ASSIGN(ConstInputToAttrInfoRegistry)
std::unordered_map<std::string, ConstInputToAttrInfoRegister> op_input_to_attr_map_; std::unordered_map<std::string, ConstInputToAttrInfoRegister> op_input_to_attr_map_;
......
...@@ -87,37 +87,5 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An ...@@ -87,37 +87,5 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
ConstInputToAttr(cnode, reg.GetConstInputAttrInfo()); ConstInputToAttr(cnode, reg.GetConstInputAttrInfo());
return cnode; return cnode;
} }
void ConvertConstInputToAttr::Init() {
ConstInputToAttrInfoRegistry::Instance().Register(prim::kPrimCast->name(), {1});
ConstInputToAttrInfoRegistry::Instance().Register(prim::kPrimConv2DBackpropInput->name(), {2});
ConstInputToAttrInfoRegistry::Instance().Register(prim::kPrimConv2DBackpropFilter->name(), {2});
ConstInputToAttrInfoRegistry::Instance().Register(prim::kPrimReshape->name(), {1});
ConstInputToAttrInfoRegistry::Instance().Register(prim::kPrimReduceMax->name(), {1});
ConstInputToAttrInfoRegistry::Instance().Register(prim::kPrimReduceMin->name(), {1});
ConstInputToAttrInfoRegistry::Instance().Register(prim::kPrimReduceSum->name(), {1});
ConstInputToAttrInfoRegistry::Instance().Register(prim::kPrimReduceMean->name(), {1});
ConstInputToAttrInfoRegistry::Instance().Register(prim::kPrimGatherV2->name(), {2});
ConstInputToAttrInfoRegistry::Instance().Register(prim::kPrimTranspose->name(), {1});
ConstInputToAttrInfoRegistry::Instance().Register(prim::kPrimUnsortedSegmentSum->name(), {2});
ConstInputToAttrInfoRegistry::Instance().Register(prim::kPrimOneHot->name(), {1});
ConstInputToAttrInfoRegistry::Instance().Register(kUnsortedSegmentProdOpName, {2});
ConstInputToAttrInfoRegistry::Instance().Register(kUnsortedSegmentMinOpName, {2});
ConstInputToAttrInfoRegistry::Instance().Register(kSimpleMeanGradOpName, {1});
ConstInputToAttrInfoRegistry::Instance().Register(kMeanGradOpName, {1});
ConstInputToAttrInfoRegistry::Instance().Register(kSliceOpName, {1, 2});
ConstInputToAttrInfoRegistry::Instance().Register(kSliceGradOpName, {2, 3});
ConstInputToAttrInfoRegistry::Instance().Register(kTileOpName, {1});
ConstInputToAttrInfoRegistry::Instance().Register(kScatterNdOpName, {2});
ConstInputToAttrInfoRegistry::Instance().Register(kStridedSliceAssignOpName, {1, 2, 3});
ConstInputToAttrInfoRegistry::Instance().Register(kStridedSliceOpName, {1, 2, 3});
ConstInputToAttrInfoRegistry::Instance().Register(kStridedSliceGradOpName, {1, 2, 3, 4});
ConstInputToAttrInfoRegistry::Instance().Register(kFlattenGradOpName, {1});
ConstInputToAttrInfoRegistry::Instance().Register(kExpandDimsOpName, {1});
ConstInputToAttrInfoRegistry::Instance().Register(kSplitOpName, {0});
ConstInputToAttrInfoRegistry::Instance().Register(kTopKOpName, {1});
ConstInputToAttrInfoRegistry::Instance().Register(kSparseApplyAdagradOpName, {2});
ConstInputToAttrInfoRegistry::Instance().Register(kResizeNearestNeighborGrad, {1});
}
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore
...@@ -27,14 +27,11 @@ namespace opt { ...@@ -27,14 +27,11 @@ namespace opt {
class ConvertConstInputToAttr : public PatternProcessPass { class ConvertConstInputToAttr : public PatternProcessPass {
public: public:
explicit ConvertConstInputToAttr(bool multigraph = true) explicit ConvertConstInputToAttr(bool multigraph = true)
: PatternProcessPass("convert_const_input_to_attr", multigraph) { : PatternProcessPass("convert_const_input_to_attr", multigraph) {}
Init();
}
~ConvertConstInputToAttr() override = default; ~ConvertConstInputToAttr() override = default;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private: private:
void Init();
std::unordered_map<std::string, std::unordered_set<size_t>> op_input_attr_map_; std::unordered_map<std::string, std::unordered_set<size_t>> op_input_attr_map_;
}; };
} // namespace opt } // namespace opt
......
...@@ -19,69 +19,13 @@ ...@@ -19,69 +19,13 @@
#include <memory> #include <memory>
#include "utils/graph_utils.h" #include "utils/graph_utils.h"
#include "pre_activate/common/helper.h"
#include "session/anf_runtime_algorithm.h" #include "session/anf_runtime_algorithm.h"
#include "session/kernel_graph.h" #include "session/kernel_graph.h"
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace { namespace {
constexpr size_t kType32Len = 4;
// Packs the scalars of a value tuple into a freshly allocated 1-D tensor.
//
// T               element type used to read each scalar through GetValue<T>.
// value_tuple_ptr tuple whose elements are read in order; checked for null.
// type_ptr        supplies the tensor's type id and its device-info type; checked for null.
// data_length     multiplied by the element count to obtain the number of bytes copied.
//
// Returns the new tensor, or nullptr (after a warning) as soon as a non-scalar element is met.
// A failed copy is reported through MS_LOG(EXCEPTION).
template <typename T>
tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_ptr, const TypePtr &type_ptr,
                                             size_t data_length) {
  MS_EXCEPTION_IF_NULL(value_tuple_ptr);
  MS_EXCEPTION_IF_NULL(type_ptr);

  // Collect the payload first; the tensor is only created once every element is known to be a scalar.
  std::vector<T> scalar_values;
  for (const auto &element : value_tuple_ptr->value()) {
    MS_EXCEPTION_IF_NULL(element);
    if (!element->isa<Scalar>()) {
      MS_LOG(WARNING) << "The value " << element << "of tuple is not a scalar";
      return nullptr;
    }
    ScalarPtr as_scalar = element->cast<ScalarPtr>();
    scalar_values.push_back(GetValue<T>(as_scalar));
  }

  // The result is one-dimensional, with one slot per tuple element.
  std::vector<int> shape = {SizeToInt(scalar_values.size())};
  tensor::TensorPtr result = std::make_shared<tensor::Tensor>(type_ptr->type_id(), shape);
  MS_EXCEPTION_IF_NULL(result);
  tensor::DeviceInfo info{kOpFormat_DEFAULT, type_ptr};
  result->set_device_info(info);

  auto dst = result->data_c(true);
  MS_EXCEPTION_IF_NULL(dst);
  // Despite the multiplication by element count, this is a byte count handed to memcpy_s.
  auto copy_bytes = scalar_values.size() * data_length;
  auto status = memcpy_s(dst, static_cast<size_t>(result->data().nbytes()), scalar_values.data(), copy_bytes);
  if (status != 0) {
    MS_LOG(EXCEPTION) << "Failed to copy data into Tensor.";
  }
  return result;
}
// Converts a tuple of scalars into a 1-D tensor.
//
// The element type is decided from the first element only: an IntergerImm selects an int32 tensor,
// a FloatImm selects a float32 tensor; both are built by CreateTensorWithValueTuple with a 4-byte
// element length (kType32Len).
//
// Returns nullptr (after logging) when the tuple is empty, when the first element is not a scalar,
// or when the scalar is neither an integer nor a float immediate. A nullptr produced by
// CreateTensorWithValueTuple is passed through unchanged.
tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) {
  MS_EXCEPTION_IF_NULL(value_tuple);
  const auto &elements = value_tuple->value();
  // Dereferencing begin() of an empty sequence is undefined behaviour, so reject the empty tuple up front.
  if (elements.empty()) {
    MS_LOG(WARNING) << "The value tuple is empty, can not create a tensor from it";
    return nullptr;
  }
  tensor::TensorPtr tensor = nullptr;
  ValuePtr v = *(elements.begin());
  MS_EXCEPTION_IF_NULL(v);
  // Currently we only deal with the scalar tuple
  if (!v->isa<Scalar>()) {
    MS_LOG(WARNING) << "The value " << v << "of tuple is not a scalar";
    return nullptr;
  }
  ScalarPtr scalar = v->cast<ScalarPtr>();
  MS_EXCEPTION_IF_NULL(scalar);
  if (scalar->isa<IntergerImm>()) {
    tensor = CreateTensorWithValueTuple<int>(value_tuple, kInt32, kType32Len);
  } else if (scalar->isa<FloatImm>()) {
    tensor = CreateTensorWithValueTuple<float>(value_tuple, kFloat32, kType32Len);
  } else {
    // Any other scalar kind is unsupported; report its type name when one is available.
    auto type = scalar->type();
    auto type_str = (type == nullptr) ? "nullptr" : type->ToString();
    MS_LOG(ERROR) << "Invalid scalar type: " << type_str;
    return nullptr;
  }
  return tensor;
}
AnfNodePtr CreateTensorInput(const KernelGraphPtr &kernel_graph, const AnfNodePtr &input_node) { AnfNodePtr CreateTensorInput(const KernelGraphPtr &kernel_graph, const AnfNodePtr &input_node) {
MS_EXCEPTION_IF_NULL(input_node); MS_EXCEPTION_IF_NULL(input_node);
auto value_node = input_node->cast<ValueNodePtr>(); auto value_node = input_node->cast<ValueNodePtr>();
......
...@@ -158,8 +158,9 @@ py::object RunOpInMs(const OpExecInfoPtr& op_exec_info, PynativeStatusCode* stat ...@@ -158,8 +158,9 @@ py::object RunOpInMs(const OpExecInfoPtr& op_exec_info, PynativeStatusCode* stat
session->Init(ms_context->device_id()); session->Init(ms_context->device_id());
std::string graph_info = GetSingleOpGraphInfo(op_exec_info); std::string graph_info = GetSingleOpGraphInfo(op_exec_info);
session->BuildOp(*op_exec_info, graph_info); std::vector<tensor::TensorPtr> input_tensors;
py::tuple result = session->RunOp(*op_exec_info, graph_info); session->BuildOp(*op_exec_info, graph_info, &input_tensors);
py::tuple result = session->RunOp(*op_exec_info, graph_info, input_tensors);
ms_context->set_enable_pynative_infer(false); ms_context->set_enable_pynative_infer(false);
*status = PYNATIVE_SUCCESS; *status = PYNATIVE_SUCCESS;
return result; return result;
......
...@@ -204,10 +204,12 @@ void AscendSession::RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_gra ...@@ -204,10 +204,12 @@ void AscendSession::RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_gra
MS_LOG(INFO) << "Finish!"; MS_LOG(INFO) << "Finish!";
} }
void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info) { void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
std::vector<tensor::TensorPtr> *input_tensors) {
MS_EXCEPTION_IF_NULL(input_tensors);
MS_LOG(INFO) << "Build op " << op_run_info.op_name << " start !"; MS_LOG(INFO) << "Build op " << op_run_info.op_name << " start !";
// construct graph include one op // construct graph include one op
auto graph = ConstructSingleOpGraph(op_run_info); auto graph = ConstructSingleOpGraph(op_run_info, input_tensors);
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
opt::RunOpAscendBackendIRFusionOptimization(graph); opt::RunOpAscendBackendIRFusionOptimization(graph);
// kernel select // kernel select
...@@ -222,14 +224,12 @@ void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph ...@@ -222,14 +224,12 @@ void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph
run_op_graphs_[graph_info] = graph; run_op_graphs_[graph_info] = graph;
} }
py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info) { py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors) {
auto graph = run_op_graphs_[graph_info]; auto graph = run_op_graphs_[graph_info];
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!"; MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!";
// malloc mem // malloc mem
std::vector<tensor::TensorPtr> input_tensors = {};
std::vector<bool> tensors_mask = {};
ToTensorPtr(op_run_info, &input_tensors, &tensors_mask);
RunOpMemoryAlloc(input_tensors, graph.get()); RunOpMemoryAlloc(input_tensors, graph.get());
// load input data to device // load input data to device
LoadInputData(graph, input_tensors); LoadInputData(graph, input_tensors);
......
...@@ -41,8 +41,10 @@ class AscendSession : public SessionBasic { ...@@ -41,8 +41,10 @@ class AscendSession : public SessionBasic {
GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
void BuildGraph(GraphId) override; void BuildGraph(GraphId) override;
void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info) override; void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info) override; std::vector<tensor::TensorPtr> *input_tensors) override;
py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors) override;
// set parameters of final graph // set parameters of final graph
GraphId SetFinalGraphInput(const std::vector<AnfNodePtr> &args) override; GraphId SetFinalGraphInput(const std::vector<AnfNodePtr> &args) override;
......
...@@ -132,9 +132,11 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten ...@@ -132,9 +132,11 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
} }
} }
void GPUSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info) { void GPUSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
std::vector<tensor::TensorPtr> *input_tensors) {
// Prepare the graph // Prepare the graph
auto kernel_graph = ConstructSingleOpGraph(op_run_info); MS_EXCEPTION_IF_NULL(input_tensors);
auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors);
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
SelectKernel(kernel_graph); SelectKernel(kernel_graph);
StartKernelRT(); StartKernelRT();
...@@ -142,12 +144,10 @@ void GPUSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_in ...@@ -142,12 +144,10 @@ void GPUSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_in
run_op_graphs_[graph_info] = kernel_graph; run_op_graphs_[graph_info] = kernel_graph;
} }
py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info) { py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors) {
auto kernel_graph = run_op_graphs_[graph_info]; auto kernel_graph = run_op_graphs_[graph_info];
MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(kernel_graph);
std::vector<tensor::TensorPtr> input_tensors = {};
std::vector<bool> tensors_mask = {};
ToTensorPtr(op_run_info, &input_tensors, &tensors_mask);
RunOpAllocateMemory(input_tensors, kernel_graph.get()); RunOpAllocateMemory(input_tensors, kernel_graph.get());
// Execute the computation // Execute the computation
LoadInputData(kernel_graph, input_tensors); LoadInputData(kernel_graph, input_tensors);
......
...@@ -39,8 +39,10 @@ class GPUSession : public SessionBasic { ...@@ -39,8 +39,10 @@ class GPUSession : public SessionBasic {
GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override; GraphId CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) override;
void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override; void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info) override; void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info) override; std::vector<tensor::TensorPtr> *input_tensors) override;
py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
const std::vector<tensor::TensorPtr> &input_tensors) override;
private: private:
void SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const; void SelectKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <utility> #include <utility>
#include <algorithm> #include <algorithm>
#include <unordered_map> #include <unordered_map>
#include <unordered_set>
#include "pipeline/parse/data_converter.h" #include "pipeline/parse/data_converter.h"
#include "ir/manager.h" #include "ir/manager.h"
#include "operator/ops.h" #include "operator/ops.h"
...@@ -26,6 +27,7 @@ ...@@ -26,6 +27,7 @@
#include "session/anf_runtime_algorithm.h" #include "session/anf_runtime_algorithm.h"
#include "kernel/oplib/oplib.h" #include "kernel/oplib/oplib.h"
#include "pre_activate/common/common_backend_optimization.h" #include "pre_activate/common/common_backend_optimization.h"
#include "pre_activate/pass/const_input_to_attr_registry.h"
#include "pre_activate/common/helper.h" #include "pre_activate/common/helper.h"
#include "common/utils.h" #include "common/utils.h"
#include "ir/dtype.h" #include "ir/dtype.h"
...@@ -178,56 +180,113 @@ BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph, ...@@ -178,56 +180,113 @@ BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
return ret; return ret;
} }
std::string FindOpInputParameterType(const std::string &op_name, kernel::OpImplyType implyType, size_t index) { bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim,
std::string para_type; const std::unordered_set<size_t> &input_attrs) {
auto op_info = kernel::OpLib::FindOp(op_name, implyType); MS_EXCEPTION_IF_NULL(op_prim);
if (op_info == nullptr) { auto input_names_value = op_prim->GetAttr(kAttrInputNames);
return para_type; if (input_names_value == nullptr) {
return false;
}
auto input_names_vec = GetValue<std::vector<std::string>>(input_names_value);
if (input_index >= input_names_vec.size()) {
MS_LOG(EXCEPTION) << "The input index: " << input_index << " is large than the input names vector size!";
} }
auto op_inputs_info_vec = op_info->inputs_ptr();
if (index >= op_inputs_info_vec.size()) { if (input_attrs.find(input_index) != input_attrs.end()) {
return para_type; ValuePtr value = parse::data_converter::PyDataToValue(input_object);
MS_EXCEPTION_IF_NULL(value);
auto input_name = input_names_vec[input_index];
op_prim->set_attr(input_name, value);
return true;
} }
auto op_io_info = op_inputs_info_vec[index]; return false;
MS_EXCEPTION_IF_NULL(op_io_info);
para_type = op_io_info->param_type();
return para_type;
} }
void RunOpConvertConstInputToAttr(const OpRunInfo &op_run_info, const std::shared_ptr<CNode> &cnode) { void PlantTensorTupleToVector(const py::tuple &tuple_inputs, const PrimitivePtr &op_prim,
MS_EXCEPTION_IF_NULL(cnode); std::vector<tensor::TensorPtr> *input_tensor) {
auto op_inputs = op_run_info.op_inputs; MS_EXCEPTION_IF_NULL(op_prim);
// get input names vector from attrs MS_EXCEPTION_IF_NULL(input_tensor);
auto primitive = AnfAlgo::GetCNodePrimitive(cnode); for (const auto &input_object : tuple_inputs) {
MS_EXCEPTION_IF_NULL(primitive); if (!py::isinstance<tensor::Tensor>(input_object)) {
auto input_names_value = primitive->GetAttr(kAttrInputNames); MS_LOG(EXCEPTION) << "The input object is not a tensor!";
if (input_names_value == nullptr) { }
auto tensor = py::cast<tensor::TensorPtr>(input_object);
MS_EXCEPTION_IF_NULL(tensor);
input_tensor->push_back(tensor);
}
op_prim->set_attr(kAttrDynInputSizes, MakeValue(std::vector<int>{SizeToInt(tuple_inputs.size())}));
}
// Converts input_object to a value, requires that value to be a ValueTuple, and appends the
// tensor built from it by opt::CreateTupleTensor to *input_tensor.
//
// Every intermediate result (converted value, cast tuple, created tensor) is checked with
// MS_EXCEPTION_IF_NULL; a value that is not a ValueTuple is reported through MS_LOG(EXCEPTION).
void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *input_tensor) {
  MS_EXCEPTION_IF_NULL(input_tensor);

  ValuePtr converted = parse::data_converter::PyDataToValue(input_object);
  MS_EXCEPTION_IF_NULL(converted);
  if (!converted->isa<ValueTuple>()) {
    MS_LOG(EXCEPTION) << "The input object is not a value tuple!";
  }

  auto tuple_value = converted->cast<ValueTuplePtr>();
  MS_EXCEPTION_IF_NULL(tuple_value);

  tensor::TensorPtr tuple_tensor = opt::CreateTupleTensor(tuple_value);
  MS_EXCEPTION_IF_NULL(tuple_tensor);
  input_tensor->push_back(tuple_tensor);
}
void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
std::vector<tensor::TensorPtr> *input_tensor) {
MS_EXCEPTION_IF_NULL(op_prim);
MS_EXCEPTION_IF_NULL(input_tensor);
tensor::TensorPtr tensor_ptr = nullptr;
if (py::isinstance<tensor::Tensor>(input_object)) {
tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
} else if (py::isinstance<py::float_>(input_object)) {
tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::float_>(input_object), kFloat32);
} else if (py::isinstance<py::int_>(input_object)) {
tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::int_>(input_object), nullptr);
} else if (py::isinstance<py::list>(input_object)) {
tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::list>(input_object), nullptr);
} else if (py::isinstance<py::array>(input_object)) {
tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::array>(input_object), nullptr);
} else if (py::isinstance<py::tuple>(input_object)) {
auto tuple_inputs = py::cast<py::tuple>(input_object);
if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) {
PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensor);
} else {
ConvertValueTupleToTensor(input_object, input_tensor);
}
return; return;
} else {
MS_LOG(EXCEPTION) << "Run op inputs type is invalid!";
} }
auto input_names_vec = GetValue<std::vector<std::string>>(input_names_value); MS_EXCEPTION_IF_NULL(tensor_ptr);
// convert const input to attr input_tensor->push_back(tensor_ptr);
size_t input_num = op_inputs.size(); }
if (input_num != input_names_vec.size()) {
MS_LOG(EXCEPTION) << "input name number " << input_names_vec.size() << "is not equal to input value number " void ConvertInputPyobject(const OpRunInfo &op_run_info, const PrimitivePtr &op_prim,
<< input_num; std::vector<tensor::TensorPtr> *input_tensors, std::vector<bool> *tensors_mask) {
MS_EXCEPTION_IF_NULL(op_prim);
MS_EXCEPTION_IF_NULL(input_tensors);
MS_EXCEPTION_IF_NULL(tensors_mask);
if (op_run_info.op_inputs.size() != op_run_info.inputs_mask.size()) {
MS_LOG(EXCEPTION) << "Op input size " << op_run_info.op_inputs.size() << " should be equal to op input mask size "
<< op_run_info.inputs_mask.size();
} }
opt::ConstInputToAttrInfoRegister reg;
bool reg_exist = opt::ConstInputToAttrInfoRegistry::Instance().GetRegisterByOpName(op_run_info.op_name, &reg);
size_t input_num = op_run_info.op_inputs.size();
MS_LOG(INFO) << "py input size: " << input_num;
for (size_t index = 0; index < input_num; ++index) { for (size_t index = 0; index < input_num; ++index) {
// skip tensor // convert const input to attr
if (py::isinstance<tensor::Tensor>(op_inputs[index])) { if (reg_exist &&
continue; RunOpConvertConstInputToAttr(op_run_info.op_inputs[index], index, op_prim, reg.GetConstInputAttrInfo())) {
}
// convert to attr
auto para_type = FindOpInputParameterType(op_run_info.op_name, kernel::OpImplyType::kTBE, index);
if (!para_type.empty() && para_type == kAttrDynInput) {
auto tuple_inputs = py::cast<py::tuple>(op_inputs[index]);
primitive->set_attr(kAttrDynInputSizes, MakeValue(std::vector<int>{SizeToInt(tuple_inputs.size())}));
continue; continue;
} }
ValuePtr value = parse::data_converter::PyDataToValue(op_inputs[index]); // convert const and tuple input to tensor
MS_EXCEPTION_IF_NULL(value); ConvertPyObjectToTensor(op_run_info.op_inputs[index], op_prim, input_tensors);
auto input_name = input_names_vec[index]; // make tensors, weight : 1, data : 0
// set the input node as attr of the cnode, key is name of input node,value is input node's value std::vector<bool> new_mask(input_tensors->size() - tensors_mask->size(),
primitive->set_attr(input_name, value); py::cast<bool>(op_run_info.inputs_mask[index]));
tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end());
} }
} }
...@@ -638,40 +697,6 @@ void SessionBasic::Summary(KernelGraph *graph) { ...@@ -638,40 +697,6 @@ void SessionBasic::Summary(KernelGraph *graph) {
summary_callback_(0, params_list); summary_callback_(0, params_list);
} }
// Collects the tensors found in op_run_info.op_inputs into *inputs and appends, for each collected
// tensor, the boolean mask of the input it came from to *tensor_mask.
//
// Per input, in this order of precedence:
//   - a tensor::Tensor object is taken as is;
//   - otherwise, when FindOpInputParameterType reports kAttrDynInput, the input is cast to a tuple and
//     every element is taken as a tensor (a null element is rejected by MS_EXCEPTION_IF_NULL);
//   - otherwise, a Python float is wrapped in a kFloat32 tensor, but only for kApplyMomentumOpName;
//   - anything else contributes nothing.
// A size mismatch between op_inputs and inputs_mask is reported through MS_LOG(EXCEPTION).
void SessionBasic::ToTensorPtr(const OpRunInfo &op_run_info, std::vector<tensor::TensorPtr> *inputs,
                               std::vector<bool> *tensor_mask) {
  MS_EXCEPTION_IF_NULL(inputs);
  MS_EXCEPTION_IF_NULL(tensor_mask);
  if (op_run_info.op_inputs.size() != op_run_info.inputs_mask.size()) {
    MS_LOG(EXCEPTION) << "Op input size " << op_run_info.op_inputs.size() << " should be equal to op input mask size "
                      << op_run_info.inputs_mask.size();
  }
  size_t total_inputs = op_run_info.op_inputs.size();
  // get tensors from op_inputs
  for (size_t idx = 0; idx < total_inputs; ++idx) {
    auto io_type = FindOpInputParameterType(op_run_info.op_name, kernel::OpImplyType::kTBE, idx);
    const bool is_tensor = py::isinstance<tensor::Tensor>(op_run_info.op_inputs[idx]);
    const bool is_dyn_input = !io_type.empty() && io_type == kAttrDynInput;

    // Dynamic input: flatten the tuple, every element shares the mask of this input.
    if (!is_tensor && is_dyn_input) {
      auto packed = py::cast<py::tuple>(op_run_info.op_inputs[idx]);
      for (auto &&item : packed) {
        tensor::TensorPtr unpacked = py::cast<tensor::TensorPtr>(item);
        MS_EXCEPTION_IF_NULL(unpacked);
        inputs->push_back(unpacked);
        tensor_mask->push_back(py::cast<bool>(op_run_info.inputs_mask[idx]));
      }
      continue;
    }

    // Single input: either already a tensor, or the ApplyMomentum float special case.
    tensor::TensorPtr single = nullptr;
    if (is_tensor) {
      single = py::cast<tensor::TensorPtr>(op_run_info.op_inputs[idx]);
    } else if (op_run_info.op_name == kApplyMomentumOpName &&
               py::isinstance<py::float_>(op_run_info.op_inputs[idx])) {
      single = std::make_shared<tensor::Tensor>(py::cast<py::float_>(op_run_info.op_inputs[idx]), kFloat32);
    }
    if (single == nullptr) {
      continue;
    }
    inputs->push_back(single);
    tensor_mask->push_back(py::cast<bool>(op_run_info.inputs_mask[idx]));
  }
}
CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) { CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> output_args; std::vector<AnfNodePtr> output_args;
...@@ -724,30 +749,27 @@ void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr ...@@ -724,30 +749,27 @@ void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr
MS_LOG(INFO) << "Finish!"; MS_LOG(INFO) << "Finish!";
} }
std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info) { std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info,
std::vector<tensor::TensorPtr> *input_tensors) {
MS_EXCEPTION_IF_NULL(input_tensors);
auto graph = std::make_shared<KernelGraph>(); auto graph = std::make_shared<KernelGraph>();
std::vector<AnfNodePtr> inputs; std::vector<AnfNodePtr> inputs;
if (op_run_info.op_inputs.size() != op_run_info.inputs_mask.size()) {
MS_LOG(EXCEPTION) << "op_run_info inputs.size" << op_run_info.op_inputs.size()
<< " should be equal to parameter_mask.size " << op_run_info.inputs_mask.size();
}
// set input[0] // set input[0]
if (op_run_info.py_primitive == nullptr) { PrimitivePtr op_prim = op_run_info.py_primitive;
inputs.push_back(std::make_shared<ValueNode>(std::make_shared<Primitive>(op_run_info.op_name))); if (op_prim == nullptr) {
} else { op_prim = std::make_shared<Primitive>(op_run_info.op_name);
inputs.push_back(std::make_shared<ValueNode>(op_run_info.py_primitive));
} }
inputs.push_back(std::make_shared<ValueNode>(op_prim));
// set input parameter // set input parameter
std::vector<tensor::TensorPtr> input_tensors;
std::vector<bool> tensors_mask; std::vector<bool> tensors_mask;
ToTensorPtr(op_run_info, &input_tensors, &tensors_mask); ConvertInputPyobject(op_run_info, op_prim, input_tensors, &tensors_mask);
MS_LOG(INFO) << "Input tensor size" << input_tensors.size(); MS_LOG(INFO) << "Input tensor size: " << input_tensors->size();
if (input_tensors.size() != tensors_mask.size()) { if (input_tensors->size() != tensors_mask.size()) {
MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size() << " should be equal to tensors mask size " MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size() << " should be equal to tensors mask size "
<< tensors_mask.size(); << tensors_mask.size();
} }
for (size_t i = 0; i < input_tensors.size(); ++i) { for (size_t i = 0; i < input_tensors->size(); ++i) {
auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]); auto parameter = ConstructRunOpParameter(graph, input_tensors->at(i), tensors_mask[i]);
inputs.push_back(parameter); inputs.push_back(parameter);
graph->MutableInputs()->push_back(parameter); graph->MutableInputs()->push_back(parameter);
} }
...@@ -756,8 +778,6 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf ...@@ -756,8 +778,6 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf
MS_EXCEPTION_IF_NULL(cnode); MS_EXCEPTION_IF_NULL(cnode);
// set abstract,which include inferred shapes and types // set abstract,which include inferred shapes and types
cnode->set_abstract(op_run_info.abstract); cnode->set_abstract(op_run_info.abstract);
// set const input to attr if value is not a tensor,such as scalar or tuple
RunOpConvertConstInputToAttr(op_run_info, cnode);
// set execution order // set execution order
std::vector<CNodePtr> exe_order = {cnode}; std::vector<CNodePtr> exe_order = {cnode};
graph->set_execution_order(exe_order); graph->set_execution_order(exe_order);
......
...@@ -61,9 +61,11 @@ class SessionBasic { ...@@ -61,9 +61,11 @@ class SessionBasic {
virtual void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) = 0; virtual void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) = 0;
virtual void BuildOp(const OpRunInfo &, const GraphInfo &) {} virtual void BuildOp(const OpRunInfo &, const GraphInfo &, std::vector<tensor::TensorPtr> *input_tensors) {}
virtual py::tuple RunOp(const OpRunInfo &, const GraphInfo &) { return py::tuple(); } virtual py::tuple RunOp(const OpRunInfo &, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors) {
return py::tuple();
}
virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback); virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
...@@ -96,10 +98,8 @@ class SessionBasic { ...@@ -96,10 +98,8 @@ class SessionBasic {
void CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph); void CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph);
CNodePtr ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph); CNodePtr ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph);
// create a single run op graph // create a single run op graph
std::shared_ptr<KernelGraph> ConstructSingleOpGraph(const OpRunInfo &op_run_info); std::shared_ptr<KernelGraph> ConstructSingleOpGraph(const OpRunInfo &op_run_info,
// get tensors from op inputs std::vector<tensor::TensorPtr> *input_tensor);
void ToTensorPtr(const OpRunInfo &op_run_info, std::vector<tensor::TensorPtr> *inputs,
std::vector<bool> *tensor_mask);
// trans BaseRef list to py::tuple // trans BaseRef list to py::tuple
BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref); BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册