提交 c23731e5 编写于 作者: C changzherui

Incremental subgraph initialization

上级 311b7e71
......@@ -374,6 +374,10 @@ TypeId Tensor::set_data_type(const TypeId data_type) {
return data_type_;
}
bool Tensor::is_init() { return init_flag_; }
void Tensor::set_init_flag(bool flag) { init_flag_ = flag; }
bool Tensor::convert_data(const py::array &in, const TypeId in_data_type, py::array *const out,
const TypeId out_data_type) {
if (out == nullptr) {
......@@ -499,6 +503,24 @@ REGISTER_PYBIND_DEFINE(Tensor, ([](const py::module *m) {
>>> data.size()
6
)mydelimiter")
.def("is_init", &Tensor::is_init, R"mydelimiter(
Get tensor init_flag.
Returns:
bool, whether the tensor init.
Examples:
>>> data = mindspore.Tensor(np.ones((2, 3)))
>>> data.is_init()
False
)mydelimiter")
.def("set_init_flag", &Tensor::set_init_flag, R"mydelimiter(
Set tensor init_flag.
Examples:
>>> data = mindspore.Tensor(np.ones((2, 3)))
>>> data.set_init_flag(True)
)mydelimiter")
.def("dim", &Tensor::DataDim, R"mydelimiter(
Get tensor's data dimension.
......
......@@ -389,6 +389,8 @@ class Tensor : public MetaTensor {
std::string ToStringRepr() const;
py::array data_; // < Tensor's data value
const bool parse_info_ = true;
bool is_init();
void set_init_flag(bool flag);
private:
// brief init tensor
......@@ -398,7 +400,7 @@ class Tensor : public MetaTensor {
// return true if succeed, false if failed.
void init(const py::array &input, const TypeId &data_type);
void init(const py::array &input, const TypePtr &type_ptr);
bool init_flag_{false};
// brief init tensor attribute
//
// param data_type [TypeId] Data type of the tensor.
......
......@@ -646,7 +646,6 @@ void DfGraphConvertor::InitParamWithData(const TensorOrderMap &tensors) {
if (adpt == nullptr) continue;
auto param_op = adpt->generate(name + "_data");
MS_LOG(INFO) << "Add parameter " << name << " as input, index " << index << ".";
(void)std::static_pointer_cast<Data>(param_op)->set_attr_index(index++);
if (!training_) {
auto adpt_const = FindAdapter(kNameConst, training_);
......@@ -675,6 +674,8 @@ void DfGraphConvertor::InitParamWithData(const TensorOrderMap &tensors) {
// we need three variable ops for each graph with same name
// build init subgraph
if (it.second->is_init() == 0) {
(void)std::static_pointer_cast<Data>(param_op)->set_attr_index(index++);
auto init_var = std::make_shared<Variable>(name);
auto assign_op = std::make_shared<Assign>("assign_" + name);
(void)init_var->update_output_desc_y(*desc);
......@@ -683,6 +684,7 @@ void DfGraphConvertor::InitParamWithData(const TensorOrderMap &tensors) {
init_ops_.push_back(param_op);
init_ops_.push_back(assign_op);
init_ops_.push_back(init_var);
}
auto variable = std::make_shared<Variable>(name);
(void)variable->update_output_desc_y(*desc);
......
......@@ -82,14 +82,15 @@ def _wrap_func(fn):
def _exec_init_graph(obj, init_phase):
"""Execute the parameter initializer graph."""
inst_executor = Executor_.get_instance()
exec_init_graph = False
for param in obj.get_parameters():
param_dict = OrderedDict()
for name, param in obj.parameters_dict().items():
if not param.is_init:
param_dict[name] = param
param.is_init = True
exec_init_graph = True
param.data.init_flag = True
if exec_init_graph:
inst_executor.run_init_graph(obj.parameters_dict(), init_phase)
if param_dict:
inst_executor.run_init_graph(param_dict, init_phase)
class _MindSporeFunction:
......
......@@ -188,11 +188,14 @@ class Parameter:
if isinstance(data, Tensor):
# make a copy of Tensor to init the parameter
data = Tensor(data.asnumpy().copy())
data.init_flag = False
elif isinstance(data, Initializer):
self.init_mode = data
data = MetaTensor(self.init_mode.dtype, self.init_mode.shape)
else:
data = Tensor(data)
data.init_flag = False
self.default_input = data
......
......@@ -65,6 +65,7 @@ class Tensor(Tensor_):
else:
super(Tensor, self).__init__(input_data, dtype)
self._virtual_flag = False
self._init_flag = False
def __repr__(self):
return str(self.__str__())
......@@ -153,3 +154,16 @@ class Tensor(Tensor_):
if not isinstance(value, bool):
raise TypeError("virtual_flag must be bool.")
self._virtual_flag = value
@property
def init_flag(self):
"""whether the tensor is init."""
return self._init_flag
@init_flag.setter
def init_flag(self, value):
"""Set the tensor is init_flag."""
if not isinstance(value, bool):
raise TypeError("init_flag must be bool.")
self.set_init_flag(value)
self._init_flag = value
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册