Unverified commit 6a1b2253, authored by daquexian, committed by GitHub

New checkpoint (#3540)

* flow.load/save/get_all_variables without large tensor and multi machine support

* add lazy blob cache and disable blob_cache after writing

* update checkpoint to call the potential slice_assign and read_slice_from_blob method

* reformat

* new checkpoint supports eager

* split mut bn into mutable input bn and output bn

* work in eager mode. deprecate checkpoint.init()

* slice_assign implementation

* new slice op

* check step > 0, add more tests, refine the code

* revert the initializer changes

* remove print

* set y to 0 for partialsum

* check sbp, fix incorrect attr check

* add more tests

* rename slice2->logical_slice

* update tests

* extract common python code into a function

* get_size_in_slice -> GetSizeInSlice, rm unused test file

* minor update about step > 0

* minor update on tests

* add WITH_CUDA guard

* integrate with logical slice/slice_assign

* set scope according to variable op_conf

* initial support of stream init

* read_slice_from_blob/as_numpy return nd_idx and set the cpu:0 placement for created variable

* extract a 'for_every_slice' function

* initializer registration

* one meta file per variable

* remove mis-added file

* code clean

* create model io jobs only if legacy model io enabled, update legacy api

* add legacy model io test

* slice operation optimization

* add and update tests

* barrier for multi node eager

* make sync as a cluster instruction

* update test

* fix life cycle problem

* add python api vm_util.Sync()

* make initializer receive a random_seed

* Add vm_util.Sync(), remove debug code

* resolve TODO, remove __repr__ for now

* use compiled op_conf for getting random_seed

* UserOpAttrVal -> AttrValue, remove debug code

* test another dtype

* remove mis-added )

* fix dtype error when shape[axis+1:] is empty

* add initializers to check_point

* code clean, enable a temporary default checkpoint for test

* move legacy implementation to deprecated/

* update deprecated implementation

* fix bug in eager, add eager tests and some other minor updates

* remove name field in FileBackendBlob, update Load for single variable, and some other minor updates

* remove mis-added file

* move initializer implementation, some minor changes

* disable some bn tests missing checkpoint.init()

* fix dtype conversion bug

* relax the tolerance of layer_norm test

* reformat

* minor code clean

* use new pybind11 eager sync api

* add assignment between memory test

* disable optimizers test for now

* code clean

* reuse CreateEagerVariableBlob

* remove mis-added file

* unify two read slice function

* minor code clean

* add initializer_updated to check_point.py

* fix typo

* resolve merge conflict

* restore bn tests

* add type annotations, add some comments and minor code clean

* add some comments, remove 'need_root_path' parameter

* fixup

* get parallel_conf from job_set instead of op_attribute

* disable two tests involving legacy model io in eager mode

* add InitialzierImpl

* add InitializerImpl

* support load from numpy array, add test

* rename and format

* Add necessary docs and TODO, improve warning message

* ParallelConf4InterfaceOpName->ParallelConf4LazyInterfaceOpName

* address some comments

* rename api

* fix problems

* add test_initializer.py

* remove unused initializers

* remove quantinfo, move new checkpoint to check_point_v2.py

* fix crash on checkpoint.init()
Signed-off-by: daquexian <daquexian566@gmail.com>

* restore optimizer test
Signed-off-by: daquexian <daquexian566@gmail.com>

* Add GetOpAttributes api
Signed-off-by: daquexian <daquexian566@gmail.com>

* restore ParallelConf4LazyOp as parallel desc symbol id in op attr doesn't align with that in job set
Signed-off-by: daquexian <daquexian566@gmail.com>

* Add TestResumeTraining, shrink the large model size
Signed-off-by: daquexian <daquexian566@gmail.com>

* restore 2n4c ci test
Signed-off-by: daquexian <daquexian566@gmail.com>

* code clean
Signed-off-by: daquexian <daquexian566@gmail.com>

* add snapshot_done
Signed-off-by: daquexian <daquexian566@gmail.com>

* add test_mixed_model, update test_load_numpy
Signed-off-by: daquexian <daquexian566@gmail.com>

* add flow.sync_default_session in api implementation
Signed-off-by: daquexian <daquexian566@gmail.com>

* change the default value of ignore_mismatch from False to True to align with existing behavior
Signed-off-by: daquexian <daquexian566@gmail.com>

* fix wrong initializer in test_mseloss.py and test_bce_loss.py
Signed-off-by: daquexian <daquexian566@gmail.com>

* ForEachOpNode -> ForEachNode
Signed-off-by: daquexian <daquexian566@gmail.com>

* fix test_partially_load_numpy
Signed-off-by: daquexian <daquexian566@gmail.com>
Co-authored-by: wanghongsheng <2496533749@qq.com>
Co-authored-by: cheng cheng <472491134@qq.com>
Parent commit: 3e6f8895
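For orientation before the diff, here is a minimal usage sketch of the checkpoint API this commit introduces. The API names are taken from the deprecation warning and the tests in this diff; the toy job, the variable name "w", and the snapshot path are placeholders.

import numpy as np
import oneflow as flow
import oneflow.typing as tp

# A toy job with a single variable, just so there is something to checkpoint.
@flow.global_function()
def job() -> tp.Numpy:
    w = flow.get_variable(
        name="w",
        shape=(5, 6),
        dtype=flow.float32,
        initializer=flow.random_normal_initializer(),
    )
    return w

job()  # run once so the variable is materialized (needed in eager mode)

# New API:
flow.checkpoint.save("/tmp/snapshot")                      # was: flow.train.CheckPoint().save(path)
flow.load_variables(flow.checkpoint.get("/tmp/snapshot"))  # was: flow.train.CheckPoint().load(path)
all_vars = flow.get_all_variables()                        # dict: variable name -> blob; call .numpy() to read a value
flow.load_variables({"w": np.zeros((5, 6), np.float32)})   # variables can also be fed from numpy arrays
# flow.train.CheckPoint().init() is no longer needed.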
......@@ -14,4 +14,4 @@ cd $test_tmp_dir
ONEFLOW_TEST_DEVICE_NUM=1 python3 -m unittest discover test/ops --failfast --verbose
ONEFLOW_TEST_DEVICE_NUM=2 python3 -m unittest discover test/ops --failfast --verbose
# ONEFLOW_TEST_DEVICE_NUM=4 python3 -m unittest discover test/ops --failfast --verbose
ONEFLOW_TEST_DEVICE_NUM=4 python3 -m unittest discover test/ops --failfast --verbose
syntax = "proto2";
package oneflow;
import "oneflow/core/common/shape.proto";
import "oneflow/core/common/data_type.proto";
message VariableMetaInfo {
required ShapeProto shape = 2;
required DataType data_type = 3;
}
......@@ -11,6 +11,7 @@ message IOConf {
optional bool save_downloaded_file_to_local_fs = 3 [default = false];
optional uint64 persistence_buf_byte = 4;
optional bool enable_model_io_v2 = 5 [default = false];
optional bool enable_legacy_model_io = 6 [default = false];
}
message ProfilerConf {
......
......@@ -907,10 +907,12 @@ Maybe<void> CompileAndMergePlanOnMaster(const PbRpf<Job>& conf_jobs, Plan* plan)
CHECK(!job_desc.Bool("__is_user_function__"));
jobs.emplace_back(new Job(*job));
};
if (Global<const IOConf>::Get()->enable_model_io_v2()) {
MakeModelIoV2Jobs(jobs, var_op_name2parallel_blob_conf, AppendJob);
} else {
MakeModelIoJobs(jobs, var_op_name2parallel_blob_conf, AppendJob);
if (Global<const IOConf>::Get()->enable_legacy_model_io()) {
if (Global<const IOConf>::Get()->enable_model_io_v2()) {
MakeModelIoV2Jobs(jobs, var_op_name2parallel_blob_conf, AppendJob);
} else {
MakeModelIoJobs(jobs, var_op_name2parallel_blob_conf, AppendJob);
}
}
}
std::vector<std::shared_ptr<Job>> function_jobs;
......
......@@ -33,6 +33,10 @@ message OpAttribute {
optional ParallelSignature parallel_signature = 108;
}
message OpAttributeList {
repeated OpAttribute op_attribute = 1;
}
message OpNodeSignature {
optional SbpSignature sbp_signature = 1;
optional MirroredSignature mirrored_signature = 2;
......
......@@ -19,10 +19,10 @@ import oneflow.python.framework.c_api_util as c_api_util
import oneflow
def Infer(op_conf, ibn2blob_object, scope_symbol=None):
if scope_symbol is None:
scope_symbol = oneflow.current_scope()
op_conf.scope_symbol_id = scope_symbol.symbol_id
def Infer(op_conf, ibn2blob_object, scope_symbol_id=None):
if scope_symbol_id is None:
scope_symbol_id = oneflow.current_scope().symbol_id
op_conf.scope_symbol_id = scope_symbol_id
upstream_signature = MakeUpstreamSignature(ibn2blob_object)
return c_api_util.InferOpConf(op_conf, upstream_signature)
......
......@@ -389,10 +389,9 @@ class InstructionsBuilder(object):
assert len(op_attribute.output_bns) == 1
obn = op_attribute.output_bns[0]
blob_parallel_desc_sym_id = op_attribute.parallel_signature.bn_in_op2parallel_desc_symbol_id[
obn
]
blob_parallel_desc_sym = symbol_storage.GetSymbol4Id(blob_parallel_desc_sym_id)
parallel_conf = sess.ParallelConf4LazyInterfaceOpName(interface_op_name)
blob_parallel_desc_sym = self.GetParallelDescSymbol(parallel_conf)
op_arg_parallel_attr = op_arg_util.GetOpArgParallelAttribute(
blob_parallel_desc_sym, op_attribute, obn
)
......
......@@ -15,14 +15,59 @@ limitations under the License.
"""
import oneflow as flow
import oneflow.core.register.logical_blob_id_pb2 as logical_blob_id_util
import oneflow.python.eager.blob_cache as blob_cache_util
import oneflow.python.eager.vm_util as vm_util
import oneflow.python.lib.core.async_util as async_util
import oneflow.python.framework.c_api_util as c_api_util
import oneflow.python.framework.input_blob_def as input_blob_def_util
import oneflow.python.framework.dtype as dtype_util
import oneflow.python.framework.remote_blob as remote_blob_util
import oneflow.python.framework.push_util as push_util
import oneflow.python.framework.session_context as session_ctx
from oneflow.python.oneflow_export import oneflow_export
import oneflow.python.eager.op_executor as op_executor
def _GetInterfaceBlobObject(builder, op_name):
if c_api_util.EagerExecutionEnabled():
return session_ctx.GetDefaultSession().var_name2var_blob[op_name].blob_object
blob_object = builder.MakeLazyRefBlobObject(op_name)
return blob_object
def GetEagerInterfaceBlob(op_name):
flow.sync_default_session()
sess = session_ctx.GetDefaultSession()
def CreateBlob():
job_name = sess.JobName4InterfaceOpName(op_name)
def Build(builder, Yield):
blob_object = _GetInterfaceBlobObject(builder, op_name)
lbi = logical_blob_id_util.LogicalBlobId()
lbi.op_name = op_name
op_attribute = sess.OpAttribute4InterfaceOpName(op_name)
assert len(op_attribute.output_bns) == 1
lbi.blob_name = op_attribute.output_bns[0]
if blob_object.op_arg_parallel_attr.is_mirrored():
remote_blob = remote_blob_util.EagerMirroredBlob(
lbi, blob_object, job_name
)
else:
remote_blob = remote_blob_util.EagerConsistentBlob(
lbi, blob_object, job_name
)
Yield(remote_blob)
def AsyncGetInterfaceBlob(Yield):
vm_util.LogicalRun(lambda builder: Build(builder, Yield))
blob = async_util.Await(1, AsyncGetInterfaceBlob)[0]
return blob
return sess.FindOrCreateLazyBlob(op_name, CreateBlob)
@oneflow_export("experimental.get_interface_blob_value")
......@@ -34,7 +79,7 @@ def GetInterfaceBlobValue(op_name):
def AsyncGetInterfaceBlobValue(Yield):
def build(builder):
blob_object = builder.MakeLazyRefBlobObject(op_name)
blob_object = GetEagerInterfaceBlob(op_name).blob_object
lbi = logical_blob_id_util.LogicalBlobId()
lbi.op_name = op_name
op_attribute = sess.OpAttribute4InterfaceOpName(op_name)
......@@ -52,7 +97,6 @@ def GetInterfaceBlobValue(op_name):
value = remote_blob.numpy_list()
else:
value = remote_blob.numpy()
Yield(value)
vm_util.LogicalRun(build)
......@@ -60,35 +104,38 @@ def GetInterfaceBlobValue(op_name):
return async_util.Await(1, AsyncGetInterfaceBlobValue)[0]
def FeedValueToInterfaceBlobObject(blob_object, ndarray):
flow.sync_default_session()
def build(builder):
if blob_object.op_arg_blob_attr.is_tensor_list:
input_blob_def = input_blob_def_util.MirroredTensorListDef(
[x.shape for x in ndarray],
dtype=dtype_util.convert_numpy_dtype_to_oneflow_dtype(ndarray.dtype),
)
elif blob_object.op_arg_parallel_attr.is_mirrored():
input_blob_def = input_blob_def_util.MirroredTensorDef(
ndarray.shape,
dtype=dtype_util.convert_numpy_dtype_to_oneflow_dtype(ndarray.dtype),
)
else:
input_blob_def = input_blob_def_util.FixedTensorDef(
ndarray.shape,
dtype=dtype_util.convert_numpy_dtype_to_oneflow_dtype(ndarray.dtype),
)
push_util.FeedValueToEagerBlob(blob_object, input_blob_def, ndarray)
vm_util.LogicalRun(build)
@oneflow_export("experimental.set_interface_blob_value")
def FeedValueToInterfaceBlob(op_name, ndarray):
flow.sync_default_session()
def AsyncFeedValueToInterfaceBlob(Yield):
def build(builder):
blob_object = builder.MakeLazyRefBlobObject(op_name)
if blob_object.op_arg_blob_attr.is_tensor_list:
input_blob_def = input_blob_def_util.MirroredTensorListDef(
[x.shape for x in ndarray],
dtype=dtype_util.convert_numpy_dtype_to_oneflow_dtype(
ndarray.dtype
),
)
elif blob_object.op_arg_parallel_attr.is_mirrored():
input_blob_def = input_blob_def_util.MirroredTensorDef(
ndarray.shape,
dtype=dtype_util.convert_numpy_dtype_to_oneflow_dtype(
ndarray.dtype
),
)
else:
input_blob_def = input_blob_def_util.FixedTensorDef(
ndarray.shape,
dtype=dtype_util.convert_numpy_dtype_to_oneflow_dtype(
ndarray.dtype
),
)
push_util.FeedValueToEagerBlob(blob_object, input_blob_def, ndarray)
blob_object = GetEagerInterfaceBlob(op_name).blob_object
FeedValueToInterfaceBlobObject(blob_object, ndarray)
Yield()
vm_util.LogicalRun(build)
......
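As a small usage note for the two experimental helpers above (a sketch, not part of the diff; the op name "w" stands in for any interface/variable op defined in the current session):

import numpy as np
import oneflow as flow

# Read the current value of an interface blob as a numpy array,
# then feed a new value back into the same blob.
value = flow.experimental.get_interface_blob_value("w")
flow.experimental.set_interface_blob_value("w", np.zeros_like(value))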
......@@ -628,6 +628,14 @@ def NewPhysicalSymbolId():
return object_id
def GetOpAttributes():
op_attributes, error_str = oneflow_internal.GetSerializedOpAttributes()
error = text_format.Parse(error_str, error_util.ErrorProto())
if error.HasField("error_type"):
raise JobBuildAndInferError(error)
return text_format.Parse(op_attributes, op_attribute_pb.OpAttributeList())
def GetJobSet():
job_set, error_str = oneflow_internal.GetSerializedJobSet()
error = text_format.Parse(error_str, error_util.ErrorProto())
......
......@@ -15,10 +15,14 @@ limitations under the License.
"""
import datetime
import os
import shutil
import numpy as np
import oneflow.python.framework.hob as hob
import oneflow.python.framework.job_instance as job_instance
import oneflow.python.framework.check_point_v2 as check_point_v2
import oneflow.python.framework.config_util as config_util
import oneflow.python.framework.session_context as session_ctx
import oneflow.python.lib.core.enable_if as enable_if
import oneflow.python.eager.op_executor as op_executor
......@@ -34,7 +38,13 @@ class CheckPoint(object):
"""
def __init__(self) -> None:
pass
if not config_util.api_legacy_model_io_enabled():
print(
"\033[1mWARNING: 'flow.train.CheckPoint' is deprecated. Please use the new API:\033[0m\n"
"flow.train.CheckPoint().save(path) => \033[1m\033[92mflow.checkpoint.save(path)\033[0m\n"
"flow.train.CheckPoint().load(path) => \033[1m\033[92mflow.load_variables(flow.checkpoint.get(path))\033[0m\n"
"flow.train.CheckPoint().init() is not needed any more.\n"
)
@session_ctx.try_init_default_session
def save(self, path: str) -> None:
......@@ -43,6 +53,9 @@ class CheckPoint(object):
Args:
path: A `string` of path to save checkpoint.
"""
if not config_util.api_legacy_model_io_enabled():
check_point_v2.SaveVarDict(path)
return
assert type(path) is str
enable_if.unique([lazy_checkpoint_save, eager_checkpoint_save])(path)
......@@ -50,6 +63,8 @@ class CheckPoint(object):
def init(self) -> None:
r"""Initialize models by default initializer of op or Job.
"""
if not config_util.api_legacy_model_io_enabled():
return
enable_if.unique([lazy_checkpoint_init, eager_checkpoint_init])()
@session_ctx.try_init_default_session
......@@ -59,6 +74,9 @@ class CheckPoint(object):
Args:
path: A `string` of path to load checkpoint.
"""
if not config_util.api_legacy_model_io_enabled():
check_point_v2.LoadVariables(check_point_v2.GetCheckpoint(path))
return
assert type(path) is str
enable_if.unique([lazy_checkpoint_load, eager_checkpoint_load])(path)
......@@ -155,7 +173,6 @@ class SimpleCheckPointManager(object):
assert os.path.isdir(root_path)
self._root_path = root_path
self._prefix = prefix
self._checkpoint = CheckPoint()
def list_checkpoints(self) -> List[str]:
def is_snapshot(name):
......@@ -176,13 +193,14 @@ class SimpleCheckPointManager(object):
def initialize_or_restore(self) -> None:
name = self.latest_checkpoint()
if name:
self._checkpoint.load(self._GetSnapshotPath(name))
check_point_v2.LoadVariables(
check_point_v2.GetCheckpoint(self._GetSnapshotPath(name))
)
else:
self._checkpoint.init()
self.save()
def save(self) -> None:
self._checkpoint.save(self._GetSnapshotPath(self._NextSnapshotName()))
check_point_v2.SaveVarDict(self._GetSnapshotPath(self._NextSnapshotName()),)
def _NextSnapshotName(self) -> str:
return self._prefix + datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
......
This diff is collapsed.
......@@ -340,6 +340,29 @@ def persistence_buf_byte(val):
sess.config_proto.io_conf.persistence_buf_byte = val
@oneflow_export("config.legacy_model_io_enabled")
def api_legacy_model_io_enabled():
sess = session_ctx.GetDefaultSession()
return sess.config_proto.io_conf.enable_legacy_model_io
@oneflow_export("config.enable_legacy_model_io")
def api_enable_legacy_model_io(val: bool = True):
r"""Whether or not use legacy model io.
Args:
val ([type]): True or False
"""
return enable_if.unique([enable_legacy_model_io, do_nothing])(val)
@enable_if.condition(hob.in_normal_mode & ~hob.session_initialized)
def enable_legacy_model_io(val):
sess = session_ctx.GetDefaultSession()
assert type(val) is bool
sess.config_proto.io_conf.enable_legacy_model_io = val
@oneflow_export("config.enable_model_io_v2")
def api_enable_model_io_v2(val):
r"""Whether or not use version2 of model input/output function.
......
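A brief usage sketch for the new flag (not part of the diff): it must be set before the session is initialized, hence the clear_default_session() call, and the assertion simply reads the stored flag back from the session config.

import oneflow as flow

flow.clear_default_session()
flow.config.enable_legacy_model_io(True)      # opt back into the old model-io jobs
assert flow.config.legacy_model_io_enabled()  # reads enable_legacy_model_io from io_conf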
......@@ -22,6 +22,7 @@ import oneflow.python.framework.python_callback as python_callback
import oneflow.python.framework.balanced_splitter as balanced_splitter
import oneflow.python.framework.remote_blob as remote_blob_util
import oneflow.python.framework.id_util as id_util
import oneflow.python.eager.blob_cache as blob_cache_util
import oneflow.python.eager.vm_util as vm_util
import oneflow.python.eager.blob_register as blob_register_util
import oneflow.python.eager.object as object_util
......@@ -108,6 +109,7 @@ def FeedValueToEagerBlob(blob_object, blob_def, ndarray):
for i, physical_blob_object in enumerate(physical_blob_objects):
feed_ctx.set_rank(i)
_FeedValueToInputPhysicalBlob(feed_ctx, blob_def, physical_blob_object)
blob_cache_util.TryDisableBlobCache(blob_object)
def _CreateEagerInputBlobAndFeedValue(arg_blob_def, arg_ndarray):
......
......@@ -31,6 +31,8 @@ import oneflow.python.framework.push_util as push_util
import oneflow.python.framework.session_context as session_ctx
import oneflow.python.lib.core.enable_if as enable_if
import oneflow.python.eager.vm_util as vm_util
import oneflow.python.eager.op_executor as op_executor
from oneflow.python.experimental import interface_op_read_and_write
from oneflow.core.job.job_set_pb2 import ConfigProto
from oneflow.python.framework.function_desc import FunctionDesc
import oneflow.python.framework.module as module_util
......@@ -42,6 +44,7 @@ from oneflow.python.framework.session_context import SessionStatus
from oneflow.python.oneflow_export import oneflow_export, oneflow_deprecate
from oneflow.python.framework.function_desc import FunctionDesc
from oneflow.python.framework.check_point import SnapshotManager
import oneflow.python.framework.check_point_v2 as check_point_v2
import oneflow.python.eager.blob_register as blob_register_util
from contextlib import contextmanager
from typing import Callable
......@@ -68,8 +71,13 @@ class Session(object):
self.job_name2module_name2module_ = {}
self.existed_module_names_ = set()
self.var_name2var_blob_ = {}
# parallel desc symbol id in op attribute does not always correct
# for lazy ops as parallel conf may be updated in some passes
# (like non_distributed_optimizer_pass)
self.interface_op_name2op_attr_ = {}
self.interface_op_name2job_name_ = {}
self.lazy_interface_op_name2parallel_conf_ = {}
self.op_name2lazy_blob_cache_ = {}
self.job_name2name_scope_stack_ = {}
self.eager_global_function_desc_stack_ = []
self.function_flag_name2default_val_ = {}
......@@ -178,6 +186,25 @@ class Session(object):
self.Init()
return self
def _UpdateInfo4LazyInterfaceOp(self):
for op_attr in c_api_util.GetOpAttributes().op_attribute:
op_conf = op_attr.op_conf
if c_api_util.IsInterfaceOpConf(op_conf):
self.interface_op_name2op_attr_[op_conf.name] = op_attr
for job in c_api_util.GetJobSet().job:
op_name2parallel_conf = {}
for placement_group in job.placement.placement_group:
for op_name in placement_group.op_set.op_name:
op_name2parallel_conf[op_name] = placement_group.parallel_conf
for op_conf in job.net.op:
if c_api_util.IsInterfaceOpConf(op_conf):
self.interface_op_name2job_name_[
op_conf.name
] = job.job_conf.job_name
self.lazy_interface_op_name2parallel_conf_[
op_conf.name
] = op_name2parallel_conf[op_conf.name]
def Init(self):
assert self.status_ is SessionStatus.OPEN
self.status_ = SessionStatus.RUNNING
......@@ -194,12 +221,21 @@ class Session(object):
assert len(self.job_name2function_desc_.items()) > 0
c_api_util.StartLazyGlobalSession()
self.inter_user_job_info_ = c_api_util.GetInterUserJobInfo()
# Get latest op_attr and job_name after compiler.Compile
self._UpdateInfo4LazyInterfaceOp()
if not config_util.api_legacy_model_io_enabled():
check_point_v2.Init()
else:
self.eager_config_proto_ctx_ = oneflow_api.LogicalConfigProtoContext(
str(self.config_proto)
)
return self
def FindOrCreateLazyBlob(self, op_name, Create):
if op_name not in self.op_name2lazy_blob_cache_:
self.op_name2lazy_blob_cache_[op_name] = Create()
return self.op_name2lazy_blob_cache_[op_name]
def TryClose(self):
if self.status_ is SessionStatus.RUNNING:
self.Close()
......@@ -210,6 +246,7 @@ class Session(object):
assert len(self.job_name2var_name2var_blob_) == 0
del self.var_name2var_blob_
del self.job_name2module_name2module_
self.ReleaseLazyRefBlob()
self.ForceReleaseEagerBlobs()
c_api_util.StopLazyGlobalSession()
c_api_util.DestroyLazyGlobalSession()
......@@ -231,6 +268,9 @@ class Session(object):
assert self.running_job_cnt_ == 0
self.cond_var_.release()
def ReleaseLazyRefBlob(self):
self.op_name2lazy_blob_cache_.clear()
def ForceReleaseEagerBlobs(self):
blob_register_util.GetDefaultBlobRegister().ForceReleaseAll()
self.backward_blob_register_.ForceReleaseAll()
......@@ -299,17 +339,29 @@ class Session(object):
self.job_name2var_name2var_blob_[job_name][var_name] = var_blob
def AddInfo4InterfaceOpName(self, interface_op_name, op_attribute):
self.interface_op_name2op_attr_[interface_op_name] = op_attribute
self.interface_op_name2job_name_[
interface_op_name
] = c_api_util.JobBuildAndInferCtx_GetCurrentJobName()
if oneflow.eager_execution_enabled():
self.interface_op_name2op_attr_[interface_op_name] = op_attribute
self.interface_op_name2job_name_[
interface_op_name
] = c_api_util.JobBuildAndInferCtx_GetCurrentJobName()
else:
# In lazy mode, we update fields with
# the latest info in another function after compiler.Compile
pass
def OpAttribute4InterfaceOpName(self, interface_op_name):
return self.interface_op_name2op_attr_[interface_op_name]
def ParallelConf4LazyInterfaceOpName(self, interface_op_name):
return self.lazy_interface_op_name2parallel_conf_[interface_op_name]
def JobName4InterfaceOpName(self, interface_op_name):
return self.interface_op_name2job_name_[interface_op_name]
@property
def interface_ops(self):
return self.interface_op_name2op_attr_.keys()
# return global_variable_blob, job_variable_blob
def TryGetVariableBlobOfJobFromStash(self, job_name, var_name):
if var_name not in self.var_name2var_blob_:
......
......@@ -84,6 +84,11 @@ std::string GetSerializedInterUserJobInfo(std::string* error_str) {
std::string(""));
}
std::string GetSerializedOpAttributes(std::string* error_str) {
return oneflow::GetSerializedOpAttributes().GetDataAndSerializedErrorProto(error_str,
std::string(""));
}
std::string GetSerializedJobSet(std::string* error_str) {
return oneflow::GetSerializedJobSet().GetDataAndSerializedErrorProto(error_str, std::string(""));
}
......
......@@ -17,7 +17,9 @@ limitations under the License.
#include <google/protobuf/text_format.h>
#include "oneflow/core/common/buffer_manager.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/graph/op_graph.h"
#include "oneflow/core/register/ofblob.h"
#include "oneflow/core/operator/op_attribute.pb.h"
#include "oneflow/core/job/job_set.pb.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/env.pb.h"
......@@ -183,6 +185,22 @@ Maybe<std::string> GetSerializedJobSet() {
return PbMessage2TxtString(job_ctx_mgr->job_set());
}
Maybe<std::string> GetSerializedOpAttributes() {
OpAttributeList op_attribute_list;
auto* job_ctx_mgr = Global<LazyJobBuildAndInferCtxMgr>::Get();
CHECK_NOTNULL_OR_RETURN(job_ctx_mgr);
for (int i = 0; i < job_ctx_mgr->job_set().job_size(); i++) {
const Job& job = job_ctx_mgr->job_set().job(i);
auto scope = std::make_unique<GlobalJobDescScope>(job.job_conf(), i);
const auto& op_graph = JUST(OpGraph::New(job));
op_graph->ForEachNode([&op_attribute_list](OpNode* op_node) {
const auto& op_attribute = op_node->op().GetOpAttributeWithoutOpNameAndLbn();
op_attribute_list.mutable_op_attribute()->Add()->CopyFrom(*op_attribute);
});
}
return PbMessage2TxtString(op_attribute_list);
}
Maybe<std::string> GetFunctionConfigDef() {
std::string ret;
google::protobuf::TextFormat::PrintToString(GlobalFunctionConfigDef(), &ret);
......
......@@ -20,6 +20,7 @@ from oneflow.python.oneflow_export import oneflow_export
import oneflow.python.framework.session_context as session_ctx
import oneflow.python.framework.compile_context as compile_context
import oneflow.python.framework.remote_blob as remote_blob_util
import oneflow.python.framework.runtime_mode as rt_mode
import oneflow.python.framework.distribute as distribute_util
import oneflow.python.experimental.name_scope as name_scope
import oneflow.core.operator.op_conf_pb2 as op_conf_util
......@@ -189,7 +190,7 @@ def get_eager_variable(
)
if job_var_blob is None:
op_conf = _GenerateVariableOpConf(
op_conf = GenerateVariableOpConf(
name=name,
shape=shape,
dtype=dtype,
......@@ -202,7 +203,7 @@ def get_eager_variable(
)
op_attribute = compile_context.CurJobAddConsistentOp(op_conf)
if var_blob is None:
var_blob = _CreateEagerVariableBlob(op_attribute)
var_blob = CreateEagerVariableBlob(op_attribute)
op_executor.EagerInitVariableBlob(sess, op_conf, var_blob)
assert isinstance(var_blob, remote_blob_util.EagerConsistentBlob)
......@@ -250,7 +251,7 @@ def get_lazy_variable(
)
if job_var_blob is None:
op_conf = _GenerateVariableOpConf(
op_conf = GenerateVariableOpConf(
name=name,
shape=shape,
dtype=dtype,
......@@ -275,7 +276,7 @@ def get_lazy_variable(
return job_var_blob
def _GenerateVariableOpConf(
def GenerateVariableOpConf(
name,
shape,
dtype=None,
......@@ -293,11 +294,14 @@ def _GenerateVariableOpConf(
assert dtype is not None
op_conf.variable_conf.data_type = dtype.oneflow_proto_dtype
root_path = (
compile_context.GetCurJobConfigProto().default_initialize_with_snapshot_path
)
dir_path = os.path.join(root_path, name)
file_path = os.path.join(dir_path, "out")
if rt_mode.CurrentMode() == rt_mode.NORMAL_MODE:
root_path = None
else:
root_path = (
compile_context.GetCurJobConfigProto().default_initialize_with_snapshot_path
)
dir_path = os.path.join(root_path, name)
file_path = os.path.join(dir_path, "out")
if root_path and os.path.isfile(file_path):
op_conf.variable_conf.initialize_with_snapshot.path = dir_path
op_conf.variable_conf.initialize_with_snapshot.key = "out"
......@@ -336,7 +340,7 @@ def _CreateVariableBlob(op_conf):
return remote_blob_util.RemoteBlob(lbi)
def _CreateEagerVariableBlob(op_attribute):
def CreateEagerVariableBlob(op_attribute, job_name=None):
bn_in_op2blob_object = {}
def BuildInstruction(builder):
......@@ -351,6 +355,6 @@ def _CreateEagerVariableBlob(op_attribute):
lbi = logical_blob_id_util.LogicalBlobId()
lbi.op_name = op_attribute.op_conf.name
lbi.blob_name = op_attribute.op_conf.variable_conf.out
return remote_blob_util.EagerLogicalBlob(
lbi, blob_object=bn_in_op2blob_object["out"]
return remote_blob_util.EagerConsistentBlob(
lbi, blob_object=bn_in_op2blob_object["out"], job_name=job_name
)
......@@ -18,11 +18,13 @@ from __future__ import absolute_import
import functools
import math
import numpy as np
import oneflow as flow
import oneflow.core.operator.op_conf_pb2 as op_conf_util
import oneflow.python.framework.dtype as dtype_util
from oneflow.python.oneflow_export import oneflow_export
from typing import Optional, Sequence
from typing import Optional, Sequence, Union
@oneflow_export("constant_initializer")
......@@ -1060,3 +1062,142 @@ def _CalcGain(nonlinearity, negative_slope):
raise NotImplementedError(
"Only support None, 'tanh', 'sigmoid', 'relu' and 'leaky_relu' nonlinearity"
)
_init_map = {}
def register_initializer(flow_initializer):
def deco(func):
_init_map[flow_initializer] = func
return func
return deco
def GetInitializer(initializer_conf, random_seed, var_blob_shape):
f = None
for m in _init_map:
if initializer_conf.HasField(m):
f = _init_map[m]
break
assert f is not None, initializer_conf
return f(getattr(initializer_conf, m), random_seed, var_blob_shape)
@register_initializer("constant_conf")
@register_initializer("constant_int_conf")
def ConstantInitializerImpl(
initializer_conf: Union[
op_conf_util.ConstantInitializerConf, op_conf_util.ConstantIntInitializerConf
],
random_seed: int,
var_blob_shape: Sequence[int],
):
return lambda length: np.full((length,), initializer_conf.value)
@register_initializer("random_normal_conf")
def RandomNormalInitializerImpl(
initializer_conf: op_conf_util.RandomNormalInitializerConf,
random_seed: int,
var_blob_shape: Sequence[int],
):
rng = np.random.default_rng(random_seed)
return lambda length: rng.normal(
loc=initializer_conf.mean, scale=initializer_conf.std, size=length
)
@register_initializer("random_uniform_conf")
def RandomUniformInitializerImpl(
initializer_conf: op_conf_util.RandomUniformInitializerConf,
random_seed: int,
var_blob_shape: Sequence[int],
):
rng = np.random.default_rng(random_seed)
return lambda length: rng.uniform(
low=initializer_conf.min,
high=np.nextafter(initializer_conf.max, float("inf")),
size=length,
)
@register_initializer("random_uniform_int_conf")
def RandomUniformIntInitializerImpl(
initializer_conf: op_conf_util.RandomUniformIntInitializerConf,
random_seed: int,
var_blob_shape: Sequence[int],
):
rng = np.random.default_rng(random_seed)
return lambda length: rng.integers(
low=initializer_conf.min, high=initializer_conf.max, size=length
)
def RngTruncatedNormal(mean, std, length, rng):
truncated_value = 2 * std
data = []
while len(data) < length:
data.extend(
filter(
lambda value: abs(value - mean) < truncated_value,
rng.normal(mean, std, size=length - len(data)),
)
)
return data
@register_initializer("truncated_normal_conf")
def TruncatedNormalInitializerImpl(
initializer_conf: op_conf_util.TruncatedNormalInitializerConf,
random_seed: int,
var_blob_shape: Sequence[int],
):
rng = np.random.default_rng(random_seed)
return lambda length: RngTruncatedNormal(
initializer_conf.mean, initializer_conf.std, length, rng,
)
def GenInitialFan(initializer_conf, var_blob_shape: Sequence[int]):
variance_norm = initializer_conf.variance_norm
data_format = initializer_conf.data_format
fan_in = np.prod(var_blob_shape[1:]).astype(np.int).item()
fan_out = var_blob_shape[0]
if data_format == "channel_first":
fan_out *= np.prod(var_blob_shape[2:]).astype(np.int).item()
else:
fan_out *= np.prod(var_blob_shape[1:-1]).astype(np.int).item()
if variance_norm == op_conf_util.kAverage:
fan = (fan_in + fan_out) / 2
elif variance_norm == op_conf_util.kFanIn:
fan = fan_in
elif variance_norm == op_conf_util.kFanOut:
fan = fan_out
else:
raise NotImplementedError()
return fan
@register_initializer("variance_scaling_conf")
def VarianceScalingInitializerImpl(
initializer_conf: op_conf_util.VarianceScalingInitializerConf,
random_seed: int,
var_blob_shape: Sequence[int],
):
scale = initializer_conf.scale / GenInitialFan(initializer_conf, var_blob_shape)
distribution = initializer_conf.distribution
rng = np.random.default_rng(random_seed)
if distribution == op_conf_util.kTruncatedNormal:
stddev = math.sqrt(scale) / 0.87962566103423978
return lambda length: RngTruncatedNormal(0, stddev, length, rng)
elif distribution == op_conf_util.kRandomNormal:
stddev = math.sqrt(scale)
return lambda length: rng.normal(0, stddev, size=length,)
elif distribution == op_conf_util.kRandomUniform:
limit = math.sqrt(3.0 * scale)
return lambda length: rng.uniform(low=-limit, high=limit, size=length)
else:
raise NotImplementedError()
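For context, a rough sketch of how these numpy-backed generators are meant to be driven through GetInitializer (the check_point_v2.py that actually consumes them is collapsed in this diff). The InitializerConf wrapper message and the module path are assumptions; the key point is that the returned generator fills a flat buffer of any requested length, which is what lets the checkpoint code initialize a variable slice by slice.

import numpy as np
import oneflow.core.operator.op_conf_pb2 as op_conf_util
# Assumed location of the helpers patched above:
from oneflow.python.ops.initializer_util import GetInitializer

conf = op_conf_util.InitializerConf()  # wrapper holding one *_conf field (assumption)
conf.random_normal_conf.mean = 0.0
conf.random_normal_conf.std = 1.0

shape = (4, 5)
gen = GetInitializer(conf, random_seed=1234, var_blob_shape=shape)
flat = np.asarray(gen(int(np.prod(shape))), dtype=np.float32)  # flat chunk of the requested length
var = flat.reshape(shape)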
......@@ -102,7 +102,7 @@ def _compare_bceloss_with_np(
v = flow.get_variable(
shape=target.shape,
dtype=flow.float32,
initializer=flow.constant_initializer(1),
initializer=flow.zeros_initializer(),
name="v",
)
......
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import os
import shutil
import tempfile
import numpy as np
import oneflow as flow
import oneflow.typing as tp
def refresh_session():
flow.clear_default_session()
flow.config.gpu_device_num(flow.unittest.env.device_num())
def get_placement():
node_size = flow.unittest.env.node_size()
device_ids = "0-{}".format(flow.unittest.env.device_num() - 1)
machine_device_ids = [
"{}:{}".format(node_id, device_ids) for node_id in range(node_size)
]
return flow.scope.placement("gpu", machine_device_ids)
def get_simple_momentum_training_model(dtype):
assert dtype == flow.float32
@flow.global_function(type="train")
def model(x: tp.Numpy.Placeholder((4, 5))) -> tp.Numpy:
with get_placement():
w = flow.get_variable(
name="w",
shape=(5, 6),
dtype=flow.float32,
initializer=flow.random_normal_initializer(mean=10, stddev=1),
distribute=flow.distribute.split(0),
)
y = flow.matmul(x, w)
flow.optimizer.SGD(
flow.optimizer.PiecewiseConstantScheduler([], [0.01]), momentum=0.9
).minimize(y)
return y
return model
def get_simple_model(dtype):
@flow.global_function()
def add() -> tp.Numpy:
with get_placement():
x = flow.get_variable(
name="x",
shape=(9, 3),
dtype=dtype,
initializer=flow.random_normal_initializer(mean=10, stddev=1),
distribute=flow.distribute.split(0),
)
y = flow.get_variable(
name="y",
shape=(9, 3),
dtype=dtype,
initializer=flow.constant_initializer(5, dtype=dtype),
)
z = flow.get_variable(
name="z",
shape=(9, 3),
dtype=dtype,
initializer=flow.random_normal_initializer(),
)
return flow.math.add_n([x, y, z])
return add
def get_large_model(dtype):
@flow.global_function()
def large() -> tp.Numpy:
with get_placement():
x = flow.get_variable(
name="x",
shape=(10, 2801, 820, 4),
dtype=dtype,
initializer=flow.random_normal_initializer(mean=10, stddev=1),
distribute=flow.distribute.split(0),
)
return flow.math.reduce_mean(x)
return large
def get_add_and_reduce_mean_model(dtype):
@flow.global_function()
def model() -> tp.Numpy:
with get_placement():
x = flow.get_variable(
name="x",
shape=(10, 801, 820, 4),
dtype=dtype,
initializer=flow.random_normal_initializer(mean=10, stddev=1),
distribute=flow.distribute.split(0),
)
y = flow.get_variable(
name="y",
shape=(10, 801, 820, 4),
dtype=dtype,
initializer=flow.random_normal_initializer(mean=10, stddev=1),
distribute=flow.distribute.split(0),
)
return flow.math.reduce_mean(x + y)
return model
def get_checkpoint_ready_model(model_getter, dtype):
model = model_getter(dtype)
if flow.eager_execution_enabled():
model()
return model
def _TestSaveCorrectness(test_case, model_getter, dtype, legacy_api):
"""
Save weights by new model io, load weights by legacy model io,
and check the equality.
"""
with tempfile.TemporaryDirectory() as save_dir:
refresh_session()
flow.config.enable_legacy_model_io(False)
large1 = get_checkpoint_ready_model(model_getter, dtype)
if legacy_api:
check_point = flow.train.CheckPoint()
check_point.save(save_dir)
else:
flow.checkpoint.save(save_dir)
res1 = large1()
refresh_session()
flow.config.enable_legacy_model_io(True)
large2 = get_checkpoint_ready_model(model_getter, dtype)
check_point = flow.train.CheckPoint()
check_point.load(save_dir)
flow.sync_default_session()
res2 = large2()
test_case.assertTrue(np.array_equal(res1, res2))
def _TestRoundTrip(test_case, model_getter, dtype):
"""
Save weights by new model io, load weights by new model io,
and check the equality.
"""
with tempfile.TemporaryDirectory() as save_dir:
refresh_session()
large1 = get_checkpoint_ready_model(model_getter, dtype)
flow.checkpoint.save(save_dir)
res1 = large1()
refresh_session()
large2 = get_checkpoint_ready_model(model_getter, dtype)
vars_in_file = flow.checkpoint.get(save_dir)
flow.load_variables(vars_in_file)
res2 = large2()
test_case.assertTrue(np.array_equal(res1, res2))
def _TestLoadCorrectness(test_case, model_getter, dtype, legacy_api):
"""
Save weights by legacy model io, load weights by new model io,
and check the equality.
"""
with tempfile.TemporaryDirectory() as save_dir:
refresh_session()
flow.config.enable_legacy_model_io(True)
large1 = get_checkpoint_ready_model(model_getter, dtype)
check_point = flow.train.CheckPoint()
check_point.init()
check_point.save(save_dir)
res1 = large1()
flow.clear_default_session()
flow.config.gpu_device_num(4)
flow.config.enable_legacy_model_io(False)
large2 = get_checkpoint_ready_model(model_getter, dtype)
if legacy_api:
check_point = flow.train.CheckPoint()
check_point.load(save_dir)
else:
vars_in_file = flow.checkpoint.get(save_dir)
flow.load_variables(vars_in_file)
res2 = large2()
test_case.assertTrue(np.array_equal(res1, res2))
def _TestPartiallyLoadNumpy(test_case, dtype):
refresh_session()
model = get_checkpoint_ready_model(get_add_and_reduce_mean_model, dtype)
var_x = flow.get_all_variables()["x"]
var_y_value_before_loading = flow.get_all_variables()["y"].numpy()
new_val_np = np.random.random(var_x.shape).astype(np.float32)
flow.load_variables({"x": new_val_np})
var_y_value_after_loading = flow.get_all_variables()["y"].numpy()
flow_res = model()
np_res = (var_y_value_after_loading + new_val_np).mean()
test_case.assertTrue(np.allclose(flow_res, np_res))
test_case.assertTrue(
np.array_equal(var_y_value_before_loading, var_y_value_after_loading)
)
def _TestMixedModel(test_case, dtype):
with tempfile.TemporaryDirectory() as save_dir1, tempfile.TemporaryDirectory() as save_dir2:
def get_variable(name):
return flow.get_variable(
name=name,
shape=(10, 80, 40, 20),
dtype=dtype,
initializer=flow.random_normal_initializer(mean=10, stddev=1),
distribute=flow.distribute.split(0),
)
def get_part_of_mixed_model(dtype):
@flow.global_function()
def model() -> tp.Numpy:
with get_placement():
x = get_variable("x")
return x
return model
def get_mixed_model(dtype):
@flow.global_function()
def model() -> tp.Numpy:
with get_placement():
x1 = get_variable("x_from_model1")
x2 = get_variable("x_from_model2")
return x1 + x2
return model
refresh_session()
model1 = get_checkpoint_ready_model(get_part_of_mixed_model, dtype)
flow.checkpoint.save(save_dir1)
refresh_session()
model2 = get_checkpoint_ready_model(get_part_of_mixed_model, dtype)
flow.checkpoint.save(save_dir2)
refresh_session()
mixed_model = get_checkpoint_ready_model(get_mixed_model, dtype)
var_dict_from_model1 = flow.checkpoint.get(save_dir1)
var_dict_from_model2 = flow.checkpoint.get(save_dir2)
new_var_dict = {}
for key, val in var_dict_from_model1.items():
new_var_dict["{}_from_model1".format(key)] = val
for key, val in var_dict_from_model2.items():
new_var_dict["{}_from_model2".format(key)] = val
flow.load_variables(new_var_dict)
res = mixed_model()
test_case.assertTrue(
np.allclose(
res,
var_dict_from_model1["x"].numpy() + var_dict_from_model2["x"].numpy(),
)
)
def _TestResumeTraining(test_case):
with tempfile.TemporaryDirectory() as save_dir:
x = np.random.random((4, 5)).astype(np.float32)
refresh_session()
model = get_checkpoint_ready_model(
get_simple_momentum_training_model, flow.float32
)
model(x)
flow.checkpoint.save(save_dir)
model(x)
w1 = flow.get_all_variables()["w"].numpy()
refresh_session()
model = get_checkpoint_ready_model(
get_simple_momentum_training_model, flow.float32
)
flow.load_variables(flow.checkpoint.get(save_dir))
model(x)
w2 = flow.get_all_variables()["w"].numpy()
test_case.assertTrue(np.array_equal(w1, w2))
def _TestAssignmentBetweenMemory(test_case, dtype):
refresh_session()
model = get_checkpoint_ready_model(get_simple_model, dtype)
all_vars = flow.get_all_variables()
flow.load_variables({"x": all_vars["z"]})
flow_res = model()
np_res = all_vars["z"].numpy() * 2 + all_vars["y"].numpy()
test_case.assertTrue(np.allclose(flow_res, np_res))
class TestCheckpoint(flow.unittest.TestCase):
@flow.unittest.skip_unless_1n4d()
@unittest.skipIf(
flow.unittest.env.eager_execution_enabled(),
"legacy model io doesn't work in eager mode",
)
def test_save_correctness_1node_legacy_api(test_case):
_TestSaveCorrectness(test_case, get_simple_model, flow.float, True)
@flow.unittest.skip_unless_1n4d()
@unittest.skipIf(
flow.unittest.env.eager_execution_enabled(),
"legacy model io doesn't work in eager mode",
)
def test_load_correctness_1node_legacy_api(test_case):
_TestLoadCorrectness(test_case, get_simple_model, flow.float, True)
@flow.unittest.skip_unless_1n4d()
@unittest.skipIf(
flow.unittest.env.eager_execution_enabled(),
"legacy model io doesn't work in eager mode",
)
def test_save_correctness_1node(test_case):
for dtype in [flow.float, flow.double]:
_TestSaveCorrectness(test_case, get_large_model, dtype, False)
@flow.unittest.skip_unless_2n4d()
@unittest.skipIf(
flow.unittest.env.eager_execution_enabled(),
"legacy model io doesn't work in eager mode",
)
def test_save_correctness_2node(test_case):
_TestSaveCorrectness(test_case, get_large_model, flow.float, False)
@flow.unittest.skip_unless_1n4d()
@unittest.skipIf(
flow.unittest.env.eager_execution_enabled(),
"legacy model io doesn't work in eager mode",
)
def test_load_correctness_1node(test_case):
for dtype in [flow.float, flow.double]:
_TestLoadCorrectness(test_case, get_large_model, dtype, False)
@flow.unittest.skip_unless_2n4d()
@unittest.skipIf(
flow.unittest.env.eager_execution_enabled(),
"legacy model io doesn't work in eager mode",
)
def test_load_correctness_2node(test_case):
_TestLoadCorrectness(test_case, get_large_model, flow.float, False)
@flow.unittest.skip_unless_1n4d()
def test_assignment_between_memory(test_case):
_TestAssignmentBetweenMemory(test_case, flow.float)
@flow.unittest.skip_unless_1n4d()
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
"Save and load are covered by other tests in lazy mode",
)
def test_round_trip(test_case):
_TestRoundTrip(test_case, get_large_model, flow.float)
@flow.unittest.skip_unless_1n4d()
def test_partially_load_numpy(test_case):
_TestPartiallyLoadNumpy(test_case, flow.float)
@flow.unittest.skip_unless_1n2d()
def test_mixed_model(test_case):
_TestMixedModel(test_case, flow.float)
@flow.unittest.skip_unless_1n2d()
def test_resume_training(test_case):
_TestResumeTraining(test_case)
if __name__ == "__main__":
unittest.main()
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
import os
import shutil
import tempfile
import time
from collections import OrderedDict
from scipy.stats import ks_2samp
import numpy as np
import oneflow as flow
import oneflow.typing as tp
from test_util import GenArgDict
SHAPE = (4, 8, 5, 6)
def get_simple_model(dtype, initializer):
@flow.global_function()
def model() -> tp.Numpy:
x = flow.get_variable(
name="x", shape=SHAPE, dtype=dtype, initializer=initializer,
)
return x
return model
def CompareTwoDistribution(test_case, dtype, initializer):
flow.clear_default_session()
flow.config.enable_legacy_model_io(True)
model = get_simple_model(dtype, initializer)
flow.train.CheckPoint().init()
legacy_init_res = model()
flow.clear_default_session()
flow.config.enable_legacy_model_io(False)
model = get_simple_model(dtype, initializer)
new_init_res = model()
s = ks_2samp(legacy_init_res.flatten(), new_init_res.flatten())
pvalue = s.pvalue
test_case.assertGreater(pvalue, 0.0001, msg=initializer)
class TestInitializer(flow.unittest.TestCase):
def test_int_initializer(test_case):
initializers = [
flow.random_uniform_initializer(minval=-6, maxval=18, dtype=flow.int32),
flow.constant_initializer(value=4, dtype=flow.int32),
]
for initializer in initializers:
CompareTwoDistribution(test_case, flow.int32, initializer)
def test_float_initializer(test_case):
initializers = [
flow.random_normal_initializer(mean=3, stddev=4),
flow.random_uniform_initializer(minval=-6, maxval=18),
flow.truncated_normal_initializer(mean=-5, stddev=8),
flow.xavier_uniform_initializer(data_format="NCHW"),
flow.xavier_uniform_initializer(data_format="NHWC"),
flow.xavier_normal_initializer(data_format="NCHW"),
flow.xavier_normal_initializer(data_format="NHWC"),
flow.constant_initializer(value=4),
flow.ones_initializer(),
flow.zeros_initializer(),
]
kaiming_args = GenArgDict(
OrderedDict(
shape=[SHAPE],
mode=["fan_in", "fan_out", "fan_avg"],
distribution=["random_normal", "random_uniform"],
data_format=["NCHW", "NHWC"],
negative_slope=[0.5],
)
)
vs_args = GenArgDict(
OrderedDict(
scale=[3.4],
mode=["fan_in", "fan_out", "fan_avg"],
distribution=["truncated_normal", "random_normal", "random_uniform"],
data_format=["NCHW", "NHWC"],
)
)
for args in kaiming_args:
initializers.append(flow.kaiming_initializer(**args))
for args in vs_args:
initializers.append(flow.variance_scaling_initializer(**args))
for initializer in initializers:
CompareTwoDistribution(test_case, flow.float32, initializer)
if __name__ == "__main__":
unittest.main()
......@@ -115,7 +115,7 @@ class TestLayerNorm(flow.unittest.TestCase):
diff = dx_tf.numpy() - b.numpy()
max_diff = np.max(np.abs(diff))
if data_type == "float16":
tolerance = 2e-3
tolerance = 3e-3
else:
tolerance = 1e-5
assert np.allclose(
......
......@@ -81,7 +81,7 @@ def _compare_mseloss_with_np(
v = flow.get_variable(
shape=input.shape,
dtype=flow.float32,
initializer=flow.ones_initializer(),
initializer=flow.zeros_initializer(),
name="x_var",
)
x_var = of_input + v
......
......@@ -235,18 +235,23 @@ void WriteSlice(user_op::KernelComputeContext* ctx, const user_op::Tensor* src,
SliceIndexHelper<NDIM> sliced_splitted_large_idx_cvtr(large_slice_param.size);
SliceIndexHelper<NDIM> entire_full_small_idx_cvtr(small_slice_param.dims);
SliceIndexHelper<NDIM> sliced_full_small_idx_cvtr(small_slice_param.size);
// Calculate the length of continuous part
int cnt = 1;
for (int i = NDIM - 1; i >= 0; i--) {
if (large_slice_param.step[i] == 1) { cnt *= large_slice_param.size[i]; }
if (!large_slice_param.IsFullSlice(i) || !small_slice_param.IsFullSlice(i)) { break; }
}
const auto* src_ptr = src->dptr<T>();
auto* dst_ptr = dst->mut_dptr<T>();
FOR_RANGE(int, i, 0, elem_cnt) {
for (int i = 0; i < elem_cnt; i += cnt) {
const int64_t large_offset = SliceOffsetToEntireOffset<NDIM>(
i, large_slice_param, entire_splitted_large_idx_cvtr, sliced_splitted_large_idx_cvtr);
const int64_t small_offset = SliceOffsetToEntireOffset<NDIM>(
i, small_slice_param, entire_full_small_idx_cvtr, sliced_full_small_idx_cvtr);
const int64_t src_offset = from_large_to_small ? large_offset : small_offset;
const int64_t dst_offset = from_large_to_small ? small_offset : large_offset;
// TODO(jianhao): optimize the performance by dedicated kernels
AutoMemcpy(ctx->device_ctx(), dst_ptr + dst_offset, src_ptr + src_offset,
GetSizeOfDataType(src->data_type()), src->mem_case(), dst->mem_case());
cnt * GetSizeOfDataType(src->data_type()), src->mem_case(), dst->mem_case());
}
}
......
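To make the effect of the loop change above concrete, here is a small sketch (not part of the diff) that mirrors the cnt computation: when the innermost dimensions are taken in full with step 1, they form one contiguous run, so a single AutoMemcpy replaces that many per-element copies. The shape matches get_large_model in the new checkpoint test; slicing one index of the first axis at a time is a hypothetical scheme used only for illustration.

def contiguous_run_length(size, step, is_full_slice):
    # Mirrors the cnt loop in WriteSlice: walk dims from innermost to outermost,
    # multiplying in each step-1 dim, and stop after the first non-full dim.
    cnt = 1
    for i in reversed(range(len(size))):
        if step[i] == 1:
            cnt *= size[i]
        if not is_full_slice[i]:
            break
    return cnt

# One index along axis 0 of a (10, 2801, 820, 4) variable, all other axes full:
cnt = contiguous_run_length(
    size=(1, 2801, 820, 4), step=(1, 1, 1, 1), is_full_slice=(False, True, True, True)
)
print(cnt)  # 9187280: the whole 2801*820*4 slice is copied with a single memcpy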