提交 5f468b65 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!2798 Decouple ParamValue from python

Merge pull request !2798 from hewei/decouple_param_value
......@@ -26,7 +26,7 @@
#include "utils/graph_utils.h"
#include "utils/symbolic.h"
#include "ir/meta_func_graph.h"
#include "ir/param_value_py.h"
#include "ir/param_value.h"
#include "ir/tensor_py.h"
#include "pipeline/parse/python_adapter.h"
#include "pipeline/parse/resolve.h"
......@@ -485,8 +485,8 @@ void AnfExporter::OutputParameters(std::ofstream &ofs, const std::vector<AnfNode
MS_LOG(EXCEPTION) << "Param could not cast to parameter";
}
if (param_ptr->has_default()) {
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_ptr->default_param());
ofs << " = @" << DumpObject(param_value->value(), "D");
auto param_value = param_ptr->default_param();
ofs << " = @" << DumpObject(py::cast(param_value), "D");
}
// output comment
......@@ -1667,7 +1667,7 @@ class IrParser {
// load parameter default value from serialized file
py::object default_obj = LoadObject(lexer_.GetTokenText());
auto param_value_new = std::make_shared<ParamValuePy>(default_obj);
auto param_value_new = py::cast<ParamValuePtr>(default_obj);
param->set_default_param(param_value_new);
tok = lexer_.GetNextToken();
......
......@@ -25,7 +25,7 @@
#include "pybind11/pybind11.h"
#include "ir/meta_func_graph.h"
#include "ir/param_value_py.h"
#include "ir/param_value.h"
#include "ir/primitive.h"
#include "utils/graph_utils.h"
#include "utils/utils.h"
......@@ -321,18 +321,9 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) {
buffer_ << parameter->ToString();
auto param = parameter->cast<ParameterPtr>();
if (param->has_default()) {
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param->default_param());
auto py_p = param_value->value();
if (py::hasattr(py_p, "default_input")) {
py_p = py_p.attr("default_input");
std::vector<int> shape;
if (py::hasattr(py_p, PYTHON_TENSOR_FLAG)) {
auto m_tensor = py_p.cast<std::shared_ptr<tensor::Tensor>>();
shape = m_tensor->shape();
} else if (py::hasattr(py_p, PYTHON_META_TENSOR_FLAG)) {
auto m_tensor = py_p.cast<std::shared_ptr<tensor::MetaTensor>>();
shape = m_tensor->shape();
}
auto tensor = param->default_param()->value();
if (tensor) {
auto &shape = tensor->shape();
std::ostringstream shape_str;
std::copy(shape.begin(), shape.end(), std::ostream_iterator<int>(shape_str, ","));
buffer_ << "[" << shape_str.str() << "]";
......
......@@ -79,11 +79,7 @@ using KernelInfoDevicePtr = std::shared_ptr<KernelInfoDevice>;
class AnfVisitor;
class ParamValue {
public:
ParamValue() = default;
virtual ~ParamValue() = default;
};
class ParamValue;
using ParamValuePtr = std::shared_ptr<ParamValue>;
// AnfNode is the basic class of the IR definition derived from Base.
......
......@@ -19,7 +19,7 @@
#include <algorithm>
#include "ir/manager.h"
#include "ir/param_value_py.h"
#include "ir/param_value.h"
#include "operator/ops.h"
#include "utils/convert_utils_base.h"
#include "utils/log_adapter.h"
......@@ -71,9 +71,8 @@ void Cloner::CloneParameter(const AnfNodePtr &node, const FuncGraphPtr &target,
new_param->set_abstract(old_param->abstract());
new_param->set_name(old_param->name());
if (old_param->has_default()) {
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(old_param->default_param());
auto param_value_new = std::make_shared<ParamValuePy>(param_value->value());
new_param->set_default_param(param_value_new);
// Default parameter can be shared since it is readonly.
new_param->set_default_param(old_param->default_param());
}
ScopePtr scope = (node->scope() != kDefaultScope) ? node->scope() : this->scope();
new_param->set_scope(scope);
......@@ -253,9 +252,8 @@ void Cloner::CloneParameter(const ParameterPtr &param, const AnfNodePtr &node) {
if (node->isa<Parameter>()) {
ParameterPtr old_param = dyn_cast<Parameter>(node);
if (old_param->has_default()) {
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(old_param->default_param());
auto param_value_new = std::make_shared<ParamValuePy>(param_value->value());
param->set_default_param(param_value_new);
// Default parameter can be shared since it is readonly.
param->set_default_param(old_param->default_param());
}
param->set_name(old_param->name());
}
......
......@@ -19,7 +19,7 @@
#include <memory>
#include "ir/anf.h"
#include "ir/param_value.h"
namespace mindspore {
class ParamValueLite : public ParamValue {
......
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_IR_PARAM_VALUE_H_
#define MINDSPORE_CCSRC_IR_PARAM_VALUE_H_
#include <atomic>
#include <memory>
#include <string>
#include <vector>
#include "ir/anf.h"
#include "ir/tensor.h"
namespace mindspore {
class ParamValue {
public:
ParamValue() {}
ParamValue(const ParamValue &other) = default;
~ParamValue() = default;
tensor::MetaTensorPtr value() const { return value_; }
void set_value(const tensor::MetaTensorPtr &value) { value_ = value; }
const std::string &name() const { return name_; }
void set_name(const std::string &name) { name_ = name; }
const std::string &sparse_grad() const { return sparse_grad_; }
void set_sparse_grad(const std::string &sparse_grad) { sparse_grad_ = sparse_grad; }
bool requires_grad() const { return requires_grad_; }
void set_requires_grad(bool requires_grad) { requires_grad_ = requires_grad; }
bool layerwise_parallel() const { return layerwise_parallel_; }
void set_layerwise_parallel(bool layerwise_parallel) { layerwise_parallel_ = layerwise_parallel; }
bool has_indexed_slices_grad() const { return has_indexed_slices_grad_; }
void set_has_indexed_slices_grad(bool b) { has_indexed_slices_grad_ = b; }
// Whether the parameter clone from other parameter.
bool cloned() const { return cloned_; }
// Whether the parameter is cloned.
bool be_cloned() const { return be_cloned_; }
// If the parameter is cloned, generate one index per clone.
const std::vector<int32_t> &be_cloned_index() const { return be_cloned_index_; }
// If the parameter clone from other parameter, it has a unique index.
int32_t cloned_index() const { return cloned_index_; }
// Make a cloned parameter and update clone info.
ParamValuePtr Clone() {
static std::atomic<int32_t> parameter_cloned_index{1};
int32_t index = parameter_cloned_index.fetch_add(1, std::memory_order_relaxed);
auto clone = std::make_shared<ParamValue>(*this);
clone->be_cloned_ = false;
clone->cloned_ = true;
clone->be_cloned_index_ = {};
clone->cloned_index_ = index;
this->be_cloned_ = true;
this->be_cloned_index_.push_back(index);
return clone;
}
private:
tensor::MetaTensorPtr value_;
std::string name_{"Parameter"};
std::string sparse_grad_;
bool requires_grad_{true};
bool layerwise_parallel_{false};
bool has_indexed_slices_grad_{false};
bool be_cloned_{false};
bool cloned_{false};
std::vector<int32_t> be_cloned_index_;
int32_t cloned_index_{0};
};
} // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_PARAM_VALUE_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ir/param_value.h"
#include "pybind11/pybind11.h"
#include "pybind_api/api_register.h"
namespace mindspore {
namespace py = pybind11;
REGISTER_PYBIND_DEFINE(ParamValue, ([](const py::module *m) {
(void)py::class_<ParamValue, ParamValuePtr>(*m, "ParamValue")
.def(py::init())
.def("clone", &ParamValue::Clone)
.def_property("data", &ParamValue::value, &ParamValue::set_value)
.def_property("name", &ParamValue::name, &ParamValue::set_name)
.def_property("requires_grad", &ParamValue::requires_grad, &ParamValue::set_requires_grad)
.def_property("layerwise_parallel", &ParamValue::layerwise_parallel,
&ParamValue::set_layerwise_parallel)
.def_property("has_indexed_slices_grad", &ParamValue::has_indexed_slices_grad,
&ParamValue::set_has_indexed_slices_grad)
.def_property("sparse_grad", &ParamValue::sparse_grad, &ParamValue::set_sparse_grad)
.def(py::pickle(
[](const ParamValue &p) { // __getstate__
return py::make_tuple(py::cast(p.value()), p.name(), p.requires_grad(),
p.layerwise_parallel(), p.has_indexed_slices_grad(),
p.sparse_grad());
},
[](const py::tuple &t) { // __setstate__
if (t.size() != 6) {
std::runtime_error("Invalid state for ParamValue!");
}
ParamValuePtr p = std::make_shared<ParamValue>();
p->set_value(t[0].cast<tensor::TensorPtr>());
p->set_name(t[1].cast<std::string>());
p->set_requires_grad(t[2].cast<bool>());
p->set_layerwise_parallel(t[3].cast<bool>());
p->set_has_indexed_slices_grad(t[4].cast<bool>());
p->set_sparse_grad(t[5].cast<std::string>());
return p;
}));
}));
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_
#define MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_
#include <memory>
#include "ir/anf.h"
#include "pybind11/pybind11.h"
namespace mindspore {
namespace py = pybind11;
class ParamValuePy : public ParamValue {
public:
ParamValuePy() : value_(py::none()) {}
explicit ParamValuePy(const py::object &value) : value_(value) {}
~ParamValuePy() override = default;
py::object value() { return value_; }
void set_value(const py::object &obj) { value_ = obj; }
private:
py::object value_;
};
using ParamValuePyPtr = std::shared_ptr<ParamValuePy>;
} // namespace mindspore
#endif // MINDSPORE_CCSRC_IR_PARAM_VALUE_PY_H_
......@@ -216,7 +216,7 @@ class Tensor : public MetaTensor {
std::string ToStringRepr() const;
bool is_init() { return init_flag_; }
bool is_init() const { return init_flag_; }
void set_init_flag(bool flag) { init_flag_ = flag; }
bool is_dirty() const { return dirty_; }
......
......@@ -213,9 +213,28 @@ static std::vector<int> GetShapeFromTuple(const py::tuple &tuple) {
}
REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
// Define python MetaTensor class.
(void)py::class_<MetaTensor, std::shared_ptr<MetaTensor>>(*m, "MetaTensor")
.def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape"))
.def_readonly(PYTHON_META_TENSOR_FLAG, &MetaTensor::parse_info_)
.def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
.def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.")
.def(py::pickle(
[](const MetaTensor &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(static_cast<int>(t.data_type()), t.shape());
},
[](const py::tuple &t) { // __setstate__
if (t.size() != 2) {
throw std::runtime_error("Invalid state!");
}
/* Create a new C++ instance */
MetaTensor tensor(TypeId(t[0].cast<int>()), t[1].cast<std::vector<int>>());
return tensor;
}));
// Define python Tensor class.
// dtype should define before Tensor, because Tensor init depend dtype
(void)py::class_<Tensor, std::shared_ptr<Tensor>>(*m, "Tensor")
(void)py::class_<Tensor, MetaTensor, std::shared_ptr<Tensor>>(*m, "Tensor")
.def(py::init([](const Tensor &tensor) { return std::make_shared<Tensor>(tensor); }),
py::arg("input"))
.def(py::init([](const Tensor &tensor, const TypePtr &type_ptr) {
......@@ -252,6 +271,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
}),
py::arg("input"), py::arg("dtype") = nullptr)
.def_readonly(PYTHON_TENSOR_FLAG, &Tensor::parse_info_)
.def_property("init_flag", &Tensor::is_init, &Tensor::set_init_flag)
.def_property_readonly("dtype", &Tensor::Dtype, R"mydelimiter(
Get the tensor's data type.
......@@ -365,26 +385,6 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
/* Create a new C++ instance */
return TensorPy::MakeTensor(t[0].cast<py::array>());
}));
// Define python MetaTensor class.
(void)py::class_<MetaTensor, std::shared_ptr<MetaTensor>>(*m, "MetaTensor")
.def(py::init<TypePtr, const std::vector<int>>(), py::arg("dtype"), py::arg("shape"))
.def_readonly(PYTHON_META_TENSOR_FLAG, &MetaTensor::parse_info_)
.def_property_readonly("dtype", &MetaTensor::Dtype, "Get the MetaTensor's dtype.")
.def_property_readonly("shape", &MetaTensor::shape, "Get the MetaTensor's shape.")
.def(py::pickle(
[](const MetaTensor &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */
return py::make_tuple(static_cast<int>(t.data_type()), t.shape());
},
[](const py::tuple &t) { // __setstate__
if (t.size() != 2) {
throw std::runtime_error("Invalid state!");
}
/* Create a new C++ instance */
MetaTensor tensor(TypeId(t[0].cast<int>()), t[1].cast<std::vector<int>>());
return tensor;
}));
}));
} // namespace tensor
} // namespace mindspore
......@@ -23,8 +23,8 @@
#include <algorithm>
#include <functional>
#include "ir/tensor_py.h"
#include "ir/param_value_py.h"
#include "ir/tensor.h"
#include "ir/param_value.h"
#include "debug/anf_ir_utils.h"
#include "operator/ops.h"
#include "proto/onnx.pb.h"
......@@ -187,13 +187,9 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::Grap
onnx::TensorProto *initializer_proto = graph_proto->add_initializer();
initializer_proto->set_name(param_name);
SetParamToTensorProto(param, initializer_proto);
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param->default_param());
py::object obj = param_value->value();
py::object data = obj.attr("data");
if (py::isinstance<tensor::Tensor>(data)) {
auto method = data.attr("asnumpy");
py::array npy_data = method();
initializer_proto->set_raw_data(npy_data.request(true).ptr, static_cast<size_t>(npy_data.nbytes()));
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param->default_param()->value());
if (tensor) {
initializer_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes());
}
}
}
......
......@@ -26,8 +26,8 @@
#include "debug/anf_ir_utils.h"
#include "proto/onnx.pb.h"
#include "operator/ops.h"
#include "ir/param_value_py.h"
#include "ir/tensor_py.h"
#include "ir/tensor.h"
#include "ir/param_value.h"
namespace mindspore {
enum OpMergeMode {
......@@ -449,13 +449,9 @@ void OnnxExporter::ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphP
initializer_proto->set_name(param_ptr->ToString());
SetTensorProtoInfo(param_ptr, initializer_proto);
// set value for initializer
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_ptr->default_param());
py::object obj = param_value->value();
py::object data = obj.attr("data");
if (py::isinstance<tensor::Tensor>(data)) {
auto method = data.attr("asnumpy");
py::array npy_data = method();
initializer_proto->set_raw_data(npy_data.request(true).ptr, static_cast<size_t>(npy_data.nbytes()));
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_ptr->default_param()->value());
if (tensor) {
initializer_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes());
}
}
}
......
......@@ -19,7 +19,7 @@
#include <string>
#include "ir/anf.h"
#include "ir/param_value_py.h"
#include "ir/param_value.h"
#include "pipeline/parse/python_adapter.h"
namespace mindspore {
......@@ -38,8 +38,7 @@ bool ParameterRequireGrad(const AnfNodePtr &node_ptr) {
if (!para_ptr->has_default()) {
return false;
}
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(para_ptr->default_param());
return py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "requires_grad"));
return para_ptr->default_param()->requires_grad();
}
} // namespace parallel
} // namespace mindspore
......@@ -28,7 +28,7 @@
#include <vector>
#include "ir/anf.h"
#include "ir/param_value_py.h"
#include "ir/param_value.h"
#include "ir/tensor.h"
#include "optimizer/opt.h"
#include "optimizer/optimizer.h"
......@@ -123,9 +123,8 @@ std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
if (input->isa<Parameter>()) {
auto input_parameter = input->cast<ParameterPtr>();
if (input_parameter->has_default()) {
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(input_parameter->default_param());
bool require_grad = py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "requires_grad"));
is_parameter.push_back(require_grad);
bool requires_grad = input_parameter->default_param()->requires_grad();
is_parameter.push_back(requires_grad);
} else {
is_parameter.push_back(false);
}
......@@ -799,9 +798,8 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
auto casted_target_parameter = target_parameter->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(casted_target_parameter);
if (casted_target_parameter->has_default()) {
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(casted_target_parameter->default_param());
bool require_grad = py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "requires_grad"));
is_parameter.push_back(require_grad);
bool requires_grad = casted_target_parameter->default_param()->requires_grad();
is_parameter.push_back(requires_grad);
} else {
is_parameter.push_back(false);
}
......
......@@ -28,7 +28,7 @@
#include <utility>
#include "ir/tensor.h"
#include "ir/param_value_py.h"
#include "ir/param_value.h"
#include "operator/ops.h"
#include "optimizer/optimizer.h"
#include "parallel/auto_parallel/graph_costmodel.h"
......@@ -1298,9 +1298,7 @@ bool ParameterIsCloned(const FuncGraphPtr &root, const AnfNodePtr &parameter_nod
return false;
}
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(cloned_parameter->default_param());
py::object clone_info = parse::python_adapter::GetPyObjAttr(param_value->value(), CLONE_INFO);
bool cloned = py::cast<bool>(parse::python_adapter::GetPyObjAttr(clone_info, CLONED));
bool cloned = cloned_parameter->default_param()->cloned();
if (!cloned) {
return false;
}
......@@ -1321,9 +1319,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
}
// get the cloned index
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(cloned_parameter->default_param());
py::object cloned_info = parse::python_adapter::GetPyObjAttr(param_value->value(), CLONE_INFO);
int32_t cloned_index = py::cast<int32_t>(parse::python_adapter::GetPyObjAttr(cloned_info, CLONED_INDEX));
int32_t cloned_index = cloned_parameter->default_param()->cloned_index();
// find the be cloned parameter
bool found_be_cloned_parameter = false;
......@@ -1337,21 +1333,17 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
continue;
}
auto param_value_cloned = std::dynamic_pointer_cast<ParamValuePy>(be_cloned_parameter->default_param());
py::object be_cloned_info = parse::python_adapter::GetPyObjAttr(param_value_cloned->value(), CLONE_INFO);
if (!py::cast<bool>(parse::python_adapter::GetPyObjAttr(be_cloned_info, BE_CLONED))) {
const auto &param_value_cloned = be_cloned_parameter->default_param();
if (!param_value_cloned->be_cloned()) {
continue;
}
// get the be cloned index
py::list be_cloned_index = parse::python_adapter::GetPyObjAttr(be_cloned_info, BE_CLONED_INDEX);
for (auto &index : be_cloned_index) {
if (cloned_index == py::cast<int32_t>(index)) {
found_be_cloned_parameter = true;
cloned_from_parameter = be_cloned_parameter;
cloned_from_node = be_cloned_parameter_node;
break;
}
auto &be_cloned_index = param_value_cloned->be_cloned_index();
if (std::find(be_cloned_index.begin(), be_cloned_index.end(), cloned_index) != be_cloned_index.end()) {
found_be_cloned_parameter = true;
cloned_from_parameter = be_cloned_parameter;
cloned_from_node = be_cloned_parameter_node;
}
}
......@@ -2090,9 +2082,9 @@ std::string NodeParameterName(const CNodePtr &node) {
if (input->isa<Parameter>()) {
auto input_parameter = input->cast<ParameterPtr>();
if (input_parameter->has_default()) {
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(input_parameter->default_param());
if (py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), REQUIRES_GRAD))) {
return py::cast<std::string>(parse::python_adapter::GetPyObjAttr(param_value->value(), PARAM_NAME));
const auto &param_value = input_parameter->default_param();
if (param_value->requires_grad()) {
return param_value->name();
}
}
}
......
......@@ -24,7 +24,7 @@
#include <functional>
#include "ir/func_graph_cloner.h"
#include "ir/param_value_py.h"
#include "ir/param_value.h"
#include "parallel/costmodel_context.h"
#include "parallel/context.h"
#include "pipeline/pass.h"
......@@ -228,14 +228,12 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
for (const auto &param : func_graph->parameters()) {
auto param_node = std::static_pointer_cast<Parameter>(param);
if (param_node->has_default()) {
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_node->default_param());
AbstractBasePtr ptr = abstract::FromValue(parse::data_converter::PyDataToValue(param_value->value()), true);
auto sparse_grad =
py::cast<std::string>(parse::python_adapter::GetPyObjAttr(param_value->value(), "sparse_grad"));
ptr->set_sparse_grad(sparse_grad);
auto has_indexed_slices_grad =
py::cast<bool>(parse::python_adapter::GetPyObjAttr(param_value->value(), "has_indexed_slices_grad"));
ptr->set_has_indexed_slices_grad(has_indexed_slices_grad);
const auto &param_value = param_node->default_param();
ValuePtr value = param_value->value();
constexpr bool broaden = true;
AbstractBasePtr ptr = abstract::FromValue(value, broaden);
ptr->set_sparse_grad(param_value->sparse_grad());
ptr->set_has_indexed_slices_grad(param_value->has_indexed_slices_grad());
parallel::ParallelParameterContextRestoreInNoTraining(func_graph, param_node, ptr);
args_spec.push_back(ptr);
......
......@@ -21,7 +21,7 @@
#include <vector>
#include <algorithm>
#include "ir/param_value_py.h"
#include "ir/param_value.h"
#include "pipeline/parse/data_converter.h"
#include "pipeline/parse/parse.h"
#include "pipeline/parse/python_adapter.h"
......@@ -103,16 +103,12 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
}
if (para_node == nullptr) {
auto node = top_graph->AddWeightParameter(param_name);
auto param_value_new = std::make_shared<ParamValuePy>(obj);
node->set_default_param(param_value_new);
auto param_value = py::cast<ParamValuePtr>(python_adapter::GetPyObjAttr(obj, "_value"));
node->set_default_param(param_value);
// set_abstract for parameter
auto to_convert = py::cast<py::object>(python_adapter::GetPyObjAttr(obj, "default_input"));
ValuePtr converted = nullptr;
(void)ConvertData(to_convert, &converted);
bool broaden = true;
node->set_abstract(abstract::FromValue(converted, broaden));
ValuePtr value = param_value->value();
constexpr bool broaden = true;
node->set_abstract(abstract::FromValue(value, broaden));
para_node = node;
}
auto iter = func_graph->make_ref_params().find(para_node);
......
......@@ -24,7 +24,7 @@
#include <cstdlib>
#include <algorithm>
#include "ir/param_value_py.h"
#include "ir/param_value.h"
#include "pipeline/pass.h"
#include "pipeline/parse/data_converter.h"
#include "optimizer/ad/dfunctor.h"
......@@ -695,10 +695,7 @@ void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef
if (!param_ptr->has_default()) {
MS_LOG(EXCEPTION) << "Parameter[" << i << "] has no default param";
}
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_ptr->default_param());
py::object obj = param_value->value();
py::object p_value = py::cast<py::object>(parse::python_adapter::GetPyObjAttr(obj, "default_input"));
(*arg_list).push_back(p_value);
arg_list->push_back(param_ptr->default_param()->value());
}
}
}
......
......@@ -24,7 +24,7 @@
#include "debug/trace.h"
#include "ir/tensor_py.h"
#include "ir/param_value_py.h"
#include "ir/param_value.h"
#include "utils/any.h"
#include "utils/utils.h"
#include "utils/context/ms_context.h"
......@@ -830,7 +830,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &o
if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) {
auto free_param = df_builder_->add_parameter();
free_param->set_name(param_name);
auto free_param_new = std::make_shared<ParamValuePy>(obj);
auto free_param_new = py::cast<ParamValuePtr>(obj.attr("_value"));
free_param->set_default_param(free_param_new);
free_param->debug_info()->set_name(param_name);
MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
......@@ -1026,8 +1026,9 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args
for (const auto &param : df_builder_->parameters()) {
auto param_node = std::static_pointer_cast<Parameter>(param);
if (param_node->has_default()) {
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_node->default_param());
AbstractBasePtr ptr = abstract::FromValue(parse::data_converter::PyDataToValue(param_value->value()), true);
const auto &param_value = param_node->default_param();
ValuePtr value = param_value->value();
AbstractBasePtr ptr = abstract::FromValue(value, true);
if (ptr == nullptr) {
MS_LOG(EXCEPTION) << "Args convert error";
}
......
......@@ -16,9 +16,8 @@
#include "session/ascend_inference_session.h"
#include "operator/ops.h"
#include "ir/tensor.h"
#include "ir/tensor_py.h"
#include "ir/anf.h"
#include "ir/param_value_py.h"
#include "ir/param_value.h"
#include "device/kernel_runtime.h"
#include "session/anf_runtime_algorithm.h"
#include "common/utils.h"
......@@ -27,66 +26,8 @@
#include "utils/config_manager.h"
#include "utils/base_ref_extends.h"
using mindspore::tensor::TensorPy;
namespace mindspore {
namespace session {
namespace {
static TypeId GetDataType(const py::buffer_info &buf) {
if (buf.format.size() == 1) {
switch (buf.format.front()) {
case 'e':
case 'f':
case 'd':
switch (buf.itemsize) {
case 2:
return TypeId::kNumberTypeFloat16;
case 4:
return TypeId::kNumberTypeFloat32;
case 8:
return TypeId::kNumberTypeFloat64;
}
break;
case 'b':
case 'h':
case 'i':
case 'l':
case 'q':
switch (buf.itemsize) {
case 1:
return TypeId::kNumberTypeInt8;
case 2:
return TypeId::kNumberTypeInt16;
case 4:
return TypeId::kNumberTypeInt32;
case 8:
return TypeId::kNumberTypeInt64;
}
break;
case 'B':
case 'H':
case 'I':
case 'L':
case 'Q':
switch (buf.itemsize) {
case 1:
return TypeId::kNumberTypeUInt8;
case 2:
return TypeId::kNumberTypeUInt16;
case 4:
return TypeId::kNumberTypeUInt32;
case 8:
return TypeId::kNumberTypeUInt64;
}
break;
case '?':
return TypeId::kNumberTypeBool;
}
}
MS_LOG(WARNING) << "Unsupported DataType format " << buf.format << " item size " << buf.itemsize;
return TypeId::kTypeUnknown;
}
} // namespace
void AscendInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
const std::vector<tensor::TensorPtr> &inputs_const) const {
MS_EXCEPTION_IF_NULL(kernel_graph);
......@@ -131,15 +72,13 @@ GraphId AscendInferenceSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
MS_EXCEPTION_IF_NULL(device_address);
if (AnfAlgo::IsParameterWeight(pk_node)) {
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(pk_node->default_param());
const auto &param_value = pk_node->default_param();
MS_EXCEPTION_IF_NULL(param_value);
auto py_param = param_value->value();
MS_EXCEPTION_IF_NULL(py_param);
py::array py_array = py_param.cast<py::array>();
py::buffer_info buf = py_array.request();
auto buf_type = GetDataType(buf);
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_value->value());
MS_EXCEPTION_IF_NULL(tensor);
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
LongToSize(buf.size * buf.itemsize), buf_type, buf.ptr)) {
LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c())) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
}
}
......
......@@ -19,7 +19,7 @@
#include <unordered_set>
#include <set>
#include "operator/ops.h"
#include "ir/param_value_py.h"
#include "ir/param_value.h"
#include "session/anf_runtime_algorithm.h"
#include "device/kernel_info.h"
#include "kernel/kernel_build_info.h"
......@@ -380,9 +380,7 @@ ParameterPtr KernelGraph::NewParameter(const ParameterPtr &parameter) {
new_parameter->set_abstract(parameter->abstract());
new_parameter->set_name(parameter->name());
if (AnfAlgo::IsParameterWeight(parameter)) {
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param());
auto param_value_new = std::make_shared<ParamValuePy>(param_value->value());
new_parameter->set_default_param(param_value_new);
new_parameter->set_default_param(parameter->default_param());
kernel_info->SetFeatureMapFlag(false);
} else {
kernel_info->SetFeatureMapFlag(true);
......
......@@ -20,7 +20,7 @@
#include <unordered_set>
#include "pipeline/parse/data_converter.h"
#include "ir/manager.h"
#include "ir/param_value_py.h"
#include "ir/param_value.h"
#include "kernel/common_utils.h"
#include "operator/ops.h"
#include "common/trans.h"
......@@ -38,12 +38,12 @@
namespace mindspore {
namespace session {
static std::shared_ptr<std::map<PyObject *, ParameterPtr>> python_paras_;
static std::shared_ptr<std::map<ParamValuePtr, ParameterPtr>> python_paras_;
void ClearPythonParasMap() { python_paras_ = nullptr; }
namespace {
const int kSummaryGetItem = 2;
PyObject *GetParamDefaultInputTensor(const AnfNodePtr &node) {
ParamValuePtr GetParamDefaultValue(const AnfNodePtr &node) {
if (node == nullptr) {
return nullptr;
}
......@@ -51,10 +51,7 @@ PyObject *GetParamDefaultInputTensor(const AnfNodePtr &node) {
if (parameter == nullptr || !parameter->has_default()) {
return nullptr;
}
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(parameter->default_param());
MS_EXCEPTION_IF_NULL(param_value);
auto py_param = param_value->value();
return py_param.ptr();
return parameter->default_param();
}
BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph,
......@@ -215,8 +212,7 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph,
auto param = graph->NewParameter();
MS_EXCEPTION_IF_NULL(param);
if (tensor_mask == kParameterWeightTensorMask) {
py::object obj;
auto param_value_new = std::make_shared<ParamValuePy>(obj);
auto param_value_new = std::make_shared<ParamValue>();
param->set_default_param(param_value_new);
}
// set the kernel info of parameter
......@@ -384,7 +380,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
}
MS_EXCEPTION_IF_NULL(graph);
auto m_tensor = GetParamDefaultInputTensor(anf);
auto param_value = GetParamDefaultValue(anf);
auto valid_inputs = graph->MutableValidInputs();
MS_EXCEPTION_IF_NULL(valid_inputs);
auto graph_inputs = graph->MutableInputs();
......@@ -392,16 +388,16 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
ParameterPtr new_parameter = nullptr;
// if parameter's python parameter has been exist a backend parameter, reuse the exist parameter
if (python_paras_ == nullptr) {
python_paras_ = std::make_shared<std::map<PyObject *, ParameterPtr>>();
python_paras_ = std::make_shared<std::map<ParamValuePtr, ParameterPtr>>();
}
auto iter = python_paras_->find(m_tensor);
auto iter = python_paras_->find(param_value);
if (iter != python_paras_->end()) {
new_parameter = iter->second;
} else {
TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
if (m_tensor != nullptr) {
(*python_paras_)[m_tensor] = new_parameter;
if (param_value != nullptr) {
(*python_paras_)[param_value] = new_parameter;
}
TraceManager::EndTrace();
}
......@@ -618,19 +614,19 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
MS_LOG(EXCEPTION) << "Anf[" << anf->DebugString() << "] is not a parameter";
}
auto m_tensor = GetParamDefaultInputTensor(anf);
auto param_value = GetParamDefaultValue(anf);
ParameterPtr new_parameter = nullptr;
if (python_paras_ == nullptr) {
python_paras_ = std::make_shared<std::map<PyObject *, ParameterPtr>>();
python_paras_ = std::make_shared<std::map<ParamValuePtr, ParameterPtr>>();
}
auto iter = python_paras_->find(m_tensor);
auto iter = python_paras_->find(param_value);
if (iter != python_paras_->end()) {
new_parameter = iter->second;
} else {
TraceManager::DebugTrace(std::make_shared<TraceCopy>(anf->debug_info()));
new_parameter = graph->NewParameter(anf->cast<ParameterPtr>());
if (m_tensor != nullptr) {
(*python_paras_)[m_tensor] = new_parameter;
if (param_value != nullptr) {
(*python_paras_)[param_value] = new_parameter;
}
TraceManager::EndTrace();
}
......
......@@ -16,7 +16,7 @@
#include "utils/callbacks_ge.h"
#include "pybind11/pybind11.h"
#include "ir/param_value_py.h"
#include "ir/param_value.h"
#include "transform/df_graph_manager.h"
#include "transform/util.h"
#include "pipeline/parse/data_converter.h"
......@@ -50,13 +50,10 @@ bool GetParameterShape(const FuncGraphPtr &graph, const std::string &param_name,
return false;
}
if (param_node->name() == param_name) {
py::object parameter;
TensorPtr tensor;
if (param_node->has_default()) {
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param_node->default_param());
parameter = param_value->value();
tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param()->value());
}
ValuePtr value = parse::data_converter::PyDataToValue(parameter);
TensorPtr tensor = std::dynamic_pointer_cast<tensor::Tensor>(value);
if (tensor == nullptr) {
shape->push_back(ONE_SHAPE);
} else {
......
......@@ -30,7 +30,7 @@
#include "pipeline/parse/parse_base.h"
#include "ir/value.h"
#include "ir/tensor.h"
#include "ir/param_value_py.h"
#include "ir/param_value.h"
#include "utils/base_ref_extends.h"
namespace mindspore {
......@@ -449,8 +449,8 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
if (!param->has_default()) {
MS_LOG(EXCEPTION) << "Can not determine value of Parameter " << index << " (" << param->name() << ")";
}
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(param->default_param());
*ret_val = param_value->value().attr("data");
auto tensor = param->default_param()->value();
*ret_val = py::cast(tensor);
}
return true;
}
......
......@@ -22,14 +22,12 @@
#include <vector>
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "ir/tensor.h"
#include "ir/tensor_py.h"
#include "ir/param_value_py.h"
#include "ir/param_value.h"
#include "operator/ops.h"
#include "pipeline/static_analysis/abstract_value.h"
#include "proto/onnx.pb.h"
#include "utils/log_adapter.h"
using mindspore::tensor::TensorPy;
using std::string;
namespace mindspore {
......@@ -123,11 +121,10 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node, cons
MS_EXCEPTION_IF_NULL(tensor_data_buf);
memcpy_s(tensor_data_buf, tensor_info->data().nbytes(), initial_data.data(), initial_data.size());
py::array array_data = TensorPy::AsNumpy(*tensor_info);
ParamValuePyPtr para_value_ptr = std::make_shared<ParamValuePy>();
MS_EXCEPTION_IF_NULL(para_value_ptr);
para_value_ptr->set_value(array_data);
node->set_default_param(para_value_ptr);
auto param_value = std::make_shared<ParamValue>();
MS_EXCEPTION_IF_NULL(param_value);
param_value->set_value(tensor_info);
node->set_default_param(param_value);
}
anfnode_build_map_[value_proto.name()] = node;
return true;
......
......@@ -17,11 +17,11 @@
import numbers
from copy import copy
from mindspore import context
from .._c_expression import ParamValue
from . import dtype as mstype
from .initializer import initializer, Initializer
from .tensor import Tensor, MetaTensor
from .._checkparam import _check_str_by_regular
from ..parallel._utils import _set_clone_info, _CloneInfo
from ..parallel._tensor import _get_slice_index
__all__ = ['Parameter', 'ParameterTuple']
......@@ -56,6 +56,7 @@ class Parameter:
"""
def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False,
sparse_grad="", has_indexed_slices_grad=False):
self._value = ParamValue()
self.set_parameter_data(default_input)
self.name = name
self.requires_grad = requires_grad
......@@ -64,13 +65,12 @@ class Parameter:
self.has_indexed_slices_grad = has_indexed_slices_grad
self._is_init = False
self._sliced = False
self.clone_info = _CloneInfo()
if context.get_context("mode") == context.PYNATIVE_MODE:
self.init_data()
def __repr__(self):
format_str = 'Parameter (name={name})'
return format_str.format(name=self._name)
return format_str.format(name=self._value.name)
def __parameter__(self):
"""For parse check."""
......@@ -78,7 +78,7 @@ class Parameter:
@property
def name(self):
"""Get the name of the parameter."""
return self._name
return self._value.name
@name.setter
def name(self, name_):
......@@ -100,7 +100,7 @@ class Parameter:
format(name_, PARAMETER_NAME_PREFIX_MAX_LEN))
else:
raise ValueError("The type of the name should be `str` or `None`.")
self._name = name_
self._value.name = name_
@property
def sliced(self):
......@@ -140,7 +140,9 @@ class Parameter:
"""
_check_str_by_regular(prefix)
x = copy(self)
x.name = prefix + '.' + x.name
# pylint: disable=protected-access
x._value = self._value.clone()
x._value.name = prefix + '.' + self._value.name
x.is_init = False
if init != 'same':
shape = self.default_input.shape
......@@ -152,58 +154,64 @@ class Parameter:
x.init_data()
else:
x.default_input = initializer(init, shape=shape, dtype=dtype)
x.clone_info = copy(self.clone_info)
_set_clone_info(self.clone_info, x.clone_info)
return x
@property
def layerwise_parallel(self):
return self._layerwise_parallel
return self._value.layerwise_parallel
@layerwise_parallel.setter
def layerwise_parallel(self, value=True):
if not isinstance(value, bool):
raise TypeError("`layerwise_parallel` parameter must be bool type")
self._layerwise_parallel = value
self._value.layerwise_parallel = value
@property
def requires_grad(self):
"""Return whether the parameter requires gradient."""
return self._requires_grad
return self._value.requires_grad
@requires_grad.setter
def requires_grad(self, value=True):
if not isinstance(value, bool):
raise TypeError("`requires_grad` parameter must be bool type")
self._requires_grad = value
self._value.requires_grad = value
@property
def sparse_grad(self):
"""Return whether the parameter's gradient is sparse."""
return self._sparse_grad
return self._value.sparse_grad
@sparse_grad.setter
def sparse_grad(self, value=""):
if not isinstance(value, str):
raise TypeError("`sparse_grad` parameter must be str type")
self._sparse_grad = value
self._value.sparse_grad = value
@property
def has_indexed_slices_grad(self):
"""Return whether the parameter's gradient is indexed_slices."""
return self._has_indexed_slices_grad
return self._value.has_indexed_slices_grad
@has_indexed_slices_grad.setter
def has_indexed_slices_grad(self, value=False):
if not isinstance(value, bool):
raise TypeError("`has_indexed_slices_grad` parameter must be bool type")
self._has_indexed_slices_grad = value
self._value.has_indexed_slices_grad = value
@property
def data(self):
return self.default_input
@property
def default_input(self):
return self._data
@default_input.setter
def default_input(self, data):
self._data = data
self._value.data = data
def __add__(self, other):
return self.default_input + other
......@@ -223,11 +231,12 @@ class Parameter:
def set_parameter_data(self, data):
"""Set `default_input` of current `Parameter`."""
self.init_mode = None
if isinstance(data, bool):
raise ValueError('Parameter data can not be `bool`')
if isinstance(data, Tensor):
# make a copy of Tensor to init the parameter
data = Tensor(data.asnumpy().copy())
data = Tensor(data.asnumpy())
data.init_flag = False
elif isinstance(data, Initializer):
self.init_mode = data
......@@ -242,7 +251,6 @@ class Parameter:
self.default_input = data
def init_data(self, layout=None, set_sliced=False):
"""
Init data of the parameter.
......@@ -256,7 +264,7 @@ class Parameter:
set_sliced (bool): True if should set parameter sliced after init the data of initializer.
Default: False.
"""
if not isinstance(self.default_input, MetaTensor):
if self.init_mode is None:
return
if layout is not None:
if not isinstance(layout, list):
......
......@@ -73,7 +73,6 @@ class Tensor(Tensor_):
else:
Tensor_.__init__(self, input_data, dtype)
self._virtual_flag = False
self._init_flag = False
def __repr__(self):
return str(self.__str__())
......@@ -205,19 +204,6 @@ class Tensor(Tensor_):
raise TypeError("virtual_flag must be bool.")
self._virtual_flag = value
@property
def init_flag(self):
"""whether the tensor is init."""
return self._init_flag
@init_flag.setter
def init_flag(self, value):
"""Set the tensor is init_flag."""
if not isinstance(value, bool):
raise TypeError("init_flag must be bool.")
self.set_init_flag(value)
self._init_flag = value
class IndexedSlices:
def __init__(self, indices, values, dense_shape):
......
......@@ -16,7 +16,7 @@
"""Utitly functions to help distribution class."""
import numpy as np
from mindspore.ops import _utils as utils
from ....common.tensor import Tensor, MetaTensor
from ....common.tensor import Tensor
from ....common.parameter import Parameter
from ....common import dtype as mstype
......@@ -152,7 +152,7 @@ def check_greater_equal_zero(value, name):
"""
if isinstance(value, Parameter):
if isinstance(value.default_input, MetaTensor):
if not isinstance(value.default_input, Tensor):
return
value = value.default_input
comp = np.less(value.asnumpy(), np.zeros(value.shape))
......@@ -188,7 +188,7 @@ def check_prob(p):
ValueError: if p is not a proper probability.
"""
if isinstance(p, Parameter):
if isinstance(p.default_input, MetaTensor):
if not isinstance(p.default_input, Tensor):
return
p = p.default_input
comp = np.less(p.asnumpy(), np.zeros(p.shape))
......
......@@ -122,47 +122,6 @@ def _parameter_broadcast_check(parallel_mode, parameter_broadcast):
"do not support parameter broadcast, parallel_mode: {0}, parameter_broadcast:{1}"
.format(parallel_mode, parameter_broadcast))
PARAMETER_CLONED_INDEX = 0
class _CloneInfo():
"""
The clone info of parameter.
Attributes:
be_cloned (bool): Whether the parameter is cloned.
cloned (bool): Whether the parameter clone from other parameter.
be_cloned_index (tuple): If the parameter is cloned, generate one index per clone.
cloned_index (int): If the parameter clone from other parameter, it has a unique index.
"""
def __init__(self):
self.be_cloned = False
self.cloned = False
self.be_cloned_index = []
self.cloned_index = None
def _set_clone_info(clone_from, clone_to):
"""
Set the clone info.
Args:
clone_from (_CloneInfo): The clone info of be_cloned parameter.
clone_to (_CloneInfo): The clone info of cloned parameter.
"""
global PARAMETER_CLONED_INDEX
clone_to.be_cloned = False
clone_to.cloned = True
clone_to.be_cloned_index = []
clone_to.cloned_index = PARAMETER_CLONED_INDEX
clone_from.be_cloned = True
clone_from.be_cloned_index.append(PARAMETER_CLONED_INDEX)
PARAMETER_CLONED_INDEX = PARAMETER_CLONED_INDEX + 1
def _get_python_op(op_name, op_path, instance_name, arglist):
"""Get python operator."""
module = __import__(op_path, fromlist=["None"])
......
......@@ -15,7 +15,7 @@
*/
#include "common/common_test.h"
#include "ir/param_value_py.h"
#include "ir/param_value.h"
#include "operator/ops.h"
#include "session/kernel_graph.h"
#include "session/anf_runtime_algorithm.h"
......@@ -764,10 +764,9 @@ TEST_F(AnfRuntimeAlgorithmTest, IsRealCNodeKernel) {
TEST_F(AnfRuntimeAlgorithmTest, IsParameterWeight) {
auto kernel_graph = std::make_shared<KernelGraph>();
py::object obj;
auto parameter_node = kernel_graph->add_parameter();
MS_EXCEPTION_IF_NULL(parameter_node);
auto param_value_new = std::make_shared<ParamValuePy>(obj);
auto param_value_new = std::make_shared<ParamValue>();
parameter_node->set_default_param(param_value_new);
EXPECT_TRUE(AnfAlgo::IsParameterWeight(parameter_node));
EXPECT_THROW(AnfAlgo::IsParameterWeight(nullptr), std::runtime_error);
......
......@@ -15,7 +15,7 @@
*/
#include "common/common_test.h"
#include "ir/param_value_py.h"
#include "ir/param_value.h"
#include "operator/ops.h"
#include "session/kernel_graph.h"
#include "session/anf_runtime_algorithm.h"
......@@ -82,8 +82,7 @@ TEST_F(KernelGraphTest, NewParameter) {
// test weight parameter node as input
auto weight_parameter_node = anf_graph->add_parameter();
MS_EXCEPTION_IF_NULL(weight_parameter_node);
py::object obj;
auto param_value_new = std::make_shared<ParamValuePy>(obj);
auto param_value_new = std::make_shared<ParamValue>();
weight_parameter_node->set_default_param(param_value_new);
weight_parameter_node->set_abstract(x_abstract);
auto new_weight_parameter_node = kernel_graph->NewParameter(weight_parameter_node);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册