Unverified commit cfee9c13 authored by Ligoml, committed by GitHub

[cherry-pick2.4]for CodeStyle (#47608)

* only run pre-commit

* only run pre-commit
Parent 99c872fa
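The diff below is a formatting-only change: the touched files are rewritten to the style enforced by the repository's pre-commit hooks (black-style wrapping of long calls onto one argument per line with a trailing comma, explicit parentheses around multi-line expressions, and 0. written as 0.0). A small standalone sketch of that style, using a hypothetical function rather than code from this commit:

# Hypothetical example (not from this commit) showing the target style:
# long argument lists are exploded one-per-line with a trailing comma,
# and bare float literals such as 0. are written as 0.0.
def scaled_sum(values, scale=1.0, offset=0.0):
    return sum(v * scale + offset for v in values)


result = scaled_sum(
    [1.0, 2.0, 3.0],
    scale=0.5,
    offset=0.0,
)
print(result)  # 3.0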
@@ -124,9 +124,7 @@ class LegacyPyLayerContext(object):

def with_mateclass(meta, *bases):
    class impl(meta):
        def __new__(cls, name, temp_bases, attrs):
            return meta(name, bases, attrs)

@@ -134,7 +132,6 @@ def with_mateclass(meta, *bases):

class CPyLayer(object):
    @classmethod
    @dygraph_only
    def apply(cls, *args, **kwargs):

@@ -182,12 +179,14 @@ class CPyLayer(object):

class PyLayerBackward(LegacyPyLayerContext):
    def backward(self, *args, **kwargs):
        with paddle.fluid.dygraph.guard():
            with paddle.fluid.dygraph.no_grad():
                if (
                    self._amp_state
                    and 'enable' in self._amp_state
                    and self._amp_state['enable']
                ):
                    with auto_cast(**args[0]._amp_state):
                        return self._forward_cls.backward(*args, **kwargs)
                else:

@@ -197,10 +196,10 @@ class PyLayerBackward(LegacyPyLayerContext):

class LayerMeta(type):
    def __init__(cls, name, bases, attrs):
        cls._backward_function = type(
            name + '_backward', (PyLayerBackward,), {"_forward_cls": cls}
        )

        return super(LayerMeta, cls).__init__(name, bases, attrs)

@@ -292,7 +291,8 @@ class LegacyPyLayer(with_mateclass(LayerMeta, CPyLayer)):
                return grad
        """
        raise NotImplementedError(
            "You must implement the forward function for PyLayer."
        )

    @staticmethod
    def backward(ctx, *args, **kwargs):

@@ -332,11 +332,11 @@ class LegacyPyLayer(with_mateclass(LayerMeta, CPyLayer)):
        """
        raise NotImplementedError(
            "You must implement the backward function for PyLayer."
        )


class EagerPyLayerContext(object):
    def save_for_backward(self, *tensors):
        """
        Saves given tensors that backward need. Use ``saved_tensor`` in the `backward` to get the saved tensors.

@@ -542,25 +542,22 @@ class EagerPyLayerContext(object):

class EagerPyLayerBackward(core.eager.PyLayer, EagerPyLayerContext):
    def backward(self, *args):
        return self._forward_cls.backward(self, *args)


class EagerPyLayerMeta(type):
    def __init__(cls, name, bases, attrs):
        cls._backward_function = type(
            name + '_backward', (EagerPyLayerBackward,), {"_forward_cls": cls}
        )

        return super(EagerPyLayerMeta, cls).__init__(name, bases, attrs)


class EagerPyLayer(
    with_mateclass(EagerPyLayerMeta, core.eager.PyLayer, EagerPyLayerContext)
):
    @staticmethod
    def forward(ctx, *args, **kwargs):
        """

@@ -597,7 +594,8 @@ class EagerPyLayer(
                return grad
        """
        raise NotImplementedError(
            "You must implement the forward function for PyLayer."
        )

    @staticmethod
    def backward(ctx, *args):
        """

@@ -637,11 +635,11 @@ class EagerPyLayer(
        """
        raise NotImplementedError(
            "You must implement the backward function for PyLayer."
        )


def once_differentiable(backward):
    def wrapper(ctx, *args):
        with paddle.fluid.dygraph.no_grad():
            outputs = backward(ctx, *args)
......
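For context on the PyLayer machinery reformatted above: a user subclasses PyLayer and supplies the static forward and backward that the NotImplementedError stubs require. A minimal sketch, assuming the public paddle.autograd.PyLayer API (this example is not part of the diff):

import paddle
from paddle.autograd import PyLayer


class CustomTanh(PyLayer):
    @staticmethod
    def forward(ctx, x):
        y = paddle.tanh(x)
        # save activations needed by backward
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, dy):
        (y,) = ctx.saved_tensor()
        # d tanh(x) / dx = 1 - tanh(x)^2
        return dy * (1 - paddle.square(y))


x = paddle.rand([2, 3], dtype='float32')
x.stop_gradient = False
y = CustomTanh.apply(x)
y.sum().backward()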
@@ -124,8 +124,11 @@ def device_count():
    '''
    num_gpus = (
        core.get_cuda_device_count()
        if hasattr(core, 'get_cuda_device_count')
        else 0
    )

    return num_gpus

@@ -165,7 +168,7 @@ def extract_cuda_device_id(device, op_name):
    Return:
        int: The id of the given device. If device is None, return the id of current device.
    '''
    if device is None:
        return core.get_cuda_current_device_id()

    if isinstance(device, int):

@@ -178,15 +181,19 @@ def extract_cuda_device_id(device, op_name):
        else:
            raise ValueError(
                "The current string {} is not expected. Because {} only support string which is like 'gpu:x'. "
                "Please input appropriate string again!".format(device, op_name)
            )
    else:
        raise ValueError(
            "The device type {} is not expected. Because {} only support int, str or paddle.CUDAPlace. "
            "Please input appropriate device again!".format(device, op_name)
        )

    assert (
        device_id >= 0
    ), f"The device id must be not less than 0, but got id = {device_id}."
    assert (
        device_id < device_count()
    ), f"The device id {device_id} exceeds gpu card number {device_count()}"

    return device_id

@@ -424,7 +431,8 @@ def get_device_properties(device=None):
        raise ValueError(
            "The API paddle.device.cuda.get_device_properties is not supported in "
            "CPU-only PaddlePaddle. Please reinstall PaddlePaddle with GPU support "
            "to call this API."
        )

    if device is not None:
        if isinstance(device, int):

@@ -438,12 +446,14 @@ def get_device_properties(device=None):
                raise ValueError(
                    "The current string {} is not expected. Because paddle.device."
                    "cuda.get_device_properties only support string which is like 'gpu:x'. "
                    "Please input appropriate string again!".format(device)
                )
        else:
            raise ValueError(
                "The device type {} is not expected. Because paddle.device.cuda."
                "get_device_properties only support int, str or paddle.CUDAPlace. "
                "Please input appropriate device again!".format(device)
            )
    else:
        device_id = -1
......
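As a usage note for the CUDA device helpers reformatted above, a minimal sketch assuming a GPU build of PaddlePaddle (on a CPU-only build get_device_properties raises the ValueError shown in the diff):

import paddle

if paddle.device.cuda.device_count() > 0:
    # the device argument may be an int, a 'gpu:x' string, or a paddle.CUDAPlace
    props = paddle.device.cuda.get_device_properties(0)
    print(props)
else:
    print("no CUDA device visible")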
@@ -40,8 +40,9 @@ def wait_server_ready(endpoints):
        not_ready_endpoints = []
        for ep in endpoints:
            ip_port = ep.split(":")
            with closing(
                socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            ) as sock:
                sock.settimeout(2)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                if hasattr(socket, 'SO_REUSEPORT'):

@@ -53,8 +54,9 @@ def wait_server_ready(endpoints):
                not_ready_endpoints.append(ep)
        if not all_ok:
            sys.stderr.write("server not ready, wait 3 sec to retry...\n")
            sys.stderr.write(
                "not ready endpoints:" + str(not_ready_endpoints) + "\n"
            )
            sys.stderr.flush()
            time.sleep(3)
        else:
......
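The wait_server_ready helper above simply probes each endpoint over TCP until every server accepts a connection, retrying every 3 seconds. A condensed standalone sketch of such a probe (illustration only, not the Paddle function itself):

import socket
from contextlib import closing


def endpoint_ready(ip, port, timeout=2):
    # TCP connect probe: returns True when something is listening on ip:port.
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
        sock.settimeout(timeout)
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return sock.connect_ex((ip, int(port))) == 0


print(endpoint_ready("127.0.0.1", 6170))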
@@ -30,7 +30,9 @@ from paddle.fluid.framework import _set_expected_place
from paddle.fluid.dygraph import parallel_helper
from paddle.distributed.fleet.launch_utils import check_backend
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.distributed.fleet.base.private_helper_function import (
    wait_server_ready,
)  # noqa: F401
from paddle.distributed import collective
from paddle.distributed.collective import _set_group_map
from paddle.distributed.collective import _set_group_map_by_name

@@ -63,6 +65,7 @@ def _get_global_parallel_env():

def _start_kv_server(port, http_server_d, size):
    from paddle.distributed.fleet.utils.http_server import KVServer

    http_server = KVServer(int(port), size=size)
    http_server.start()
    wait_seconds = 3

@@ -73,10 +76,15 @@ def _start_kv_server(port, http_server_d, size):

def _is_cpuonly(backend):
    check_backend(backend)
    if (
        backend in ['auto', 'nccl', 'bkcl', 'hccl', 'heter', 'cncl']
        and (
            core.is_compiled_with_cuda()
            or core.is_compiled_with_xpu()
            or core.is_compiled_with_npu()
            or core.is_compiled_with_mlu()
        )
    ) or backend is 'xccl':
        # passes 'auto' and can use cuda or xpu, use the default logics. so return False
        return False

@@ -87,9 +95,10 @@ def _is_cpuonly(backend):

def _check_var_exists(var_name):
    var = os.environ.get(var_name, None)
    if var is None:
        raise ValueError(
            "paddle.distributed initialize error, "
            "environment variable %s is needed, but not set." % var_name
        )


def init_parallel_env():

@@ -167,15 +176,21 @@ def init_parallel_env():
    backend = os.environ.get('PADDLE_DISTRI_BACKEND', 'auto')
    is_cpu_only = _is_cpuonly(backend)
    # 1. gpu xpu check, must be gpu or xpu,
    if not (
        is_cpu_only
        or core.is_compiled_with_cuda()
        or core.is_compiled_with_xpu()
        or core.is_compiled_with_npu()
        or core.is_compiled_with_mlu()
    ):
        raise NotImplementedError(
            "If you want to use CPU-only version, please use 'gloo' as backend"
        )

    if backend == "xccl":
        FLAGS_selected_custom_devices = 'FLAGS_selected_{}s'.format(
            parallel_env.device_type
        )
        _check_var_exists(FLAGS_selected_custom_devices)
    else:
        if not is_cpu_only and core.is_compiled_with_cuda():

@@ -203,8 +218,9 @@ def init_parallel_env():
    # they need to call a function to change default place,
    # here just set correctly place to users
    if backend == "xccl":
        place = core.CustomPlace(
            parallel_env.device_type, parallel_env.device_id
        )
    elif is_cpu_only:
        place = core.CPUPlace()
    elif core.is_compiled_with_cuda():

@@ -228,11 +244,15 @@ def init_parallel_env():
        assert rank >= 0 and world_size > rank and world_size > 1, (
            "rank must be non-negative and world_size must be the "
            "maximum rank plus one. Moreover, at least two processes are "
            "required to create a process group."
        )

        master_addr = os.getenv("MASTER_ADDR", None)
        master_port = os.getenv("MASTER_PORT", None)
        endpoints = (
            ":".join([master_addr, master_port])
            if master_addr and master_port
            else None
        )
        if endpoints is None:
            endpoints = os.getenv("PADDLE_MASTER", None)
        if endpoints is None:

@@ -241,23 +261,28 @@ def init_parallel_env():
                "The environment variable 'MASTER_ADDR' and 'MASTER_PORT' "
                "must be specified, for example 'export MASTER_ADDR=127.0.0.1' "
                "and 'export MASTER_ADDR=54612'. Or you can start your training"
                "with paddle.distributed.run module."
            )
        master_addr, master_port = endpoints.split(":")
        master_port = int(master_port)
        is_master = rank == 0
        stop_check_timeout = int(os.getenv("FLAGS_stop_check_timeout", "900"))
        default_store = core.TCPStore(
            master_addr,
            master_port,
            is_master,
            world_size,
            timeout=stop_check_timeout,
        )
        _set_default_store(default_store)
        pg = _new_process_group_impl(
            backend,
            default_store,
            rank,
            world_size,
            _default_group_name,
            pg_options=None,
        )
        ranks = list(range(world_size))
        group = Group(rank, 0, ranks, pg=pg, name=_default_group_name)
        _set_group_map_by_name(_default_group_name, group)

@@ -283,8 +308,10 @@ def init_parallel_env():
            size = {'_worker': parallel_env.world_size}
            if backend == "heter":
                size = {'_worker': len(node_num)}
            http_server = Process(
                target=_start_kv_server,
                args=(int(ep_rank_0[1]), http_server_d, size),
            )
            http_server.daemon = True
            http_server_d["running"] = True
            http_server.start()

@@ -302,22 +329,28 @@ def init_parallel_env():
    # init nccl or hccl or bkcl or heter context
    if is_cpu_only:
        parallel_helper._set_parallel_ctx(
            core.GLOOParallelContext(strategy, place)
        )
    elif backend == "heter":
        parallel_helper._set_parallel_ctx(
            core.HeterParallelContext(strategy, parallel_env.device_id)
        )
    elif core.is_compiled_with_cuda():
        parallel_helper._set_parallel_ctx(
            core.NCCLParallelContext(strategy, place)
        )
    elif core.is_compiled_with_xpu():
        parallel_helper._set_parallel_ctx(
            core.BKCLParallelContext(strategy, place)
        )
    elif core.is_compiled_with_npu():
        parallel_helper._set_parallel_ctx(
            core.HCCLParallelContext(strategy, place)
        )
    elif core.is_compiled_with_mlu():
        parallel_helper._set_parallel_ctx(
            core.CNCLParallelContext(strategy, place)
        )

    if backend != "heter":
        other_endpoints = strategy.trainer_endpoints[:]
......
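For reference, init_parallel_env as reformatted above is normally driven through paddle.distributed.spawn (or the paddle.distributed.launch/run module), which sets the PADDLE_* environment variables it checks. A minimal sketch based on the public API, not part of this diff:

import paddle
import paddle.nn as nn
import paddle.distributed as dist


def train():
    # reads the PADDLE_* / MASTER_* environment variables set by the launcher
    dist.init_parallel_env()
    layer = nn.Linear(10, 1)
    dp_layer = paddle.DataParallel(layer)
    x = paddle.rand([4, 10])
    loss = dp_layer(x).mean()
    loss.backward()


if __name__ == '__main__':
    dist.spawn(train, nprocs=2)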
@@ -23,21 +23,38 @@ from paddle.distributed.utils.log_utils import get_logger
from paddle.fluid.framework import in_dygraph_mode

# Old version
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import (
    ShardingOptimizerStage2,
)
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage2 import (
    ShardingStage2,
)
from paddle.distributed.fleet.meta_parallel.sharding.sharding_stage3 import (
    ShardingStage3,
)
from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import (
    ShardingScaler,
)

# New version
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_optimizer_stage2 import (
    GroupShardedOptimizerStage2,
)
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage2 import (
    GroupShardedStage2,
)
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage3 import (
    GroupShardedStage3,
)
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import (
    GroupShardedScaler,
)

logger_ = get_logger(logging.WARNING)


def group_sharded_parallel(
    model,
    optimizer,
    level,
    scaler=None,

@@ -46,7 +63,8 @@ def group_sharded_parallel(model,
    sync_buffers=False,
    buffer_max_size=2**23,
    segment_size=2**20,
    sync_comm=False,
):
    """
    Use group_sharded_parallel can perform group shared configuration on the model, optimizer and GradScaler. Level has three string options, 'os', 'os_g' and 'p_g_os' corresponds to three different usage scenarios: optimizer state segmentation, optimizer state + gradient segmentation, and parameter + gradient + optimizer state segmentation.
    Usually, optimizer state + gradient segmentation is actually a re optimization of optimizer state segmentation, so optimizer state + gradient segmentation can be used to realize optimizer state segmentation.

@@ -100,13 +118,16 @@ def group_sharded_parallel(model,
    """
    # check optition type
    assert isinstance(
        model, paddle.nn.Layer
    ), "The model must be the instance of paddle.nn.Layer."
    assert isinstance(
        optimizer, Optimizer
    ), "The optimizer must be the instance of paddle.optimizer.Optimizer."
    assert level in [
        'os',
        'os_g',
        'p_g_os',
    ], "The level must be os, os_g or p_g_os."

    def check_dtype(param):
        return param.dtype == paddle.float16

@@ -124,39 +145,50 @@ def group_sharded_parallel(model,
                params=optimizer._parameter_list,
                optim=optimizer,
                group=group,
                offload=offload,
            )
            model = GroupShardedStage2(
                model,
                optimizer,
                group=group,
                sync_buffers=sync_buffers,
                buffer_max_size=buffer_max_size,
            )
        else:
            optimizer = ShardingOptimizerStage2(
                params=model.parameters(),
                optim=optimizer,
                group=group,
                offload=offload,
            )
            model = ShardingStage2(
                model,
                optimizer,
                group=group,
                sync_buffers=sync_buffers,
                buffer_max_size=buffer_max_size,
            )
    elif level == 'p_g_os':
        if in_dygraph_mode():
            model = GroupShardedStage3(
                model,
                optimizer=optimizer,
                group=group,
                sync_buffers=sync_buffers,
                segment_size=segment_size,
                offload=offload,
                sync_comm=sync_comm,
            )
        else:
            model = ShardingStage3(
                model,
                optimizer=optimizer,
                group=group,
                sync_buffers=sync_buffers,
                segment_size=segment_size,
                offload=offload,
                sync_comm=sync_comm,
            )
    else:
        raise ValueError("Please enter the correct level.")
    if isinstance(scaler, paddle.amp.GradScaler):

@@ -219,7 +251,8 @@ def save_group_sharded_model(model, output, optimizer=None):
            save_group_sharded_model(model, optimizer, output=output_dir)
    """
    logger_.info(
        "==========Begin to save group sharded model and optimizer=========="
    )
    assert not os.path.isfile(
        output
    ), "Saving directory ({}) should be a directory, not a file".format(output)

@@ -243,4 +276,5 @@ def save_group_sharded_model(model, output, optimizer=None):
        output_opt = os.path.join(output, "model.pdopt")
        paddle.save(optimizer._optim.state_dict(), output_opt)
    logger_.info(
        "==========End to save group sharded model and optimizer=========="
    )
@@ -28,12 +28,27 @@ import numpy as np
import paddle
from paddle import _C_ops, _legacy_C_ops
from paddle.fluid import core
from paddle.fluid.data_feeder import (
    check_dtype,
    check_type,
    check_variable_and_dtype,
    convert_dtype,
)
from paddle.fluid.framework import (
    _non_static_mode,
    in_dygraph_mode,
    _in_legacy_dygraph,
)
from paddle.fluid.layers import (
    control_flow,
    elementwise_add,
    elementwise_div,
    elementwise_mul,
    elementwise_sub,
    nn,
    ops,
    tensor,
)
from paddle.tensor import arange, concat, gather_nd, multinomial

@@ -53,10 +68,16 @@ class Distribution(object):
    def __init__(self, batch_shape=(), event_shape=()):

        self._batch_shape = (
            batch_shape
            if isinstance(batch_shape, tuple)
            else tuple(batch_shape)
        )
        self._event_shape = (
            event_shape
            if isinstance(event_shape, tuple)
            else tuple(event_shape)
        )

        super(Distribution, self).__init__()

@@ -155,7 +176,8 @@ class Distribution(object):
        if is_variable and is_number:
            raise ValueError(
                'if one argument is Tensor, all arguments should be Tensor'
            )

        return is_variable

@@ -170,15 +192,17 @@ class Distribution(object):
        """
        numpy_args = []
        variable_args = []
        tmp = 0.0

        for arg in args:
            if isinstance(arg, float):
                arg = [arg]
            if not isinstance(arg, (list, tuple, np.ndarray, tensor.Variable)):
                raise TypeError(
                    "Type of input args must be float, list, numpy.ndarray or Tensor, but received type {}".format(
                        type(arg)
                    )
                )

            arg_np = np.array(arg)
            arg_dtype = arg_np.dtype

@@ -216,20 +240,24 @@ class Distribution(object):
            value (Tensor): Change value's dtype if value's dtype is different from param.
        """
        if _non_static_mode():
            if value.dtype != param.dtype and convert_dtype(value.dtype) in [
                'float32',
                'float64',
            ]:
                warnings.warn(
                    "dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
                )
                if in_dygraph_mode():
                    return _C_ops.cast(value, param.dtype)
                if _in_legacy_dygraph():
                    return _legacy_C_ops.cast(
                        value, 'in_dtype', value.dtype, 'out_dtype', param.dtype
                    )
            return value

        check_variable_and_dtype(
            value, 'value', ['float32', 'float64'], 'log_prob'
        )
        if value.dtype != param.dtype:
            warnings.warn(
                "dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."

@@ -244,8 +272,11 @@ class Distribution(object):
        multi-dimensional, values of last axis denote the probabilities of
        occurrence of each of the events.
        """
        return (
            (paddle.log(probs) - paddle.log1p(-probs))
            if is_binary
            else paddle.log(probs)
        )

    def _logits_to_probs(self, logits, is_binary=False):
        r"""

@@ -253,5 +284,8 @@ class Distribution(object):
        log odds, whereas for the multi-dimensional case, the values along the
        last dimension denote the log probabilities of the events.
        """
        return (
            paddle.nn.functional.sigmoid(logits)
            if is_binary
            else paddle.nn.functional.softmax(logits, axis=-1)
        )
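The two helpers above are inverses of each other in the binary case; a small standalone check using plain Paddle ops (it does not call the private methods):

import paddle
import paddle.nn.functional as F

probs = paddle.to_tensor([0.25, 0.75])
# binary case of _probs_to_logits: log-odds
logits = paddle.log(probs) - paddle.log1p(-probs)
# binary case of _logits_to_probs: sigmoid recovers the probabilities
print(F.sigmoid(logits))  # ~[0.25, 0.75]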
@@ -83,8 +83,9 @@ def register_kl(cls_p, cls_q):
        def kl_beta_beta():
            pass  # insert implementation here
    """
    if not issubclass(cls_p, Distribution) or not issubclass(
        cls_q, Distribution
    ):
        raise TypeError('cls_p and cls_q must be subclass of Distribution')

    def decorator(f):

@@ -98,8 +99,11 @@ def _dispatch(cls_p, cls_q):
    """Multiple dispatch into concrete implement function"""
    # find all matched super class pair of p and q
    matchs = [
        (super_p, super_q)
        for super_p, super_q in _REGISTER_TABLE
        if issubclass(cls_p, super_p) and issubclass(cls_q, super_q)
    ]
    if not matchs:
        raise NotImplementedError

@@ -108,16 +112,20 @@ def _dispatch(cls_p, cls_q):
    if _REGISTER_TABLE[left_p, left_q] is not _REGISTER_TABLE[right_p, right_q]:
        warnings.warn(
            'Ambiguous kl_divergence({}, {}). Please register_kl({}, {})'.format(
                cls_p.__name__,
                cls_q.__name__,
                left_p.__name__,
                right_q.__name__,
            ),
            RuntimeWarning,
        )

    return _REGISTER_TABLE[left_p, left_q]


@functools.total_ordering
class _Compare(object):
    def __init__(self, *classes):
        self.classes = classes

@@ -135,22 +143,33 @@ class _Compare(object):

@register_kl(Beta, Beta)
def _kl_beta_beta(p, q):
    return (
        (q.alpha.lgamma() + q.beta.lgamma() + (p.alpha + p.beta).lgamma())
        - (p.alpha.lgamma() + p.beta.lgamma() + (q.alpha + q.beta).lgamma())
        + ((p.alpha - q.alpha) * p.alpha.digamma())
        + ((p.beta - q.beta) * p.beta.digamma())
        + (
            ((q.alpha + q.beta) - (p.alpha + p.beta))
            * (p.alpha + p.beta).digamma()
        )
    )


@register_kl(Dirichlet, Dirichlet)
def _kl_dirichlet_dirichlet(p, q):
    return (
        (p.concentration.sum(-1).lgamma() - q.concentration.sum(-1).lgamma())
        - ((p.concentration.lgamma() - q.concentration.lgamma()).sum(-1))
        + (
            (
                (p.concentration - q.concentration)
                * (
                    p.concentration.digamma()
                    - p.concentration.sum(-1).digamma().unsqueeze(-1)
                )
            ).sum(-1)
        )
    )


@register_kl(Categorical, Categorical)

@@ -170,8 +189,7 @@ def _kl_uniform_uniform(p, q):

@register_kl(ExponentialFamily, ExponentialFamily)
def _kl_expfamily_expfamily(p, q):
    """Compute kl-divergence using `Bregman divergences <https://www.lix.polytechnique.fr/~nielsen/EntropyEF-ICIP2010.pdf>`_"""
    if not type(p) == type(q):
        raise NotImplementedError

@@ -187,19 +205,22 @@ def _kl_expfamily_expfamily(p, q):
    try:
        if _non_static_mode():
            p_grads = paddle.grad(
                p_log_norm, p_natural_params, create_graph=True
            )
        else:
            p_grads = paddle.static.gradients(p_log_norm, p_natural_params)
    except RuntimeError as e:
        raise TypeError(
            "Cann't compute kl_divergence({cls_p}, {cls_q}) use bregman divergence. Please register_kl({cls_p}, {cls_q}).".format(
                cls_p=type(p).__name__, cls_q=type(q).__name__
            )
        ) from e

    kl = q._log_normalizer(*q_natural_params) - p_log_norm
    for p_param, q_param, p_grad in zip(
        p_natural_params, q_natural_params, p_grads
    ):
        term = (q_param - p_param) * p_grad
        kl -= _sum_rightmost(term, len(q.event_shape))
......
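The dispatch table above backs the public paddle.distribution.kl_divergence; a minimal sketch of calling it for two registered distributions (assumed public API, not part of this diff):

import paddle
from paddle.distribution import Beta, kl_divergence

p = Beta(alpha=paddle.to_tensor(0.5), beta=paddle.to_tensor(0.5))
q = Beta(alpha=paddle.to_tensor(0.3), beta=paddle.to_tensor(0.7))
# dispatches to the registered _kl_beta_beta implementation
print(kl_divergence(p, q))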
@@ -19,12 +19,23 @@ import numpy as np
from paddle import _C_ops, _legacy_C_ops
from paddle.distribution import distribution
from paddle.fluid import core
from paddle.fluid.data_feeder import (
    check_dtype,
    check_type,
    check_variable_and_dtype,
    convert_dtype,
)
from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
from paddle.fluid.layers import (
    control_flow,
    elementwise_add,
    elementwise_div,
    elementwise_mul,
    elementwise_sub,
    nn,
    ops,
    tensor,
)


class Normal(distribution.Distribution):

@@ -90,12 +101,18 @@ class Normal(distribution.Distribution):
    def __init__(self, loc, scale, name=None):
        if not _non_static_mode():
            check_type(
                loc,
                'loc',
                (int, float, np.ndarray, tensor.Variable, list, tuple),
                'Normal',
            )
            check_type(
                scale,
                'scale',
                (int, float, np.ndarray, tensor.Variable, list, tuple),
                'Normal',
            )

        self.batch_size_unknown = False
        self.all_arg_is_float = False

@@ -115,11 +132,15 @@ class Normal(distribution.Distribution):
        else:
            if isinstance(loc, float) and isinstance(scale, float):
                self.all_arg_is_float = True
            if isinstance(loc, np.ndarray) and str(loc.dtype) in [
                'float32',
                'float64',
            ]:
                self.dtype = loc.dtype
            elif isinstance(scale, np.ndarray) and str(scale.dtype) in [
                'float32',
                'float64',
            ]:
                self.dtype = scale.dtype
            # pylint: disable=unbalanced-tuple-unpacking
            self.loc, self.scale = self._to_tensor(loc, scale)

@@ -149,21 +170,21 @@ class Normal(distribution.Distribution):
        if self.batch_size_unknown:
            output_shape = shape + batch_shape
            zero_tmp = tensor.fill_constant_batch_size_like(
                self.loc + self.scale, batch_shape + shape, self.dtype, 0.0
            )
            zero_tmp_reshape = nn.reshape(zero_tmp, output_shape)
            zero_tmp_shape = nn.shape(zero_tmp_reshape)
            normal_random_tmp = nn.gaussian_random(
                zero_tmp_shape, mean=0.0, std=1.0, seed=seed, dtype=self.dtype
            )
            output = normal_random_tmp * (zero_tmp_reshape + self.scale)
            output = elementwise_add(output, self.loc, name=name)
            return output
        else:
            output_shape = shape + batch_shape
            output = nn.gaussian_random(
                output_shape, mean=0.0, std=1.0, seed=seed, dtype=self.dtype
            ) * (tensor.zeros(output_shape, dtype=self.dtype) + self.scale)
            output = elementwise_add(output, self.loc, name=name)
            if self.all_arg_is_float:
                return nn.reshape(output, shape, name=name)

@@ -189,13 +210,14 @@ class Normal(distribution.Distribution):
        """
        name = self.name + '_entropy'
        batch_shape = list((self.loc + self.scale).shape)
        zero_tmp = tensor.fill_constant_batch_size_like(
            self.loc + self.scale, batch_shape, self.dtype, 0.0
        )
        return elementwise_add(
            0.5 + zero_tmp,
            0.5 * math.log(2 * math.pi) + nn.log((self.scale + zero_tmp)),
            name=name,
        )

    def log_prob(self, value):
        """Log probability density/mass function.

@@ -212,10 +234,11 @@ class Normal(distribution.Distribution):
        var = self.scale * self.scale
        log_scale = nn.log(self.scale)
        return elementwise_sub(
            -1.0 * ((value - self.loc) * (value - self.loc)) / (2.0 * var),
            log_scale + math.log(math.sqrt(2.0 * math.pi)),
            name=name,
        )

    def probs(self, value):
        """Probability density/mass function.

@@ -231,10 +254,13 @@ class Normal(distribution.Distribution):
        value = self._check_values_dtype_in_probs(self.loc, value)
        var = self.scale * self.scale
        return elementwise_div(
            ops.exp(
                -1.0 * ((value - self.loc) * (value - self.loc)) / (2.0 * var)
            ),
            (math.sqrt(2 * math.pi) * self.scale),
            name=name,
        )

    def kl_divergence(self, other):
        r"""The KL-divergence between two normal distributions.

@@ -274,9 +300,9 @@ class Normal(distribution.Distribution):
        name = self.name + '_kl_divergence'
        var_ratio = self.scale / other.scale
        var_ratio = var_ratio * var_ratio
        t1 = (self.loc - other.loc) / other.scale
        t1 = t1 * t1
        return elementwise_add(
            0.5 * var_ratio, 0.5 * (t1 - 1.0 - nn.log(var_ratio)), name=name
        )
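For reference, the Normal distribution reformatted above is used roughly as follows (based on the public API; not part of this diff):

import paddle
from paddle.distribution import Normal

dist_a = Normal(loc=0.0, scale=1.0)
dist_b = Normal(loc=0.5, scale=2.0)

samples = dist_a.sample([3])                   # draw 3 samples
lp = dist_a.log_prob(paddle.to_tensor([0.3]))  # log density at 0.3
ent = dist_a.entropy()                         # differential entropy
kl = dist_a.kl_divergence(dist_b)              # KL(a || b)
print(samples, lp, ent, kl)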
@@ -21,20 +21,33 @@ import typing
import paddle
import paddle.nn.functional as F
from paddle.distribution import (
    constraint,
    distribution,
    transformed_distribution,
    variable,
)

__all__ = [  # noqa
    'Transform',
    'AbsTransform',
    'AffineTransform',
    'ChainTransform',
    'ExpTransform',
    'IndependentTransform',
    'PowerTransform',
    'ReshapeTransform',
    'SigmoidTransform',
    'SoftmaxTransform',
    'StackTransform',
    'StickBreakingTransform',
    'TanhTransform',
]


class Type(enum.Enum):
    """Mapping type of a transformation."""

    BIJECTION = 'bijection'  # bijective(injective and surjective)
    INJECTION = 'injection'  # injective-only
    SURJECTION = 'surjection'  # surjective-only

@@ -42,8 +55,7 @@ class Type(enum.Enum):
    @classmethod
    def is_injective(cls, _type):
        """Both bijection and injection are injective mapping."""
        return _type in (cls.BIJECTION, cls.INJECTION)

@@ -139,7 +151,8 @@ class Transform(object):
        """
        if isinstance(input, distribution.Distribution):
            return transformed_distribution.TransformedDistribution(
                input, [self]
            )
        if isinstance(input, Transform):
            return ChainTransform([self, input])
        return self.forward(x)

@@ -158,11 +171,13 @@ class Transform(object):
        """
        if not isinstance(x, paddle.fluid.framework.Variable):
            raise TypeError(
                f"Expected 'x' is a Tensor or Real, but got {type(x)}."
            )
        if x.dim() < self._domain.event_rank:
            raise ValueError(
                f'The dimensions of x({x.dim()}) should be '
                f'grater than or equal to {self._domain.event_rank}'
            )
        return self._forward(x)

    def inverse(self, y):

@@ -177,11 +192,13 @@ class Transform(object):
        """
        if not isinstance(y, paddle.fluid.framework.Variable):
            raise TypeError(
                f"Expected 'y' is a Tensor or Real, but got {type(y)}."
            )
        if y.dim() < self._codomain.event_rank:
            raise ValueError(
                f'The dimensions of y({y.dim()}) should be '
                f'grater than or equal to {self._codomain.event_rank}'
            )
        return self._inverse(y)

    def forward_log_det_jacobian(self, x):

@@ -197,16 +214,21 @@ class Transform(object):
        """
        if not isinstance(x, paddle.fluid.framework.Variable):
            raise TypeError(
                f"Expected 'y' is a Tensor or Real, but got {type(x)}."
            )
        if (
            isinstance(x, paddle.fluid.framework.Variable)
            and x.dim() < self._domain.event_rank
        ):
            raise ValueError(
                f'The dimensions of x({x.dim()}) should be '
                f'grater than or equal to {self._domain.event_rank}'
            )
        if not self._is_injective():
            raise NotImplementedError(
                "forward_log_det_jacobian can't be implemented for non-injective"
                "transforms."
            )

        return self._call_forward_log_det_jacobian(x)

@@ -227,7 +249,8 @@ class Transform(object):
        if y.dim() < self._codomain.event_rank:
            raise ValueError(
                f'The dimensions of y({y.dim()}) should be '
                f'grater than or equal to {self._codomain.event_rank}'
            )
        return self._call_inverse_log_det_jacobian(y)

    def forward_shape(self, shape):

@@ -241,7 +264,8 @@ class Transform(object):
        """
        if not isinstance(shape, typing.Sequence):
            raise TypeError(
                f"Expected shape is Sequence[int] type, but got {type(shape)}."
            )
        return self._forward_shape(shape)

    def inverse_shape(self, shape):

@@ -255,7 +279,8 @@ class Transform(object):
        """
        if not isinstance(shape, typing.Sequence):
            raise TypeError(
                f"Expected shape is Sequence[int] type, but got {type(shape)}."
            )
        return self._inverse_shape(shape)

    @property

@@ -288,7 +313,8 @@ class Transform(object):
            return -self._inverse_log_det_jacobian(self.forward(y))
        raise NotImplementedError(
            'Neither _forward_log_det_jacobian nor _inverse_log_det_jacobian'
            'is implemented. One of them is required.'
        )

    def _call_inverse_log_det_jacobian(self, y):
        """Inner method called by ``inverse_log_det_jacobian``"""

@@ -298,7 +324,8 @@ class Transform(object):
            return -self._forward_log_det_jacobian(self._inverse(y))
        raise NotImplementedError(
            'Neither _forward_log_det_jacobian nor _inverse_log_det_jacobian '
            'is implemented. One of them is required'
        )

    def _forward_shape(self, shape):
        """Inner method called by ``forward_shape``, which is used to infer the

@@ -421,7 +448,8 @@ class AffineTransform(Transform):
            raise TypeError(f"Expected 'loc' is a Tensor, but got {type(loc)}")
        if not isinstance(scale, paddle.fluid.framework.Variable):
            raise TypeError(
                f"Expected scale is a Tensor, but got {type(scale)}"
            )
        self._loc = loc
        self._scale = scale
        super(AffineTransform, self).__init__()

@@ -447,13 +475,17 @@ class AffineTransform(Transform):
        return tuple(
            paddle.broadcast_shape(
                paddle.broadcast_shape(shape, self._loc.shape),
                self._scale.shape,
            )
        )

    def _inverse_shape(self, shape):
        return tuple(
            paddle.broadcast_shape(
                paddle.broadcast_shape(shape, self._loc.shape),
                self._scale.shape,
            )
        )

    @property
    def _domain(self):

@@ -505,7 +537,8 @@ class ChainTransform(Transform):
            )
        if not all(isinstance(t, Transform) for t in transforms):
            raise TypeError(
                "All elements of transforms should be Transform type."
            )
        self.transforms = transforms
        super(ChainTransform, self).__init__()

@@ -524,11 +557,12 @@ class ChainTransform(Transform):
        return y

    def _forward_log_det_jacobian(self, x):
        value = 0.0
        event_rank = self._domain.event_rank
        for t in self.transforms:
            value += self._sum_rightmost(
                t.forward_log_det_jacobian(x), event_rank - t._domain.event_rank
            )
            x = t.forward(x)
            event_rank += t._codomain.event_rank - t._domain.event_rank
        return value

@@ -683,7 +717,8 @@ class IndependentTransform(Transform):
    def __init__(self, base, reinterpreted_batch_rank):
        if not isinstance(base, Transform):
            raise TypeError(
                f"Expected 'base' is Transform type, but get {type(base)}"
            )
        if reinterpreted_batch_rank <= 0:
            raise ValueError(
                f"Expected 'reinterpreted_batch_rank' is grater than zero, but got {reinterpreted_batch_rank}"
            )

@@ -708,7 +743,8 @@ class IndependentTransform(Transform):
    def _forward_log_det_jacobian(self, x):
        return self._base.forward_log_det_jacobian(x).sum(
            list(range(-self._reinterpreted_batch_rank, 0))
        )

    def _forward_shape(self, shape):
        return self._base.forward_shape(shape)

@@ -718,13 +754,15 @@ class IndependentTransform(Transform):
    @property
    def _domain(self):
        return variable.Independent(
            self._base._domain, self._reinterpreted_batch_rank
        )

    @property
    def _codomain(self):
        return variable.Independent(
            self._base._codomain, self._reinterpreted_batch_rank
        )


class PowerTransform(Transform):

@@ -758,7 +796,8 @@ class PowerTransform(Transform):
    def __init__(self, power):
        if not isinstance(power, paddle.fluid.framework.Variable):
            raise TypeError(
                f"Expected 'power' is a tensor, but got {type(power)}"
            )
        self._power = power
        super(PowerTransform, self).__init__()

@@ -827,13 +866,16 @@ class ReshapeTransform(Transform):
    def __init__(self, in_event_shape, out_event_shape):
        if not isinstance(in_event_shape, typing.Sequence) or not isinstance(
            out_event_shape, typing.Sequence
        ):
            raise TypeError(
                f"Expected type of 'in_event_shape' and 'out_event_shape' is "
                f"Squence[int], but got 'in_event_shape': {in_event_shape}, "
                f"'out_event_shape': {out_event_shape}"
            )
        if functools.reduce(operator.mul, in_event_shape) != functools.reduce(
operator.mul, out_event_shape): operator.mul, out_event_shape
):
raise ValueError( raise ValueError(
f"The numel of 'in_event_shape' should be 'out_event_shape', " f"The numel of 'in_event_shape' should be 'out_event_shape', "
f"but got {functools.reduce(operator.mul, in_event_shape)}!={functools.reduce(operator.mul, out_event_shape)}" f"but got {functools.reduce(operator.mul, in_event_shape)}!={functools.reduce(operator.mul, out_event_shape)}"
...@@ -861,39 +903,45 @@ class ReshapeTransform(Transform): ...@@ -861,39 +903,45 @@ class ReshapeTransform(Transform):
def _forward(self, x): def _forward(self, x):
return x.reshape( return x.reshape(
tuple(x.shape)[:x.dim() - len(self._in_event_shape)] + tuple(x.shape)[: x.dim() - len(self._in_event_shape)]
self._out_event_shape) + self._out_event_shape
)
def _inverse(self, y): def _inverse(self, y):
return y.reshape( return y.reshape(
tuple(y.shape)[:y.dim() - len(self._out_event_shape)] + tuple(y.shape)[: y.dim() - len(self._out_event_shape)]
self._in_event_shape) + self._in_event_shape
)
def _forward_shape(self, shape): def _forward_shape(self, shape):
if len(shape) < len(self._in_event_shape): if len(shape) < len(self._in_event_shape):
raise ValueError( raise ValueError(
f"Expected length of 'shape' is not less than {len(self._in_event_shape)}, but got {len(shape)}" f"Expected length of 'shape' is not less than {len(self._in_event_shape)}, but got {len(shape)}"
) )
if shape[-len(self._in_event_shape):] != self._in_event_shape: if shape[-len(self._in_event_shape) :] != self._in_event_shape:
raise ValueError( raise ValueError(
f"Event shape mismatch, expected: {self._in_event_shape}, but got {shape[-len(self._in_event_shape):]}" f"Event shape mismatch, expected: {self._in_event_shape}, but got {shape[-len(self._in_event_shape):]}"
) )
return tuple(shape[:-len(self._in_event_shape)]) + self._out_event_shape return (
tuple(shape[: -len(self._in_event_shape)]) + self._out_event_shape
)
def _inverse_shape(self, shape): def _inverse_shape(self, shape):
if len(shape) < len(self._out_event_shape): if len(shape) < len(self._out_event_shape):
raise ValueError( raise ValueError(
f"Expected 'shape' length is not less than {len(self._out_event_shape)}, but got {len(shape)}" f"Expected 'shape' length is not less than {len(self._out_event_shape)}, but got {len(shape)}"
) )
if shape[-len(self._out_event_shape):] != self._out_event_shape: if shape[-len(self._out_event_shape) :] != self._out_event_shape:
raise ValueError( raise ValueError(
f"Event shape mismatch, expected: {self._out_event_shape}, but got {shape[-len(self._out_event_shape):]}" f"Event shape mismatch, expected: {self._out_event_shape}, but got {shape[-len(self._out_event_shape):]}"
) )
return tuple(shape[:-len(self._out_event_shape)]) + self._in_event_shape return (
tuple(shape[: -len(self._out_event_shape)]) + self._in_event_shape
)
def _forward_log_det_jacobian(self, x): def _forward_log_det_jacobian(self, x):
# paddle.zeros not support zero dimension Tensor. # paddle.zeros not support zero dimension Tensor.
shape = x.shape[:x.dim() - len(self._in_event_shape)] or [1] shape = x.shape[: x.dim() - len(self._in_event_shape)] or [1]
return paddle.zeros(shape, dtype=x.dtype) return paddle.zeros(shape, dtype=x.dtype)
...@@ -928,7 +976,7 @@ class SigmoidTransform(Transform): ...@@ -928,7 +976,7 @@ class SigmoidTransform(Transform):
@property @property
def _codomain(self): def _codomain(self):
return variable.Variable(False, 0, constraint.Range(0., 1.)) return variable.Variable(False, 0, constraint.Range(0.0, 1.0))
def _forward(self, x): def _forward(self, x):
return F.sigmoid(x) return F.sigmoid(x)
...@@ -997,7 +1045,7 @@ class SoftmaxTransform(Transform): ...@@ -997,7 +1045,7 @@ class SoftmaxTransform(Transform):
class StackTransform(Transform): class StackTransform(Transform):
r""" ``StackTransform`` applies a sequence of transformations along the r"""``StackTransform`` applies a sequence of transformations along the
specific axis. specific axis.
Args: Args:
...@@ -1042,7 +1090,8 @@ class StackTransform(Transform): ...@@ -1042,7 +1090,8 @@ class StackTransform(Transform):
) )
if not all(isinstance(t, Transform) for t in transforms): if not all(isinstance(t, Transform) for t in transforms):
raise TypeError( raise TypeError(
'Expected all element in transforms is Transform Type.') 'Expected all element in transforms is Transform Type.'
)
if not isinstance(axis, int): if not isinstance(axis, int):
raise TypeError(f"Expected 'axis' is int, but got{type(axis)}.") raise TypeError(f"Expected 'axis' is int, but got{type(axis)}.")
...@@ -1062,34 +1111,45 @@ class StackTransform(Transform): ...@@ -1062,34 +1111,45 @@ class StackTransform(Transform):
def _forward(self, x): def _forward(self, x):
self._check_size(x) self._check_size(x)
return paddle.stack([ return paddle.stack(
[
t.forward(v) t.forward(v)
for v, t in zip(paddle.unstack(x, self._axis), self._transforms) for v, t in zip(paddle.unstack(x, self._axis), self._transforms)
], self._axis) ],
self._axis,
)
def _inverse(self, y): def _inverse(self, y):
self._check_size(y) self._check_size(y)
return paddle.stack([ return paddle.stack(
[
t.inverse(v) t.inverse(v)
for v, t in zip(paddle.unstack(y, self._axis), self._transforms) for v, t in zip(paddle.unstack(y, self._axis), self._transforms)
], self._axis) ],
self._axis,
)
def _forward_log_det_jacobian(self, x): def _forward_log_det_jacobian(self, x):
self._check_size(x) self._check_size(x)
return paddle.stack([ return paddle.stack(
[
t.forward_log_det_jacobian(v) t.forward_log_det_jacobian(v)
for v, t in zip(paddle.unstack(x, self._axis), self._transforms) for v, t in zip(paddle.unstack(x, self._axis), self._transforms)
], self._axis) ],
self._axis,
)
def _check_size(self, v): def _check_size(self, v):
if not (-v.dim() <= self._axis < v.dim()): if not (-v.dim() <= self._axis < v.dim()):
raise ValueError( raise ValueError(
f'Input dimensions {v.dim()} should be grater than stack ' f'Input dimensions {v.dim()} should be grater than stack '
f'transform axis {self._axis}.') f'transform axis {self._axis}.'
)
if v.shape[self._axis] != len(self._transforms): if v.shape[self._axis] != len(self._transforms):
raise ValueError( raise ValueError(
f'Input size along {self._axis} should be equal to the ' f'Input size along {self._axis} should be equal to the '
f'length of transforms.') f'length of transforms.'
)
@property @property
def _domain(self): def _domain(self):
...@@ -1097,8 +1157,9 @@ class StackTransform(Transform): ...@@ -1097,8 +1157,9 @@ class StackTransform(Transform):
@property @property
def _codomain(self): def _codomain(self):
return variable.Stack([t._codomain for t in self._transforms], return variable.Stack(
self._axis) [t._codomain for t in self._transforms], self._axis
)
class StickBreakingTransform(Transform): class StickBreakingTransform(Transform):
...@@ -1131,8 +1192,9 @@ class StickBreakingTransform(Transform): ...@@ -1131,8 +1192,9 @@ class StickBreakingTransform(Transform):
offset = x.shape[-1] + 1 - paddle.ones([x.shape[-1]]).cumsum(-1) offset = x.shape[-1] + 1 - paddle.ones([x.shape[-1]]).cumsum(-1)
z = F.sigmoid(x - offset.log()) z = F.sigmoid(x - offset.log())
z_cumprod = (1 - z).cumprod(-1) z_cumprod = (1 - z).cumprod(-1)
return F.pad(z, [0]*2*(len(x.shape)-1) + [0, 1], value=1) * \ return F.pad(z, [0] * 2 * (len(x.shape) - 1) + [0, 1], value=1) * F.pad(
F.pad(z_cumprod, [0]*2*(len(x.shape)-1) + [1, 0], value=1) z_cumprod, [0] * 2 * (len(x.shape) - 1) + [1, 0], value=1
)
def _inverse(self, y): def _inverse(self, y):
y_crop = y[..., :-1] y_crop = y[..., :-1]
...@@ -1150,12 +1212,12 @@ class StickBreakingTransform(Transform): ...@@ -1150,12 +1212,12 @@ class StickBreakingTransform(Transform):
def _forward_shape(self, shape): def _forward_shape(self, shape):
if not shape: if not shape:
raise ValueError(f"Expected 'shape' is not empty, but got {shape}") raise ValueError(f"Expected 'shape' is not empty, but got {shape}")
return shape[:-1] + (shape[-1] + 1, ) return shape[:-1] + (shape[-1] + 1,)
def _inverse_shape(self, shape): def _inverse_shape(self, shape):
if not shape: if not shape:
raise ValueError(f"Expected 'shape' is not empty, but got {shape}") raise ValueError(f"Expected 'shape' is not empty, but got {shape}")
return shape[:-1] + (shape[-1] - 1, ) return shape[:-1] + (shape[-1] - 1,)
@property @property
def _domain(self): def _domain(self):
...@@ -1219,4 +1281,4 @@ class TanhTransform(Transform): ...@@ -1219,4 +1281,4 @@ class TanhTransform(Transform):
See details: https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L69-L80 See details: https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L69-L80
""" """
return 2. * (math.log(2.) - x - F.softplus(-2. * x)) return 2.0 * (math.log(2.0) - x - F.softplus(-2.0 * x))
...@@ -19,12 +19,27 @@ import numpy as np ...@@ -19,12 +19,27 @@ import numpy as np
from paddle import _C_ops, _legacy_C_ops from paddle import _C_ops, _legacy_C_ops
from paddle.distribution import distribution from paddle.distribution import distribution
from paddle.fluid import core from paddle.fluid import core
from paddle.fluid.data_feeder import (check_dtype, check_type, from paddle.fluid.data_feeder import (
check_variable_and_dtype, convert_dtype) check_dtype,
from paddle.fluid.framework import _non_static_mode, in_dygraph_mode, _in_legacy_dygraph check_type,
from paddle.fluid.layers import (control_flow, elementwise_add, elementwise_div, check_variable_and_dtype,
elementwise_mul, elementwise_sub, nn, ops, convert_dtype,
tensor) )
from paddle.fluid.framework import (
_non_static_mode,
in_dygraph_mode,
_in_legacy_dygraph,
)
from paddle.fluid.layers import (
control_flow,
elementwise_add,
elementwise_div,
elementwise_mul,
elementwise_sub,
nn,
ops,
tensor,
)
from paddle.tensor import arange, concat, gather_nd, multinomial from paddle.tensor import arange, concat, gather_nd, multinomial
...@@ -91,12 +106,18 @@ class Uniform(distribution.Distribution): ...@@ -91,12 +106,18 @@ class Uniform(distribution.Distribution):
def __init__(self, low, high, name=None): def __init__(self, low, high, name=None):
if not _non_static_mode(): if not _non_static_mode():
check_type(low, 'low', check_type(
low,
'low',
(int, float, np.ndarray, tensor.Variable, list, tuple), (int, float, np.ndarray, tensor.Variable, list, tuple),
'Uniform') 'Uniform',
check_type(high, 'high', )
check_type(
high,
'high',
(int, float, np.ndarray, tensor.Variable, list, tuple), (int, float, np.ndarray, tensor.Variable, list, tuple),
'Uniform') 'Uniform',
)
self.all_arg_is_float = False self.all_arg_is_float = False
self.batch_size_unknown = False self.batch_size_unknown = False
...@@ -116,11 +137,15 @@ class Uniform(distribution.Distribution): ...@@ -116,11 +137,15 @@ class Uniform(distribution.Distribution):
else: else:
if isinstance(low, float) and isinstance(high, float): if isinstance(low, float) and isinstance(high, float):
self.all_arg_is_float = True self.all_arg_is_float = True
if isinstance(low, np.ndarray) and str( if isinstance(low, np.ndarray) and str(low.dtype) in [
low.dtype) in ['float32', 'float64']: 'float32',
'float64',
]:
self.dtype = low.dtype self.dtype = low.dtype
elif isinstance(high, np.ndarray) and str( elif isinstance(high, np.ndarray) and str(high.dtype) in [
high.dtype) in ['float32', 'float64']: 'float32',
'float64',
]:
self.dtype = high.dtype self.dtype = high.dtype
# pylint: disable=unbalanced-tuple-unpacking # pylint: disable=unbalanced-tuple-unpacking
self.low, self.high = self._to_tensor(low, high) self.low, self.high = self._to_tensor(low, high)
...@@ -148,27 +173,33 @@ class Uniform(distribution.Distribution): ...@@ -148,27 +173,33 @@ class Uniform(distribution.Distribution):
if self.batch_size_unknown: if self.batch_size_unknown:
output_shape = shape + batch_shape output_shape = shape + batch_shape
zero_tmp = tensor.fill_constant_batch_size_like( zero_tmp = tensor.fill_constant_batch_size_like(
self.low + self.high, batch_shape + shape, self.dtype, 0.) self.low + self.high, batch_shape + shape, self.dtype, 0.0
)
uniform_random_tmp = nn.uniform_random_batch_size_like( uniform_random_tmp = nn.uniform_random_batch_size_like(
zero_tmp, zero_tmp,
zero_tmp.shape, zero_tmp.shape,
dtype=self.dtype, dtype=self.dtype,
min=0., min=0.0,
max=1., max=1.0,
seed=seed) seed=seed,
)
zero_tmp_reshape = nn.reshape(zero_tmp, output_shape) zero_tmp_reshape = nn.reshape(zero_tmp, output_shape)
uniform_random_tmp_reshape = nn.reshape(uniform_random_tmp, uniform_random_tmp_reshape = nn.reshape(
output_shape) uniform_random_tmp, output_shape
output = uniform_random_tmp_reshape * (zero_tmp_reshape + )
self.high - self.low) output = uniform_random_tmp_reshape * (
zero_tmp_reshape + self.high - self.low
)
output = elementwise_add(output, self.low, name=name) output = elementwise_add(output, self.low, name=name)
return output return output
else: else:
output_shape = shape + batch_shape output_shape = shape + batch_shape
output = nn.uniform_random( output = nn.uniform_random(
output_shape, dtype=self.dtype, min=0., max=1., output_shape, dtype=self.dtype, min=0.0, max=1.0, seed=seed
seed=seed) * (tensor.zeros(output_shape, dtype=self.dtype) + ) * (
(self.high - self.low)) tensor.zeros(output_shape, dtype=self.dtype)
+ (self.high - self.low)
)
output = elementwise_add(output, self.low, name=name) output = elementwise_add(output, self.low, name=name)
if self.all_arg_is_float: if self.all_arg_is_float:
return nn.reshape(output, shape, name=name) return nn.reshape(output, shape, name=name)
...@@ -197,10 +228,12 @@ class Uniform(distribution.Distribution): ...@@ -197,10 +228,12 @@ class Uniform(distribution.Distribution):
return nn.log(lb * ub) - nn.log(self.high - self.low) return nn.log(lb * ub) - nn.log(self.high - self.low)
if _in_legacy_dygraph(): if _in_legacy_dygraph():
lb = _legacy_C_ops.cast(lb_bool, 'in_dtype', lb_bool.dtype, lb = _legacy_C_ops.cast(
'out_dtype', value.dtype) lb_bool, 'in_dtype', lb_bool.dtype, 'out_dtype', value.dtype
ub = _legacy_C_ops.cast(ub_bool, 'in_dtype', ub_bool.dtype, )
'out_dtype', value.dtype) ub = _legacy_C_ops.cast(
ub_bool, 'in_dtype', ub_bool.dtype, 'out_dtype', value.dtype
)
return nn.log(lb * ub) - nn.log(self.high - self.low) return nn.log(lb * ub) - nn.log(self.high - self.low)
name = self.name + '_log_prob' name = self.name + '_log_prob'
...@@ -208,9 +241,9 @@ class Uniform(distribution.Distribution): ...@@ -208,9 +241,9 @@ class Uniform(distribution.Distribution):
ub_bool = value < self.high ub_bool = value < self.high
lb = tensor.cast(lb_bool, dtype=value.dtype) lb = tensor.cast(lb_bool, dtype=value.dtype)
ub = tensor.cast(ub_bool, dtype=value.dtype) ub = tensor.cast(ub_bool, dtype=value.dtype)
return elementwise_sub(nn.log(lb * ub), return elementwise_sub(
nn.log(self.high - self.low), nn.log(lb * ub), nn.log(self.high - self.low), name=name
name=name) )
def probs(self, value): def probs(self, value):
"""Probability density/mass function. """Probability density/mass function.
...@@ -233,10 +266,12 @@ class Uniform(distribution.Distribution): ...@@ -233,10 +266,12 @@ class Uniform(distribution.Distribution):
return (lb * ub) / (self.high - self.low) return (lb * ub) / (self.high - self.low)
if _in_legacy_dygraph(): if _in_legacy_dygraph():
lb = _legacy_C_ops.cast(lb_bool, 'in_dtype', lb_bool.dtype, lb = _legacy_C_ops.cast(
'out_dtype', value.dtype) lb_bool, 'in_dtype', lb_bool.dtype, 'out_dtype', value.dtype
ub = _legacy_C_ops.cast(ub_bool, 'in_dtype', ub_bool.dtype, )
'out_dtype', value.dtype) ub = _legacy_C_ops.cast(
ub_bool, 'in_dtype', ub_bool.dtype, 'out_dtype', value.dtype
)
return (lb * ub) / (self.high - self.low) return (lb * ub) / (self.high - self.low)
name = self.name + '_probs' name = self.name + '_probs'
......
...@@ -46,11 +46,16 @@ def set_default_dtype(d): ...@@ -46,11 +46,16 @@ def set_default_dtype(d):
else: else:
raise TypeError( raise TypeError(
"set_default_dtype only supports [float16, float32, float64] " "set_default_dtype only supports [float16, float32, float64] "
", but received %s" % d.__name__) ", but received %s" % d.__name__
)
else: else:
if d in [ if d in [
'float16', 'float32', 'float64', u'float16', u'float32', 'float16',
u'float64' 'float32',
'float64',
u'float16',
u'float32',
u'float64',
]: ]:
# this code is a little bit dangerous, since error could happen # this code is a little bit dangerous, since error could happen
# when casting no-ascii code to str in python2. # when casting no-ascii code to str in python2.
...@@ -61,7 +66,8 @@ def set_default_dtype(d): ...@@ -61,7 +66,8 @@ def set_default_dtype(d):
else: else:
raise TypeError( raise TypeError(
"set_default_dtype only supports [float16, float32, float64] " "set_default_dtype only supports [float16, float32, float64] "
", but received %s" % str(d)) ", but received %s" % str(d)
)
LayerHelperBase.set_default_dtype(d) LayerHelperBase.set_default_dtype(d)
......
...@@ -64,25 +64,35 @@ def forward_grad(outputs, inputs, grad_inputs=None): ...@@ -64,25 +64,35 @@ def forward_grad(outputs, inputs, grad_inputs=None):
paddle.disable_static() paddle.disable_static()
""" """
if not utils.prim_enabled(): if not utils.prim_enabled():
raise RuntimeError('forward_grad must be running on primitive' raise RuntimeError(
'operators, use enable_prim to turn it on.') 'forward_grad must be running on primitive'
'operators, use enable_prim to turn it on.'
)
if not isinstance(outputs, (framework.Variable, typing.Sequence)): if not isinstance(outputs, (framework.Variable, typing.Sequence)):
raise TypeError(f'Expected outputs is Tensor|Sequence[Tesnor], ' raise TypeError(
f'but got {type(outputs)}.') f'Expected outputs is Tensor|Sequence[Tesnor], '
f'but got {type(outputs)}.'
)
if not isinstance(inputs, (framework.Variable, typing.Sequence)): if not isinstance(inputs, (framework.Variable, typing.Sequence)):
raise TypeError(f'Expected inputs is Tensor|Sequence[Tesnor], ' raise TypeError(
f'but got {type(inputs)}.') f'Expected inputs is Tensor|Sequence[Tesnor], '
f'but got {type(inputs)}.'
)
ys, xs, xs_dot = utils.as_tensors(outputs), utils.as_tensors( ys, xs, xs_dot = (
inputs), utils.as_tensors(grad_inputs) utils.as_tensors(outputs),
utils.as_tensors(inputs),
utils.as_tensors(grad_inputs),
)
block = framework.default_main_program().current_block() block = framework.default_main_program().current_block()
if any(x.block != block for x in xs + ys): if any(x.block != block for x in xs + ys):
raise RuntimeError( raise RuntimeError(
'Variable in inputs and targets should exist in current block of ' 'Variable in inputs and targets should exist in current block of '
'main program.') 'main program.'
)
primx.orig2prim(block) primx.orig2prim(block)
ad = primx.Transform(ys[0].block) ad = primx.Transform(ys[0].block)
...@@ -141,22 +151,32 @@ def grad(outputs, inputs, grad_outputs=None): ...@@ -141,22 +151,32 @@ def grad(outputs, inputs, grad_outputs=None):
# backward.gradients returns a list though the inputs is a signle Tensor. # backward.gradients returns a list though the inputs is a signle Tensor.
# The follow code snippet fixes the problem by return the first element # The follow code snippet fixes the problem by return the first element
# of grad_inputs when the inputs is a signle Tensor. # of grad_inputs when the inputs is a signle Tensor.
if isinstance(inputs, framework.Variable) and isinstance( if (
grad_inputs, typing.Sequence) and len(grad_inputs) > 0: isinstance(inputs, framework.Variable)
and isinstance(grad_inputs, typing.Sequence)
and len(grad_inputs) > 0
):
return grad_inputs[0] return grad_inputs[0]
else: else:
return grad_inputs return grad_inputs
if not isinstance(outputs, (framework.Variable, typing.Sequence)): if not isinstance(outputs, (framework.Variable, typing.Sequence)):
raise TypeError(f'Expected outputs is Tensor|Sequence[Tesnor], ' raise TypeError(
f'but got {type(outputs)}.') f'Expected outputs is Tensor|Sequence[Tesnor], '
f'but got {type(outputs)}.'
)
if not isinstance(inputs, (framework.Variable, typing.Sequence)): if not isinstance(inputs, (framework.Variable, typing.Sequence)):
raise TypeError(f'Expected inputs is Tensor|Sequence[Tesnor], ' raise TypeError(
f'but got {type(inputs)}.') f'Expected inputs is Tensor|Sequence[Tesnor], '
f'but got {type(inputs)}.'
)
ys, xs, ys_bar = utils.as_tensors(outputs), utils.as_tensors( ys, xs, ys_bar = (
inputs), utils.as_tensors(grad_outputs) utils.as_tensors(outputs),
utils.as_tensors(inputs),
utils.as_tensors(grad_outputs),
)
block = framework.default_main_program().current_block() block = framework.default_main_program().current_block()
if any((x is not None and x.block != block) for x in xs + ys): if any((x is not None and x.block != block) for x in xs + ys):
raise RuntimeError( raise RuntimeError(
......
...@@ -21,15 +21,23 @@ from paddle.fluid.framework import Operator, default_main_program ...@@ -21,15 +21,23 @@ from paddle.fluid.framework import Operator, default_main_program
from paddle.incubate.autograd.utils import as_tensors from paddle.incubate.autograd.utils import as_tensors
from .primops import add, fill_const from .primops import add, fill_const
from .primreg import (lookup_orig2prim, lookup_prim2orig, op_position_inputs, from .primreg import (
op_position_output) lookup_orig2prim,
lookup_prim2orig,
op_position_inputs,
op_position_output,
)
from .primrules import _jvp, _orig2prim, _prim2orig, _transpose from .primrules import _jvp, _orig2prim, _prim2orig, _transpose
from .utils import (flatten, flatten_and_remove_none, get_input_var_list, from .utils import (
get_output_var_list) flatten,
flatten_and_remove_none,
get_input_var_list,
get_output_var_list,
)
def topo_path(xs, ys, block=None): def topo_path(xs, ys, block=None):
""" Returns the list of ops on the path from `xs` to `ys` in topological """Returns the list of ops on the path from `xs` to `ys` in topological
order. order.
TODO(Tongxin): supporting control flow and nested blocks. TODO(Tongxin): supporting control flow and nested blocks.
...@@ -51,13 +59,16 @@ def topo_path(xs, ys, block=None): ...@@ -51,13 +59,16 @@ def topo_path(xs, ys, block=None):
# Initialize reached vars # Initialize reached vars
for x in xs: for x in xs:
assert x is None or x.block == block, f'x is not None and x.block != block' assert (
x is None or x.block == block
), f'x is not None and x.block != block'
reached_vars[id(x)] = x reached_vars[id(x)] = x
# Reaching test, returning whether an op is reached from the given input # Reaching test, returning whether an op is reached from the given input
reaching = lambda op: any( reaching = lambda op: any(
id(v) in reached_vars id(v) in reached_vars
for v in flatten_and_remove_none(get_input_var_list(op))) for v in flatten_and_remove_none(get_input_var_list(op))
)
# block.ops are supposedly in the order that preserves correct data # block.ops are supposedly in the order that preserves correct data
# dependence. # dependence.
...@@ -71,7 +82,8 @@ def topo_path(xs, ys, block=None): ...@@ -71,7 +82,8 @@ def topo_path(xs, ys, block=None):
used_vars = OrderedDict((id(y), y) for y in ys if id(y) in reached_vars) used_vars = OrderedDict((id(y), y) for y in ys if id(y) in reached_vars)
back_reaching = lambda op: any( back_reaching = lambda op: any(
id(out) in used_vars id(out) in used_vars
for out in flatten_and_remove_none(get_output_var_list(op))) for out in flatten_and_remove_none(get_output_var_list(op))
)
# Backward pass to find all used variables # Backward pass to find all used variables
for op in reversed(path): for op in reversed(path):
...@@ -87,7 +99,7 @@ def topo_path(xs, ys, block=None): ...@@ -87,7 +99,7 @@ def topo_path(xs, ys, block=None):
def output_vars_on_path(path): def output_vars_on_path(path):
""" Returns the output variables of all the ops on the path from `xs` """Returns the output variables of all the ops on the path from `xs`
to `ys`. to `ys`.
Args: Args:
...@@ -105,7 +117,7 @@ def output_vars_on_path(path): ...@@ -105,7 +117,7 @@ def output_vars_on_path(path):
class VarMap(object): class VarMap(object):
""" A general map data structure for linking variables to variables. """A general map data structure for linking variables to variables.
An example is linking variables to their gradients. An example is linking variables to their gradients.
""" """
...@@ -126,7 +138,8 @@ class VarMap(object): ...@@ -126,7 +138,8 @@ class VarMap(object):
if isinstance(key_vars, paddle.fluid.framework.Variable): if isinstance(key_vars, paddle.fluid.framework.Variable):
if not isinstance(value_vars, paddle.fluid.framework.Variable): if not isinstance(value_vars, paddle.fluid.framework.Variable):
raise TypeError( raise TypeError(
f'value_vars must be Variable, but got {type(value_vars)}') f'value_vars must be Variable, but got {type(value_vars)}'
)
self.tab[id(key_vars)] = id(value_vars) self.tab[id(key_vars)] = id(value_vars)
else: else:
assert len(key_vars) == len(value_vars), ( assert len(key_vars) == len(value_vars), (
...@@ -169,11 +182,12 @@ class VarMap(object): ...@@ -169,11 +182,12 @@ class VarMap(object):
# TODO(lml): supporting control flow, nested blocks, and block other than current block of main program. # TODO(lml): supporting control flow, nested blocks, and block other than current block of main program.
class Transform(object): class Transform(object):
""" An object that maintains the state of transformations applied to a """An object that maintains the state of transformations applied to a
primitve program. """ primitve program."""
def __init__(self, block): def __init__(self, block):
assert block == default_main_program().current_block( assert (
block == default_main_program().current_block()
), f'only support transform on current block of main program.' ), f'only support transform on current block of main program.'
self.block = block self.block = block
self.vars = self.init_vars(block) self.vars = self.init_vars(block)
...@@ -225,7 +239,7 @@ class Transform(object): ...@@ -225,7 +239,7 @@ class Transform(object):
block._sync_with_cpp() block._sync_with_cpp()
def var2dot_rec(self, vars): def var2dot_rec(self, vars):
""" Lookup var2dot recursively.""" """Lookup var2dot recursively."""
if isinstance(vars, paddle.fluid.framework.Variable): if isinstance(vars, paddle.fluid.framework.Variable):
dot = self.var2dot.lookup(vars) dot = self.var2dot.lookup(vars)
return dot return dot
...@@ -244,7 +258,7 @@ class Transform(object): ...@@ -244,7 +258,7 @@ class Transform(object):
return bars return bars
def linearize(self, xs, ys, xs_dot=None): def linearize(self, xs, ys, xs_dot=None):
""" Performs the linearization transform, a.k.a, forward mode AD """Performs the linearization transform, a.k.a, forward mode AD
transform, on a primitive lowered program. transform, on a primitive lowered program.
Args: Args:
...@@ -266,15 +280,18 @@ class Transform(object): ...@@ -266,15 +280,18 @@ class Transform(object):
else: else:
assert len(xs) == len(xs_dot), ( assert len(xs) == len(xs_dot), (
f'len(xs) should be equal to len(xs_dot), ' f'len(xs) should be equal to len(xs_dot), '
f'but len(xs)={len(xs)} and len(xs_dot)={len(xs_dot)}') f'but len(xs)={len(xs)} and len(xs_dot)={len(xs_dot)}'
)
for x, dot in zip(xs, xs_dot): for x, dot in zip(xs, xs_dot):
assert x.dtype == dot.dtype, ( assert x.dtype == dot.dtype, (
f'x.dtype should be equal to dot.dtype, ' f'x.dtype should be equal to dot.dtype, '
f'but x.dtype={x.dtype} and dot.dtype={dot.dtype}') f'but x.dtype={x.dtype} and dot.dtype={dot.dtype}'
)
assert x.shape == dot.shape, ( assert x.shape == dot.shape, (
f'x.shape should be equal to dot.shape, ' f'x.shape should be equal to dot.shape, '
f'but x.shape={x.shape} and dot.shape={dot.shape}') f'but x.shape={x.shape} and dot.shape={dot.shape}'
)
self.var2dot.add(x, dot) self.var2dot.add(x, dot)
path, unused_xs, _ = topo_path(xs, ys, self.block) path, unused_xs, _ = topo_path(xs, ys, self.block)
...@@ -300,7 +317,7 @@ class Transform(object): ...@@ -300,7 +317,7 @@ class Transform(object):
return xs_dot, ys_dot return xs_dot, ys_dot
def transpose(self, ys_dot, xs_dot, ys_bar=None, retain_fwd=False): def transpose(self, ys_dot, xs_dot, ys_bar=None, retain_fwd=False):
""" Performs the transpose transform, a.k.a, reverse mode AD """Performs the transpose transform, a.k.a, reverse mode AD
transform, on a linearized primitive program. transform, on a linearized primitive program.
Note, `transpose` is supposed to be used in couple with `linearize`. Note, `transpose` is supposed to be used in couple with `linearize`.
...@@ -329,7 +346,8 @@ class Transform(object): ...@@ -329,7 +346,8 @@ class Transform(object):
else: else:
assert len(ys_dot) == len(ys_bar), ( assert len(ys_dot) == len(ys_bar), (
f'len(ys_dot) should be equal to len(ys_bar), ' f'len(ys_dot) should be equal to len(ys_bar), '
f'but len(ys_dot)={len(ys_dot)} and len(ys_bar)={len(ys_bar)}') f'but len(ys_dot)={len(ys_dot)} and len(ys_bar)={len(ys_bar)}'
)
for y_dot, y_bar in zip(ys_dot, ys_bar): for y_dot, y_bar in zip(ys_dot, ys_bar):
assert y_dot.shape == y_bar.shape, ( assert y_dot.shape == y_bar.shape, (
f'y_dot.shape should be equal to y_bar.shape, ' f'y_dot.shape should be equal to y_bar.shape, '
...@@ -373,7 +391,8 @@ class Transform(object): ...@@ -373,7 +391,8 @@ class Transform(object):
ins = flatten(op_position_inputs(op)) ins = flatten(op_position_inputs(op))
assert len(ins) == len(ins_bar), ( assert len(ins) == len(ins_bar), (
f'len(ins) should be equal to len(ins_bar), ' f'len(ins) should be equal to len(ins_bar), '
f'but len(ins)={len(ins)} and len(ins_bar)={len(ins_bar)}') f'but len(ins)={len(ins)} and len(ins_bar)={len(ins_bar)}'
)
for dot, bar in zip(ins, ins_bar): for dot, bar in zip(ins, ins_bar):
if bar is not None: if bar is not None:
...@@ -392,7 +411,8 @@ class Transform(object): ...@@ -392,7 +411,8 @@ class Transform(object):
vars_to_remove = set() vars_to_remove = set()
for op in path: for op in path:
vars_to_remove.update( vars_to_remove.update(
flatten_and_remove_none(get_output_var_list(op))) flatten_and_remove_none(get_output_var_list(op))
)
op_indexes = [] op_indexes = []
...@@ -461,9 +481,11 @@ def _lower(block, reverse, blacklist): ...@@ -461,9 +481,11 @@ def _lower(block, reverse, blacklist):
for orig_out, new_out in zip( for orig_out, new_out in zip(
expand_nested_list(get_output_var_list(op)), expand_nested_list(get_output_var_list(op)),
expand_nested_list(as_tensors(lower_fn(op, *input_args)))): expand_nested_list(as_tensors(lower_fn(op, *input_args))),
):
assert not (orig_out is None) ^ ( assert not (orig_out is None) ^ (
new_out is None), "orig_out and new_out should match." new_out is None
), "orig_out and new_out should match."
vars_to_remove.add(new_out.name) vars_to_remove.add(new_out.name)
value_table[new_out.name] = new_out value_table[new_out.name] = new_out
to_bind[orig_out.name] = new_out.name to_bind[orig_out.name] = new_out.name
...@@ -472,7 +494,8 @@ def _lower(block, reverse, blacklist): ...@@ -472,7 +494,8 @@ def _lower(block, reverse, blacklist):
inputs = {} inputs = {}
for i in range(len(op.input_names)): for i in range(len(op.input_names)):
inputs[op.input_names[i]] = bind_name( inputs[op.input_names[i]] = bind_name(
op.input(op.input_names[i]), to_bind) op.input(op.input_names[i]), to_bind
)
outputs = {} outputs = {}
for i in range(len(op.output_names)): for i in range(len(op.output_names)):
...@@ -482,14 +505,17 @@ def _lower(block, reverse, blacklist): ...@@ -482,14 +505,17 @@ def _lower(block, reverse, blacklist):
for name in sorted(op.attr_names): for name in sorted(op.attr_names):
attrs[name] = op.attr(name) attrs[name] = op.attr(name)
from paddle.fluid.dygraph.base import param_guard from paddle.fluid.dygraph.base import param_guard
new_op_desc = block.desc.append_op() new_op_desc = block.desc.append_op()
with param_guard(inputs), param_guard(outputs): with param_guard(inputs), param_guard(outputs):
op = Operator(block=block, op = Operator(
block=block,
desc=new_op_desc, desc=new_op_desc,
type=op.type, type=op.type,
inputs=inputs, inputs=inputs,
outputs=outputs, outputs=outputs,
attrs=attrs) attrs=attrs,
)
block.ops.append(op) block.ops.append(op)
# Step3: Do some post-processing work # Step3: Do some post-processing work
...@@ -509,8 +535,9 @@ def _lower(block, reverse, blacklist): ...@@ -509,8 +535,9 @@ def _lower(block, reverse, blacklist):
op._rename_output(out_name, to_bind_rev[out_name]) op._rename_output(out_name, to_bind_rev[out_name])
for var_name in sorted(vars_to_remove): for var_name in sorted(vars_to_remove):
assert var_name in to_bind_rev, 'var_name "{}" is not in to_bind_rev.'.format( assert (
var_name) var_name in to_bind_rev
), 'var_name "{}" is not in to_bind_rev.'.format(var_name)
if var_name != to_bind_rev[var_name]: if var_name != to_bind_rev[var_name]:
block.desc._remove_var(cpt.to_bytes(var_name)) block.desc._remove_var(cpt.to_bytes(var_name))
del block.vars[var_name] del block.vars[var_name]
...@@ -536,7 +563,8 @@ def orig2prim(block=None): ...@@ -536,7 +563,8 @@ def orig2prim(block=None):
""" """
block = default_main_program().current_block() if block is None else block block = default_main_program().current_block() if block is None else block
assert block == default_main_program().current_block( assert (
block == default_main_program().current_block()
), f'block is neither None nor current block of main program' ), f'block is neither None nor current block of main program'
_lower(block, reverse=False, blacklist=[]) _lower(block, reverse=False, blacklist=[])
...@@ -581,7 +609,8 @@ def prim2orig(block=None, blacklist=None): ...@@ -581,7 +609,8 @@ def prim2orig(block=None, blacklist=None):
""" """
block = default_main_program().current_block() if block is None else block block = default_main_program().current_block() if block is None else block
assert block == default_main_program().current_block( assert (
block == default_main_program().current_block()
), f'block is neither None nor current block of main program' ), f'block is neither None nor current block of main program'
blacklist = [] if blacklist is None else blacklist blacklist = [] if blacklist is None else blacklist
_lower(block, reverse=True, blacklist=blacklist) _lower(block, reverse=True, blacklist=blacklist)
...@@ -18,7 +18,6 @@ from paddle.fluid import framework as framework ...@@ -18,7 +18,6 @@ from paddle.fluid import framework as framework
class PrimOption(object): class PrimOption(object):
def __init__(self): def __init__(self):
self.enable_prim = False self.enable_prim = False
...@@ -175,7 +174,7 @@ def flatten_and_remove_none(inp): ...@@ -175,7 +174,7 @@ def flatten_and_remove_none(inp):
def as_tensors(xs): def as_tensors(xs):
if isinstance(xs, framework.Variable): if isinstance(xs, framework.Variable):
return (xs, ) return (xs,)
elif isinstance(xs, typing.Sequence): elif isinstance(xs, typing.Sequence):
return tuple(xs) return tuple(xs)
else: else:
......
...@@ -14,7 +14,14 @@ ...@@ -14,7 +14,14 @@
from paddle.optimizer import Optimizer from paddle.optimizer import Optimizer
from paddle.fluid import core, framework, layers, unique_name from paddle.fluid import core, framework, layers, unique_name
from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program, device_guard from paddle.fluid.framework import (
Program,
Variable,
name_scope,
default_main_program,
default_startup_program,
device_guard,
)
from paddle.fluid.layer_helper import LayerHelper from paddle.fluid.layer_helper import LayerHelper
import paddle import paddle
import numpy as np import numpy as np
...@@ -116,24 +123,27 @@ class LookAhead(Optimizer): ...@@ -116,24 +123,27 @@ class LookAhead(Optimizer):
_slow_str = "slow" _slow_str = "slow"
def __init__(self, inner_optimizer, alpha=0.5, k=5, name=None): def __init__(self, inner_optimizer, alpha=0.5, k=5, name=None):
assert (inner_optimizer is not None), "inner optimizer can not be None" assert inner_optimizer is not None, "inner optimizer can not be None"
assert ( assert (
0.0 <= alpha <= 1.0 0.0 <= alpha <= 1.0
), "alpha should be larger or equal to 0.0, and less or equal than 1.0" ), "alpha should be larger or equal to 0.0, and less or equal than 1.0"
assert (isinstance(k, int) and k > 0), "k should be a positive integer" assert isinstance(k, int) and k > 0, "k should be a positive integer"
self.inner_optimizer = inner_optimizer self.inner_optimizer = inner_optimizer
if self.inner_optimizer._parameter_list is None: if self.inner_optimizer._parameter_list is None:
parameters = framework.default_main_program().global_block( parameters = (
).all_parameters() framework.default_main_program().global_block().all_parameters()
)
else: else:
parameters = self.inner_optimizer._parameter_list parameters = self.inner_optimizer._parameter_list
super(LookAhead, self).__init__(learning_rate=alpha, super(LookAhead, self).__init__(
learning_rate=alpha,
parameters=parameters, parameters=parameters,
weight_decay=None, weight_decay=None,
grad_clip=None, grad_clip=None,
name=name) name=name,
)
self.alpha = alpha self.alpha = alpha
self.k = k self.k = k
...@@ -179,9 +189,9 @@ class LookAhead(Optimizer): ...@@ -179,9 +189,9 @@ class LookAhead(Optimizer):
grad_var = param._grad_ivar() grad_var = param._grad_ivar()
params_grads.append((param, grad_var)) params_grads.append((param, grad_var))
self._apply_optimize(loss=None, self._apply_optimize(
startup_program=None, loss=None, startup_program=None, params_grads=params_grads
params_grads=params_grads) )
def _create_accumulators(self, block, parameters): def _create_accumulators(self, block, parameters):
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
...@@ -196,24 +206,28 @@ class LookAhead(Optimizer): ...@@ -196,24 +206,28 @@ class LookAhead(Optimizer):
shape=[1], shape=[1],
value=0, value=0,
dtype='int32', dtype='int32',
persistable=True) persistable=True,
)
self.helper.append_op(type='increment', self.helper.append_op(
type='increment',
inputs={'X': [self._global_step_var]}, inputs={'X': [self._global_step_var]},
outputs={'Out': [self._global_step_var]}, outputs={'Out': [self._global_step_var]},
attrs={'step': 1.0}) attrs={'step': 1.0},
)
def _append_optimize_op(self, block, param_and_grad): def _append_optimize_op(self, block, param_and_grad):
one_var = paddle.ones(shape=[1], dtype='int32', name='lookahead_ones') one_var = paddle.ones(shape=[1], dtype='int32', name='lookahead_ones')
zero_var = paddle.zeros(shape=[1], zero_var = paddle.zeros(
dtype='int32', shape=[1], dtype='int32', name='lookahead_zeros'
name='lookahead_zeros') )
k_var = layers.create_global_var( k_var = layers.create_global_var(
name=unique_name.generate("lookahead_k"), name=unique_name.generate("lookahead_k"),
shape=[1], shape=[1],
value=self.k, value=self.k,
dtype='int32', dtype='int32',
persistable=True) persistable=True,
)
mod = paddle.remainder(self._global_step_var, k_var) mod = paddle.remainder(self._global_step_var, k_var)
...@@ -236,11 +250,9 @@ class LookAhead(Optimizer): ...@@ -236,11 +250,9 @@ class LookAhead(Optimizer):
paddle.assign(tmp_var_1, slow_var) paddle.assign(tmp_var_1, slow_var)
@imperative_base.no_grad @imperative_base.no_grad
def minimize(self, def minimize(
loss, self, loss, startup_program=None, parameters=None, no_grad_set=None
startup_program=None, ):
parameters=None,
no_grad_set=None):
""" """
Add operations to minimize ``loss`` by updating ``parameters``. Add operations to minimize ``loss`` by updating ``parameters``.
...@@ -287,12 +299,13 @@ class LookAhead(Optimizer): ...@@ -287,12 +299,13 @@ class LookAhead(Optimizer):
loss, loss,
startup_program=startup_program, startup_program=startup_program,
parameters=parameters, parameters=parameters,
no_grad_set=no_grad_set) no_grad_set=no_grad_set,
)
self._increment_global_var() self._increment_global_var()
_ = self._apply_optimize(loss, _ = self._apply_optimize(
startup_program=startup_program, loss, startup_program=startup_program, params_grads=params_grads
params_grads=params_grads) )
return optimize_ops, params_grads return optimize_ops, params_grads
...@@ -20,14 +20,16 @@ from paddle import _C_ops, _legacy_C_ops ...@@ -20,14 +20,16 @@ from paddle import _C_ops, _legacy_C_ops
from paddle import in_dynamic_mode from paddle import in_dynamic_mode
def sparse_attention(query, def sparse_attention(
query,
key, key,
value, value,
sparse_csr_offset, sparse_csr_offset,
sparse_csr_columns, sparse_csr_columns,
key_padding_mask=None, key_padding_mask=None,
attn_mask=None, attn_mask=None,
name=None): name=None,
):
r""" r"""
This operator sparsify the Attention matrix in Transformer module This operator sparsify the Attention matrix in Transformer module
to achieve the effect of reducing memory consumption and computation. to achieve the effect of reducing memory consumption and computation.
...@@ -144,9 +146,19 @@ def sparse_attention(query, ...@@ -144,9 +146,19 @@ def sparse_attention(query,
# [1.99830270, 2.99830270]]]] # [1.99830270, 2.99830270]]]]
""" """
if in_dynamic_mode(): if in_dynamic_mode():
result_attention, result_sdd, result_softmax = _legacy_C_ops.sparse_attention( (
query, key, value, sparse_csr_offset, sparse_csr_columns, result_attention,
key_padding_mask, attn_mask) result_sdd,
result_softmax,
) = _legacy_C_ops.sparse_attention(
query,
key,
value,
sparse_csr_offset,
sparse_csr_columns,
key_padding_mask,
attn_mask,
)
return result_attention return result_attention
helper = LayerHelper('sparse_attention', **locals()) helper = LayerHelper('sparse_attention', **locals())
...@@ -166,7 +178,7 @@ def sparse_attention(query, ...@@ -166,7 +178,7 @@ def sparse_attention(query,
outputs = { outputs = {
'Out': out, 'Out': out,
'SparseDotSdd': result_sdd, 'SparseDotSdd': result_sdd,
'Softmax': result_softmax 'Softmax': result_softmax,
} }
helper.append_op(type='sparse_attention', inputs=inputs, outputs=outputs) helper.append_op(type='sparse_attention', inputs=inputs, outputs=outputs)
return out return out
...@@ -37,16 +37,15 @@ def l2_norm(x, axis, epsilon=1e-12, name=None): ...@@ -37,16 +37,15 @@ def l2_norm(x, axis, epsilon=1e-12, name=None):
helper = LayerHelper("l2_normalize", **locals()) helper = LayerHelper("l2_normalize", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype) out = helper.create_variable_for_type_inference(dtype=x.dtype)
norm = helper.create_variable_for_type_inference(dtype=x.dtype) norm = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(type="norm", helper.append_op(
type="norm",
inputs={"X": x}, inputs={"X": x},
outputs={ outputs={"Out": out, "Norm": norm},
"Out": out,
"Norm": norm
},
attrs={ attrs={
"axis": 1 if axis is None else axis, "axis": 1 if axis is None else axis,
"epsilon": epsilon, "epsilon": epsilon,
}) },
)
return paddle.squeeze(norm, axis=[axis]) return paddle.squeeze(norm, axis=[axis])
...@@ -93,14 +92,13 @@ def _weight_norm(v, g, dim): ...@@ -93,14 +92,13 @@ def _weight_norm(v, g, dim):
v_normalized = F.l2_normalize(p_matrix, axis=1) v_normalized = F.l2_normalize(p_matrix, axis=1)
v_normalized = paddle.reshape(v_normalized, transposed_shape) v_normalized = paddle.reshape(v_normalized, transposed_shape)
v_normalized = paddle.transpose(v_normalized, perm) v_normalized = paddle.transpose(v_normalized, perm)
weight = F.elementwise_mul(v_normalized, weight = F.elementwise_mul(
g, v_normalized, g, axis=dim if dim is not None else -1
axis=dim if dim is not None else -1) )
return weight return weight
class WeightNorm(object): class WeightNorm(object):
def __init__(self, name, dim): def __init__(self, name, dim):
if dim is None: if dim is None:
dim = -1 dim = -1
...@@ -116,8 +114,10 @@ class WeightNorm(object): ...@@ -116,8 +114,10 @@ class WeightNorm(object):
def apply(layer, name, dim): def apply(layer, name, dim):
for k, hook in layer._forward_pre_hooks.items(): for k, hook in layer._forward_pre_hooks.items():
if isinstance(hook, WeightNorm) and hook.name == name: if isinstance(hook, WeightNorm) and hook.name == name:
raise RuntimeError("Cannot register two weight_norm hooks on " raise RuntimeError(
"the same parameter {}".format(name)) "Cannot register two weight_norm hooks on "
"the same parameter {}".format(name)
)
if dim is None: if dim is None:
dim = -1 dim = -1
......
...@@ -94,11 +94,14 @@ def export(layer, path, input_spec=None, opset_version=9, **configs): ...@@ -94,11 +94,14 @@ def export(layer, path, input_spec=None, opset_version=9, **configs):
raise ValueError( raise ValueError(
"The input path MUST be format of dirname/file_prefix " "The input path MUST be format of dirname/file_prefix "
"[dirname\\file_prefix in Windows system], but " "[dirname\\file_prefix in Windows system], but "
"the file_prefix is empty in received path: {}".format(path)) "the file_prefix is empty in received path: {}".format(path)
)
save_file = path + '.onnx' save_file = path + '.onnx'
p2o.dygraph2onnx(layer, p2o.dygraph2onnx(
layer,
save_file, save_file,
input_spec=input_spec, input_spec=input_spec,
opset_version=opset_version, opset_version=opset_version,
**configs) **configs
)