Unverified commit ff0886a9 authored by H hong, committed by GitHub

save load problem fix and new feature add (#20823)

* fix persistable;

* fix save load bugs; test=develop

* fix bug; test=develop

* add example for new io api; test=develop

* add example; test=develop
Parent 2058bab1
@@ -126,6 +126,8 @@ class Executor {
Scope* scope, Dataset* dataset);
void RunFromDataset(std::shared_ptr<TrainerBase> trainer);
const platform::Place GetPlace() const { return place_; }
private:
const platform::Place place_;
};
......
@@ -237,6 +237,56 @@ static std::vector<std::string> inline GetNameList(
return vec_res;
}
static void inline CreateVariableIfNotExit(
const py::handle &py_handle, const framework::Scope &scope,
const framework::Executor *exe = nullptr) {
std::vector<std::string> vec_res;
PyObject *py_obj = py_handle.ptr(); // get underlying PyObject
// Python None is not nullptr in C++!
if (!py_obj || py_obj == Py_None) {
PADDLE_THROW("Save parameter list is None");
}
if (PyList_Check(py_obj)) {
size_t len = PyList_GET_SIZE(py_obj);
vec_res.reserve(len);
const char *kNameField = "name";
const char *kVarDescField = "desc";
for (size_t i = 0; i < len; ++i) {
PyObject *py_name =
PyObject_GetAttrString(PyList_GET_ITEM(py_obj, i), kNameField);
PADDLE_ENFORCE_NOT_NULL(py_name);
auto para_name = PyObjectCast<std::string>(py_name);
Py_DECREF(py_name);
auto var = scope.FindVar(para_name);
if (var == nullptr) {
PADDLE_ENFORCE_NE(exe, nullptr,
"Parameter not initialized; "
"please pass a non-None [executor] argument "
"or run the startup program first");
PyObject *py_var_desc =
PyObject_GetAttrString(PyList_GET_ITEM(py_obj, i), kVarDescField);
PADDLE_ENFORCE_NOT_NULL(py_var_desc);
auto var_desc = PyObjectCast<framework::VarDesc>(py_var_desc);
Py_DECREF(py_var_desc);
var = const_cast<framework::Scope *>(&scope)->Var(para_name);
auto *tensor_temp = var->GetMutable<framework::LoDTensor>();
tensor_temp->Resize(framework::make_ddim(var_desc.GetShape()));
tensor_temp->mutable_data(exe->GetPlace(), var_desc.GetDataType());
}
}
} else {
PADDLE_THROW("Set parameter should be a list");
}
return;
}
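In effect, this helper lets the load path materialize parameters that the startup program never created: each name is looked up in the scope, and only on a miss is the variable built from its VarDesc shape and dtype on the executor's place. Below is a rough Python-side sketch of the same check-then-create behavior (the helper name create_if_missing and the float32 default are illustrative, not part of this patch):

    import numpy as np
    import paddle.fluid as fluid

    def create_if_missing(parameter_list, scope, place):
        # Mirrors CreateVariableIfNotExit: only variables missing from
        # the scope are created and given a tensor of the recorded shape.
        for param in parameter_list:
            if scope.find_var(param.name) is None:
                var = scope.var(param.name)  # create an empty variable
                # dtype is hard-coded to float32 here for brevity; the
                # C++ helper reads the real dtype from the VarDesc.
                var.get_tensor().set(
                    np.zeros(param.shape, dtype='float32'), place)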
#ifdef PADDLE_WITH_AVX
PYBIND11_MODULE(core_avx, m) {
#else
@@ -285,11 +335,18 @@ PYBIND11_MODULE(core_noavx, m) {
m.def("_load_static_dict",
[](const std::string &str_file_name, const py::handle &vec_var_list,
const Scope &scope) {
const Scope &scope, const Executor *executor) {
std::vector<std::string> vec_name_list = GetNameList(vec_var_list);
CreateVariableIfNotExit(vec_var_list, scope, executor);
LoadStaticNameListFromDisk(str_file_name, vec_name_list, scope);
});
m.def("_create_loaded_parameter",
[](const py::handle &vec_var_list, const Scope &scope,
const Executor *executor) {
CreateVariableIfNotExit(vec_var_list, scope, executor);
});
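From Python, the new binding is reachable as fluid.core._create_loaded_parameter; it is what fluid.load calls when an executor is passed, and the new TestVariableInit test below uses it directly to prepare a fresh scope before filling it. A minimal sketch along the same lines, assuming the program's parameters were saved earlier:

    import paddle.fluid as fluid

    prog = fluid.default_main_program()
    exe = fluid.Executor(fluid.CPUPlace())
    new_scope = fluid.core.Scope()

    # Create uninitialized parameter variables in new_scope so that a
    # later load can write values into their tensors.
    param_list = list(filter(fluid.io.is_parameter, prog.list_vars()))
    fluid.core._create_loaded_parameter(param_list, new_scope,
                                        exe._default_executor)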
m.def("_save_dygraph_dict", [](const std::string &str_file_name,
const PyNameVarBaseMap &state_dict) {
auto vec_var_base_list = GetVarBaseList(state_dict);
......
@@ -86,7 +86,7 @@ from paddle.fluid.layers.math_op_patch import monkey_patch_variable
from . import install_check
from .dygraph.nn import *
from .dygraph.layers import *
from .io import save, load
from .io import save, load, load_program_state, set_program_state
from .dygraph.checkpoint import save_dygraph, load_dygraph
Tensor = LoDTensor
......
@@ -181,6 +181,7 @@ if avx_supported():
from .core_avx import _load_static_dict
from .core_avx import _save_dygraph_dict
from .core_avx import _load_dygraph_dict
from .core_avx import _create_loaded_parameter
except Exception as e:
if has_avx_core:
raise e
@@ -214,6 +215,7 @@ if load_noavx:
from .core_noavx import _load_static_dict
from .core_noavx import _save_dygraph_dict
from .core_noavx import _load_dygraph_dict
from .core_noavx import _create_loaded_parameter
except Exception as e:
if has_noavx_core:
sys.stderr.write(
......
@@ -41,9 +41,19 @@ from .. import compat as cpt
batch = paddle.batch
__all__ = [
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
'load_persistables', 'save_inference_model', 'load_inference_model',
'batch', 'save', 'load'
'save_vars',
'save_params',
'save_persistables',
'load_vars',
'load_params',
'load_persistables',
'save_inference_model',
'load_inference_model',
'batch',
'save',
'load',
'load_program_state',
'set_program_state',
] + reader.__all__ + paddle.reader.__all__
_logger = get_logger(
@@ -97,7 +107,10 @@ def is_persistable(var):
def is_belong_to_optimizer(var):
return var.belong_to_optimizer
if not isinstance(var, Parameter):
return is_persistable(var)
return False
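With this change, is_belong_to_optimizer treats any persistable non-Parameter variable as optimizer state instead of relying on the belong_to_optimizer attribute. A small sketch of how the save/load paths (and the tests below) use it:

    import paddle.fluid as fluid

    prog = fluid.default_main_program()
    # Collect optimizer state vars, e.g. Adam moments and beta pow
    # accumulators, which are persistable but not Parameters.
    opt_vars = list(filter(fluid.io.is_belong_to_optimizer, prog.list_vars()))
    print([v.name for v in opt_vars])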
def _clone_var_in_block_(block, var):
@@ -1531,16 +1544,16 @@ def save(program, model_path):
f.write(program.desc.serialize_to_string())
def load(program, model_path):
def load(program, model_path, executor=None):
"""
This function filters out the parameters and optimizer information from the program, and then gets the corresponding values from the file.
An exception will be thrown if the shape or dtype of the parameters does not match between the program and the loaded file.
NOTICE: This function MUST be called after running the startup program
An exception will be thrown if the shape or dtype of the parameters does not match.
Args:
program: The program to be loaded
model_path: The file prefix that stores the program
program(Program): The program to be loaded
model_path(str): The file prefix that stores the program
executor(Executor, optional): The executor used to initialize the parameters
when the startup program has not been run.
Returns:
None
@@ -1557,6 +1570,8 @@ def load(program, model_path):
"""
assert executor is None or isinstance(executor, Executor)
parameter_file_name = model_path + ".pdparams"
assert os.path.exists(parameter_file_name), \
"Parameter file [{}] does not exist".format(parameter_file_name)
@@ -1576,6 +1591,11 @@ def load(program, model_path):
t.set(ndarray, place)
parameter_list = list(filter(is_parameter, program.list_vars()))
if executor:
paddle.fluid.core._create_loaded_parameter(parameter_list,
global_scope(),
executor._default_executor)
with open(parameter_file_name, 'rb') as f:
load_dict = pickle.load(f)
for v in parameter_list:
@@ -1590,7 +1610,11 @@ def load(program, model_path):
if len(optimizer_var_list) > 0:
opt_file_name = model_path + ".pdopt"
assert os.path.exists(opt_file_name), \
"Optimizer file [{}] does not exist".format(opt_file_name)
"Optimizer file [{}] does not exist".format( opt_file_name)
if executor:
paddle.fluid.core._create_loaded_parameter(
optimizer_var_list, global_scope(), executor._default_executor)
with open(opt_file_name, 'rb') as f:
load_dict = pickle.load(f)
@@ -1599,3 +1623,126 @@ def load(program, model_path):
"Can not find [{}] in model file [{}]".format(
v.name, opt_file_name)
set_var(v, load_dict[v.name])
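Taken together, the new executor argument enables loading into a program whose parameters were never initialized: the variables are first created via _create_loaded_parameter and then overwritten with the values from disk. A minimal sketch, assuming "./temp" was written earlier by fluid.save:

    import paddle.fluid as fluid

    x = fluid.data(name="x", shape=[10, 10], dtype='float32')
    y = fluid.layers.fc(x, 10)

    prog = fluid.default_main_program()
    exe = fluid.Executor(fluid.CPUPlace())

    # No exe.run(fluid.default_startup_program()) here: passing the
    # executor lets fluid.load create the parameters before filling them.
    fluid.load(prog, "./temp", exe)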
def load_program_state(model_path):
"""
Load program state from a local file
Args:
model_path(str): The file prefix that stores the program
Returns:
state_dict(dict): the dict that stores Parameter and optimizer information
Examples:
.. code-block:: python
import paddle.fluid as fluid
x = fluid.data( name="x", shape=[10, 10], dtype='float32')
y = fluid.layers.fc( x, 10)
z = fluid.layers.fc( y, 10)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run( fluid.default_startup_program() )
prog = fluid.default_main_program()
fluid.save( prog, "./temp")
program_state = fluid.load_program_state( "./temp")
fluid.set_program_state( prog, program_state)
"""
parameter_file_name = model_path + ".pdparams"
assert os.path.exists(parameter_file_name), \
"Parameter file [{}] does not exist".format( parameter_file_name)
with open(parameter_file_name, 'rb') as f:
para_dict = pickle.load(f)
opt_file_name = model_path + ".pdopt"
if os.path.exists(opt_file_name):
with open(opt_file_name, 'rb') as f:
opti_dict = pickle.load(f)
para_dict.update(opti_dict)
return para_dict
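The returned dict maps variable names to numpy ndarrays, and optimizer state is merged in whenever a .pdopt file sits next to the .pdparams file. A quick way to inspect a checkpoint, assuming "./temp" as in the docstring example:

    import paddle.fluid as fluid

    state = fluid.load_program_state("./temp")
    for name, array in state.items():
        # Each entry is a plain numpy ndarray.
        print(name, array.shape, array.dtype)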
def set_program_state(program, state_dict):
"""
Set program parameters from state_dict
An exception will be thrown if the shape or dtype of the parameters does not match.
NOTICE: This function MUST be called after running the startup program
Args:
program(Program): The program to be set
state_dict(dict): the dict that stores Parameter and optimizer information
Returns:
None
Examples:
.. code-block:: python
import paddle.fluid as fluid
x = fluid.data( name="x", shape=[10, 10], dtype='float32')
y = fluid.layers.fc( x, 10)
z = fluid.layers.fc( y, 10)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run( fluid.default_startup_program() )
prog = fluid.default_main_program()
fluid.save( prog, "./temp")
program_state = fluid.load_program_state( "./temp")
fluid.set_program_state( prog, program_state)
"""
parameter_list = list(filter(is_persistable, program.list_vars()))
used_para_list = {}
for para in parameter_list:
var_temp = paddle.fluid.global_scope().find_var(para.name)
assert var_temp is not None, \
"Variable [ {} ] not found, please make sure the startup program has been run".format( para.name )
if para.name in state_dict:
# set value from state dict
orig_para_np = np.array(var_temp.get_tensor())
new_para_np = state_dict[para.name]
assert orig_para_np.shape == new_para_np.shape, \
"Shape not matching: the Program requires a parameter with a shape of ({}), " \
"while the loaded parameter (namely [ {} ]) has a shape of ({})." \
.format(orig_para_np.shape, para.name, new_para_np.shape)
assert orig_para_np.dtype == new_para_np.dtype, \
"Dtype not matching: the Program requires a parameter with a dtype of ({}), " \
"while the loaded parameter (namely [ {} ]) has a dtype of ({})." \
.format(orig_para_np.dtype, para.name, new_para_np.dtype)
ten = var_temp.get_tensor()
ten_place = ten._place()
assert ten_place.is_gpu_place() or ten_place.is_cpu_place(), \
"Place not supported: only CPUPlace and GPUPlace are supported, but got {}".format( str(ten_place))
py_place = paddle.fluid.CPUPlace()
if ten_place.is_cuda_pinned_place():
place = paddle.fluid.CUDAPinnedPlace()
elif ten_place.is_gpu_place():
p = paddle.fluid.core.Place()
p.set_place(ten_place)
py_place = paddle.fluid.CUDAPlace(p.gpu_device_id())
ten.set(new_para_np, py_place)
used_para_list[para.name] = 1
unused_para_list = []
for k, v in state_dict.items():
if k not in used_para_list:
unused_para_list.append(k)
if len(unused_para_list) > 0:
warnings.warn(
"This list is not set, Because of Paramerter not found in program. There are: {}".
format(" ".join(unused_para_list)))
@@ -15,6 +15,7 @@
from __future__ import print_function
import unittest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.dygraph.nn import Embedding
@@ -22,8 +23,10 @@ import paddle.fluid.framework as framework
from paddle.fluid.optimizer import Adam
from paddle.fluid.dygraph.base import to_variable
from test_imperative_base import new_program_scope
from paddle.fluid.executor import global_scope
import numpy as np
import six
import pickle
class SimpleLSTMRNN(fluid.Layer):
@@ -210,7 +213,7 @@ class PtbModel(fluid.Layer):
return loss, last_hidden, last_cell
class TestDygraphPtbRnn(unittest.TestCase):
class TestSaveLoadBase(unittest.TestCase):
def test_ptb_rnn_cpu_float32(self):
seed = 90
hidden_size = 10
@@ -281,8 +284,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
main_program = framework.default_main_program()
base_map = {}
for var in main_program.list_vars():
if isinstance(var,
framework.Parameter) or var.belong_to_optimizer:
if isinstance(var, framework.Parameter) or var.persistable:
t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
# make sure all the parameter and optimizer vars have been updated
@@ -293,8 +295,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
# set var to zero
for var in main_program.list_vars():
if isinstance(var,
framework.Parameter) or var.belong_to_optimizer:
if isinstance(var, framework.Parameter) or var.persistable:
ten = fluid.global_scope().find_var(var.name).get_tensor()
ten.set(np.zeros_like(np.array(ten)), place)
@@ -303,18 +304,17 @@ class TestDygraphPtbRnn(unittest.TestCase):
# make sure all the parameter and optimizer vars have been set to zero
self.assertTrue(np.sum(np.abs(new_t)) == 0)
fluid.load(main_program, "./test_1")
fluid.load(main_program, "./test_1", exe)
for var in main_program.list_vars():
if isinstance(var,
framework.Parameter) or var.belong_to_optimizer:
if isinstance(var, framework.Parameter) or var.persistable:
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
base_t = base_map[var.name]
self.assertTrue(np.array_equal(new_t, base_t))
class TestDygraphPtbRnnPartial(unittest.TestCase):
class TestSaveLoadPartial(unittest.TestCase):
def test_ptb_rnn_cpu_float32(self):
seed = 90
hidden_size = 10
@@ -393,8 +393,7 @@ class TestDygraphPtbRnnPartial(unittest.TestCase):
main_program = framework.default_main_program()
base_map = {}
for var in main_program.list_vars():
if isinstance(var,
framework.Parameter) or var.belong_to_optimizer:
if isinstance(var, framework.Parameter) or var.persistable:
t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
# make sure all the parameter and optimizer vars have been updated
@@ -405,8 +404,7 @@ class TestDygraphPtbRnnPartial(unittest.TestCase):
# set var to zero
for var in main_program.list_vars():
if isinstance(var,
framework.Parameter) or var.belong_to_optimizer:
if isinstance(var, framework.Parameter) or var.persistable:
ten = fluid.global_scope().find_var(var.name).get_tensor()
ten.set(np.zeros_like(np.array(ten)), place)
@@ -415,11 +413,10 @@ class TestDygraphPtbRnnPartial(unittest.TestCase):
# make sure all the parameter and optimizer vars have been set to zero
self.assertTrue(np.sum(np.abs(new_t)) == 0)
fluid.load(test_program, "./test_1")
fluid.load(test_program, "./test_1", None)
for var in test_program.list_vars():
if isinstance(var,
framework.Parameter) or var.belong_to_optimizer:
if isinstance(var, framework.Parameter) or var.persistable:
print(var.name)
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
@@ -427,5 +424,301 @@ class TestDygraphPtbRnnPartial(unittest.TestCase):
self.assertTrue(np.array_equal(new_t, base_t))
class TestSaveLoadSetStateDict(unittest.TestCase):
def test_ptb_rnn_cpu_float32(self):
seed = 90
hidden_size = 10
vocab_size = 1000
num_layers = 1
num_steps = 3
init_scale = 0.1
batch_size = 4
batch_num = 200
with new_program_scope():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
ptb_model = PtbModel(
"ptb_model",
hidden_size=hidden_size,
vocab_size=vocab_size,
num_layers=num_layers,
num_steps=num_steps,
init_scale=init_scale)
place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
exe = fluid.Executor(place)
sgd = Adam(learning_rate=1e-3)
x = fluid.layers.data(
name="x", shape=[-1, num_steps, 1], dtype='int64')
y = fluid.layers.data(name="y", shape=[-1, 1], dtype='float32')
init_hidden = fluid.layers.data(
name="init_hidden", shape=[1], dtype='float32')
init_cell = fluid.layers.data(
name="init_cell", shape=[1], dtype='float32')
static_loss, static_last_hidden, static_last_cell = ptb_model(
x, y, init_hidden, init_cell)
sgd.minimize(static_loss)
static_param_updated = dict()
static_param_init = dict()
out = exe.run(framework.default_startup_program())
static_loss_value = None
static_last_cell_value = None
static_last_hidden_value = None
for i in range(batch_num):
x_data = np.arange(12).reshape(4, 3).astype('int64')
y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
x_data = x_data.reshape((-1, num_steps, 1))
y_data = y_data.reshape((-1, 1))
init_hidden_data = np.zeros(
(num_layers, batch_size, hidden_size), dtype='float32')
init_cell_data = np.zeros(
(num_layers, batch_size, hidden_size), dtype='float32')
fetch_list = [static_loss, static_last_hidden, static_last_cell]
out = exe.run(fluid.default_main_program(),
feed={
"x": x_data,
"y": y_data,
"init_hidden": init_hidden_data,
"init_cell": init_cell_data
},
fetch_list=fetch_list)
static_loss_value = out[0]
static_last_hidden_value = out[1]
static_last_cell_value = out[2]
# get value before save
main_program = framework.default_main_program()
base_map = {}
for var in main_program.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
# make sure all the parameter and optimizer vars have been updated
self.assertTrue(np.sum(np.abs(t)) != 0)
base_map[var.name] = t
fluid.save(main_program, "./test_1")
# set var to zero
for var in main_program.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
ten = fluid.global_scope().find_var(var.name).get_tensor()
ten.set(np.zeros_like(np.array(ten)), place)
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
# make sure all the parameter and optimizer vars have been set to zero
self.assertTrue(np.sum(np.abs(new_t)) == 0)
fluid.load(main_program, "./test_1", exe)
for var in main_program.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
base_t = base_map[var.name]
self.assertTrue(np.array_equal(new_t, base_t))
class TestProgramStatePartial(unittest.TestCase):
def test_ptb_rnn_cpu_float32(self):
seed = 90
hidden_size = 10
vocab_size = 1000
num_layers = 1
num_steps = 3
init_scale = 0.1
batch_size = 4
batch_num = 200
with new_program_scope():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
ptb_model = PtbModel(
"ptb_model",
hidden_size=hidden_size,
vocab_size=vocab_size,
num_layers=num_layers,
num_steps=num_steps,
init_scale=init_scale)
place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
exe = fluid.Executor(place)
sgd = Adam(learning_rate=1e-3)
x = fluid.layers.data(
name="x", shape=[-1, num_steps, 1], dtype='int64')
y = fluid.layers.data(name="y", shape=[-1, 1], dtype='float32')
init_hidden = fluid.layers.data(
name="init_hidden", shape=[1], dtype='float32')
init_cell = fluid.layers.data(
name="init_cell", shape=[1], dtype='float32')
static_loss, static_last_hidden, static_last_cell = ptb_model(
x, y, init_hidden, init_cell)
test_program = fluid.default_main_program().clone(for_test=True)
add_1 = fluid.layers.fc(static_last_hidden,
size=hidden_size,
num_flatten_dims=2,
bias_attr=False)
sgd.minimize(static_loss)
static_param_updated = dict()
static_param_init = dict()
out = exe.run(framework.default_startup_program())
static_loss_value = None
static_last_cell_value = None
static_last_hidden_value = None
for i in range(batch_num):
x_data = np.arange(12).reshape(4, 3).astype('int64')
y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
x_data = x_data.reshape((-1, num_steps, 1))
y_data = y_data.reshape((-1, 1))
init_hidden_data = np.zeros(
(num_layers, batch_size, hidden_size), dtype='float32')
init_cell_data = np.zeros(
(num_layers, batch_size, hidden_size), dtype='float32')
fetch_list = [static_loss, static_last_hidden, static_last_cell]
out = exe.run(fluid.default_main_program(),
feed={
"x": x_data,
"y": y_data,
"init_hidden": init_hidden_data,
"init_cell": init_cell_data
},
fetch_list=fetch_list)
static_loss_value = out[0]
static_last_hidden_value = out[1]
static_last_cell_value = out[2]
# get value before save
main_program = framework.default_main_program()
base_map = {}
for var in main_program.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
# make sure all the parameter and optimizer vars have been updated
self.assertTrue(np.sum(np.abs(t)) != 0)
base_map[var.name] = t
fluid.save(main_program, "./test_1")
# set var to zero
for var in main_program.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
ten = fluid.global_scope().find_var(var.name).get_tensor()
ten.set(np.zeros_like(np.array(ten)), place)
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
# make sure all the parameter and optimizer vars have been set to zero
self.assertTrue(np.sum(np.abs(new_t)) == 0)
#fluid.load(test_program, "./test_1", None )
program_state = fluid.load_program_state("./test_1")
fluid.set_program_state(test_program, program_state)
for var in test_program.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
print(var.name)
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
base_t = base_map[var.name]
self.assertTrue(np.array_equal(new_t, base_t))
class TestVariableInit(unittest.TestCase):
def test_variable_init(self):
x = fluid.data(name="x", shape=[10, 10], dtype='float32')
y = fluid.layers.fc(x, 10)
z = fluid.layers.fc(y, 10)
place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
fluid.save(fluid.default_main_program(), "./test_path")
def set_var(var, ndarray):
t = var.get_tensor()
p = t._place()
if p.is_cpu_place():
place = paddle.fluid.CPUPlace()
elif p.is_cuda_pinned_place():
place = paddle.fluid.CUDAPinnedPlace()
else:
p = paddle.fluid.core.Place()
p.set_place(t._place())
place = paddle.fluid.CUDAPlace(p.gpu_device_id())
t.set(ndarray, place)
program = fluid.default_main_program()
new_scope = fluid.core.Scope()
place = fluid.CPUPlace() if not core.is_compiled_with_cuda(
) else fluid.CUDAPlace(0)
exe = fluid.Executor(place)
parameter_list = list(
filter(fluid.io.is_parameter, program.list_vars()))
fluid.core._create_loaded_parameter(parameter_list, new_scope,
exe._default_executor)
parameter_file_name = "./test_path.pdparams"
with open(parameter_file_name, 'rb') as f:
load_dict = pickle.load(f)
for v in parameter_list:
assert v.name in load_dict, \
"Can not find [{}] in model file [{}]".format(
v.name, parameter_file_name)
new_v = new_scope.find_var(v.name)
set_var(new_v, load_dict[v.name])
opt_list = list(
filter(fluid.io.is_belong_to_optimizer, program.list_vars()))
fluid.core._create_loaded_parameter(opt_list, new_scope,
exe._default_executor)
opt_file_name = "./test_path.pdopt"
with open(opt_file_name, 'rb') as f:
load_dict = pickle.load(f)
for v in opt_list:
assert v.name in load_dict, \
"Can not find [{}] in model file [{}]".format(
v.name, opt_file_name)
new_v = new_scope.find_var(v.name)
set_var(new_v, load_dict[v.name])
base_map = {}
for var in program.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
# make sure all the parameter and optimizer vars have been updated
base_map[var.name] = t
for var in program.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
new_t = np.array(new_scope.find_var(var.name).get_tensor())
base_t = base_map[var.name]
self.assertTrue(np.array_equal(new_t, base_t))
if __name__ == '__main__':
unittest.main()