未验证 提交 faa8de5d 编写于 作者: D daquexian 提交者: GitHub

consistent init/save/load (#5896)

* single device
Signed-off-by: daquexian <daquexian566@gmail.com>

* support consistent load and save
Signed-off-by: Ndaquexian <daquexian566@gmail.com>

* add consistent init test
Signed-off-by: Ndaquexian <daquexian566@gmail.com>

* reformat
Signed-off-by: Ndaquexian <daquexian566@gmail.com>

* refine
Signed-off-by: Ndaquexian <daquexian566@gmail.com>

* refine
Signed-off-by: Ndaquexian <daquexian566@gmail.com>

* refine CpuBroadcast(py::bytes* in, int64_t root)
Signed-off-by: Ndaquexian <daquexian566@gmail.com>

* refine CpuBroadcast(py::bytes* in, int64_t root)
Signed-off-by: Ndaquexian <daquexian566@gmail.com>

* refine CpuBroadcast(py::bytes* in, int64_t root)
Signed-off-by: Ndaquexian <daquexian566@gmail.com>

* auto format by CI

* fix test script path
Signed-off-by: Ndaquexian <daquexian566@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
上级 73a9e421
......@@ -563,7 +563,7 @@ jobs:
${{ env.extra_docker_args }} ${{ env.pip_cache_docker_args }} \
-e ONEFLOW_TEST_DIR=$PWD/python/oneflow/test/tensor \
${{ env.image_tag }} \
bash -c "python3 -m pip config set global.index-url ${{ env.pip_index_mirror }} && bash ci/test/try_install.sh && bash ci/test/generic_test.sh"
bash -c "python3 -m pip config set global.index-url ${{ env.pip_index_mirror }} && bash ci/test/try_install.sh && bash ci/test/generic_test_multi_client.sh"
- name: Graph API test
if: contains(fromJson('["cuda_new_interface", "cpu_new_interface"]'), matrix.test_suite)
run: |
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include <pybind11/pybind11.h>
#include <pybind11/pytypes.h>
#include "oneflow/api/python/of_api_registry.h"
#include "oneflow/core/framework/dtype.h"
#include "oneflow/core/ccl/ccl.h"
#include "oneflow/core/rpc/include/global_process_ctx.h"
#include "oneflow/core/job/rank_group.h"
namespace py = pybind11;
namespace oneflow {
namespace {
// Broadcast a Python bytes object from rank `root` to every rank in the
// default (CPU) rank group.
// On the root rank `in` must be non-null and is returned unchanged; on all
// other ranks `in` may be null and a freshly allocated bytes object holding
// the received payload is returned.
Maybe<py::bytes> CpuBroadcast(py::bytes* in, int64_t root) {
  const auto& rank_group = JUST(RankGroup::DefaultRankGroup());
  const auto& parallel_desc = JUST(RankGroup::GetDefaultParallelDesc(DeviceType::kCPU, rank_group));
  Py_ssize_t length;  // set by PyBytes_AsStringAndSize on root, by Broadcast elsewhere
  char* buffer;
  if (GlobalProcessCtx::Rank() == root) {
    CHECK_NOTNULL_OR_RETURN(in);
    // NOTE(review): the return value of PyBytes_AsStringAndSize is unchecked;
    // a non-bytes argument would leave buffer/length unset — confirm callers
    // always pass real bytes.
    PyBytes_AsStringAndSize(in->ptr(), &buffer, &length);
  }
  // First broadcast the payload size so non-root ranks know how much to allocate.
  JUST(ccl::Broadcast<DeviceType::kCPU>(&length, &length, sizeof(length), DataType::kChar, root,
                                        parallel_desc, nullptr));
  if (GlobalProcessCtx::Rank() == root) {
    JUST(ccl::Broadcast<DeviceType::kCPU>(buffer, buffer, length, DataType::kChar, root,
                                          parallel_desc, nullptr));
    // Root already owns the bytes object; hand it back as-is.
    return *in;
  } else {
    // Allocate an uninitialized PyBytes object by hand and receive directly
    // into its internal buffer, avoiding an extra copy. See:
    // https://github.com/pybind/pybind11/issues/1236#issuecomment-527730864
    PyBytesObject* bytesObject =
        static_cast<PyBytesObject*>(PyObject_Malloc(offsetof(PyBytesObject, ob_sval) + length + 1));
    PyObject_INIT_VAR(bytesObject, &PyBytes_Type, length);
    bytesObject->ob_shash = -1;  // -1 marks the cached hash as not yet computed
    bytesObject->ob_sval[length] = '\0';  // bytes objects are NUL-terminated
    buffer = bytesObject->ob_sval;
    JUST(ccl::Broadcast<DeviceType::kCPU>(nullptr, buffer, length, DataType::kChar, root,
                                          parallel_desc, nullptr));
    // Steal the new reference so pybind11 takes ownership of the object.
    return py::reinterpret_steal<py::bytes>(reinterpret_cast<PyObject*>(bytesObject));
  }
}
} // namespace
// Expose CpuBroadcast to Python as `cpu_broadcast`. Two overloads are
// registered: one taking the bytes payload (used on the root rank) and one
// taking None (used on all other ranks, which only receive).
ONEFLOW_API_PYBIND11_MODULE("", m) {
  m.def("cpu_broadcast", [](py::bytes in, int64_t root) -> py::bytes {
    return CpuBroadcast(&in, root).GetOrThrow();
  });
  m.def("cpu_broadcast", [](py::none in, int64_t root) -> py::bytes {
    return CpuBroadcast(nullptr, root).GetOrThrow();
  });
}
} // namespace oneflow
......@@ -294,7 +294,7 @@ from oneflow.nn.modules.unsqueeze import unsqueeze_op as unsqueeze
from oneflow.nn.modules.where import where_op as where
from oneflow.nn.modules.scatter import *
from oneflow.ops.builtin_ops import BuiltinOp as builtin_op
from oneflow.ops.initializer_util import constant_initializer, empty_initializer
from oneflow.ops.initializer_util import constant_initializer
from oneflow.ops.initializer_util import glorot_normal_initializer
from oneflow.ops.initializer_util import (
glorot_normal_initializer as xavier_normal_initializer,
......
......@@ -20,14 +20,15 @@ import numpy as np
from google.protobuf import text_format
import oneflow
import oneflow as flow
import oneflow._oneflow_internal
import oneflow.core.framework.variable_meta_info_pb2 as variable_meta_info_pb
import oneflow.framework.dtype as dtype_util
import pickle
SNAPSHOT_DONE_FILENAME = "snapshot_done"
META_INFO_FILENAME = "meta"
DATA_FILENAME = "out"
FAKE_JOB_NAME = "system_checkpoint"
OP_PREFIX = "system_checkpoint"
class FileBackendVariableBlob:
......@@ -100,97 +101,77 @@ def _ElemCnt(shape):
return np.prod(shape).astype(int).item()
def _LoadSingleVariable(path: str) -> Optional[FileBackendVariableBlob]:
    """Return a blob backed by *path*, or None when it holds no data file."""
    data_file = os.path.join(path, DATA_FILENAME)
    return FileBackendVariableBlob(path) if os.path.isfile(data_file) else None
def _LoadSingleVariable(
    path: Optional[str], consistent_src_rank: Optional[int] = None
) -> "flow.Tensor":
    # Consistent-load mode: only the src rank actually reads the file from
    # disk; every other rank contributes an empty placeholder tensor, and the
    # real data is then distributed to all ranks via to_consistent/broadcast.
    if consistent_src_rank is not None:
        rank = flow.framework.distribute.get_rank()
        if rank == consistent_src_rank:
            assert isinstance(path, str)
            file_backed_blob = FileBackendVariableBlob(path)
            loaded = flow.tensor(
                file_backed_blob.numpy(), dtype=file_backed_blob.dtype
            ).to("cuda")
        else:
            # Non-src ranks supply an empty local tensor; its contents are
            # replaced by the broadcast below.
            loaded = flow.tensor([]).to("cuda")
        # NOTE(review): placement is hard-coded to {0: [0]} — presumably the
        # caller re-places the tensor afterwards; confirm against callers.
        loaded = loaded.to_consistent(
            flow.placement("cuda", {0: [0]}), flow.sbp.broadcast
        )
        return loaded
def _GetCheckpoint(
    path: str,
) -> Union[Dict[str, FileBackendVariableBlob], FileBackendVariableBlob]:
    """Load the single variable stored at *path*, or, when *path* is a
    directory of variables, a dict mapping each subdirectory name to its blob.
    """
    assert os.path.isdir(path), "Directory {} doesn't exist!".format(path)
    single_var = _LoadSingleVariable(path)
    if single_var is not None:
        return single_var
    candidates = (
        (name, _LoadSingleVariable(os.path.join(path, name)))
        for name in os.listdir(path)
    )
    return {name: blob for name, blob in candidates if blob is not None}
assert isinstance(path, str)
return flow.tensor(FileBackendVariableBlob(path).numpy())
def GetCheckpoint(
    path: str,
) -> Union[Dict[str, FileBackendVariableBlob], FileBackendVariableBlob]:
    """Public entry point: load the variable (or dict of variables) stored
    under *path* from the file system."""
    checkpoint = _GetCheckpoint(path)
    return checkpoint
def _broadcast_py_object(obj, src: int = 0):
    """Pickle *obj* on rank *src*, broadcast the bytes to every rank, and
    return the unpickled object on each rank."""
    current_rank = flow.framework.distribute.get_rank()
    payload = pickle.dumps(obj) if current_rank == src else None
    return pickle.loads(flow._oneflow_internal.cpu_broadcast(payload, src))
def Load(
    path: str,
) -> Union[Dict[str, FileBackendVariableBlob], FileBackendVariableBlob]:
    """Read the variable(s) saved under *path*; alias of GetCheckpoint."""
    return _GetCheckpoint(path)
def _ReadSlice(
    container: ValueContainer,
) -> Iterable[Tuple[Sequence[int], Sequence[int], np.ndarray]]:
    """
    Return a generator which iterates over the input blob or array and yields
    (start_nd_idx, stop_nd_idx, slice_np_array)
    """
    if isinstance(container, oneflow.Tensor):
        # In-memory tensor: build a tuple of slice objects and index directly.
        def ReadFromTensor(tensor, start_nd_idx, stop_nd_idx):
            start_nd_idx = list(map(int, start_nd_idx))
            stop_nd_idx = list(map(int, stop_nd_idx))
            return tensor[
                tuple(
                    [
                        slice(start_nd_idx[i], stop_nd_idx[i])
                        for i in range(len(start_nd_idx))
                    ]
                )
            ].numpy()

        yield from _ForEachSlice(container, ReadFromTensor)
    elif isinstance(container, FileBackendVariableBlob):
        np_dtype = np.dtype(
            dtype_util.convert_oneflow_dtype_to_numpy_dtype(container.dtype)
        )
        with open(container.file_path, "rb") as f:
            # Slices are produced in file order by _ForEachSlice, so a plain
            # sequential read of length * itemsize bytes yields each slice.
            def ReadFromFile(_, start_nd_idx, stop_nd_idx):
                length = _ElemCnt(np.array(stop_nd_idx) - np.array(start_nd_idx))
                slice = f.read(length * np_dtype.itemsize)
                return np.frombuffer(slice, dtype=np_dtype).reshape(
                    np.array(stop_nd_idx) - np.array(start_nd_idx)
                )

            yield from _ForEachSlice(container, ReadFromFile)
    elif isinstance(container, np.ndarray):

        def ReadFromNpArray(array, start_nd_idx, stop_nd_idx):
            slice_objs = []
            for (start, stop) in zip(start_nd_idx, stop_nd_idx):
                slice_objs.append(slice(start, stop))
            return array[tuple(slice_objs)]

        yield from _ForEachSlice(container, ReadFromNpArray)
path: str, consistent_src_rank: Optional[int] = None,
) -> Dict[str, "flow.Tensor"]:
assert os.path.isdir(path), "Directory {} doesn't exist!".format(path)
rank = flow.framework.distribute.get_rank()
var_dict = {}
if consistent_src_rank is None or rank == consistent_src_rank:
all_files = os.listdir(path)
assert SNAPSHOT_DONE_FILENAME in all_files
all_files.remove(SNAPSHOT_DONE_FILENAME)
if consistent_src_rank is not None:
_broadcast_py_object(all_files, consistent_src_rank)
else:
raise RuntimeError("Unknown type: {}".format(type(container).__name__))
all_files = _broadcast_py_object(None, consistent_src_rank)
for f in all_files:
var_dir = os.path.join(path, f)
var_dict[f] = _LoadSingleVariable(var_dir, consistent_src_rank)
return var_dict
def _SaveVarDict(
path: str, var_dict: Optional[Dict[str, FileBackendVariableBlob]] = None,
def save(
var_dict: Dict[str, "flow.Tensor"],
path: str,
consistent_dst_rank: Optional[int] = None,
) -> None:
if var_dict is None:
var_dict = GetAllVariables()
consistent_mode = consistent_dst_rank is not None
for (name, var) in var_dict.items():
if consistent_mode:
assert (
var.is_consistent
), f"consistent tensor is needed, but {name} is a local tensor"
var_dict[name] = var.to_consistent(sbp=flow.sbp.broadcast).to_local()
else:
assert (
not var.is_consistent
), f"local tensor is needed, but {name} is a consistent tensor"
rank = flow.framework.distribute.get_rank()
if consistent_mode and rank != consistent_dst_rank:
return
def IsFileOrNonEmptyDir(path):
if os.path.isfile(path):
......@@ -205,6 +186,7 @@ def _SaveVarDict(
path
)
os.makedirs(path, exist_ok=True)
for (name, var) in var_dict.items():
meta_info = variable_meta_info_pb.VariableMetaInfo()
meta_info.shape.dim[:] = var.shape
......@@ -215,74 +197,14 @@ def _SaveVarDict(
param_path = os.path.join(var_dir, DATA_FILENAME)
os.makedirs(os.path.dirname(param_path))
with open(param_path, "wb") as f:
for (_, _, slice) in _ReadSlice(var):
f.write(slice.tobytes())
f.write(var.numpy().tobytes())
with open(os.path.join(var_dir, META_INFO_FILENAME), "w") as f:
f.write(text_format.MessageToString(meta_info))
with open(os.path.join(path, "snapshot_done"), "w"):
with open(os.path.join(path, SNAPSHOT_DONE_FILENAME), "w"):
pass
def SaveVarDict(
    path: str, var_dict: Optional[Dict[str, FileBackendVariableBlob]] = None,
) -> None:
    """Persist every variable in `var_dict` under the directory `path`."""
    return _SaveVarDict(path, var_dict)
def save(obj, save_dir):
    """Save *obj* (a variable dict) into the directory *save_dir*."""
    return _SaveVarDict(save_dir, obj)
def _ForEachSlice(
    container: ValueContainer,
    f: Union[
        Callable[[Sequence[int], Sequence[int]], Any],
        Callable[[FileBackendVariableBlob, Sequence[int], Sequence[int]], Any],
        Callable[[np.ndarray, Sequence[int], Sequence[int]], Any],
    ],
):
    """
    Slice container into slices whose size < SLICE_BYTES. For every slice,
    yield start_nd_idx, stop_nd_idx and f(slice)
    """
    assert isinstance(
        container, (FileBackendVariableBlob, np.ndarray, oneflow.Tensor)
    ), "Unknown type: {}".format(type(container).__name__)
    assert container.shape is not None
    SLICE_BYTES = 32 * 1024 * 1024  # upper bound on one slice's byte size
    if isinstance(container, np.ndarray):
        np_dtype = container.dtype
    else:
        np_dtype = np.dtype(
            dtype_util.convert_oneflow_dtype_to_numpy_dtype(container.dtype)
        )
    SLICE_LEN = SLICE_BYTES // np_dtype.itemsize  # max element count per slice
    start_idx = 0
    size = _ElemCnt(container.shape)
    cnt = 1
    # Walk axes from the innermost outwards until the accumulated element
    # count first exceeds SLICE_LEN; slices are then cut along that `axis`.
    # If no prefix exceeds SLICE_LEN, axis ends at 0 and the whole container
    # fits in one slice per outermost row.
    for axis in reversed(range(len(container.shape))):
        cnt *= container.shape[axis]
        if cnt > SLICE_LEN:
            break
    # Elements in one "unit": everything inner to `axis`.
    unit_size = _ElemCnt(tuple(container.shape)[axis + 1 :])
    max_unit_num = SLICE_LEN // unit_size  # whole units that fit in one slice
    while start_idx < size:
        remainder = container.shape[axis]
        while remainder > 0:
            unit_num = max_unit_num if remainder >= max_unit_num else remainder
            length = unit_num * unit_size
            remainder -= unit_num
            stop_idx = start_idx + length
            # Flat -> n-d index conversion. stop is exclusive, hence unravel
            # (stop_idx - 1) and bump every coordinate by one.
            start_nd_idx = np.unravel_index(start_idx, container.shape)
            stop_nd_idx = np.unravel_index(stop_idx - 1, container.shape)
            stop_nd_idx = tuple([x + 1 for x in stop_nd_idx])
            yield (start_nd_idx, stop_nd_idx, f(container, start_nd_idx, stop_nd_idx))
            start_idx = stop_idx
def generate_values_by_initializer(initializer, shape, dtype):
np_dtype = np.dtype(dtype_util.convert_oneflow_dtype_to_numpy_dtype(dtype))
length = _ElemCnt(shape)
......
......@@ -28,9 +28,9 @@ from typing import Any, List, Optional, Tuple, Union
import numpy as np
import oneflow as flow
import oneflow._oneflow_internal
import oneflow.framework.dtype as dtype_util
from oneflow.framework.check_point_v2 import GetCheckpoint, SaveVarDict
from oneflow.framework.function_util import FunctionConfig as ExecutionConfig
from oneflow.framework.tensor import Tensor
from oneflow.nn.module import Module
......@@ -573,12 +573,14 @@ class CheckpointModel(SubModel):
def _load_checkpoint(self, dirpath: str):
"""Load model states from a checkpoint.
"""
LoadVariables(GetCheckpoint(path=dirpath))
stat_dict = flow.load(path=dirpath)
self._model.load_state_dict(stat_dict)
def _save_checkpoint(self, dirpath: str):
"""Save model states as a checkpoint.
"""
SaveVarDict(path=dirpath)
stat_dict = self._model.state_dict()
flow.save(stat_dict, dirpath)
class TrainModelOOPStyle(SubModel):
......@@ -703,14 +705,14 @@ class CheckpointModelOOPStyle(SubModel):
def _load_checkpoint(self, dirpath: str):
"""Load model states from a checkpoint.
"""
stat_dict = GetCheckpoint(path=dirpath)
stat_dict = flow.load(path=dirpath)
self._model.load_state_dict(stat_dict)
def _save_checkpoint(self, dirpath: str):
"""Save model states as a checkpoint.
"""
stat_dict = self._model.state_dict()
SaveVarDict(path=dirpath, var_dict=stat_dict)
flow.save(stat_dict, dirpath)
def _infer_job_signature(data_module, batch, optimizer_idx, job):
......
......@@ -101,6 +101,18 @@ def _setitem(self, key, value):
value = flow.F.consistent_constant(
[1], value, self.dtype, placement=self.placement, sbp=flow.sbp.broadcast
)
else:
if value.is_consistent:
value = value.to_consistent(sbp=flow.sbp.broadcast)
# TODO: remove these lines after asymmetric boxing is ready
local_tensor = value.to_local()
if local_tensor.nelement() == 0:
local_tensor = flow.zeros(*value.shape)
value = local_tensor.to_consistent(
self.placement, sbp=flow.sbp.broadcast
)
else:
value = value.to_consistent(self.placement, sbp=flow.sbp.broadcast)
else:
if isinstance(value, (int, float)):
value = flow.F.constant([1], value, self.dtype, device=self.device)
......@@ -297,70 +309,41 @@ def _copy_from_numpy_to_eager_local_tensor(eager_local_tensor, np_arr):
copy_from_numpy(np_arr)
def _init_eager_local_tensor_by_initializer_conf(
eager_local_tensor, initializer_conf, random_seed=None
):
def _init_by_initializer_conf(tensor, initializer_conf, random_seed=None):
if random_seed is None:
random_seed = flow.default_generator().seed()
shape = tuple(eager_local_tensor.shape)
shape = tuple(tensor.shape)
initializer = initializer_util.GetInitializer(initializer_conf, random_seed, shape)
# initializer is None if and only if the initializer_conf is empty_initializer
if initializer is None:
return
_copy_from_numpy_to_eager_local_tensor(
eager_local_tensor,
check_point_v2.generate_values_by_initializer(
initializer, shape, eager_local_tensor.dtype
),
)
def _init_by_initializer_conf(tensor, initializer_conf):
np_arr = check_point_v2.generate_values_by_initializer(
initializer, shape, tensor.dtype
)
if tensor.is_consistent:
raise NotImplementedError(" consistent initializer unvailiable now")
else:
_init_eager_local_tensor_by_initializer_conf(tensor, initializer_conf)
return tensor
def _convert_to_placement_scope(placement_or_device):
if isinstance(placement_or_device, flow.placement):
placement = placement_or_device
return flow.scope.placement(
placement.device_tag,
list(placement.parallel_conf.device_name()),
placement.hierarchy,
src_tensor = flow.tensor(np_arr)
src_tensor = src_tensor.to_consistent(
placement=tensor.placement, sbp=flow.sbp.broadcast
)
tensor.copy_(src_tensor)
else:
device = placement_or_device
# TODO(jianhao): replace 0 with real machine id
machine_id = 0
# TODO(jianhao): support cuda in of
if device.type == "cuda":
device_tag = "gpu"
else:
device_tag = device.type
return flow.scope.placement(
device_tag, "{}:{}".format(machine_id, device.index), None
_copy_from_numpy_to_eager_local_tensor(
tensor, np_arr,
)
def _placement_scope(self):
if self.is_consistent:
return _convert_to_placement_scope(self.placement)
else:
return _convert_to_placement_scope(self.device)
return tensor
def _copy(self, other: Union[Tensor, np.ndarray]):
if isinstance(other, (Tensor, check_point_v2.FileBackendVariableBlob)):
src_np = other.numpy()
if self.is_consistent:
assert isinstance(other, Tensor)
assert other.is_consistent
self[:] = other
else:
assert isinstance(other, np.ndarray)
src_np = other
if isinstance(other, (Tensor)):
src_np = other.numpy()
else:
assert isinstance(other, np.ndarray)
src_np = other
_copy_from_numpy_to_eager_local_tensor(self, src_np)
_copy_from_numpy_to_eager_local_tensor(self, src_np)
def _get_device(self):
......@@ -420,7 +403,6 @@ def RegisterMethods():
Tensor.xavier_uniform_ = _xavier_uniform
Tensor.normal_ = _normal
Tensor.fill_ = _fill
Tensor._placement_scope = _placement_scope
Tensor.copy_ = _copy
Tensor.get_device = _get_device
Tensor._meta_repr = _meta_repr
......
......@@ -373,7 +373,8 @@ class Module(object):
)
continue
try:
param.copy_(input_param)
with flow.no_grad():
param.copy_(input_param)
except Exception as ex:
error_msgs.append(
'While copying the parameter named "{}", whose dimensions in the model are {} and whose dimensions in the checkpoint are {}, an exception occurred : {}.'.format(
......
......@@ -24,15 +24,6 @@ import oneflow.core.job.initializer_conf_pb2 as initializer_conf_util
import oneflow.core.operator.op_conf_pb2 as op_conf_util
def empty_initializer(
    dtype: flow.dtype = flow.float,
) -> initializer_conf_util.InitializerConf:
    """Build an InitializerConf whose `empty_conf` field is set, i.e. an
    initializer that leaves the variable's values unwritten."""
    conf = initializer_conf_util.InitializerConf()
    conf.empty_conf.CopyFrom(initializer_conf_util.EmptyInitializerConf())
    return conf
def constant_initializer(
value: float = 0, dtype: flow.dtype = flow.float
) -> initializer_conf_util.InitializerConf:
......
......@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import collections.abc
import tempfile
import unittest
......@@ -30,8 +31,8 @@ def np_relu(np_arr):
return np.where(np_arr > 0, np_arr, 0)
@flow.unittest.skip_unless_1n1d()
class TestModule(flow.unittest.TestCase):
@flow.unittest.skip_unless_1n1d()
def test_nested_module(test_case):
class CustomModule(flow.nn.Module):
def __init__(self):
......@@ -47,6 +48,7 @@ class TestModule(flow.unittest.TestCase):
y = m(x)
test_case.assertTrue(np.array_equal(np_relu(x.numpy()), y.numpy()))
@flow.unittest.skip_unless_1n1d()
def test_relu(test_case):
relu = flow.nn.ReLU()
x = flow.Tensor(2, 3)
......@@ -54,6 +56,7 @@ class TestModule(flow.unittest.TestCase):
y = relu(x)
test_case.assertTrue(np.array_equal(np_relu(x.numpy()), y.numpy()))
@flow.unittest.skip_unless_1n1d()
def test_load_state_dict(test_case):
class CustomModule(flow.nn.Module):
def __init__(self):
......@@ -70,6 +73,7 @@ class TestModule(flow.unittest.TestCase):
y = m(x).numpy()
test_case.assertTrue(np.array_equal(y, ones))
@flow.unittest.skip_unless_1n1d()
def test_state_dict(test_case):
class CustomModule(flow.nn.Module):
def __init__(self, param1, param2):
......@@ -87,6 +91,7 @@ class TestModule(flow.unittest.TestCase):
{"param2.param1": tensor0, "param2.param2": tensor1, "param1": tensor1},
)
@flow.unittest.skip_unless_1n1d()
def test_parameter(test_case):
shape = (3, 4)
t = flow.Tensor(*shape)
......@@ -94,6 +99,7 @@ class TestModule(flow.unittest.TestCase):
test_case.assertEqual(type(p), flow.nn.Parameter)
test_case.assertEqual(p.shape, shape)
@flow.unittest.skip_unless_1n1d()
def test_module_forward(test_case):
class CustomModule(flow.nn.Module):
def __init__(self, w):
......@@ -108,6 +114,7 @@ class TestModule(flow.unittest.TestCase):
m = CustomModule(4)
test_case.assertEqual(m(3), 7)
@flow.unittest.skip_unless_1n1d()
def test_train_eval(test_case):
m = flow.nn.Module()
test_case.assertEqual(m.training, True)
......@@ -116,6 +123,7 @@ class TestModule(flow.unittest.TestCase):
m.eval()
test_case.assertEqual(m.training, False)
@flow.unittest.skip_unless_1n1d()
def test_module_setattr(test_case):
class CustomModule(flow.nn.Module):
def __init__(self, param1, param2):
......@@ -146,6 +154,7 @@ class TestModule(flow.unittest.TestCase):
test_case.assertTrue(np.allclose(child_params[0].numpy(), param0.numpy()))
test_case.assertTrue(np.allclose(child_params[1].numpy(), param1.numpy()))
@flow.unittest.skip_unless_1n1d()
def test_module_apply(test_case):
class CustomModule(flow.nn.Module):
def __init__(self):
......@@ -163,6 +172,7 @@ class TestModule(flow.unittest.TestCase):
net.apply(get_module_num)
test_case.assertEqual(module_num, 2)
@flow.unittest.skip_unless_1n1d()
def test_save_state_dict(test_case):
class CustomModule(flow.nn.Module):
def __init__(self):
......@@ -183,6 +193,55 @@ class TestModule(flow.unittest.TestCase):
res2 = m()
test_case.assertTrue(np.array_equal(res1.numpy(), res2.numpy()))
@flow.unittest.skip_unless_1n2d()
def test_save_and_load_consistent(test_case):
    # Module holding a single consistent (broadcast) parameter across 2 GPUs.
    class CustomModule(flow.nn.Module):
        def __init__(self):
            super().__init__()
            self.param = flow.nn.Parameter(flow.randn(3, 32, 3, 3))

        def forward(self):
            return self.param

    m = CustomModule()
    m = m.to_consistent(flow.placement("cuda", {0: range(2)}), flow.sbp.broadcast)
    res1 = m()
    state_dict = m.state_dict()
    with tempfile.TemporaryDirectory() as f:
        # Saving a consistent state dict without consistent_dst_rank must fail.
        with test_case.assertRaises(Exception):
            flow.save(state_dict, f)
        consistent_src_dst_rank = 0
        flow.save(state_dict, f, consistent_dst_rank=consistent_src_dst_rank)
        rank = flow.framework.distribute.get_rank()
        if rank != consistent_src_dst_rank:
            # Only the dst rank writes files; other ranks' dirs stay empty.
            test_case.assertEqual(len(os.listdir(f)), 0)
        m = CustomModule()
        m = m.to_consistent(
            flow.placement("cuda", {0: range(2)}), flow.sbp.broadcast
        )
        # Loading without consistent_src_rank must also fail for a
        # consistent module.
        with test_case.assertRaises(Exception):
            loaded_state_dict = flow.load(f)
            m.load_state_dict(loaded_state_dict)
        loaded_state_dict = flow.load(
            f, consistent_src_rank=consistent_src_dst_rank
        )
        test_case.assertEqual(len(loaded_state_dict), 1)
        test_case.assertEqual(list(loaded_state_dict.keys())[0], "param")
        m.load_state_dict(loaded_state_dict)
        res2 = m()
        # The reloaded parameter must match the original on every rank.
        test_case.assertTrue(
            np.array_equal(
                res1.to_consistent(sbp=flow.sbp.broadcast).to_local().numpy(),
                res2.to_consistent(sbp=flow.sbp.broadcast).to_local().numpy(),
            )
        )
if __name__ == "__main__":
    # Run this module's tests when executed directly.
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册