Unverified · Commit 6b5749eb authored by wanghuancoder, committed by GitHub

[Eager] save load testcase (#39571)

* eager, test=develop

* fix bug, test=develop

* eager, test=develop

* merge legacy to fluid

* eager, test=develop

* eager, test=develop

* Refactor TensorAdd func by template and remove gradient_accumulation in eager

* Remove needless target name

* eager, test=develop

* eager, test=develop

* Use overload instead of template

* Remove legacy code

* Remove legacy code

* selectedrows, test=develop

* Remove DataType test

* eager, test=develop

* eager, test=develop

* support gan, test=develop

* Using Tensor directly instead of using EagerTensor

* support gradient_accumulation

* make test_imperative_lod_tensor_to_selected_rows longer

* make test_imperative_lod_tensor_to_selected_rows longer

* refine code

* ptb, test=develop

* Rename all EagerTensor to Tensor

* Rename some EagerTensor to Tensor

* rename EagerTensor to EagerVariable

* eager, test=develop

* eager, test=develop

* eager, test=develop

* eager, test=develop

* add more test

* eager, test=develop

* Support copiable selected rows and merge develop

* save load, eager, test=develop

* save load, eager, test=develop

* refine, test=develop

* refine, test=develop

* refine, test=develop

* revert static_runner, test=develop

* EagerTensor to Tensor, test=develop

* refine, test=develop

* refine, test=develop

* clear grad, test=develop

* merge, develop

* merge, develop

* merge, test=develop

* merge, test=develop
Co-authored-by: JiabinYang <360788950@qq.com>
Co-authored-by: Weilong Wu <veyron_wu@163.com>
Parent 2136bd42
@@ -64,12 +64,6 @@ void EmptyTensorInitializer(TensorObject* self, const std::string& name,
     framework::proto::VarType::Type var_type =
         paddle::framework::proto::VarType::LOD_TENSOR) {
   auto ddims = phi::make_ddim(dims);
-  PADDLE_ENFORCE_GE(
-      phi::product(ddims), 0,
-      paddle::platform::errors::InvalidArgument(
-          "Create Eager Tensor with dims contain minus num is ilegal"
-          "Please check your code and make sure you new a "
-          "eager tensor with fixed shape instead of using -1."));
   self->tensor.set_name(name);
   auto autograd_meta = egr::EagerUtils::autograd_meta(&(self->tensor));
   autograd_meta->SetPersistable(persistable);
@@ -83,13 +77,10 @@ void EmptyTensorInitializer(TensorObject* self, const std::string& name,
         phi::make_intrusive<paddle::experimental::SharedStorage>(place),
         phi::DenseTensorMeta(paddle::framework::TransToPtenDataType(dtype),
                              ddims));
-    if (phi::product(ddims) > 0) {
-      dense_tensor->mutable_data(place);
-    }
     self->tensor.set_impl(dense_tensor);
   } else {
     PADDLE_THROW(platform::errors::InvalidArgument(
         "We only support LoDTensor to be constructed by this initializer, "
         "please check your var type first and make sure you are going to "
         "construct LoDTensor."));
   }
   if (!autograd_meta->GetMutableGradNode()) {
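The two hunks above drop the non-negative-shape check from the eager tensor initializer and remove the eager allocation, so a `core.eager.Tensor` can now be constructed as a pure placeholder with an empty or not-yet-known shape. A minimal sketch of why that matters, assuming an eager-mode build of this branch (it mirrors the `"program_out_scope"` construction later in this patch):

```python
from paddle.fluid import core
from paddle.fluid.framework import _test_eager_guard

with _test_eager_guard():  # switch this thread into eager mode
    # With the PADDLE_ENFORCE_GE check removed above, a placeholder tensor
    # with an empty shape can be built without allocating any data.
    scope_vec = core.eager.Tensor(
        dtype=core.VarDesc.VarType.FP32,
        dims=[],
        name="program_out_scope",
        type=core.VarDesc.VarType.STEP_SCOPES,
        persistable=True)
    print(scope_vec.name)
```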
@@ -17,7 +17,7 @@ from __future__ import print_function
 import os
 import collections
 import functools
-from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase, _varbase_creator, _dygraph_tracer
+from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase, _varbase_creator, _dygraph_tracer, EagerParamBase
 import pickle
 from . import learning_rate_scheduler
 import warnings
@@ -94,7 +94,7 @@ def save_dygraph(state_dict, model_path):
     param_num = 0
     for k, v in state_dict.items():
-        if isinstance(v, ParamBase):
+        if isinstance(v, (ParamBase, EagerParamBase)):
             param_num += 1
     if param_num == 0:
@@ -103,7 +103,7 @@ def save_dygraph(state_dict, model_path):
     model_dict = {}
     name_table = {}
     for k, v in state_dict.items():
-        if isinstance(v, (Variable, core.VarBase)):
+        if isinstance(v, (Variable, core.VarBase, core.eager.Tensor)):
             model_dict[k] = v.numpy()
             name_table[k] = v.name
         else:
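With `EagerParamBase` and `core.eager.Tensor` accepted above, the legacy `save_dygraph`/`load_dygraph` round trip works unchanged in either mode. A small usage sketch based on the Embedding example the tests below use (the `./emb_dy` path is illustrative):

```python
import paddle.fluid as fluid

with fluid.dygraph.guard():
    emb = fluid.dygraph.Embedding([10, 10])
    state_dict = emb.state_dict()
    # save_dygraph now counts either ParamBase (legacy dygraph) or
    # EagerParamBase (eager mode) as a parameter before writing .pdparams.
    fluid.save_dygraph(state_dict, "./emb_dy")
    para_state_dict, opti_state_dict = fluid.load_dygraph("./emb_dy")
    emb.set_state_dict(para_state_dict)
```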
@@ -535,6 +535,14 @@ def _load_persistable_vars_by_program(model_path,
         orig_each_name = program_holder._suffix_varname_dict[each_var.name()]
         if _is_parameter(each_var, program_holder.infer_program):
             # create output varbase
+            if framework._in_eager_mode():
+                new_var = framework.EagerParamBase(
+                    shape=each_var.shape(),
+                    dtype=each_var.dtype(),
+                    name=each_var.name(),
+                    type=each_var.type(),
+                    persistable=True)
+            else:
                 new_var = framework.ParamBase(
                     shape=each_var.shape(),
                     dtype=each_var.dtype(),
@@ -620,8 +628,19 @@ def _load_persistable_vars(model_path, var_info_path, program_holder,
         # create output varbase
         if extra_var_info[name].get('trainable', None) is not None:
             # use default shape and dtype
-            new_var = framework.ParamBase(
-                shape=[1],  # only to pass check, this shape is not meaningful
+            if framework._in_eager_mode():
+                new_var = framework.EagerParamBase(
+                    shape=[
+                        1
+                    ],  # only to pass check, this shape is not meaningful
+                    dtype=core.VarDesc.VarType.FP32,
+                    name=new_name,
+                    persistable=True)
+            else:
+                new_var = framework.ParamBase(
+                    shape=[
+                        1
+                    ],  # only to pass check, this shape is not meaningful
                     dtype=core.VarDesc.VarType.FP32,
                     name=new_name,
                     persistable=True)
@@ -747,12 +766,20 @@ def _run_dygraph(instance, input, program_holder):
     # 1. prepare inputs, outputs, attrs
     input_vars = []
     for i, value in enumerate(input):
-        if not isinstance(value, (np.ndarray, core.VarBase)):
+        if not isinstance(value, (np.ndarray, core.VarBase, core.eager.Tensor)):
             raise TypeError(
                 "The type of input in TranslatedLayer must be numpy array or Variable(VarBase), but received %s."
                 % type(value))
         # NOTE: In order to unify the API, firstly convert the input to VarBase
         if isinstance(value, np.ndarray):
+            if framework._in_eager_mode():
+                var = core.eager.Tensor(
+                    value=value,
+                    name=program_holder.input_descs[i].name(),
+                    persistable=False,
+                    place=framework._current_expected_place(),
+                    zero_copy=True)
+            else:
                 var = core.VarBase(
                     value=value,
                     name=program_holder.input_descs[i].name(),
@@ -784,12 +811,28 @@ def _run_dygraph(instance, input, program_holder):
     output_vars = []
     for var_desc in program_holder.output_descs:
+        if framework._in_eager_mode():
+            var = core.eager.Tensor(
+                dtype=var_desc.dtype(),
+                dims=var_desc.shape(),
+                name=var_desc.name(),
+                type=var_desc.type(),
+                persistable=False)
+        else:
             var = core.VarBase(var_desc.dtype(),
                                var_desc.shape(),
                                var_desc.name(), var_desc.type(), False)
         output_vars.append(var)

     # hold forward variables
+    if framework._in_eager_mode():
+        tmp_scope_vec = core.eager.Tensor(
+            dtype=core.VarDesc.VarType.FP32,
+            dims=[],
+            name="program_out_scope",
+            type=core.VarDesc.VarType.STEP_SCOPES,
+            persistable=True)
+    else:
         tmp_scope_vec = core.VarBase(core.VarDesc.VarType.FP32, [],
                                      "program_out_scope",
                                      core.VarDesc.VarType.STEP_SCOPES, True)
@@ -797,11 +840,27 @@ def _run_dygraph(instance, input, program_holder):
     double_grad_vars = []
     for var_desc in program_holder.double_grad_descs:
+        if framework._in_eager_mode():
+            var = core.eager.Tensor(
+                dtype=var_desc.dtype(),
+                dims=var_desc.shape(),
+                name=var_desc.name(),
+                type=var_desc.type(),
+                persistable=False)
+        else:
             var = core.VarBase(var_desc.dtype(),
                                var_desc.shape(),
                                var_desc.name(), var_desc.type(), False)
         double_grad_vars.append(var)
     if len(double_grad_vars) == 0:
+        if framework._in_eager_mode():
+            double_grad_vars = [
+                core.eager.Tensor(
+                    value=[1],
+                    name='Fake_var',
+                    place=framework._current_expected_place())
+            ]
+        else:
             double_grad_vars = [
                 core.VarBase(
                     value=[1],
@@ -1215,11 +1274,12 @@ class TranslatedLayer(layers.Layer):
         # the TranslatedLayer object holded var names count started from 0
         with unique_name.guard():
             for name, var in persistable_vars.items():
-                if isinstance(var, framework.ParamBase):
+                if isinstance(var,
+                              (framework.ParamBase, framework.EagerParamBase)):
                     dy_name = _generate_unique_var_name(PARAMETER_NAME_PREFIX)
                     self._persistable_var_name_dict[name] = dy_name
                     self.add_parameter(dy_name, var)
-                elif isinstance(var, core.VarBase):
+                elif isinstance(var, (core.VarBase, core.eager.Tensor)):
                     dy_name = _generate_unique_var_name(BUFFER_NAME_PREFIX)
                     self._persistable_var_name_dict[name] = dy_name
                     self.register_buffer(dy_name, var)
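The same `framework._in_eager_mode()` branch recurs throughout this file for inputs, outputs, the scope holder, and double-grad variables. A hypothetical helper condensing the pattern (`_new_output_var` is illustrative only, not part of this patch):

```python
from paddle.fluid import core, framework

def _new_output_var(var_desc):
    # Condenses the branch _run_dygraph repeats for outputs and
    # double-grad variables: eager mode constructs core.eager.Tensor,
    # legacy dygraph mode constructs core.VarBase.
    if framework._in_eager_mode():
        return core.eager.Tensor(
            dtype=var_desc.dtype(),
            dims=var_desc.shape(),
            name=var_desc.name(),
            type=var_desc.type(),
            persistable=False)
    return core.VarBase(var_desc.dtype(),
                        var_desc.shape(),
                        var_desc.name(), var_desc.type(), False)
```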
@@ -27,6 +27,7 @@ from test_imperative_base import new_program_scope
 import numpy as np
 import six
 import paddle
+from paddle.fluid.framework import _test_eager_guard


 class SimpleLSTMRNN(fluid.Layer):
@@ -208,7 +209,7 @@ class PtbModel(fluid.Layer):

 class TestDygraphPtbRnn(unittest.TestCase):
-    def setUp(self):
+    def func_setUp(self):
         seed = 90
         hidden_size = 10
         vocab_size = 1000
@@ -277,7 +278,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
             self.opti_dict = adam.state_dict()
             self.base_opti = {}
             for k, v in self.opti_dict.items():
-                if isinstance(v, core.VarBase):
+                if isinstance(v, (core.VarBase, core.eager.Tensor)):
                     self.base_opti[v.name] = v.numpy()
                     self.assertTrue(np.sum(np.abs(v.numpy())) != 0)
                 else:
@@ -294,7 +295,7 @@ class TestDygraphPtbRnn(unittest.TestCase):

             fluid.save_dygraph(self.state_dict, "./test_dy")

-    def testLoadAndSetVarBase(self):
+    def func_testLoadAndSetVarBase(self):
         seed = 90
         hidden_size = 10
         vocab_size = 1000
@@ -363,7 +364,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
             opti_dict = adam.state_dict()
             # set to zero
             for k, v in opti_dict.items():
-                if isinstance(v, core.VarBase):
+                if isinstance(v, (core.VarBase, core.eager.Tensor)):
                     np_t = v.numpy()
                     var = v.value().get_tensor()
                     var.set(np.zeros_like(np_t), place)
@@ -374,11 +375,12 @@ class TestDygraphPtbRnn(unittest.TestCase):
                 adam._learning_rate.step_num = 0

             para_state_dict, opti_state_dict = fluid.load_dygraph("./test_dy")
+            print(opti_state_dict.keys())
             adam.set_state_dict(opti_state_dict)

             opti_dict = adam.state_dict()
             for k, v in opti_dict.items():
-                if isinstance(v, core.VarBase):
+                if isinstance(v, (core.VarBase, core.eager.Tensor)):
                     self.assertTrue(
                         np.array_equal(v.numpy(), self.base_opti[v.name]))
                 else:
@@ -403,7 +405,7 @@ class TestDygraphPtbRnn(unittest.TestCase):

             self.assertTrue(np.array_equal(new_t, base_t))

-    def testSetVariable(self):
+    def func_testSetVariable(self):
         seed = 90
         hidden_size = 10
         vocab_size = 1000
@@ -472,7 +474,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
             opti_dict = adam.state_dict()
             # set to zero
             for k, v in opti_dict.items():
-                if isinstance(v, core.VarBase):
+                if isinstance(v, (core.VarBase, core.eager.Tensor)):
                     np_t = v.numpy()
                     var = v.value().get_tensor()
                     var.set(np.zeros_like(np_t), place)
@@ -485,7 +487,7 @@ class TestDygraphPtbRnn(unittest.TestCase):

             adam.set_state_dict(self.opti_dict)
             opti_dict = adam.state_dict()
             for k, v in opti_dict.items():
-                if isinstance(v, core.VarBase):
+                if isinstance(v, (core.VarBase, core.eager.Tensor)):
                     self.assertTrue(
                         np.array_equal(v.numpy(), self.base_opti[v.name]))
                 else:
@@ -510,7 +512,7 @@ class TestDygraphPtbRnn(unittest.TestCase):

             self.assertTrue(np.array_equal(new_t, base_t))

-    def testSetNumpy(self):
+    def func_testSetNumpy(self):
         seed = 90
         hidden_size = 10
         vocab_size = 1000
@@ -580,7 +582,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
             np_opti_dict = {}
             # set to zero
             for k, v in opti_dict.items():
-                if isinstance(v, core.VarBase):
+                if isinstance(v, (core.VarBase, core.eager.Tensor)):
                     np_t = v.numpy()
                     np_opti_dict[v.name] = np_t
                     var = v.value().get_tensor()
@@ -596,7 +598,7 @@ class TestDygraphPtbRnn(unittest.TestCase):

             opti_dict = adam.state_dict()
             for k, v in opti_dict.items():
-                if isinstance(v, core.VarBase):
+                if isinstance(v, (core.VarBase, core.eager.Tensor)):
                     self.assertTrue(
                         np.array_equal(v.numpy(), self.base_opti[v.name]))
                 else:
@@ -623,7 +625,7 @@ class TestDygraphPtbRnn(unittest.TestCase):

             self.assertTrue(np.array_equal(new_t, base_t))

-    def testSetVariableBeforeTrain(self):
+    def func_testSetVariableBeforeTrain(self):
         seed = 90
         hidden_size = 10
         vocab_size = 1000
@@ -700,7 +702,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
                 base_t = self.model_base[k]
                 self.assertTrue(np.array_equal(new_t, base_t))

-    def testLoadAndSetVarBaseBeforeTrain(self):
+    def func_testLoadAndSetVarBaseBeforeTrain(self):
         seed = 90
         hidden_size = 10
         vocab_size = 1000
@@ -791,7 +793,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
                 base_t = self.model_base[k]
                 self.assertTrue(np.array_equal(new_t, base_t))

-    def testSetNumpyBeforeTrain(self):
+    def func_testSetNumpyBeforeTrain(self):
         seed = 90
         hidden_size = 10
         vocab_size = 1000
@@ -840,7 +842,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
             np_state_dict = {}
             for k, v in self.opti_dict.items():
-                if isinstance(v, core.VarBase):
+                if isinstance(v, (core.VarBase, core.eager.Tensor)):
                     np_opti_dict[v.name] = v.numpy()
                 else:
                     np_opti_dict[k] = v
@@ -894,7 +896,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
                 base_t = self.model_base[k]
                 self.assertTrue(np.array_equal(new_t, base_t))

-    def testOnlyLoadParams(self):
+    def func_testOnlyLoadParams(self):
         with fluid.dygraph.guard():
             emb = fluid.dygraph.Embedding([10, 10])
             state_dict = emb.state_dict()
@@ -911,7 +913,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
             para_state_dict, opti_state_dict = fluid.load_dygraph(
                 os.path.join('saved_dy', 'emb_dy.pdopt'))

-    def test_load_compatible_with_keep_name_table(self):
+    def func_test_load_compatible_with_keep_name_table(self):
         with fluid.dygraph.guard():
             emb = fluid.dygraph.Embedding([10, 10])
             state_dict = emb.state_dict()
@@ -922,6 +924,27 @@ class TestDygraphPtbRnn(unittest.TestCase):
             self.assertTrue(para_state_dict != None)
             self.assertTrue(opti_state_dict == None)

+    def test_main(self):
+        self.func_setUp()
+        self.func_testLoadAndSetVarBase()
+        self.func_testSetVariable()
+        self.func_testSetNumpy()
+        self.func_testSetVariableBeforeTrain()
+        self.func_testLoadAndSetVarBaseBeforeTrain()
+        self.func_testSetNumpyBeforeTrain()
+        self.func_testOnlyLoadParams()
+        self.func_test_load_compatible_with_keep_name_table()
+        with _test_eager_guard():
+            self.func_setUp()
+            self.func_testLoadAndSetVarBase()
+            self.func_testSetVariable()
+            self.func_testSetNumpy()
+            self.func_testSetVariableBeforeTrain()
+            self.func_testLoadAndSetVarBaseBeforeTrain()
+            self.func_testSetNumpyBeforeTrain()
+            self.func_testOnlyLoadParams()
+            self.func_test_load_compatible_with_keep_name_table()

 if __name__ == '__main__':
     unittest.main()
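The test class above is reorganized so that each case body becomes a `func_*` method and a single `test_main` drives every case twice: once in legacy dygraph mode and once under `_test_eager_guard()`. A stripped-down sketch of the pattern (the class and method names here are illustrative):

```python
import unittest
from paddle.fluid.framework import _test_eager_guard

class ExampleTest(unittest.TestCase):
    def func_check_something(self):
        self.assertTrue(True)  # stand-in for a real check

    def test_main(self):
        # Run once in legacy dygraph mode...
        self.func_check_something()
        # ...and once more with the eager-mode flag enabled.
        with _test_eager_guard():
            self.func_check_something()
```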
@@ -27,6 +27,7 @@ from test_imperative_base import new_program_scope
 import numpy as np
 import six
 import paddle
+from paddle.fluid.framework import _test_eager_guard


 class SimpleLSTMRNN(fluid.Layer):
@@ -208,7 +209,7 @@ class PtbModel(fluid.Layer):

 class TestDygraphPtbRnn(unittest.TestCase):
-    def setUp(self):
+    def func_setUp(self):
         seed = 90
         hidden_size = 10
         vocab_size = 1000
@@ -279,7 +280,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
             self.opti_dict = adam.state_dict()
             self.base_opti = {}
             for k, v in self.opti_dict.items():
-                if isinstance(v, core.VarBase):
+                if isinstance(v, (core.VarBase, core.eager.Tensor)):
                     self.base_opti[v.name] = v.numpy()
                     self.assertTrue(np.sum(np.abs(v.numpy())) != 0)
                 else:
@@ -296,7 +297,7 @@ class TestDygraphPtbRnn(unittest.TestCase):

             paddle.save(self.state_dict, "./test_dy_v2.pdparams")

-    def testLoadAndSetVarBase(self):
+    def func_testLoadAndSetVarBase(self):
         self.setUp()
         seed = 90
         hidden_size = 10
@@ -367,7 +368,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
             opti_dict = adam.state_dict()
             # set to zero
             for k, v in opti_dict.items():
-                if isinstance(v, core.VarBase):
+                if isinstance(v, (core.VarBase, core.eager.Tensor)):
                     np_t = v.numpy()
                     var = v.value().get_tensor()
                     var.set(np.zeros_like(np_t), place)
@@ -380,7 +381,7 @@ class TestDygraphPtbRnn(unittest.TestCase):

             opti_dict = adam.state_dict()
             for k, v in opti_dict.items():
-                if isinstance(v, core.VarBase):
+                if isinstance(v, (core.VarBase, core.eager.Tensor)):
                     self.assertTrue(
                         np.array_equal(v.numpy(), self.base_opti[v.name]))
                 else:
@@ -405,7 +406,7 @@ class TestDygraphPtbRnn(unittest.TestCase):

             self.assertTrue(np.array_equal(new_t, base_t))

-    def testSetVariable(self):
+    def func_testSetVariable(self):
         seed = 90
         hidden_size = 10
         vocab_size = 1000
@@ -475,7 +476,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
             opti_dict = adam.state_dict()
             # set to zero
             for k, v in opti_dict.items():
-                if isinstance(v, core.VarBase):
+                if isinstance(v, (core.VarBase, core.eager.Tensor)):
                     np_t = v.numpy()
                     var = v.value().get_tensor()
                     var.set(np.zeros_like(np_t), place)
@@ -488,7 +489,7 @@ class TestDygraphPtbRnn(unittest.TestCase):

             adam.set_state_dict(self.opti_dict)
             opti_dict = adam.state_dict()
             for k, v in opti_dict.items():
-                if isinstance(v, core.VarBase):
+                if isinstance(v, (core.VarBase, core.eager.Tensor)):
                     self.assertTrue(
                         np.array_equal(v.numpy(), self.base_opti[v.name]))
                 else:
@@ -513,7 +514,7 @@ class TestDygraphPtbRnn(unittest.TestCase):

             self.assertTrue(np.array_equal(new_t, base_t))

-    def testSetNumpy(self):
+    def func_testSetNumpy(self):
         seed = 90
         hidden_size = 10
         vocab_size = 1000
@@ -584,7 +585,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
             np_opti_dict = {}
             # set to zero
             for k, v in opti_dict.items():
-                if isinstance(v, core.VarBase):
+                if isinstance(v, (core.VarBase, core.eager.Tensor)):
                     np_t = v.numpy()
                     np_opti_dict[v.name] = np_t
                     var = v.value().get_tensor()
@@ -600,7 +601,7 @@ class TestDygraphPtbRnn(unittest.TestCase):

             opti_dict = adam.state_dict()
             for k, v in opti_dict.items():
-                if isinstance(v, core.VarBase):
+                if isinstance(v, (core.VarBase, core.eager.Tensor)):
                     self.assertTrue(
                         np.array_equal(v.numpy(), self.base_opti[v.name]))
                 else:
@@ -627,7 +628,7 @@ class TestDygraphPtbRnn(unittest.TestCase):

             self.assertTrue(np.array_equal(new_t, base_t))

-    def testSetVariableBeforeTrain(self):
+    def func_testSetVariableBeforeTrain(self):
         seed = 90
         hidden_size = 10
         vocab_size = 1000
@@ -706,7 +707,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
                 base_t = self.model_base[k]
                 self.assertTrue(np.array_equal(new_t, base_t))

-    def testLoadAndSetVarBaseBeforeTrain(self):
+    def func_testLoadAndSetVarBaseBeforeTrain(self):
         seed = 90
         hidden_size = 10
         vocab_size = 1000
@@ -797,7 +798,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
                 base_t = self.model_base[k]
                 self.assertTrue(np.array_equal(new_t, base_t))

-    def testSetNumpyBeforeTrain(self):
+    def func_testSetNumpyBeforeTrain(self):
         seed = 90
         hidden_size = 10
         vocab_size = 1000
@@ -846,7 +847,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
             np_state_dict = {}
             for k, v in self.opti_dict.items():
-                if isinstance(v, core.VarBase):
+                if isinstance(v, (core.VarBase, core.eager.Tensor)):
                     np_opti_dict[v.name] = v.numpy()
                 else:
                     np_opti_dict[k] = v
@@ -902,7 +903,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
                 base_t = self.model_base[k]
                 self.assertTrue(np.array_equal(new_t, base_t))

-    def testOnlyLoadParams(self):
+    def func_testOnlyLoadParams(self):
         with fluid.dygraph.guard():
             emb = fluid.dygraph.Embedding([10, 10])
             state_dict = emb.state_dict()
@@ -911,7 +912,7 @@ class TestDygraphPtbRnn(unittest.TestCase):
             para_state_dict = paddle.load(
                 os.path.join('saved_dy', 'emb_dy.pdparams'))

-    def test_no_state_in_input_dict(self):
+    def func_test_no_state_in_input_dict(self):
         with fluid.dygraph.guard():
             emb = fluid.dygraph.Embedding([10, 10])
             state_dict = emb.state_dict()
@@ -923,7 +924,7 @@ class TestDygraphPtbRnn(unittest.TestCase):

             emb.set_state_dict(para_state_dict)

-    def test_state_shape_mismatch(self):
+    def func_test_state_shape_mismatch(self):
         with fluid.dygraph.guard():
             emb = fluid.dygraph.Embedding([10, 10])
             state_dict = emb.state_dict()
@@ -936,6 +937,29 @@ class TestDygraphPtbRnn(unittest.TestCase):

             emb.set_state_dict(para_state_dict)

+    def test_main(self):
+        self.func_setUp()
+        self.func_testLoadAndSetVarBase()
+        self.func_testSetVariable()
+        self.func_testSetNumpy()
+        self.func_testSetVariableBeforeTrain()
+        self.func_testLoadAndSetVarBaseBeforeTrain()
+        self.func_testSetNumpyBeforeTrain()
+        self.func_testOnlyLoadParams()
+        self.func_test_no_state_in_input_dict()
+        self.func_test_state_shape_mismatch()
+        with _test_eager_guard():
+            self.func_setUp()
+            self.func_testLoadAndSetVarBase()
+            self.func_testSetVariable()
+            self.func_testSetNumpy()
+            self.func_testSetVariableBeforeTrain()
+            self.func_testLoadAndSetVarBaseBeforeTrain()
+            self.func_testSetNumpyBeforeTrain()
+            self.func_testOnlyLoadParams()
+            self.func_test_no_state_in_input_dict()
+            self.func_test_state_shape_mismatch()

 if __name__ == '__main__':
     unittest.main()
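This second test file exercises the same scenarios through the 2.x `paddle.save`/`paddle.load` API instead of `fluid.save_dygraph`. A minimal round trip, assuming this branch (the `linear.pdparams` file name is illustrative):

```python
import paddle

net = paddle.nn.Linear(4, 4)
# paddle.save pickles the state dict; with this patch the reducer also
# covers core.eager.Tensor / EagerParamBase entries in eager mode.
paddle.save(net.state_dict(), "linear.pdparams")
state = paddle.load("linear.pdparams")
net.set_state_dict(state)
```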
@@ -30,7 +30,7 @@ from paddle.fluid.io import _unpack_saved_dict, _pack_loaded_dict, _pickle_loads
 from paddle.fluid.io import _legacy_save as _legacy_static_save
 from paddle.fluid.io import _open_file_buffer, _is_file_path, _is_memory_buffer
-from paddle.fluid.framework import Variable, _varbase_creator, _dygraph_tracer, in_dygraph_mode, ParamBase, _current_expected_place, Program
+from paddle.fluid.framework import Variable, _varbase_creator, _dygraph_tracer, in_dygraph_mode, ParamBase, EagerParamBase, _current_expected_place, Program
 from paddle.fluid.dygraph.jit import _SaveLoadConfig
 from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers
 from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX, INFER_PARAMS_INFO_SUFFIX
@@ -42,7 +42,7 @@ def _build_saved_state_dict(state_dict):
     save_dict = {}
     name_table = {}
     for key, value in state_dict.items():
-        if isinstance(value, (Variable, core.VarBase)):
+        if isinstance(value, (Variable, core.VarBase, core.eager.Tensor)):
             if value.type == core.VarDesc.VarType.VOCAB:
                 save_dict[key] = value.value().get_map_tensor()
             else:
@@ -260,6 +260,8 @@ def _pickle_save(obj, f, protocol):
         # This is not a good method, because the pickle module has been modified.
         pickle.dispatch_table[core.VarBase] = reduce_varbase
         pickle.dispatch_table[ParamBase] = reduce_varbase
+        pickle.dispatch_table[core.eager.Tensor] = reduce_varbase
+        pickle.dispatch_table[EagerParamBase] = reduce_varbase
         pickle.dispatch_table[core.LoDTensor] = reduce_LoDTensor
         pickle.dispatch_table.update(dispatch_table_layer)
@@ -267,6 +269,8 @@ def _pickle_save(obj, f, protocol):
         pickle.dispatch_table.pop(core.VarBase)
         pickle.dispatch_table.pop(core.LoDTensor)
         pickle.dispatch_table.pop(ParamBase)
+        pickle.dispatch_table.pop(core.eager.Tensor)
+        pickle.dispatch_table.pop(EagerParamBase)
         for k in dispatch_table_layer:
             pickle.dispatch_table.pop(k)
@@ -286,6 +290,8 @@ def _pickle_save(obj, f, protocol):
         pickler.dispatch_table[core.VarBase] = reduce_varbase
         pickler.dispatch_table[core.LoDTensor] = reduce_LoDTensor
         pickler.dispatch_table[ParamBase] = reduce_varbase
+        pickler.dispatch_table[core.eager.Tensor] = reduce_varbase
+        pickler.dispatch_table[EagerParamBase] = reduce_varbase
         pickler.dispatch_table.update(dispatch_table_layer)
         pickler.dump(obj)
@@ -317,7 +323,8 @@ def _is_state_dict(obj):
     def condition(obj):
-        return isinstance(obj, (fluid.Layer, Program, core.VarBase,
-                                core.LoDTensor, core.SelectedRows))
+        return isinstance(obj, (fluid.Layer, Program, core.VarBase,
+                                core.eager.Tensor, core.LoDTensor,
+                                core.SelectedRows))

     # If the value of a dict is a core.VarBase/LoDTensor or a dict
     # that does not contain a paddle type(Layer, Program, VarBase, LoDTensor, SelectedRows),
@@ -327,7 +334,8 @@ def _is_state_dict(obj):
             for k, v in value.items():
                 if _contain_x(v, condition):
                     return False
-        elif not isinstance(value, (core.VarBase, core.LoDTensor)):
+        elif not isinstance(value, (core.VarBase, core.eager.Tensor,
+                                    core.LoDTensor)):
             return False
     return True
@@ -412,8 +420,9 @@ def _parse_every_object(obj, condition_func, convert_func):
    elif type(obj) == set:
        return set(_parse_every_object(list(obj), condition_func, convert_func))
    else:
-       if isinstance(obj, collections.Iterable) and not isinstance(obj, (
-               str, np.ndarray, core.VarBase, core.LoDTensor)):
+       if isinstance(obj, collections.Iterable) and not isinstance(
+               obj,
+               (str, np.ndarray, core.VarBase, core.eager.Tensor, core.LoDTensor)):
            raise NotImplementedError(
                "The iteratable objects supported are tuple, list, dict, OrderedDict, string. But received {}.".
                format(type(obj)))
@@ -541,7 +550,7 @@ def _save_binary_var(obj, path):
        _save_lod_tensor(obj, path)
    elif isinstance(obj, core.SelectedRows):
        _save_selected_rows(obj, path)
-   elif isinstance(obj, core.VarBase):
+   elif isinstance(obj, (core.VarBase, core.eager.Tensor)):
        _save_lod_tensor(obj.value().get_tensor(), path)
    else:
        # Since the concept of 'Tensor' is only exposed to users, the error message can only contain tensor instead of 'LoDTensor' or 'SelectedRows'
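The `_pickle_save` hunks register `reduce_varbase` for the eager types in pickle's per-pickler `dispatch_table`, which is standard-library machinery. A self-contained sketch of that mechanism with a stand-in type (`Handle` and `reduce_handle` are illustrative, not Paddle APIs):

```python
import copyreg
import io
import pickle

class Handle:
    """Stand-in for a type that cannot be pickled directly."""
    def __init__(self, value):
        self.value = value

def reduce_handle(obj):
    # A reducer returns (callable, args) used to rebuild the object,
    # the same contract reduce_varbase fulfils above.
    return (Handle, (obj.value,))

buf = io.BytesIO()
pickler = pickle.Pickler(buf, protocol=2)
pickler.dispatch_table = copyreg.dispatch_table.copy()
pickler.dispatch_table[Handle] = reduce_handle  # per-pickler override
pickler.dump(Handle(3))
assert pickle.loads(buf.getvalue()).value == 3
```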