Commit a05c38bb, authored by Wei Luning

make python Parameter inherit from Tensor

Parent 2b565627
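In short, this change makes the Python `Parameter` inherit from `MetaTensor`/`Tensor`, so the parameter object carries its data directly instead of wrapping a separate tensor inside a `ParamValue`; on the C++ side, `Parameter::default_param()` now holds a plain `ValuePtr` (typically a `tensor::Tensor`). A minimal sketch of the user-visible effect (illustrative shapes and names, not part of the diff):

```python
import numpy as np
from mindspore import Tensor, Parameter

# Illustrative sketch of the behaviour after this commit.
w = Parameter(Tensor(np.ones((2, 3), np.float32)), name="w")

# Before: the data sat behind an extra indirection (w.default_input / ParamValue.data).
# After: the Parameter *is* a Tensor, so it can be used directly.
assert isinstance(w, Tensor)
print(w.shape, w.dtype)  # (2, 3) Float32
```

The file hunks below follow from this: every `default_param()->value()` double hop collapses to a single `default_param()` call.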
...@@ -21,12 +21,12 @@ from .parser import (Parser, create_obj_instance, generate_scope, ...@@ -21,12 +21,12 @@ from .parser import (Parser, create_obj_instance, generate_scope,
get_class_member_namespace_symbol, create_slice_obj, get_class_member_namespace_symbol, create_slice_obj,
get_dataclass_attributes, get_dataclass_methods, get_obj_id, get_dataclass_attributes, get_dataclass_methods, get_obj_id,
get_module_namespace, get_obj_type, get_object_key, get_module_namespace, get_obj_type, get_object_key,
get_default_input, get_parse_method_of_class, get_scope_name, get_parse_method_of_class, get_scope_name,
is_class_member, parse_cb, resolve_symbol) is_class_member, parse_cb, resolve_symbol)
from .serialize import * from .serialize import *
__all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol', __all__ = ['parse_cb', 'get_parse_method_of_class', 'get_bprop_method_of_class', 'resolve_symbol',
'get_object_key', 'get_default_input', 'get_class_instance_type', 'is_class_member', 'get_object_key', 'get_class_instance_type', 'is_class_member',
'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace', 'get_obj_type', 'get_obj_id', 'create_obj_instance', 'get_module_namespace',
'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes', 'get_class_member_namespace_symbol', 'get_obj_id', 'Parser', 'get_dataclass_attributes',
'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name', 'get_dataclass_methods', 'dump_obj', 'load_obj', 'get_dataclass_methods', 'get_scope_name',
......
...@@ -206,16 +206,6 @@ def get_object_key(obj): ...@@ -206,16 +206,6 @@ def get_object_key(obj):
return obj_id, obj_key return obj_id, obj_key
def get_default_input(obj):
if hasattr(obj, '__parameter__'):
return obj.default_input
if isinstance(obj, tuple):
convert = lambda x: x.default_input if hasattr(x, '__parameter__') else x
args = tuple(convert(x) for x in obj)
return args
return obj
def is_class_member(node): def is_class_member(node):
"""Check the attr is class member variable.""" """Check the attr is class member variable."""
type_ = node.__class__.__name__ type_ = node.__class__.__name__
......
...@@ -76,7 +76,7 @@ GraphId AscendInferenceSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) { ...@@ -76,7 +76,7 @@ GraphId AscendInferenceSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
if (AnfAlgo::IsParameterWeight(pk_node)) { if (AnfAlgo::IsParameterWeight(pk_node)) {
const auto &param_value = pk_node->default_param(); const auto &param_value = pk_node->default_param();
MS_EXCEPTION_IF_NULL(param_value); MS_EXCEPTION_IF_NULL(param_value);
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_value->value()); auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_value);
MS_EXCEPTION_IF_NULL(tensor); MS_EXCEPTION_IF_NULL(tensor);
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0), if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(), LongToSize(tensor->data().nbytes()), tensor->data_type(),
......
...@@ -42,12 +42,12 @@ ...@@ -42,12 +42,12 @@
namespace mindspore { namespace mindspore {
namespace session { namespace session {
static std::shared_ptr<std::map<ParamValuePtr, ParameterPtr>> python_paras; static std::shared_ptr<std::map<ValuePtr, ParameterPtr>> python_paras;
void ClearPythonParasMap() { python_paras = nullptr; } void ClearPythonParasMap() { python_paras = nullptr; }
namespace { namespace {
const int kSummaryGetItem = 2; const int kSummaryGetItem = 2;
ParamValuePtr GetParamDefaultValue(const AnfNodePtr &node) { ValuePtr GetParamDefaultValue(const AnfNodePtr &node) {
if (node == nullptr) { if (node == nullptr) {
return nullptr; return nullptr;
} }
...@@ -209,8 +209,7 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, ...@@ -209,8 +209,7 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph,
auto param = graph->NewParameter(); auto param = graph->NewParameter();
MS_EXCEPTION_IF_NULL(param); MS_EXCEPTION_IF_NULL(param);
if (tensor_mask == kParameterWeightTensorMask) { if (tensor_mask == kParameterWeightTensorMask) {
auto param_value_new = std::make_shared<ParamValue>(); param->set_default_param(input_tensor);
param->set_default_param(param_value_new);
} }
// set the kernel info of parameter // set the kernel info of parameter
auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(); auto kernel_build_info_builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
...@@ -390,7 +389,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf ...@@ -390,7 +389,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
ParameterPtr new_parameter = nullptr; ParameterPtr new_parameter = nullptr;
// if parameter's python parameter has been exist a backend parameter, reuse the exist parameter // if parameter's python parameter has been exist a backend parameter, reuse the exist parameter
if (python_paras == nullptr) { if (python_paras == nullptr) {
python_paras = std::make_shared<std::map<ParamValuePtr, ParameterPtr>>(); python_paras = std::make_shared<std::map<ValuePtr, ParameterPtr>>();
} }
auto iter = python_paras->find(param_value); auto iter = python_paras->find(param_value);
if (iter != python_paras->end()) { if (iter != python_paras->end()) {
...@@ -667,7 +666,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph ...@@ -667,7 +666,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
auto param_value = GetParamDefaultValue(anf); auto param_value = GetParamDefaultValue(anf);
ParameterPtr new_parameter = nullptr; ParameterPtr new_parameter = nullptr;
if (python_paras == nullptr) { if (python_paras == nullptr) {
python_paras = std::make_shared<std::map<ParamValuePtr, ParameterPtr>>(); python_paras = std::make_shared<std::map<ValuePtr, ParameterPtr>>();
} }
auto iter = python_paras->find(param_value); auto iter = python_paras->find(param_value);
if (iter != python_paras->end()) { if (iter != python_paras->end()) {
......
...@@ -1670,7 +1670,7 @@ class IrParser { ...@@ -1670,7 +1670,7 @@ class IrParser {
// load parameter default value from serialized file // load parameter default value from serialized file
py::object default_obj = LoadObject(lexer_.GetTokenText()); py::object default_obj = LoadObject(lexer_.GetTokenText());
auto param_value_new = py::cast<ParamValuePtr>(default_obj); auto param_value_new = py::cast<tensor::TensorPtr>(default_obj);
param->set_default_param(param_value_new); param->set_default_param(param_value_new);
tok = lexer_.GetNextToken(); tok = lexer_.GetNextToken();
......
...@@ -318,8 +318,9 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) { ...@@ -318,8 +318,9 @@ void BaseDigraph::FuncGraphParameters(const FuncGraphPtr &key) {
buffer_ << parameter->ToString(); buffer_ << parameter->ToString();
auto param = parameter->cast<ParameterPtr>(); auto param = parameter->cast<ParameterPtr>();
if (param->has_default()) { if (param->has_default()) {
auto tensor = param->default_param()->value(); auto tensor_v = param->default_param();
if (tensor) { if (tensor_v && tensor_v->isa<tensor::Tensor>()) {
auto tensor = tensor_v->cast<tensor::TensorPtr>();
auto &shape = tensor->shape(); auto &shape = tensor->shape();
std::ostringstream shape_str; std::ostringstream shape_str;
std::copy(shape.begin(), shape.end(), std::ostream_iterator<int>(shape_str, ",")); std::copy(shape.begin(), shape.end(), std::ostream_iterator<int>(shape_str, ","));
......
...@@ -38,7 +38,12 @@ bool ParameterRequireGrad(const AnfNodePtr &node_ptr) { ...@@ -38,7 +38,12 @@ bool ParameterRequireGrad(const AnfNodePtr &node_ptr) {
if (!para_ptr->has_default()) { if (!para_ptr->has_default()) {
return false; return false;
} }
return para_ptr->default_param()->requires_grad(); auto obj = py::cast(para_ptr->default_param());
auto param_value = py::cast<ParamValuePtr>(obj.attr("_value"));
if (param_value == nullptr) {
return false;
}
return param_value->requires_grad();
} }
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore
...@@ -41,6 +41,7 @@ ...@@ -41,6 +41,7 @@
#include "frontend/parallel/context.h" #include "frontend/parallel/context.h"
#include "frontend/parallel/ops_info/tmp_identity_info.h" #include "frontend/parallel/ops_info/tmp_identity_info.h"
#include "frontend/parallel/ops_info/reshape_info.h" #include "frontend/parallel/ops_info/reshape_info.h"
#include "frontend/parallel/graph_util/node_info.h"
#include "frontend/parallel/step_parallel.h" #include "frontend/parallel/step_parallel.h"
#include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h" #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
#include "pipeline/jit/parse/python_adapter.h" #include "pipeline/jit/parse/python_adapter.h"
...@@ -122,12 +123,7 @@ std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) { ...@@ -122,12 +123,7 @@ std::vector<bool> ExtractInputParameterByNode(const CNodePtr &node) {
if (input->isa<Parameter>()) { if (input->isa<Parameter>()) {
auto input_parameter = input->cast<ParameterPtr>(); auto input_parameter = input->cast<ParameterPtr>();
if (input_parameter->has_default()) { is_parameter.push_back(ParameterRequireGrad(input_parameter));
bool requires_grad = input_parameter->default_param()->requires_grad();
is_parameter.push_back(requires_grad);
} else {
is_parameter.push_back(false);
}
} else if (input->isa<CNode>() || IsValueNode<tensor::Tensor>(input) || IsValueNode<RefKey>(input)) { } else if (input->isa<CNode>() || IsValueNode<tensor::Tensor>(input) || IsValueNode<RefKey>(input)) {
is_parameter.push_back(false); is_parameter.push_back(false);
} }
...@@ -798,12 +794,7 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) { ...@@ -798,12 +794,7 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
std::vector<bool> is_parameter; std::vector<bool> is_parameter;
auto casted_target_parameter = target_parameter->cast<ParameterPtr>(); auto casted_target_parameter = target_parameter->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(casted_target_parameter); MS_EXCEPTION_IF_NULL(casted_target_parameter);
if (casted_target_parameter->has_default()) { is_parameter.push_back(ParameterRequireGrad(casted_target_parameter));
bool requires_grad = casted_target_parameter->default_param()->requires_grad();
is_parameter.push_back(requires_grad);
} else {
is_parameter.push_back(false);
}
if (tmp_identity_ptr->set_is_parameter(is_parameter) != SUCCESS) { if (tmp_identity_ptr->set_is_parameter(is_parameter) != SUCCESS) {
MS_LOG(EXCEPTION) << "Setting parameter for TmpIdentityInfo failed"; MS_LOG(EXCEPTION) << "Setting parameter for TmpIdentityInfo failed";
} }
......
...@@ -1295,11 +1295,8 @@ void CoverSliceShape(const FuncGraphPtr &root) { ...@@ -1295,11 +1295,8 @@ void CoverSliceShape(const FuncGraphPtr &root) {
g_RefMap.clear(); g_RefMap.clear();
} }
bool ParameterIsCloned(const FuncGraphPtr &root, const AnfNodePtr &parameter_node) { bool ParameterIsCloned(const AnfNodePtr &parameter_node) {
MS_EXCEPTION_IF_NULL(root);
MS_EXCEPTION_IF_NULL(parameter_node); MS_EXCEPTION_IF_NULL(parameter_node);
FuncGraphManagerPtr manager = root->manager();
MS_EXCEPTION_IF_NULL(manager);
auto cloned_parameter = parameter_node->cast<ParameterPtr>(); auto cloned_parameter = parameter_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(cloned_parameter); MS_EXCEPTION_IF_NULL(cloned_parameter);
...@@ -1307,8 +1304,12 @@ bool ParameterIsCloned(const FuncGraphPtr &root, const AnfNodePtr &parameter_nod ...@@ -1307,8 +1304,12 @@ bool ParameterIsCloned(const FuncGraphPtr &root, const AnfNodePtr &parameter_nod
if (!cloned_parameter->has_default()) { if (!cloned_parameter->has_default()) {
return false; return false;
} }
auto obj = py::cast(cloned_parameter->default_param());
bool cloned = cloned_parameter->default_param()->cloned(); auto param_value = py::cast<ParamValuePtr>(obj.attr("_value"));
if (param_value == nullptr) {
return false;
}
bool cloned = param_value->cloned();
if (!cloned) { if (!cloned) {
return false; return false;
} }
...@@ -1324,12 +1325,16 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { ...@@ -1324,12 +1325,16 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
auto cloned_parameter = cloned_parameter_node->cast<ParameterPtr>(); auto cloned_parameter = cloned_parameter_node->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(cloned_parameter); MS_EXCEPTION_IF_NULL(cloned_parameter);
if (!ParameterIsCloned(root, cloned_parameter_node)) { if (!ParameterIsCloned(cloned_parameter_node)) {
continue;
}
auto obj = py::cast(cloned_parameter->default_param());
auto param_value = py::cast<ParamValuePtr>(obj.attr("_value"));
if (param_value == nullptr) {
continue; continue;
} }
// get the cloned index // get the cloned index
int32_t cloned_index = cloned_parameter->default_param()->cloned_index(); int32_t cloned_index = param_value->cloned_index();
// find the be cloned parameter // find the be cloned parameter
bool found_be_cloned_parameter = false; bool found_be_cloned_parameter = false;
...@@ -1344,12 +1349,18 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { ...@@ -1344,12 +1349,18 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
} }
const auto &param_value_cloned = be_cloned_parameter->default_param(); const auto &param_value_cloned = be_cloned_parameter->default_param();
if (!param_value_cloned->be_cloned()) {
auto obj_in = py::cast(param_value_cloned);
auto param_value_in = py::cast<ParamValuePtr>(obj_in.attr("_value"));
if (param_value_in == nullptr) {
continue;
}
if (!param_value_in->be_cloned()) {
continue; continue;
} }
// get the be cloned index // get the be cloned index
auto &be_cloned_index = param_value_cloned->be_cloned_index(); auto &be_cloned_index = param_value_in->be_cloned_index();
if (std::find(be_cloned_index.begin(), be_cloned_index.end(), cloned_index) != be_cloned_index.end()) { if (std::find(be_cloned_index.begin(), be_cloned_index.end(), cloned_index) != be_cloned_index.end()) {
found_be_cloned_parameter = true; found_be_cloned_parameter = true;
cloned_from_parameter = be_cloned_parameter; cloned_from_parameter = be_cloned_parameter;
...@@ -2103,10 +2114,7 @@ std::string NodeParameterName(const CNodePtr &node) { ...@@ -2103,10 +2114,7 @@ std::string NodeParameterName(const CNodePtr &node) {
if (input->isa<Parameter>()) { if (input->isa<Parameter>()) {
auto input_parameter = input->cast<ParameterPtr>(); auto input_parameter = input->cast<ParameterPtr>();
if (input_parameter->has_default()) { if (input_parameter->has_default()) {
const auto &param_value = input_parameter->default_param(); return input_parameter->name();
if (param_value->requires_grad()) {
return param_value->name();
}
} }
} }
} }
......
...@@ -233,8 +233,7 @@ bool AbstractSpecializeAction(const ResourcePtr &res) { ...@@ -233,8 +233,7 @@ bool AbstractSpecializeAction(const ResourcePtr &res) {
for (const auto &param : func_graph->parameters()) { for (const auto &param : func_graph->parameters()) {
auto param_node = std::static_pointer_cast<Parameter>(param); auto param_node = std::static_pointer_cast<Parameter>(param);
if (param_node->has_default()) { if (param_node->has_default()) {
const auto &param_value = param_node->default_param(); ValuePtr value = param_node->default_param();
ValuePtr value = param_value->value();
constexpr bool broaden = true; constexpr bool broaden = true;
AbstractBasePtr ptr = abstract::FromValue(value, broaden); AbstractBasePtr ptr = abstract::FromValue(value, broaden);
......
...@@ -68,6 +68,8 @@ PYBIND11_MODULE(_c_expression, m) { ...@@ -68,6 +68,8 @@ PYBIND11_MODULE(_c_expression, m) {
py::arg("type") = py::str("onnx_ir"), "Get graph proto string by specifying ir type.") py::arg("type") = py::str("onnx_ir"), "Get graph proto string by specifying ir type.")
.def("compile", &ExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("phase") = py::str(""), .def("compile", &ExecutorPy::Compile, py::arg("obj"), py::arg("args"), py::arg("phase") = py::str(""),
py::arg("use_vm") = py::bool_(false), "Compile obj by executor.") py::arg("use_vm") = py::bool_(false), "Compile obj by executor.")
.def("updata_param_node_default_input", &ExecutorPy::UpdataParamNodeDefaultInput, py::arg("phase"),
py::arg("params"), "Fetch the inputs of Conv or Matmul for quant export.")
.def("get_parameter_layout", &ExecutorPy::GetParameterLayout, py::arg("phase") = py::str("train"), .def("get_parameter_layout", &ExecutorPy::GetParameterLayout, py::arg("phase") = py::str("train"),
"Get Parameter Tensor Layout Dictionary.") "Get Parameter Tensor Layout Dictionary.")
.def("get_strategy", &ExecutorPy::GetCNodeStrategy, py::arg("phase") = py::str("train"), .def("get_strategy", &ExecutorPy::GetCNodeStrategy, py::arg("phase") = py::str("train"),
......
...@@ -205,41 +205,6 @@ bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_ ...@@ -205,41 +205,6 @@ bool ConvertMetaFuncGraph(const py::object &obj, ValuePtr *const data, bool use_
return true; return true;
} }
bool ConvertDataType(const py::object &obj, ValuePtr *const data) {
MS_LOG(DEBUG) << "Converting type object";
auto typeptr = obj.cast<TypePtr>();
if (typeptr == nullptr) {
MS_LOG(ERROR) << "Resolve TypePtr error, get ptr is null";
return false;
}
*data = typeptr;
return true;
}
bool ConvertMetaTensor(const py::object &obj, ValuePtr *const data) {
MS_LOG(DEBUG) << "Converting MetaTensor object.";
auto m_tensor = obj.cast<MetaTensorPtr>();
if (m_tensor == nullptr) {
MS_LOG(ERROR) << "Resolve MetaTensor error, get ptr is null.";
return false;
}
*data = m_tensor;
return true;
}
bool ConvertTensor(const py::object &obj, ValuePtr *const data) {
MS_LOG(DEBUG) << "Converting tensor object";
auto m_tensor = obj.cast<TensorPtr>();
if (m_tensor == nullptr) {
MS_LOG(ERROR) << "Resolve Tensor error, get ptr is null";
return false;
}
*data = m_tensor;
return true;
}
bool ConvertSlice(const py::object &obj, ValuePtr *const data) { bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
MS_LOG(DEBUG) << "Converting slice object"; MS_LOG(DEBUG) << "Converting slice object";
...@@ -364,11 +329,11 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature ...@@ -364,11 +329,11 @@ bool ConvertData(const py::object &obj, ValuePtr *const data, bool use_signature
} else if (py::isinstance<MetaFuncGraph>(obj)) { } else if (py::isinstance<MetaFuncGraph>(obj)) {
ret = ConvertMetaFuncGraph(obj, &converted, use_signature); ret = ConvertMetaFuncGraph(obj, &converted, use_signature);
} else if (py::isinstance<Type>(obj)) { } else if (py::isinstance<Type>(obj)) {
ret = ConvertDataType(obj, &converted); converted = obj.cast<TypePtr>();
} else if (py::isinstance<Tensor>(obj)) { } else if (py::isinstance<Tensor>(obj)) {
ret = ConvertTensor(obj, &converted); converted = obj.cast<TensorPtr>();
} else if (py::isinstance<MetaTensor>(obj)) { } else if (py::isinstance<MetaTensor>(obj)) {
ret = ConvertMetaTensor(obj, &converted); converted = obj.cast<MetaTensorPtr>();
} else if (py::isinstance<EnvInstance>(obj)) { } else if (py::isinstance<EnvInstance>(obj)) {
std::shared_ptr<EnvInstance> env = obj.cast<std::shared_ptr<EnvInstance>>(); std::shared_ptr<EnvInstance> env = obj.cast<std::shared_ptr<EnvInstance>>();
converted = env; converted = env;
......
...@@ -85,7 +85,6 @@ const char PYTHON_PARSE_ANALYZE_SUPER[] = "analyze_super"; ...@@ -85,7 +85,6 @@ const char PYTHON_PARSE_ANALYZE_SUPER[] = "analyze_super";
const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj"; const char PYTHON_PARSE_CLASS_SLICE[] = "create_slice_obj";
const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj"; const char PYTHON_PARSE_CLASS_ELLIPSIS[] = "create_ellipsis_obj";
const char PYTHON_MOD_GET_DEFAULT_INPUT[] = "get_default_input";
// define the common name // define the common name
const char NAMED_PRIMITIVE_LEN[] = "len"; const char NAMED_PRIMITIVE_LEN[] = "len";
......
...@@ -103,10 +103,9 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object ...@@ -103,10 +103,9 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr &func_graph, const py::object
} }
if (para_node == nullptr) { if (para_node == nullptr) {
auto node = top_graph->AddWeightParameter(param_name); auto node = top_graph->AddWeightParameter(param_name);
auto param_value = py::cast<ParamValuePtr>(python_adapter::GetPyObjAttr(obj, "_value")); auto value = py::cast<tensor::MetaTensorPtr>(obj);
node->set_default_param(param_value); node->set_default_param(value);
// set_abstract for parameter // set_abstract for parameter
ValuePtr value = param_value->value();
constexpr bool broaden = true; constexpr bool broaden = true;
node->set_abstract(abstract::FromValue(value, broaden)); node->set_abstract(abstract::FromValue(value, broaden));
para_node = node; para_node = node;
......
...@@ -719,7 +719,11 @@ void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef ...@@ -719,7 +719,11 @@ void ProcessVmArgInner(const py::tuple &args, const ResourcePtr &res, VectorRef
if (!param_ptr->has_default()) { if (!param_ptr->has_default()) {
MS_LOG(EXCEPTION) << "Parameter[" << i << "] has no default param"; MS_LOG(EXCEPTION) << "Parameter[" << i << "] has no default param";
} }
arg_list->push_back(param_ptr->default_param()->value()); if (!param_ptr->default_param()->isa<Tensor>()) {
MS_LOG(EXCEPTION) << "Parameter[" << param_ptr->ToString()
<< "] is not initialized, need to call `.init_data()`";
}
arg_list->push_back(param_ptr->default_param());
} }
} }
} }
...@@ -782,6 +786,24 @@ FuncGraphPtr ExecutorPy::BuildGraph(const py::dict &init_params, const std::stri ...@@ -782,6 +786,24 @@ FuncGraphPtr ExecutorPy::BuildGraph(const py::dict &init_params, const std::stri
#endif #endif
} }
void ExecutorPy::UpdataParamNodeDefaultInput(const std::string &phase,
const std::unordered_map<std::string, tensor::TensorPtr> &params_value) {
FuncGraphPtr func_graph = info_[phase]->resource->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
MS_LOG(DEBUG) << "UpdataParamNodeDefaultInput for func graph(" << func_graph->ToString() << ") phase(" << phase
<< ")!";
auto &params = func_graph->parameters();
for (const auto &param : params) {
MS_EXCEPTION_IF_NULL(param);
auto param_cast = param->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(param_cast);
auto iter = params_value.find(param_cast->name());
if (iter != params_value.end()) {
param_cast->set_default_param(iter->second);
}
}
}
void ExecutorPy::RunInitGraph(const py::dict &init_params, const std::string &phase) { void ExecutorPy::RunInitGraph(const py::dict &init_params, const std::string &phase) {
#if ENABLE_GE #if ENABLE_GE
RunGEInitGraph(init_params, phase); RunGEInitGraph(init_params, phase);
......
...@@ -88,6 +88,8 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> { ...@@ -88,6 +88,8 @@ class ExecutorPy : public std::enable_shared_from_this<ExecutorPy> {
FuncGraphPtr BuildGraph(const py::dict &init_params, const std::string &phase, FuncGraphPtr BuildGraph(const py::dict &init_params, const std::string &phase,
const py::object &broadcast_params = {}); const py::object &broadcast_params = {});
void UpdataParamNodeDefaultInput(const std::string &phase,
const std::unordered_map<std::string, tensor::TensorPtr> &params);
void RunInitGraph(const py::dict &init_params, const std::string &phase); void RunInitGraph(const py::dict &init_params, const std::string &phase);
py::dict GetParameterLayout(const std::string &phase); py::dict GetParameterLayout(const std::string &phase);
py::dict GetCNodeStrategy(const std::string &phase); py::dict GetCNodeStrategy(const std::string &phase);
......
...@@ -146,12 +146,6 @@ static std::string GetOpId(const OpExecInfoPtr &op_exec_info) { ...@@ -146,12 +146,6 @@ static std::string GetOpId(const OpExecInfoPtr &op_exec_info) {
return id; return id;
} }
py::object GetTupleObj(const py::object &obj) {
py::module mod = parse::python_adapter::GetPyModule(parse::PYTHON_MOD_PARSE_MODULE);
py::object obj_tuple = parse::python_adapter::CallPyModFn(mod, parse::PYTHON_MOD_GET_DEFAULT_INPUT, obj);
return obj_tuple;
}
std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector<SignatureEnumDType> &dtypes) { std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector<SignatureEnumDType> &dtypes) {
std::map<SignatureEnumDType, std::vector<size_t>> type_indexes; std::map<SignatureEnumDType, std::vector<size_t>> type_indexes;
for (size_t i = 0; i < dtypes.size(); ++i) { for (size_t i = 0; i < dtypes.size(); ++i) {
...@@ -242,7 +236,7 @@ py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tu ...@@ -242,7 +236,7 @@ py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tu
py::tuple input_mask(args.size()); py::tuple input_mask(args.size());
for (size_t i = 0; i < args.size(); ++i) { for (size_t i = 0; i < args.size(); ++i) {
input_mask[i] = py::hasattr(args[i], "__parameter__"); input_mask[i] = py::hasattr(args[i], "__parameter__");
py_args[i] = GetTupleObj(args[i]); py_args[i] = args[i];
} }
auto signature = prim->signatures(); auto signature = prim->signatures();
std::vector<SignatureEnumDType> dtypes; std::vector<SignatureEnumDType> dtypes;
...@@ -366,9 +360,6 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat ...@@ -366,9 +360,6 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
py::tuple result(op_inputs.size()); py::tuple result(op_inputs.size());
for (size_t i = 0; i < op_inputs.size(); i++) { for (size_t i = 0; i < op_inputs.size(); i++) {
py::object input = op_inputs[i]; py::object input = op_inputs[i];
if (py::hasattr(input, "__parameter__")) {
input = py::getattr(input, "data");
}
auto tensor = py::cast<tensor::TensorPtr>(input); auto tensor = py::cast<tensor::TensorPtr>(input);
auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr()); auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
new_tensor->set_device_address(tensor->device_address()); new_tensor->set_device_address(tensor->device_address());
...@@ -878,8 +869,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &o ...@@ -878,8 +869,7 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, const py::object &o
if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) { if (graph_info_map_[df_builder_].param_map.count(obj_id) == 0) {
auto free_param = df_builder_->add_parameter(); auto free_param = df_builder_->add_parameter();
free_param->set_name(param_name); free_param->set_name(param_name);
auto free_param_new = py::cast<ParamValuePtr>(obj.attr("_value")); free_param->set_default_param(py::cast<tensor::TensorPtr>(obj));
free_param->set_default_param(free_param_new);
free_param->debug_info()->set_name(param_name); free_param->debug_info()->set_name(param_name);
MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id; MS_LOG(DEBUG) << "Top graph set free parameter " << obj_id;
graph_info_map_[df_builder_].param_map[obj_id] = free_param; graph_info_map_[df_builder_].param_map[obj_id] = free_param;
...@@ -1074,8 +1064,7 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args ...@@ -1074,8 +1064,7 @@ abstract::AbstractBasePtrList PynativeExecutor::GetArgsSpec(const py::args &args
for (const auto &param : df_builder_->parameters()) { for (const auto &param : df_builder_->parameters()) {
auto param_node = std::static_pointer_cast<Parameter>(param); auto param_node = std::static_pointer_cast<Parameter>(param);
if (param_node->has_default()) { if (param_node->has_default()) {
const auto &param_value = param_node->default_param(); ValuePtr value = param_node->default_param();
ValuePtr value = param_value->value();
AbstractBasePtr ptr = abstract::FromValue(value, true); AbstractBasePtr ptr = abstract::FromValue(value, true);
if (ptr == nullptr) { if (ptr == nullptr) {
MS_LOG(EXCEPTION) << "Args convert error"; MS_LOG(EXCEPTION) << "Args convert error";
......
...@@ -187,7 +187,7 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::Grap ...@@ -187,7 +187,7 @@ void IrExportBuilder::BuildParameters(const FuncGraphPtr &func_graph, onnx::Grap
onnx::TensorProto *initializer_proto = graph_proto->add_initializer(); onnx::TensorProto *initializer_proto = graph_proto->add_initializer();
initializer_proto->set_name(param_name); initializer_proto->set_name(param_name);
SetParamToTensorProto(param, initializer_proto); SetParamToTensorProto(param, initializer_proto);
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param->default_param()->value()); auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param->default_param());
if (tensor) { if (tensor) {
initializer_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes()); initializer_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes());
} }
......
...@@ -449,7 +449,7 @@ void OnnxExporter::ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphP ...@@ -449,7 +449,7 @@ void OnnxExporter::ExportParameters(const FuncGraphPtr &func_graph, onnx::GraphP
initializer_proto->set_name(param_ptr->ToString()); initializer_proto->set_name(param_ptr->ToString());
SetTensorProtoInfo(param_ptr, initializer_proto); SetTensorProtoInfo(param_ptr, initializer_proto);
// set value for initializer // set value for initializer
auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_ptr->default_param()->value()); auto tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_ptr->default_param());
if (tensor) { if (tensor) {
initializer_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes()); initializer_proto->set_raw_data(tensor->data_c(), tensor->data().nbytes());
} }
......
...@@ -52,7 +52,7 @@ bool GetParameterShape(const FuncGraphPtr &graph, const std::string &param_name, ...@@ -52,7 +52,7 @@ bool GetParameterShape(const FuncGraphPtr &graph, const std::string &param_name,
if (param_node->name() == param_name) { if (param_node->name() == param_name) {
TensorPtr tensor; TensorPtr tensor;
if (param_node->has_default()) { if (param_node->has_default()) {
tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param()->value()); tensor = std::dynamic_pointer_cast<tensor::Tensor>(param_node->default_param());
} }
if (tensor == nullptr) { if (tensor == nullptr) {
shape->push_back(ONE_SHAPE); shape->push_back(ONE_SHAPE);
......
...@@ -448,7 +448,7 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple ...@@ -448,7 +448,7 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &output, const py::tuple
if (!param->has_default()) { if (!param->has_default()) {
MS_LOG(EXCEPTION) << "Can not determine value of Parameter " << index << " (" << param->name() << ")"; MS_LOG(EXCEPTION) << "Can not determine value of Parameter " << index << " (" << param->name() << ")";
} }
auto tensor = param->default_param()->value(); auto tensor = param->default_param();
*ret_val = py::cast(tensor); *ret_val = py::cast(tensor);
} }
return true; return true;
......
...@@ -124,10 +124,7 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node, cons ...@@ -124,10 +124,7 @@ bool MSANFModelParser::BuildParameterForFuncGraph(const ParameterPtr &node, cons
MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret; MS_LOG(EXCEPTION) << "memcpy_s error, errorno" << ret;
} }
auto param_value = std::make_shared<ParamValue>(); node->set_default_param(tensor_info);
MS_EXCEPTION_IF_NULL(param_value);
param_value->set_value(tensor_info);
node->set_default_param(param_value);
} }
anfnode_build_map_[value_proto.name()] = node; anfnode_build_map_[value_proto.name()] = node;
return true; return true;
......
...@@ -24,22 +24,19 @@ REGISTER_PYBIND_DEFINE(ParamValue, ([](const py::module *m) { ...@@ -24,22 +24,19 @@ REGISTER_PYBIND_DEFINE(ParamValue, ([](const py::module *m) {
(void)py::class_<ParamValue, ParamValuePtr>(*m, "ParamValue") (void)py::class_<ParamValue, ParamValuePtr>(*m, "ParamValue")
.def(py::init()) .def(py::init())
.def("clone", &ParamValue::Clone) .def("clone", &ParamValue::Clone)
.def_property("data", &ParamValue::value, &ParamValue::set_value)
.def_property("name", &ParamValue::name, &ParamValue::set_name) .def_property("name", &ParamValue::name, &ParamValue::set_name)
.def_property("requires_grad", &ParamValue::requires_grad, &ParamValue::set_requires_grad) .def_property("requires_grad", &ParamValue::requires_grad, &ParamValue::set_requires_grad)
.def_property("layerwise_parallel", &ParamValue::layerwise_parallel, .def_property("layerwise_parallel", &ParamValue::layerwise_parallel,
&ParamValue::set_layerwise_parallel) &ParamValue::set_layerwise_parallel)
.def(py::pickle( .def(py::pickle(
[](const ParamValue &p) { // __getstate__ [](const ParamValue &p) { // __getstate__
return py::make_tuple(py::cast(p.value()), p.name(), p.requires_grad(), return py::make_tuple(p.name(), p.requires_grad(), p.layerwise_parallel());
p.layerwise_parallel());
}, },
[](const py::tuple &t) { // __setstate__ [](const py::tuple &t) { // __setstate__
if (t.size() != 6) { if (t.size() != 6) {
std::runtime_error("Invalid state for ParamValue!"); std::runtime_error("Invalid state for ParamValue!");
} }
ParamValuePtr p = std::make_shared<ParamValue>(); ParamValuePtr p = std::make_shared<ParamValue>();
p->set_value(t[0].cast<tensor::TensorPtr>());
p->set_name(t[1].cast<std::string>()); p->set_name(t[1].cast<std::string>());
p->set_requires_grad(t[2].cast<bool>()); p->set_requires_grad(t[2].cast<bool>());
p->set_layerwise_parallel(t[3].cast<bool>()); p->set_layerwise_parallel(t[3].cast<bool>());
......
...@@ -372,7 +372,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) { ...@@ -372,7 +372,7 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
.def(py::pickle( .def(py::pickle(
[](const Tensor &t) { // __getstate__ [](const Tensor &t) { // __getstate__
/* Return a tuple that fully encodes the state of the object */ /* Return a tuple that fully encodes the state of the object */
return py::make_tuple(TensorPy::AsNumpy(t)); return py::make_tuple(TensorPy::SyncAsNumpy(t));
}, },
[](const py::tuple &t) { // __setstate__ [](const py::tuple &t) { // __setstate__
if (t.size() != 1) { if (t.size() != 1) {
......
...@@ -255,7 +255,6 @@ def ms_function(fn=None, obj=None, input_signature=None): ...@@ -255,7 +255,6 @@ def ms_function(fn=None, obj=None, input_signature=None):
process_obj = obj process_obj = obj
if args and not isinstance(args[0], MsTensor) and hasattr(args[0], func.__name__): if args and not isinstance(args[0], MsTensor) and hasattr(args[0], func.__name__):
process_obj = args[0] process_obj = args[0]
args = (x.default_input if hasattr(x, 'default_input') else x for x in args)
return _MindSporeFunction(func, input_signature, process_obj)(*args) return _MindSporeFunction(func, input_signature, process_obj)(*args)
return staging_specialize return staging_specialize
...@@ -354,28 +353,8 @@ class _Executor: ...@@ -354,28 +353,8 @@ class _Executor:
raise RuntimeError("Failure to init and dataset subgraph!") raise RuntimeError("Failure to init and dataset subgraph!")
return True return True
def _build_data_graph(self, obj, params, phase): def _build_data_graph(self, obj, phase):
if params is None: self._executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict())
self._executor.build_data_graph(obj.parameters_dict(), phase, obj.parameters_broadcast_dict())
elif isinstance(params, OrderedDict):
self._executor.build_data_graph(params, phase)
else:
raise TypeError('Parameters need OrderedDict type, but got {}'.
format(type(params)))
def _params_init_data(self, obj, params, auto_parallel_mode=False):
"""Init parameters' data."""
if params is not None:
for key, param in params.items():
if not auto_parallel_mode:
param.init_data()
elif key not in obj.parameter_layout_dict:
logger.debug("Layout dict does not contain the key %s.", key)
param.init_data(set_sliced=True)
else:
layout = obj.parameter_layout_dict[key]
param.init_data(layout, set_sliced=True)
obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
def _set_dataset_mode(self, args_list): def _set_dataset_mode(self, args_list):
"""set dataset mode.""" """set dataset mode."""
...@@ -386,7 +365,7 @@ class _Executor: ...@@ -386,7 +365,7 @@ class _Executor:
else: else:
_set_dataset_mode_config('normal') _set_dataset_mode_config('normal')
def compile(self, obj, *args, phase='predict', params=None, do_convert=True, auto_parallel_mode=False): def compile(self, obj, *args, phase='predict', do_convert=True, auto_parallel_mode=False):
""" """
Compiles graph. Compiles graph.
...@@ -394,7 +373,6 @@ class _Executor: ...@@ -394,7 +373,6 @@ class _Executor:
obj (Function/Cell): The function or cell instance need compile. obj (Function/Cell): The function or cell instance need compile.
args (tuple): Function or cell input arguments. args (tuple): Function or cell input arguments.
phase (str): The name of compile phase. Default: 'predict'. phase (str): The name of compile phase. Default: 'predict'.
params (OrderedDict): The parameters dictionary used for init data graph. Default: None.
do_convert (bool): When set to True, convert ME graph to GE graph after compiling graph. do_convert (bool): When set to True, convert ME graph to GE graph after compiling graph.
auto_parallel_mode: When set to True, use auto parallel mode to compile graph. auto_parallel_mode: When set to True, use auto parallel mode to compile graph.
...@@ -435,10 +413,8 @@ class _Executor: ...@@ -435,10 +413,8 @@ class _Executor:
if auto_parallel_mode: if auto_parallel_mode:
obj.parameter_layout_dict = self._executor.get_parameter_layout(phase) obj.parameter_layout_dict = self._executor.get_parameter_layout(phase)
self._params_init_data(obj, params, auto_parallel_mode) replace = obj.init_parameters_data(auto_parallel_mode=auto_parallel_mode)
if not enable_debug_runtime or enable_ge: self._updata_param_node_default_input(phase, replace)
if auto_parallel_mode:
obj.load_parameter_slice(params)
# set parallel inputs in sink mode # set parallel inputs in sink mode
if auto_parallel_mode and (args and isinstance(args[0], Tensor) and args[0].virtual_flag): if auto_parallel_mode and (args and isinstance(args[0], Tensor) and args[0].virtual_flag):
...@@ -446,16 +422,20 @@ class _Executor: ...@@ -446,16 +422,20 @@ class _Executor:
# the following GE init process is not needed when use vm or ms backend # the following GE init process is not needed when use vm or ms backend
if enable_ge: if enable_ge:
self._build_data_graph(obj, params, phase) self._build_data_graph(obj, phase)
if "export" not in phase: if "export" not in phase:
init_phase = "init_subgraph" + "." + str(obj.create_time) init_phase = "init_subgraph" + "." + str(obj.create_time)
_exec_init_graph(obj, init_phase) _exec_init_graph(obj, init_phase)
elif not enable_ge and "export" in phase: elif not enable_ge and "export" in phase:
self._build_data_graph(obj, params, phase) self._build_data_graph(obj, phase)
return phase, True return phase, True
def _updata_param_node_default_input(self, phase, replace):
new_param = {x.name: replace[x] for x in replace if id(x) != id(replace[x])}
return self._executor.updata_param_node_default_input(phase, new_param)
def _get_strategy(self, obj): def _get_strategy(self, obj):
real_phase = self.phase_prefix + obj.phase + '.' + str(obj.create_time) real_phase = self.phase_prefix + obj.phase + '.' + str(obj.create_time)
return self._executor.get_strategy(real_phase) return self._executor.get_strategy(real_phase)
......
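The compile path above now initializes parameters itself and pushes the resulting tensors back into the compiled graph. A rough sketch of that flow, using the names from this diff (`net` is a hypothetical `Cell`; the executor call is paraphrased in a comment):

```python
# Hypothetical walk-through of the new _Executor.compile() parameter handling.
replace = net.init_parameters_data(auto_parallel_mode=False)  # Initializer -> Tensor Parameters

# Only parameters that were actually re-created need their graph nodes updated
# (this mirrors _updata_param_node_default_input above).
new_params = {p.name: new_p for p, new_p in replace.items() if p is not new_p}

# executor.updata_param_node_default_input(phase, new_params) then attaches each
# Tensor as the default_param of the matching ParameterPtr in the FuncGraph.
```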
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
# ============================================================================ # ============================================================================
"""Parameter for cell.""" """Parameter for cell."""
import numbers
from copy import copy from copy import copy
from mindspore import context from mindspore import context
from .._c_expression import ParamValue from .._c_expression import ParamValue
...@@ -37,10 +36,17 @@ def _check_type(x): ...@@ -37,10 +36,17 @@ def _check_type(x):
return True return True
class Parameter: class Parameter(MetaTensor):
""" """
Parameter types of cell models. Parameter types of cell models.
After initialization, a `Parameter` is a subtype of `Tensor`.
In graph mode, if a `Parameter` is initialized with an `Initializer`, its type will be `MetaTensor`
rather than `Tensor`. A `MetaTensor` only saves the shape and type info of a tensor, with no memory usage. The shape
can be changed during compilation for auto-parallel. Calling `init_data` returns a Tensor Parameter with
initialized data.
Note: Note:
Each parameter of Cell is represented by Parameter class. Each parameter of Cell is represented by Parameter class.
...@@ -52,23 +58,85 @@ class Parameter: ...@@ -52,23 +58,85 @@ class Parameter:
layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in paralle mode, layerwise_parallel (bool): A kind of model parallel mode. When layerwise_parallel is true in paralle mode,
broadcast and gradients communication would not be applied on parameters. Default: False. broadcast and gradients communication would not be applied on parameters. Default: False.
""" """
__base_type__ = {}
def __new__(cls, default_input, name, *args, **kwargs):
input_class, *class_init_args = Parameter._get_parameter_new_args(default_input)
new_type = Parameter._get_base_class(input_class)
obj = input_class.__new__(new_type)
input_class.__init__(obj, *class_init_args)
# it's better to make the Initializer a kind of metatensor.
obj.init_mode = None
if isinstance(default_input, Initializer):
obj.init_mode = default_input
return obj
def __reduce_ex__(self, _):
data = self
if self.init_mode is not None:
data = self.init_mode
else:
# cast to Tensor to break the infinite recursion during deepcopy
data = Tensor(self)
return (
Parameter, (data, self.name, self.requires_grad, self.layerwise_parallel))
def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False): def __init__(self, default_input, name, requires_grad=True, layerwise_parallel=False):
self._value = ParamValue() self._value = ParamValue()
self.set_parameter_data(default_input)
self.name = name self.name = name
self.requires_grad = requires_grad self.requires_grad = requires_grad
self.layerwise_parallel = layerwise_parallel self.layerwise_parallel = layerwise_parallel
# this flag for tensor copy data.
self.init_flag = False
# this flag is for ge variable copy data.
self._is_init = False self._is_init = False
self._inited_param = None
self._sliced = False self._sliced = False
self.is_param_ps = False self.is_param_ps = False
self._cast_type = None self._cast_type = None
self.init_in_server = False self.init_in_server = False
if context.get_context("mode") == context.PYNATIVE_MODE:
self.init_data() @staticmethod
def _get_base_class(input_class):
input_class_name = f'Parameter{input_class.__name__}'
if input_class_name in Parameter.__base_type__:
new_type = Parameter.__base_type__[input_class_name]
else:
new_type = type(input_class_name, (Parameter, input_class), {})
Parameter.__base_type__[input_class_name] = new_type
return new_type
@staticmethod
def _get_parameter_new_args(data):
"""Set `default_input` of current `Parameter`."""
if isinstance(data, bool):
raise ValueError('Parameter data can not be `bool`')
if isinstance(data, Initializer):
if context.get_context("mode") == context.PYNATIVE_MODE:
# always init data while in pynative mode.
data = data.to_tensor()
else:
return (MetaTensor, data.dtype, data.shape)
if isinstance(data, Tensor):
# make a copy of Tensor to init the parameter
return (Tensor, data.asnumpy(),)
if isinstance(data, int):
return (Tensor, data, mstype.int32)
if isinstance(data, float):
return (Tensor, data, mstype.float32)
return (Tensor, data)
def __str__(self):
value_str = MetaTensor.__repr__(self)
if isinstance(self, Tensor):
value_str = Tensor.__repr__(self)
return f'Parameter (name={self._value.name}, value={value_str})'
def __repr__(self): def __repr__(self):
format_str = 'Parameter (name={name})' value_str = MetaTensor.__repr__(self)
return format_str.format(name=self._value.name) if isinstance(self, Tensor):
value_str = Tensor.__repr__(self)
return f'Parameter (name={self._value.name}, value={value_str})'
def __parameter__(self): def __parameter__(self):
"""For parse check.""" """For parse check."""
...@@ -77,6 +145,13 @@ class Parameter: ...@@ -77,6 +145,13 @@ class Parameter:
self.is_param_ps = True self.is_param_ps = True
self.init_in_server = init_in_server self.init_in_server = init_in_server
@property
def inited_param(self):
"""Get the new parameter after call the init_data."""
return self._inited_param
@property @property
def name(self): def name(self):
"""Get the name of the parameter.""" """Get the name of the parameter."""
...@@ -157,15 +232,11 @@ class Parameter: ...@@ -157,15 +232,11 @@ class Parameter:
x._value.name = prefix + '.' + self._value.name x._value.name = prefix + '.' + self._value.name
x.is_init = False x.is_init = False
if init != 'same': if init != 'same':
shape = self.default_input.shape shape = self.shape
dtype = self.default_input.dtype dtype = self.dtype
if isinstance(init, (str, Initializer, numbers.Number)): x.default_input = initializer(init, shape=shape, dtype=dtype)
x.init_mode = initializer(init, shape=shape, dtype=dtype) if context.get_context("mode") == context.PYNATIVE_MODE:
x.default_input = MetaTensor(dtype, shape) x.init_data()
if context.get_context("mode") == context.PYNATIVE_MODE:
x.init_data()
else:
x.default_input = initializer(init, shape=shape, dtype=dtype)
return x return x
@property @property
...@@ -195,50 +266,65 @@ class Parameter: ...@@ -195,50 +266,65 @@ class Parameter:
@property @property
def default_input(self): def default_input(self):
return self._data return self
@default_input.setter @default_input.setter
def default_input(self, data): def default_input(self, data):
self._data = data self.set_parameter_data(data)
self._value.data = data
def __add__(self, other):
return self.default_input + other
def __sub__(self, other):
return self.default_input - other
def __mul__(self, other): def _update_tensor_data(self, data):
return self.default_input * other "Update the parameter by a Tensor."
if isinstance(self, Tensor):
# already a Tensor: update the existing tensor in place
return self.assign_value(data)
# still a MetaTensor: create a new Tensor-backed Parameter holding the data
return Parameter(data, self.name, self.requires_grad)
def __truediv__(self, other): def set_parameter_data(self, data, slice_shape=False):
return self.default_input / other """
Set `default_input` of current `Parameter`.
def __setitem__(self, index, value): Args:
default_input = self.default_input data (Union[Tensor, Initializer]): new data.
default_input[index] = value slice_shape (bool): If slice the Parameter. Default: False.
return self
def set_parameter_data(self, data): Retruns:
"""Set `default_input` of current `Parameter`.""" Parameter, the parameter after set data.
if isinstance(data, bool): """
raise ValueError('Parameter data can not be `bool`') if not isinstance(data, (MetaTensor, Initializer)):
if isinstance(data, Tensor): raise ValueError(f"Parameter data must be `Initializer` or a kind of `MetaTensor` "
# make a copy of Tensor to init the parameter f"(like `Tensor` or `MetaTensor`). But with type {type(data)}.")
data = Tensor(data.asnumpy()) # check whether the incoming data and the current Parameter are already real Tensors.
data.init_flag = False is_incoming_tensor = isinstance(data, Tensor)
elif isinstance(data, Initializer): is_current_tensor = isinstance(self, Tensor)
self.init_mode = data
data = MetaTensor(self.init_mode.dtype, self.init_mode.shape) if is_incoming_tensor and not is_current_tensor:
elif isinstance(data, int): raise TypeError("Parameter is a `MetaTensor` and not initialized, `data` for `set_parameter_data` "
data = Tensor(data, dtype=mstype.int32) "should be an Initializer. If you want to update it by Tensor, call the method "
elif isinstance(data, float): "`init_parameters_data` of `Cell` to init and replace all the Parameters of the "
data = Tensor(data, dtype=mstype.float32) "network, then call this method.")
if tuple(self.shape) != tuple(data.shape):
# When creating a sliced Parameter, the shape is allowed to change.
if slice_shape:
self._update_tensor_data(data)
self.sliced = True
else:
raise ValueError(f"Can not change the shape of Parameter which has been initialized."
f" Current shape is {self.shape}, and incoming is {data.shape}.")
if self.dtype != data.dtype:
raise ValueError(f"Can not change the Parameter dtype. Current dtype is {self.set_dtype}"
f", and incoming is {data.dtype}. Use .set_dtype(xxx) to change the dtype.")
if isinstance(data, Initializer):
# The parameter has been initializered, directly update by the data
if is_current_tensor:
self._update_tensor_data(data.to_tensor())
else:
self.init_mode = data
elif is_incoming_tensor or is_current_tensor:
self._update_tensor_data(data)
else: else:
data = Tensor(data) raise ValueError(f"Not support to update the Parameter by {data}")
data.init_flag = False return self
self.default_input = data
def init_data(self, layout=None, set_sliced=False): def init_data(self, layout=None, set_sliced=False):
""" """
...@@ -252,31 +338,37 @@ class Parameter: ...@@ -252,31 +338,37 @@ class Parameter:
- slice_shape (list[int]): Shape of slice. - slice_shape (list[int]): Shape of slice.
set_sliced (bool): True if should set parameter sliced after init the data of initializer. set_sliced (bool): True if should set parameter sliced after init the data of initializer.
Default: False. Default: False.
Returns:
Parameter, Parameter after init data.
""" """
if isinstance(self.default_input, Tensor): if self.init_mode is None:
# skip if data already initialized. return self
return if self.inited_param is not None:
return self.inited_param
if layout is not None: if layout is not None:
if not isinstance(layout, list): if not isinstance(layout, list):
raise TypeError("The layout should be list! layout is {}." raise TypeError("The layout should be list! layout is {}.".format(layout))
.format(layout))
if len(layout) < 3: if len(layout) < 3:
raise ValueError("The length of layout must be larger than 3! layout is {}." raise ValueError("The length of layout must be larger than 3! layout is {}.".format(layout))
.format(layout))
slice_index = int(_get_slice_index(layout[0], layout[1])) slice_index = int(_get_slice_index(layout[0], layout[1]))
if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Initializer)): if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Initializer)):
self.default_input = self.init_mode.to_tensor(0, [1]) data = self.init_mode.to_tensor(0, [1])
else: else:
self.default_input = self.init_mode.to_tensor(slice_index, layout[2]) data = self.init_mode.to_tensor(slice_index, layout[2])
else: else:
if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Initializer)): if (self.init_in_server and self.is_param_ps and isinstance(self.init_mode, Initializer)):
self.default_input = self.init_mode.to_tensor(0, [1]) data = self.init_mode.to_tensor(0, [1])
else: else:
self.default_input = self.init_mode.to_tensor() data = self.init_mode.to_tensor()
self.init_mode = None obj = self._update_tensor_data(data)
if id(obj) != id(self):
self._inited_param = obj
obj.init_mode = None
if set_sliced: if set_sliced:
self.sliced = True obj.sliced = True
return obj
class ParameterTuple(tuple): class ParameterTuple(tuple):
......
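A short sketch of the deferred-initialization behaviour described in the new `Parameter` docstring above. It assumes graph mode and that `initializer(...)` hands back an `Initializer` object (as in the `_get_parameter_new_args` path); treat it as illustrative rather than an exact API trace:

```python
from mindspore import Parameter, context
from mindspore.common.initializer import initializer
import mindspore.common.dtype as mstype

context.set_context(mode=context.GRAPH_MODE)

# Initializer-backed Parameter: only shape/dtype are kept (a MetaTensor), no data yet.
w = Parameter(initializer('normal', [2, 3], mstype.float32), name="w")

# init_data() materializes the data; because the underlying class changes from
# MetaTensor to Tensor, a *new* Parameter object may be returned.
new_w = w.init_data()
if new_w is not w:
    assert w.inited_param is new_w  # the original remembers its replacement
print(new_w.asnumpy().shape)        # (2, 3)
```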
...@@ -75,7 +75,7 @@ class Tensor(Tensor_): ...@@ -75,7 +75,7 @@ class Tensor(Tensor_):
self._virtual_flag = False self._virtual_flag = False
def __repr__(self): def __repr__(self):
return str(self.__str__()) return str(Tensor_.__str__(self))
def __add__(self, other): def __add__(self, other):
out = tensor_operator_registry.get('__add__')(self, other) out = tensor_operator_registry.get('__add__')(self, other)
......
...@@ -283,11 +283,11 @@ class Parameter : public ANode { ...@@ -283,11 +283,11 @@ class Parameter : public ANode {
std::string fullname_with_scope() override { return name(); }; std::string fullname_with_scope() override { return name(); };
bool has_default() const { return has_default_; } bool has_default() const { return has_default_; }
void set_default_param(ParamValuePtr param) { void set_default_param(ValuePtr param) {
default_param_ = param; default_param_ = param;
has_default_ = true; has_default_ = true;
} }
ParamValuePtr default_param() const { return default_param_; } ValuePtr default_param() const { return default_param_; }
bool operator==(const AnfNode &other) const override { bool operator==(const AnfNode &other) const override {
if (!other.isa<Parameter>()) { if (!other.isa<Parameter>()) {
...@@ -303,7 +303,7 @@ class Parameter : public ANode { ...@@ -303,7 +303,7 @@ class Parameter : public ANode {
private: private:
std::string name_; std::string name_;
bool has_default_; bool has_default_;
ParamValuePtr default_param_; ValuePtr default_param_;
}; };
using ParameterPtr = std::shared_ptr<Parameter>; using ParameterPtr = std::shared_ptr<Parameter>;
......
...@@ -33,9 +33,6 @@ class ParamValue { ...@@ -33,9 +33,6 @@ class ParamValue {
virtual ~ParamValue() = default; virtual ~ParamValue() = default;
tensor::MetaTensorPtr value() const { return value_; }
void set_value(const tensor::MetaTensorPtr &value) { value_ = value; }
const std::string &name() const { return name_; } const std::string &name() const { return name_; }
void set_name(const std::string &name) { name_ = name; } void set_name(const std::string &name) { name_ = name; }
...@@ -72,7 +69,6 @@ class ParamValue { ...@@ -72,7 +69,6 @@ class ParamValue {
} }
private: private:
tensor::MetaTensorPtr value_;
std::string name_{"Parameter"}; std::string name_{"Parameter"};
bool requires_grad_{true}; bool requires_grad_{true};
bool layerwise_parallel_{false}; bool layerwise_parallel_{false};
......
...@@ -36,7 +36,7 @@ struct AnfQuantParam { ...@@ -36,7 +36,7 @@ struct AnfQuantParam {
int32_t numBits; int32_t numBits;
AnfQuantParam() : scale(1.0), zeroPoint(0), min(0.0), max(0.0), narrowRange(false), numBits(8), inited(false) {} AnfQuantParam() : scale(1.0), zeroPoint(0), min(0.0), max(0.0), narrowRange(false), numBits(8), inited(false) {}
}; };
class ParamValueLite : public ParamValue { class ParamValueLite : public Value {
public: public:
ParamValueLite() : tensor_addr_(nullptr), tensor_size_(0) {} ParamValueLite() : tensor_addr_(nullptr), tensor_size_(0) {}
virtual ~ParamValueLite() = default; virtual ~ParamValueLite() = default;
...@@ -65,6 +65,10 @@ class ParamValueLite : public ParamValue { ...@@ -65,6 +65,10 @@ class ParamValueLite : public ParamValue {
quant_params_.emplace_back(std::move(quant_param)); quant_params_.emplace_back(std::move(quant_param));
} }
bool operator==(const Value &other) const override {
return this == &other;
}
private: private:
void *tensor_addr_; void *tensor_addr_;
size_t tensor_size_; size_t tensor_size_;
......
...@@ -229,7 +229,6 @@ class Cell: ...@@ -229,7 +229,6 @@ class Cell:
for item in inputs: for item in inputs:
if isinstance(item, numpy.ndarray): if isinstance(item, numpy.ndarray):
raise TypeError("cell inputs should not be numpy array.") raise TypeError("cell inputs should not be numpy array.")
self.init_parameters_data()
orign_grad = [] orign_grad = []
if self.requires_grad is True: if self.requires_grad is True:
_pynative_exec.set_grad_flag(True) _pynative_exec.set_grad_flag(True)
...@@ -350,19 +349,8 @@ class Cell: ...@@ -350,19 +349,8 @@ class Cell:
params (dict): The parameters dictionary used for init data graph. params (dict): The parameters dictionary used for init data graph.
""" """
if params is None: if params is None:
for key in self.parameters_dict(): params = self.parameters_dict()
tensor = self.parameters_dict()[key].data if isinstance(params, OrderedDict):
if key not in self.parameter_layout_dict:
logger.info("layout dict does not contain the key %s", key)
continue
if self.parameters_dict()[key].sliced:
logger.debug("Param %s is already sliced.", key)
continue
layout = self.parameter_layout_dict[key]
new_tensor = _load_tensor_by_layout(tensor, layout)
self.parameters_dict()[key].set_parameter_data(new_tensor)
self.parameters_dict()[key].sliced = True
elif isinstance(params, OrderedDict):
for key in params: for key in params:
tensor = params[key].data tensor = params[key].data
if key not in self.parameter_layout_dict: if key not in self.parameter_layout_dict:
...@@ -373,8 +361,7 @@ class Cell: ...@@ -373,8 +361,7 @@ class Cell:
continue continue
layout = self.parameter_layout_dict[key] layout = self.parameter_layout_dict[key]
new_tensor = _load_tensor_by_layout(tensor, layout) new_tensor = _load_tensor_by_layout(tensor, layout)
params[key].set_parameter_data(new_tensor) params[key].set_parameter_data(new_tensor, True)
params[key].sliced = True
else: else:
raise TypeError('Parameters need OrderedDict type, but got {}'. raise TypeError('Parameters need OrderedDict type, but got {}'.
format(type(params))) format(type(params)))
...@@ -545,17 +532,46 @@ class Cell: ...@@ -545,17 +532,46 @@ class Cell:
""" """
raise NotImplementedError raise NotImplementedError
def init_parameters_data(self, recurse=True, auto_parallel_mode=False): def init_parameters_data(self, auto_parallel_mode=False):
"""Init parameters' data.""" """
for param in self.get_parameters(expand=recurse): Init all parameters' data and replace the original saved parameters in cell.
if not auto_parallel_mode:
param.init_data() Args:
elif param.name not in self.parameter_layout_dict: auto_parallel_mode (bool): If running in auto_parallel_mode.
logger.debug("Layout dict does not contain the key %s.", param.name)
param.init_data(set_sliced=True) Returns:
else: Dict[Parameter, Parameter], returns a dict of original parameter and replaced parameter.
layout = self.parameter_layout_dict[param.name] """
param.init_data(layout, set_sliced=True) replace = dict()
def _updata(param):
if param in replace:
return replace[param]
layout = None
set_sliced = False
if auto_parallel_mode:
set_sliced = True
if param.name not in self.parameter_layout_dict:
logger.debug("Layout dict does not contain the key %s.", param.name)
else:
layout = self.parameter_layout_dict[param.name]
new_p = param.init_data(layout, set_sliced=set_sliced)
replace[param] = new_p
return new_p
# replace all original usage.
cells = self.cells_and_names()
for _, cell in cells:
params = cell._params.items()
for param_name, param in params:
cell._params[param_name] = _updata(param)
cell_dict = cell.__dict__
for key in cell_dict:
if isinstance(cell_dict[key], ParameterTuple):
param_tuple = cell_dict[key]
new_param_tuple = []
for param in param_tuple:
new_param_tuple.append(_updata(param))
cell.__dict__[key] = ParameterTuple(new_param_tuple)
return replace
def parameters_dict(self, recurse=True): def parameters_dict(self, recurse=True):
""" """
...@@ -682,9 +698,10 @@ class Cell: ...@@ -682,9 +698,10 @@ class Cell:
for cell_name, cell in cells: for cell_name, cell in cells:
params = cell._params.items() params = cell._params.items()
for par_name, par in params: for par_name, par in params:
if par and par not in params_set: if par.inited_param is not None:
par = par.inited_param
if par is not None and par not in params_set:
params_set.add(par) params_set.add(par)
par_new_name = par_name par_new_name = par_name
if cell_name: if cell_name:
par_new_name = cell_name + '.' + par_new_name par_new_name = cell_name + '.' + par_new_name
......
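A minimal sketch of the new call pattern (the TinyNet class below is only an assumed example, not part of this commit): init_parameters_data() realizes the lazily initialized data, swaps the new Parameter objects into every sub-cell, and returns a mapping from each original parameter to its replacement.

    import mindspore.nn as nn
    import mindspore.common.dtype as mstype
    from mindspore import Parameter
    from mindspore.common.initializer import initializer

    class TinyNet(nn.Cell):
        # Assumed example network with a single lazily initialized weight.
        def __init__(self):
            super(TinyNet, self).__init__()
            self.weight = Parameter(initializer('ones', [2, 3], mstype.float32), name='weight')

        def construct(self, x):
            return x * self.weight

    net = TinyNet()
    replaced = net.init_parameters_data()   # Dict[Parameter, Parameter]
    # The cell now holds the replacement objects; the mapping lets callers migrate
    # any references they still keep to the original Parameter instances.
    for old_param, new_param in replaced.items():
        print(old_param.name, '->', new_param.name)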
...@@ -90,7 +90,7 @@ class Optimizer(Cell): ...@@ -90,7 +90,7 @@ class Optimizer(Cell):
def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0): def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0):
super(Optimizer, self).__init__(auto_prefix=False) super(Optimizer, self).__init__(auto_prefix=False)
if parameters and not isinstance(parameters, list): if parameters is not None and not isinstance(parameters, list):
parameters = list(parameters) parameters = list(parameters)
if not parameters: if not parameters:
......
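The guard above now tests identity against None instead of the argument's own truth value; presumably this is because a Parameter is now Tensor-like, so boolean evaluation of the argument (or of a container built from Parameters) is no longer a safe emptiness check. A small standalone sketch of the same normalization (the helper name is made up):

    def _normalize_params(parameters):
        # Mirror of the guard in Optimizer.__init__: compare with None explicitly,
        # then materialize generators so the list can be validated and reused.
        if parameters is not None and not isinstance(parameters, list):
            parameters = list(parameters)
        if not parameters:
            raise ValueError("Optimizer got an empty parameter list.")
        return parameters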
...@@ -295,7 +295,6 @@ def load_param_into_net(net, parameter_dict): ...@@ -295,7 +295,6 @@ def load_param_into_net(net, parameter_dict):
logger.error("Failed to combine the net and the parameters.") logger.error("Failed to combine the net and the parameters.")
msg = ("Argument parameter_dict element should be a Parameter, but got {}.".format(type(new_param))) msg = ("Argument parameter_dict element should be a Parameter, but got {}.".format(type(new_param)))
raise TypeError(msg) raise TypeError(msg)
param.init_data()
_update_param(param, new_param) _update_param(param, new_param)
else: else:
param_not_load.append(param.name) param_not_load.append(param.name)
...@@ -362,15 +361,13 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True, a ...@@ -362,15 +361,13 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True, a
integrated_save (bool): Whether to integrated save in automatic model parallel scene. integrated_save (bool): Whether to integrated save in automatic model parallel scene.
async_save (bool): Whether asynchronous execute save checkpoint into file. Default: False. async_save (bool): Whether asynchronous execute save checkpoint into file. Default: False.
""" """
train_network.init_parameters_data()
param_dict = {} param_dict = {}
for _, param in train_network.parameters_and_names(): for _, param in train_network.parameters_and_names():
param_dict[param.name] = param param_dict[param.name] = param
param_list = [] param_list = []
for (key, value) in param_dict.items(): for (key, value) in param_dict.items():
each_param = {"name": key} each_param = {"name": key}
value.init_data()
if isinstance(value.data, Tensor): if isinstance(value.data, Tensor):
param_data = value.data param_data = value.data
else: else:
......
...@@ -263,6 +263,7 @@ class MobileNetV2(nn.Cell): ...@@ -263,6 +263,7 @@ class MobileNetV2(nn.Cell):
Examples: Examples:
>>> _initialize_weights() >>> _initialize_weights()
""" """
self.init_parameters_data()
for _, m in self.cells_and_names(): for _, m in self.cells_and_names():
if isinstance(m, (nn.Conv2d, DepthwiseConv)): if isinstance(m, (nn.Conv2d, DepthwiseConv)):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
......
...@@ -196,6 +196,7 @@ class mobilenetV2(nn.Cell): ...@@ -196,6 +196,7 @@ class mobilenetV2(nn.Cell):
self.head = nn.SequentialCell(head) self.head = nn.SequentialCell(head)
# init weights # init weights
self.init_parameters_data()
self._initialize_weights() self._initialize_weights()
def construct(self, x): def construct(self, x):
...@@ -215,6 +216,7 @@ class mobilenetV2(nn.Cell): ...@@ -215,6 +216,7 @@ class mobilenetV2(nn.Cell):
Examples: Examples:
>>> _initialize_weights() >>> _initialize_weights()
""" """
self.init_parameters_data()
for _, m in self.cells_and_names(): for _, m in self.cells_and_names():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
......
...@@ -200,6 +200,7 @@ class ResUnit(nn.Cell): ...@@ -200,6 +200,7 @@ class ResUnit(nn.Cell):
self.add = P.TensorAdd() if self.use_short_cut_conv else None self.add = P.TensorAdd() if self.use_short_cut_conv else None
def construct(self, x): def construct(self, x):
"""construct"""
if self.first_conv: if self.first_conv:
out = self.expand(x) out = self.expand(x)
else: else:
...@@ -289,6 +290,7 @@ class MobileNetV3(nn.Cell): ...@@ -289,6 +290,7 @@ class MobileNetV3(nn.Cell):
kernel_size=1, has_bias=True, pad_mode='pad') kernel_size=1, has_bias=True, pad_mode='pad')
self.squeeze = P.Squeeze(axis=(2, 3)) self.squeeze = P.Squeeze(axis=(2, 3))
self.init_parameters_data()
self._initialize_weights() self._initialize_weights()
def construct(self, x): def construct(self, x):
......
...@@ -171,9 +171,9 @@ def test_bert_tdt(): ...@@ -171,9 +171,9 @@ def test_bert_tdt():
netwithgrads.set_train(True) netwithgrads.set_train(True)
model = Model(netwithgrads) model = Model(netwithgrads)
callback = ModelCallback() callback = ModelCallback()
netwithloss.init_parameters_data()
params = netwithloss.trainable_params() params = netwithloss.trainable_params()
for param in params: for param in params:
param.init_data()
value = param.default_input value = param.default_input
name = param.name name = param.name
if isinstance(value, Tensor): if isinstance(value, Tensor):
......
...@@ -207,9 +207,9 @@ def test_bert_percision(): ...@@ -207,9 +207,9 @@ def test_bert_percision():
netwithgrads.set_train(True) netwithgrads.set_train(True)
model = Model(netwithgrads) model = Model(netwithgrads)
callback = ModelCallback() callback = ModelCallback()
netwithloss.init_parameters_data()
params = netwithloss.trainable_params() params = netwithloss.trainable_params()
for param in params: for param in params:
param.init_data()
value = param.default_input value = param.default_input
name = param.name name = param.name
if isinstance(value, Tensor): if isinstance(value, Tensor):
...@@ -279,9 +279,9 @@ def test_bert_performance(): ...@@ -279,9 +279,9 @@ def test_bert_performance():
netwithgrads.set_train(True) netwithgrads.set_train(True)
model = Model(netwithgrads) model = Model(netwithgrads)
callback = ModelCallback() callback = ModelCallback()
netwithloss.init_parameters_data()
params = netwithloss.trainable_params() params = netwithloss.trainable_params()
for param in params: for param in params:
param.init_data()
value = param.default_input value = param.default_input
name = param.name name = param.name
if isinstance(value, Tensor): if isinstance(value, Tensor):
......
...@@ -63,6 +63,7 @@ class LossCallBack(Callback): ...@@ -63,6 +63,7 @@ class LossCallBack(Callback):
str(cb_params.net_outputs))) str(cb_params.net_outputs)))
def model_fine_tune(train_net, fix_weight_layer): def model_fine_tune(train_net, fix_weight_layer):
train_net.init_parameters_data()
for para in train_net.trainable_params(): for para in train_net.trainable_params():
para.set_parameter_data(Tensor(np.ones(para.data.shape).astype(np.float32) * 0.02)) para.set_parameter_data(Tensor(np.ones(para.data.shape).astype(np.float32) * 0.02))
if fix_weight_layer in para.name: if fix_weight_layer in para.name:
......
...@@ -174,9 +174,14 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl): ...@@ -174,9 +174,14 @@ def train_process(q, device_id, epoch_size, device_num, enable_hccl):
steps_per_epoch=step_size, lr_decay_mode=config.lr_decay_mode)) steps_per_epoch=step_size, lr_decay_mode=config.lr_decay_mode))
# optimizer # optimizer
decayed_params = list(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name, decayed_params = []
net.trainable_params())) no_decayed_params = []
no_decayed_params = [param for param in net.trainable_params() if param not in decayed_params] for param in net.trainable_params():
if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
decayed_params.append(param)
else:
no_decayed_params.append(param)
group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay}, group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
{'params': no_decayed_params, 'weight_decay': 0.0}, {'params': no_decayed_params, 'weight_decay': 0.0},
{'order_params': net.trainable_params()}] {'order_params': net.trainable_params()}]
......
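The earlier one-liner derived no_decayed_params with `param not in decayed_params`; that membership test runs Parameter.__eq__, which after this change presumably follows Tensor semantics (elementwise comparison) and cannot be collapsed to a single bool. The rewritten single-pass loop avoids comparing Parameter objects entirely; an equivalent sketch that groups purely by name (the helper name is an assumption):

    def split_decay_groups(params, skip_keys=('beta', 'gamma', 'bias')):
        # Never use `in` or `==` on Parameter objects here; key the split on names only.
        decayed, no_decayed = [], []
        for p in params:
            (no_decayed if any(k in p.name for k in skip_keys) else decayed).append(p)
        return decayed, no_decayed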
...@@ -107,7 +107,6 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond2) { ...@@ -107,7 +107,6 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond2) {
for (auto p : kg->parameters()) { for (auto p : kg->parameters()) {
auto param = p->cast<ParameterPtr>(); auto param = p->cast<ParameterPtr>();
EXPECT_NE(param, nullptr); EXPECT_NE(param, nullptr);
param->set_default_param(std::make_shared<ParamValue>());
} }
auto optimizer = std::make_shared<opt::GraphOptimizer>(); auto optimizer = std::make_shared<opt::GraphOptimizer>();
...@@ -157,7 +156,6 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond4) { ...@@ -157,7 +156,6 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond4) {
for (auto p : kg->parameters()) { for (auto p : kg->parameters()) {
auto param = p->cast<ParameterPtr>(); auto param = p->cast<ParameterPtr>();
EXPECT_NE(param, nullptr); EXPECT_NE(param, nullptr);
param->set_default_param(std::make_shared<ParamValue>());
} }
auto optimizer = std::make_shared<opt::GraphOptimizer>(); auto optimizer = std::make_shared<opt::GraphOptimizer>();
...@@ -185,7 +183,6 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond5) { ...@@ -185,7 +183,6 @@ TEST_F(TestHWInsertMemcpyForHccl, test_cond5) {
for (auto p : kg->parameters()) { for (auto p : kg->parameters()) {
auto param = p->cast<ParameterPtr>(); auto param = p->cast<ParameterPtr>();
EXPECT_NE(param, nullptr); EXPECT_NE(param, nullptr);
param->set_default_param(std::make_shared<ParamValue>());
} }
auto optimizer = std::make_shared<opt::GraphOptimizer>(); auto optimizer = std::make_shared<opt::GraphOptimizer>();
......
...@@ -766,7 +766,7 @@ TEST_F(AnfRuntimeAlgorithmTest, IsParameterWeight) { ...@@ -766,7 +766,7 @@ TEST_F(AnfRuntimeAlgorithmTest, IsParameterWeight) {
auto kernel_graph = std::make_shared<KernelGraph>(); auto kernel_graph = std::make_shared<KernelGraph>();
auto parameter_node = kernel_graph->add_parameter(); auto parameter_node = kernel_graph->add_parameter();
MS_EXCEPTION_IF_NULL(parameter_node); MS_EXCEPTION_IF_NULL(parameter_node);
auto param_value_new = std::make_shared<ParamValue>(); auto param_value_new = std::make_shared<tensor::Tensor>(int64_t(0), kInt32);
parameter_node->set_default_param(param_value_new); parameter_node->set_default_param(param_value_new);
EXPECT_TRUE(AnfAlgo::IsParameterWeight(parameter_node)); EXPECT_TRUE(AnfAlgo::IsParameterWeight(parameter_node));
EXPECT_THROW(AnfAlgo::IsParameterWeight(nullptr), std::runtime_error); EXPECT_THROW(AnfAlgo::IsParameterWeight(nullptr), std::runtime_error);
......
...@@ -82,7 +82,7 @@ TEST_F(KernelGraphTest, NewParameter) { ...@@ -82,7 +82,7 @@ TEST_F(KernelGraphTest, NewParameter) {
// test weight parameter node as input // test weight parameter node as input
auto weight_parameter_node = anf_graph->add_parameter(); auto weight_parameter_node = anf_graph->add_parameter();
MS_EXCEPTION_IF_NULL(weight_parameter_node); MS_EXCEPTION_IF_NULL(weight_parameter_node);
auto param_value_new = std::make_shared<ParamValue>(); auto param_value_new = std::make_shared<tensor::Tensor>(kNumberTypeFloat32, shape);
weight_parameter_node->set_default_param(param_value_new); weight_parameter_node->set_default_param(param_value_new);
weight_parameter_node->set_abstract(x_abstract); weight_parameter_node->set_abstract(x_abstract);
auto new_weight_parameter_node = kernel_graph->NewParameter(weight_parameter_node); auto new_weight_parameter_node = kernel_graph->NewParameter(weight_parameter_node);
......
...@@ -225,7 +225,7 @@ def test_div(): ...@@ -225,7 +225,7 @@ def test_div():
@non_graph_engine @non_graph_engine
def test_parameter(): def test_parameter():
x = Parameter(initializer(1, [1], ms.float32), name="beta1_power") x = Parameter(initializer(1, [1], ms.float32), name="beta1_power")
x.init_data() x = x.init_data()
z = x / 2 z = x / 2
print(z) print(z)
......
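Call sites now rebind the result of init_data(): the method can return a fresh Parameter carrying the realized Tensor data instead of mutating the receiver in place. A minimal sketch of the updated idiom:

    import numpy as np
    from mindspore import Parameter, Tensor
    import mindspore.common.dtype as mstype
    from mindspore.common.initializer import initializer

    x = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
    x = x.init_data()                       # rebind to the returned, initialized Parameter
    assert isinstance(x.default_input, Tensor)
    assert np.array_equal(x.default_input.asnumpy(), np.ones([1], dtype=np.float32))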
...@@ -139,14 +139,31 @@ def test_parameter_lazy_init(): ...@@ -139,14 +139,31 @@ def test_parameter_lazy_init():
# Call init_data() without set default_input. # Call init_data() without set default_input.
para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test1') para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test1')
assert not isinstance(para.default_input, Tensor) assert not isinstance(para.default_input, Tensor)
para.init_data() para = para.init_data()
assert isinstance(para.default_input, Tensor) assert isinstance(para.default_input, Tensor)
assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3))) assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3)))
# Call init_data() after default_input is set. # Call init_data() after default_input is set.
para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test2') para = Parameter(initializer('ones', [1, 2, 3], mstype.float32), 'test2')
assert not isinstance(para.default_input, Tensor) assert not isinstance(para.default_input, Tensor)
para.default_input = Tensor(np.zeros((1, 2, 3))) # expect type error when not init
assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2, 3))) with pytest.raises(TypeError):
para.init_data() # expect no effect. para.default_input = Tensor(np.zeros((1, 2, 3)))
# init then assign
para = para.init_data()
# check the type
with pytest.raises(ValueError):
para.default_input = Tensor(np.zeros((1, 2, 3)))
# check the shape
with pytest.raises(ValueError):
para.default_input = Tensor(np.zeros((1, 2)))
# expect change ok
para.default_input = Tensor(np.zeros((1, 2, 3)).astype(np.float32))
assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2, 3))) assert np.array_equal(para.default_input.asnumpy(), np.zeros((1, 2, 3)))
para.default_input = initializer('ones', [1, 2, 3], mstype.float32)
assert isinstance(para.default_input, Tensor)
# same object and has inited
assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3)))
# expect no effect.
para.init_data()
assert np.array_equal(para.default_input.asnumpy(), np.ones((1, 2, 3)))
...@@ -69,8 +69,7 @@ def test_qat_lenet(): ...@@ -69,8 +69,7 @@ def test_qat_lenet():
net = qat.convert_quant_network( net = qat.convert_quant_network(
net, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) net, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
# should load the checkpoint. mock here # should load the checkpoint. mock here
for param in net.get_parameters(): net.init_parameters_data()
param.init_data()
qat.export(net, img, file_name="quant.pb") qat.export(net, img, file_name="quant.pb")
...@@ -80,8 +79,7 @@ def test_qat_mobile_per_channel_tf(): ...@@ -80,8 +79,7 @@ def test_qat_mobile_per_channel_tf():
img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32)) img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
network = qat.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) network = qat.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False])
# should load the checkpoint. mock here # should load the checkpoint. mock here
for param in network.get_parameters(): network.init_parameters_data()
param.init_data()
qat.export(network, img, file_name="quant.pb") qat.export(network, img, file_name="quant.pb")
@pytest.mark.skip(reason="no `te.lang.cce` in ut env") @pytest.mark.skip(reason="no `te.lang.cce` in ut env")
...@@ -90,6 +88,5 @@ def test_qat_mobile_per_channel_ff(): ...@@ -90,6 +88,5 @@ def test_qat_mobile_per_channel_ff():
img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32)) img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
network = qat.convert_quant_network(network, bn_fold=True, per_channel=[False, False], symmetric=[True, False]) network = qat.convert_quant_network(network, bn_fold=True, per_channel=[False, False], symmetric=[True, False])
# should load the checkpoint. mock here # should load the checkpoint. mock here
for param in network.get_parameters(): network.init_parameters_data()
param.init_data()
qat.export(network, img, file_name="quant.pb") qat.export(network, img, file_name="quant.pb")
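In these quantization tests the per-parameter loop is collapsed into one cell-level call; a short before/after sketch, assuming any Cell instance named network:

    # Before (removed in this commit):
    #     for param in network.get_parameters():
    #         param.init_data()
    # After: a single call realizes the data and installs the initialized Parameters.
    network.init_parameters_data()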