Unverified · Commit 983ae1d7, authored by wanghuancoder and committed by GitHub

delete legacy dygraph code in python/paddle/distributed (#49304)

* delete legacy dygraph code in python/paddle/distributed

* refine
Parent 91cdd295
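The whole change follows one recurring pattern: the legacy-dygraph dispatch (`framework._in_legacy_dygraph()` / `_non_static_mode()`) and the `_legacy_C_ops` fallback branches are deleted, leaving only the `in_dygraph_mode()` check and the `paddle.distributed.communication.stream` call path. Below is a minimal before/after sketch of that pattern; the wrapper names `my_all_reduce_*` are hypothetical, the `stream.all_reduce` call mirrors the one kept in this diff, and the "before" half uses the pre-PR `framework._in_legacy_dygraph()` helper that this commit retires.

import paddle.distributed.communication.stream as stream
import paddle.fluid.framework as framework
from paddle.distributed.communication.reduce import ReduceOp


# Before: two dispatch paths, the second being the legacy-dygraph
# fallback built on paddle._legacy_C_ops (the code this PR deletes).
def my_all_reduce_old(tensor, op=ReduceOp.SUM, group=None, sync_op=True):
    if not framework._in_legacy_dygraph():
        return stream.all_reduce(
            tensor, op=op, group=group, sync_op=sync_op, use_calc_stream=False
        )
    # ... legacy _legacy_C_ops branch, removed by this PR ...


# After: the legacy branch is gone; only the stream-based path remains.
def my_all_reduce_new(tensor, op=ReduceOp.SUM, group=None, sync_op=True):
    return stream.all_reduce(
        tensor, op=op, group=group, sync_op=sync_op, use_calc_stream=False
    )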
......@@ -31,7 +31,7 @@ from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.executor import _to_name_str, global_scope
from paddle.fluid.framework import Operator
from paddle.fluid.framework import _current_expected_place as _get_device
from paddle.fluid.framework import _non_static_mode
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.layers.utils import flatten
from paddle.metric import Metric
from paddle.static import InputSpec
......@@ -300,7 +300,7 @@ class Engine:
return inputs_spec, labels_spec
def _prepare_data_tensor(self, inputs_spec, labels_spec, inputs, labels):
if _non_static_mode() or self._dygraph_mode:
if in_dygraph_mode() or self._dygraph_mode:
raise ValueError("Only support static graph mode.")
if inputs_spec:
......@@ -512,7 +512,7 @@ class Engine:
self._has_prepared[mode] = True
def _build(self, mode):
if _non_static_mode() or self._dygraph_mode:
if in_dygraph_mode() or self._dygraph_mode:
paddle.disable_static()
self._dygraph_mode = True
self._logger.info("Building model with 'to_static' method.")
......@@ -1713,7 +1713,7 @@ class Engine:
self._build(mode)
self._plan(mode)
else:
if _non_static_mode() or self._dygraph_mode:
if in_dygraph_mode() or self._dygraph_mode:
raise ValueError(
"Please call `prepare()` or `fit()` or `evaluate()` or `predict()` before calling `cost()`."
)
......
......@@ -17,8 +17,8 @@ from collections import OrderedDict
import paddle
import paddle.fluid.core as core
from paddle import _legacy_C_ops
from paddle.fluid.framework import in_dygraph_mode
from ...fluid.framework import _non_static_mode
from ...fluid.layers.tensor import fill_constant
from ..collective import _get_global_env, _new_ring_id
......@@ -154,7 +154,7 @@ class ProcessGroup:
)
tmp = (
paddle.to_tensor([1], dtype="int32")
if _non_static_mode()
if in_dygraph_mode()
else fill_constant([0], dtype="int32", value="1")
)
# use legacy ops
......
......@@ -18,7 +18,7 @@ import paddle
# (TODO: GhostScreaming) It will be removed later.
import paddle.fluid.core as core
from paddle.framework import _non_static_mode, in_dygraph_mode
from paddle.framework import in_dygraph_mode
from .communication.group import Group, _add_new_group, is_initialized
from .fleet.layers.mpu.mp_ops import _c_concat # noqa: F401
......@@ -301,7 +301,7 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout):
# hang caused by cross-creation of new_group
tmp = (
paddle.to_tensor([1], dtype="int32")
if _non_static_mode()
if in_dygraph_mode()
else paddle.full([0], 1, dtype="int32")
)
paddle.distributed.all_reduce(tmp, sync_op=True)
......
......@@ -18,7 +18,6 @@ import pickle
import numpy as np
import paddle
import paddle.distributed as dist
import paddle.distributed.communication.stream as stream
import paddle.fluid.framework as framework
......@@ -64,38 +63,7 @@ def all_gather(tensor_list, tensor, group=None, sync_op=True):
print(tensor_list)
# [[[4, 5, 6], [4, 5, 6]], [[1, 2, 3], [1, 2, 3]]] (2 GPUs)
"""
if not framework._in_legacy_dygraph():
return stream.all_gather(tensor_list, tensor, group, sync_op)
# NOTE: uncomment code below when having fully complex support
# def convert_to_complex(list_of_tensor):
# list_of_complex = []
# for tensor in list_of_tensor:
# list_of_complex.append(paddle.as_complex(tensor))
# return list_of_complex
# is_input_complex = (tensor.dtype == paddle.complex64
# or tensor.dtype == paddle.complex128)
# if is_input_complex:
# tensor = paddle.as_real(tensor)
# code below will be removed after we remove the old dygraph
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
nranks = dist.get_world_size()
out = paddle._legacy_C_ops.c_allgather(
tensor,
'use_calc_stream',
sync_op,
'ring_id',
ring_id,
'nranks',
nranks,
)
tensor_list.clear()
tensor_list.extend(paddle.split(out, nranks, 0))
return stream.all_gather(tensor_list, tensor, group, sync_op)
def _convert_object_to_tensor(obj):
......
......@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.distributed.communication.stream as stream
import paddle.fluid.framework as framework
from paddle.distributed.communication.reduce import ReduceOp
......@@ -57,31 +55,6 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=None, sync_op=True):
print(data)
# [[5, 7, 9], [5, 7, 9]] (2 GPUs)
"""
if not framework._in_legacy_dygraph():
return stream.all_reduce(
tensor, op=op, group=group, sync_op=sync_op, use_calc_stream=False
)
# code below will be removed after we remove the old dygraph
if group is not None and not group.is_member():
return
use_calc_stream = sync_op
ring_id = 0 if group is None else group.id
if op == ReduceOp.SUM:
return paddle._legacy_C_ops.c_allreduce_sum_(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id
)
elif op == ReduceOp.MAX:
return paddle._legacy_C_ops.c_allreduce_max_(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id
)
elif op == ReduceOp.MIN:
return paddle._legacy_C_ops.c_allreduce_min_(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id
)
elif op == ReduceOp.PROD:
return paddle._legacy_C_ops.c_allreduce_prod_(
tensor, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id
)
else:
raise ValueError("Unknown parameter: {}.".format(op))
return stream.all_reduce(
tensor, op=op, group=group, sync_op=sync_op, use_calc_stream=False
)
......@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.distributed.communication.stream as stream
import paddle.fluid.framework as framework
def alltoall(in_tensor_list, out_tensor_list, group=None, sync_op=True):
......@@ -59,22 +57,9 @@ def alltoall(in_tensor_list, out_tensor_list, group=None, sync_op=True):
# [[[1, 2, 3], [4, 5, 6]], [[13, 14, 15], [16, 17, 18]]] (2 GPUs, out for rank 0)
# [[[7, 8, 9], [10, 11, 12]], [[19, 20, 21], [22, 23, 24]]] (2 GPUs, out for rank 1)
"""
if not framework._in_legacy_dygraph():
return stream.alltoall(
out_tensor_list, in_tensor_list, group, sync_op, False
)
# code below will be removed after we remove the old dygraph
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
temp = paddle.concat(in_tensor_list, axis=0)
nranks = len(in_tensor_list)
use_calc_stream = sync_op
out = paddle._legacy_C_ops.alltoall(
temp, 'use_calc_stream', use_calc_stream, 'ring_id', ring_id
return stream.alltoall(
out_tensor_list, in_tensor_list, group, sync_op, False
)
out_tensor_list.extend(paddle.split(out, nranks, 0))
def alltoall_single(
......@@ -149,13 +134,12 @@ def alltoall_single(
# output for rank 1: [[0., 0.], [0., 0.], [1., 1.], [1., 1.]]
"""
if not framework._in_legacy_dygraph():
return stream.alltoall_single(
out_tensor,
in_tensor,
out_split_sizes,
in_split_sizes,
group,
sync_op,
False,
)
return stream.alltoall_single(
out_tensor,
in_tensor,
out_split_sizes,
in_split_sizes,
group,
sync_op,
False,
)
......@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.distributed.communication.stream as stream
import paddle.fluid.framework as framework
def broadcast(tensor, src, group=None, sync_op=True):
......@@ -55,31 +53,10 @@ def broadcast(tensor, src, group=None, sync_op=True):
print(data)
# [[1, 2, 3], [1, 2, 3]] (2 GPUs)
"""
if not framework._in_legacy_dygraph():
return stream.broadcast(
tensor,
src,
group=group,
sync_op=sync_op,
use_calc_stream=False,
)
# code below will be removed after we remove the old dygraph
if group is not None and not group.is_member():
return
use_calc_stream = sync_op
ring_id = 0 if group is None else group.id
gsrc = src if group is None else group.get_group_rank(src)
assert gsrc >= 0, "src rank out of group, need global rank"
return paddle._legacy_C_ops.c_broadcast(
tensor,
return stream.broadcast(
tensor,
'root',
gsrc,
'use_calc_stream',
use_calc_stream,
'ring_id',
ring_id,
src,
group=group,
sync_op=sync_op,
use_calc_stream=False,
)
......@@ -19,6 +19,7 @@ import paddle.distributed as dist
import paddle.fluid.core as core
import paddle.fluid.framework as framework
import paddle.fluid.layer_helper as layer_helper
from paddle.fluid.framework import in_dygraph_mode
class Group:
......@@ -235,32 +236,32 @@ def get_group(id=0):
def _sync_calc_stream(tensor):
if framework._non_static_mode():
if in_dygraph_mode():
return paddle._legacy_C_ops.c_sync_calc_stream(tensor, tensor)
op_type = 'c_sync_calc_stream'
helper = layer_helper.LayerHelper(op_type, **locals())
helper.append_op(
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [tensor]},
)
else:
op_type = 'c_sync_calc_stream'
helper = layer_helper.LayerHelper(op_type, **locals())
helper.append_op(
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [tensor]},
)
def _sync_comm_stream(tensor, ring_id=0):
if framework._non_static_mode():
if in_dygraph_mode():
return paddle._legacy_C_ops.c_sync_comm_stream(
[tensor], [tensor], 'ring_id', ring_id
)
op_type = 'c_sync_comm_stream'
helper = layer_helper.LayerHelper(op_type, **locals())
helper.append_op(
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [tensor]},
attrs={'ring_id': ring_id},
)
else:
op_type = 'c_sync_comm_stream'
helper = layer_helper.LayerHelper(op_type, **locals())
helper.append_op(
type=op_type,
inputs={'X': [tensor]},
outputs={'Out': [tensor]},
attrs={'ring_id': ring_id},
)
def wait(tensor, group=None, use_calc_stream=True):
......@@ -336,18 +337,18 @@ def barrier(group=None):
ring_id = 0 if group is None else group.id
barrier_tensor = paddle.full([1], 1, dtype="int32")
if framework._non_static_mode():
if in_dygraph_mode():
return paddle._legacy_C_ops.barrier(
barrier_tensor, barrier_tensor, 'ring_id', ring_id
)
op_type = 'barrier'
if not isinstance(ring_id, int):
raise ValueError("The type of 'group' for barrier must be int.")
helper = layer_helper.LayerHelper(op_type, **locals())
helper.append_op(
type=op_type,
inputs={'X': [barrier_tensor]},
outputs={'Out': [barrier_tensor]},
attrs={'ring_id': ring_id},
)
else:
op_type = 'barrier'
if not isinstance(ring_id, int):
raise ValueError("The type of 'group' for barrier must be int.")
helper = layer_helper.LayerHelper(op_type, **locals())
helper.append_op(
type=op_type,
inputs={'X': [barrier_tensor]},
outputs={'Out': [barrier_tensor]},
attrs={'ring_id': ring_id},
)
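Where a static-graph fallback has to stay (e.g. `_sync_calc_stream` and `barrier` above, and the mp_ops, dropout, and MoE helpers further down), the refactor swaps `framework._non_static_mode()` for `in_dygraph_mode()` and moves the `LayerHelper.append_op` path into an explicit `else:` branch. A compressed sketch of that second pattern, assuming a hypothetical helper name `my_sync_op` (the calls mirror `_sync_calc_stream` in this diff):

import paddle
import paddle.fluid.layer_helper as layer_helper
from paddle.fluid.framework import in_dygraph_mode


def my_sync_op(tensor):  # hypothetical helper mirroring _sync_calc_stream
    if in_dygraph_mode():
        # eager path: call the C++ op directly
        return paddle._legacy_C_ops.c_sync_calc_stream(tensor, tensor)
    else:
        # static-graph path: append the op to the current program
        op_type = 'c_sync_calc_stream'
        helper = layer_helper.LayerHelper(op_type, **locals())
        helper.append_op(
            type=op_type,
            inputs={'X': [tensor]},
            outputs={'Out': [tensor]},
        )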
......@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.distributed.communication.stream as stream
import paddle.fluid.framework as framework
def recv(tensor, src=0, group=None, sync_op=True):
......@@ -48,29 +46,8 @@ def recv(tensor, src=0, group=None, sync_op=True):
print(data)
# [7, 8, 9] (2 GPUs)
"""
if not framework._in_legacy_dygraph():
return stream.recv(
tensor, src=src, group=group, sync_op=sync_op, use_calc_stream=False
)
# code below will be removed after we remove the old dygraph
if group is not None and not group.is_member():
return
use_calc_stream = sync_op
gsrc = src if group is None else group.get_group_rank(src)
ring_id = 0 if group is None else group.id
return paddle._legacy_C_ops.recv_v2(
tensor,
'use_calc_stream',
use_calc_stream,
'ring_id',
ring_id,
'peer',
src,
'dtype',
tensor.dtype,
'out_shape',
tensor.shape,
return stream.recv(
tensor, src=src, group=group, sync_op=sync_op, use_calc_stream=False
)
......
......@@ -121,16 +121,14 @@ def reduce(tensor, dst, op=ReduceOp.SUM, group=None, sync_op=True):
# [[5, 7, 9], [5, 7, 9]] (2 GPUs, out for rank 0)
# [[1, 2, 3], [1, 2, 3]] (2 GPUs, out for rank 1)
"""
if not framework._in_legacy_dygraph():
return stream.reduce(
tensor,
dst=dst,
op=op,
group=group,
sync_op=sync_op,
use_calc_stream=False,
)
return stream.reduce(
tensor,
dst=dst,
op=op,
group=group,
sync_op=sync_op,
use_calc_stream=False,
)
# code below will be removed after we remove the old dygraph
if group is not None and not group.is_member():
......
......@@ -13,7 +13,6 @@
# limitations under the License.
import paddle.distributed.communication.stream as stream
import paddle.fluid.framework as framework
from paddle.distributed.communication.reduce import ReduceOp
from paddle.distributed.communication.stream.reduce_scatter import (
_reduce_scatter_base as _reduce_scatter_base_stream,
......@@ -62,15 +61,14 @@ def reduce_scatter(
# [8, 10] (2 GPUs, out for rank 1)
"""
if not framework._in_legacy_dygraph():
return stream.reduce_scatter(
tensor,
tensor_list,
op=op,
group=group,
sync_op=sync_op,
use_calc_stream=False,
)
return stream.reduce_scatter(
tensor,
tensor_list,
op=op,
group=group,
sync_op=sync_op,
use_calc_stream=False,
)
def _reduce_scatter_base(
......@@ -111,12 +109,11 @@ def _reduce_scatter_base(
# [5, 7] (2 GPUs, out for rank 1)
"""
if not framework._in_legacy_dygraph():
return _reduce_scatter_base_stream(
output,
input,
op=op,
group=group,
sync_op=sync_op,
use_calc_stream=False,
)
return _reduce_scatter_base_stream(
output,
input,
op=op,
group=group,
sync_op=sync_op,
use_calc_stream=False,
)
......@@ -12,10 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.distributed.communication.stream as stream
import paddle.fluid.framework as framework
from paddle.distributed.communication.group import _get_global_group
def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True):
......@@ -61,34 +58,4 @@ def scatter(tensor, tensor_list=None, src=0, group=None, sync_op=True):
# [1, 2, 3] [10, 11, 12] (2 GPUs, out for rank 0)
# [4, 5, 6] [4, 5, 6] (2 GPUs, out for rank 1)
"""
if not framework._in_legacy_dygraph():
return stream.scatter(tensor, tensor_list, src, group, sync_op)
# code below will be removed after we remove the old dygraph
if group is not None and not group.is_member():
return
ring_id = 0 if group is None else group.id
gsrc = src if group is None else group.get_group_rank(src)
rank = _get_global_group().rank if group is None else group.rank
nranks = _get_global_group().nranks if group is None else group.nranks
assert gsrc >= 0, "src rank out of group, need global rank"
if rank != gsrc:
tensor_list = []
for _ in range(nranks):
tensor_list.append(tensor)
temp = paddle.concat(tensor_list, axis=0)
use_calc_stream = sync_op
return framework._legacy_C_ops.c_scatter(
temp,
tensor,
'use_calc_stream',
use_calc_stream,
'ring_id',
ring_id,
'nranks',
nranks,
'root',
gsrc,
)
return stream.scatter(tensor, tensor_list, src, group, sync_op)
......@@ -12,9 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.distributed.communication.stream as stream
import paddle.fluid.framework as framework
def send(tensor, dst=0, group=None, sync_op=True):
......@@ -48,27 +46,8 @@ def send(tensor, dst=0, group=None, sync_op=True):
print(data)
# [7, 8, 9] (2 GPUs)
"""
if not framework._in_legacy_dygraph():
return stream.send(
tensor, dst=dst, group=group, sync_op=sync_op, use_calc_stream=False
)
# code below will be removed after we remove the old dygraph
if group is not None and not group.is_member():
return
use_calc_stream = sync_op
gdst = dst if group is None else group.get_group_rank(dst)
assert gdst >= 0, "dst rank out of group, need global rank"
ring_id = 0 if group is None else group.id
return paddle._legacy_C_ops.send_v2(
tensor,
'use_calc_stream',
use_calc_stream,
'ring_id',
ring_id,
'peer',
gdst,
return stream.send(
tensor, dst=dst, group=group, sync_op=sync_op, use_calc_stream=False
)
......
......@@ -18,13 +18,7 @@ from paddle.common_ops_import import dygraph_utils
from paddle.distributed import collective
from paddle.fluid import core
from paddle.fluid.data_feeder import check_dtype, check_variable_and_dtype
from paddle.framework import (
LayerHelper,
_in_legacy_dygraph,
_varbase_creator,
in_dygraph_mode,
in_dynamic_mode,
)
from paddle.framework import LayerHelper, _varbase_creator, in_dygraph_mode
from paddle.nn import Layer
from ....communication.reduce import ReduceOp, _get_reduce_op
......@@ -69,39 +63,29 @@ def _c_identity(tensor, group=None):
return dy
return c_identity_eager.apply(tensor)
else:
op_type = 'c_identity'
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
elif _in_legacy_dygraph():
return _legacy_C_ops.c_identity(
check_variable_and_dtype(
tensor,
'use_calc_stream',
True,
'ring_id',
ring_id,
'use_model_parallel',
True,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'_c_identity',
)
op_type = 'c_identity'
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
check_variable_and_dtype(
tensor,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'_c_identity',
)
helper.append_op(
type=op_type,
inputs={'X': tensor},
outputs={'Out': out},
attrs={
'ring_id': ring_id,
'use_calc_stream': True,
'use_model_parallel': True,
},
)
return out
helper.append_op(
type=op_type,
inputs={'X': tensor},
outputs={'Out': out},
attrs={
'ring_id': ring_id,
'use_calc_stream': True,
'use_model_parallel': True,
},
)
return out
def _c_concat(tensor, group=None):
......@@ -125,7 +109,7 @@ def _c_concat(tensor, group=None):
rank = group.rank
nranks = group.nranks
if in_dynamic_mode():
if in_dygraph_mode():
return _legacy_C_ops.c_concat(
tensor,
'ring_id',
......@@ -139,31 +123,31 @@ def _c_concat(tensor, group=None):
'use_model_parallel',
True,
)
else:
op_type = 'c_concat'
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
op_type = 'c_concat'
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
check_variable_and_dtype(
tensor,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'_c_concat',
)
check_variable_and_dtype(
tensor,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'_c_concat',
)
helper.append_op(
type=op_type,
inputs={'X': tensor},
outputs={'Out': out},
attrs={
'ring_id': ring_id,
'use_calc_stream': True,
'use_model_parallel': True,
'nranks': nranks,
'rank': rank,
},
)
return out
helper.append_op(
type=op_type,
inputs={'X': tensor},
outputs={'Out': out},
attrs={
'ring_id': ring_id,
'use_calc_stream': True,
'use_model_parallel': True,
'nranks': nranks,
'rank': rank,
},
)
return out
def _c_split(tensor, group=None):
......@@ -191,7 +175,7 @@ def _c_split(tensor, group=None):
else group.nranks
)
if in_dynamic_mode():
if in_dygraph_mode():
return _legacy_C_ops.c_split(
tensor,
'use_calc_stream',
......@@ -205,31 +189,31 @@ def _c_split(tensor, group=None):
'use_model_parallel',
True,
)
else:
op_type = 'c_split'
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
op_type = 'c_split'
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
check_variable_and_dtype(
tensor,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'_c_split',
)
check_variable_and_dtype(
tensor,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
'_c_split',
)
helper.append_op(
type=op_type,
inputs={'X': tensor},
outputs={'Out': out},
attrs={
'ring_id': ring_id,
'use_calc_stream': True,
'rank': rank,
'nranks': nranks,
'use_model_parallel': True,
},
)
return out
helper.append_op(
type=op_type,
inputs={'X': tensor},
outputs={'Out': out},
attrs={
'ring_id': ring_id,
'use_calc_stream': True,
'rank': rank,
'nranks': nranks,
'use_model_parallel': True,
},
)
return out
def _mp_allreduce(
......@@ -286,41 +270,29 @@ def _mp_allreduce(
return mp_allreduce_eager.apply(
tensor, group, use_calc_stream, use_model_parallel
)
else:
ring_id = 0 if group is None else group.id
op_type = 'mp_allreduce_sum'
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
ring_id = 0 if group is None else group.id
if _in_legacy_dygraph():
if op == ReduceOp.SUM:
return _legacy_C_ops.mp_allreduce_sum_(
tensor,
'use_calc_stream',
use_calc_stream,
'ring_id',
ring_id,
)
else:
raise ValueError("Unknown parameter: {}.".format(op))
op_type = 'mp_allreduce_sum'
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(dtype=tensor.dtype)
check_variable_and_dtype(
tensor,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
op_type,
)
check_variable_and_dtype(
tensor,
'tensor',
['float16', 'float32', 'float64', 'int32', 'int64'],
op_type,
)
helper.append_op(
type=op_type,
inputs={'X': tensor},
outputs={'Out': out},
attrs={
'ring_id': ring_id,
'use_calc_stream': use_calc_stream,
},
)
return out
helper.append_op(
type=op_type,
inputs={'X': tensor},
outputs={'Out': out},
attrs={
'ring_id': ring_id,
'use_calc_stream': use_calc_stream,
},
)
return out
def _c_lookup_table(table, index, start_index=0, name=None):
......@@ -337,23 +309,23 @@ def _c_lookup_table(table, index, start_index=0, name=None):
Returns:
Tensor.
"""
if in_dynamic_mode():
if in_dygraph_mode():
return _legacy_C_ops.c_embedding(
table, index, "start_index", start_index
)
op_type = 'c_embedding'
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype(input_param_name='table')
check_variable_and_dtype(index, 'input', ['int32', 'int64'], op_type)
tmp = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='c_embedding',
inputs={'Ids': index, 'W': table},
outputs={'Out': tmp},
attrs={"start_index": start_index},
)
return tmp
else:
op_type = 'c_embedding'
helper = LayerHelper(op_type, **locals())
dtype = helper.input_dtype(input_param_name='table')
check_variable_and_dtype(index, 'input', ['int32', 'int64'], op_type)
tmp = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='c_embedding',
inputs={'Ids': index, 'W': table},
outputs={'Out': tmp},
attrs={"start_index": start_index},
)
return tmp
class _Linear(Layer):
......@@ -426,7 +398,7 @@ def _c_softmax_with_cross_entropy(
if input_dims - 1 == label_dims:
label = paddle.unsqueeze(label, axis=-1)
if in_dynamic_mode():
if in_dygraph_mode():
softmax, loss = _legacy_C_ops.c_softmax_with_cross_entropy(
logits, label, 'ring_id', ring_id, 'rank', rank, 'nranks', nranks
)
......@@ -434,33 +406,33 @@ def _c_softmax_with_cross_entropy(
return loss
else:
return loss, softmax
else:
attrs = {
'ring_id': ring_id,
'rank': rank,
'nranks': nranks,
}
helper = LayerHelper('c_softmax_with_cross_entropy', **locals())
softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
helper.append_op(
type='c_softmax_with_cross_entropy',
inputs={'Logits': logits, 'Label': label},
outputs={'Softmax': softmax, 'Loss': loss},
attrs=attrs,
)
attrs = {
'ring_id': ring_id,
'rank': rank,
'nranks': nranks,
}
helper = LayerHelper('c_softmax_with_cross_entropy', **locals())
softmax = helper.create_variable_for_type_inference(dtype=logits.dtype)
loss = helper.create_variable_for_type_inference(dtype=logits.dtype)
helper.append_op(
type='c_softmax_with_cross_entropy',
inputs={'Logits': logits, 'Label': label},
outputs={'Softmax': softmax, 'Loss': loss},
attrs=attrs,
)
if return_softmax:
return loss, softmax
if return_softmax:
return loss, softmax
return loss
return loss
def _linear(x, weight, bias=None, name=None):
"""
Fuction Linear
"""
if in_dynamic_mode():
if in_dygraph_mode():
pre_bias = _varbase_creator(dtype=x.dtype)
_legacy_C_ops.matmul(
x,
......@@ -827,7 +799,7 @@ def split(
supported_operations
)
)
if in_dynamic_mode():
if in_dygraph_mode():
raise ValueError(
"paddle.distributed.split cannot be used in dynamic "
"graph mode, plese use ParallelEmbedding, ParallelRowLinear, "
......
......@@ -20,7 +20,8 @@ import paddle
from paddle import _legacy_C_ops
from paddle.fluid import core
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.framework import LayerHelper, in_dynamic_mode
from paddle.fluid.framework import in_dygraph_mode
from paddle.framework import LayerHelper
from paddle.static import Variable
__all__ = []
......@@ -211,7 +212,7 @@ def dropout(
) # semantic transfer
# dygraph using tracker, doesn't need determinate seed
if in_dynamic_mode():
if in_dygraph_mode():
out, mask = _legacy_C_ops.dropout(
x,
'dropout_prob',
......@@ -226,34 +227,34 @@ def dropout(
mode,
)
return out
else:
seed = determinate_seed(rng_name)
seed = determinate_seed(rng_name)
if isinstance(p, Variable) and not p.shape != [1]:
raise TypeError(
"Required p.shape == [1] if type(p) is Variable, but received p.shape = {}".format(
p.shape
if isinstance(p, Variable) and not p.shape != [1]:
raise TypeError(
"Required p.shape == [1] if type(p) is Variable, but received p.shape = {}".format(
p.shape
)
)
)
helper = LayerHelper('dropout', **locals())
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'dropout'
)
helper = LayerHelper('dropout', **locals())
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'dropout'
)
out = helper.create_variable_for_type_inference(dtype=x.dtype)
mask = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.UINT8, stop_gradient=True
)
out = helper.create_variable_for_type_inference(dtype=x.dtype)
mask = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.UINT8, stop_gradient=True
)
helper.append_op(
type='dropout',
inputs={'X': [x], 'Seed': seed},
outputs={'Out': [out], 'Mask': [mask]},
attrs={
'dropout_prob': p,
'is_test': not training,
'dropout_implementation': mode,
},
)
return out
helper.append_op(
type='dropout',
inputs={'X': [x], 'Seed': seed},
outputs={'Out': [out], 'Mask': [mask]},
attrs={
'dropout_prob': p,
'is_test': not training,
'dropout_implementation': mode,
},
)
return out
......@@ -19,10 +19,10 @@ from .meta_optimizer_base import MetaOptimizerBase
__all__ = []
import paddle
from paddle import framework
from paddle.common_ops_import import LayerHelper
from paddle.fluid.clip import GradientClipByNorm, append_gradient_clip_ops
from paddle.fluid.dygraph import base as imperative_base
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.optimizer import Momentum, Optimizer
from paddle.framework import core
from paddle.static import create_global_var
......@@ -46,7 +46,7 @@ class DGCMomentumOptimizer(Optimizer):
grad_clip=None,
name=None,
):
if framework._non_static_mode():
if in_dygraph_mode():
raise Exception("In dygraph, don't support DGCMomentumOptimizer.")
assert (
......
......@@ -16,8 +16,7 @@ import numpy as np
import paddle
import paddle.fluid.core as core
from paddle import _legacy_C_ops
from paddle.fluid.framework import _in_legacy_dygraph, in_dygraph_mode
from paddle.fluid.framework import in_dygraph_mode
from ...utils.log_util import logger
from .utils import number_2_dtype, paddle_2_number
......@@ -189,21 +188,7 @@ def _partial_send_op(
tensor, group, use_calc_stream, ring_id, dst, nranks, rank_id
):
dst_rank_in_group = dst if group is None else group.get_group_rank(dst)
if _in_legacy_dygraph():
return _legacy_C_ops.partial_send(
tensor.detach(),
'use_calc_stream',
use_calc_stream,
'ring_id',
ring_id,
'peer',
dst_rank_in_group,
'num',
nranks,
'id',
rank_id,
)
elif in_dygraph_mode():
if in_dygraph_mode():
group = (
paddle.distributed.collective._get_default_group()
if group is None
......@@ -234,12 +219,7 @@ def send_partial(
tensor, group, use_calc_stream, ring_id, dst_rank, nranks, rank_id
)
else:
if _in_legacy_dygraph():
send_op = lambda x, dst, group: paddle.distributed.send(
x, dst, group, use_calc_stream
)
elif in_dygraph_mode():
send_op = paddle.distributed.isend
send_op = paddle.distributed.isend
return send_op(tensor.detach(), dst=dst_rank, group=group)
......@@ -247,37 +227,17 @@ def _partial_recv_op(
tensor, group, use_calc_stream, ring_id, src, nranks, rank_id
):
src_rank_in_group = src if group is None else group.get_group_rank(src)
if _in_legacy_dygraph():
assert use_calc_stream
return _legacy_C_ops.partial_recv(
tensor.detach(),
'use_calc_stream',
use_calc_stream,
'ring_id',
ring_id,
'peer',
src_rank_in_group,
'num',
nranks,
'id',
rank_id,
'dtype',
tensor.dtype,
'out_shape',
tensor.shape,
)
elif in_dygraph_mode():
group = (
paddle.distributed.collective._get_default_group()
if group is None
else group
)
comm_op = (
group.process_group.recv_partial_on_calc_stream
if use_calc_stream
else group.process_group.recv_partial
)
return comm_op(tensor, src_rank_in_group, nranks, rank_id)
group = (
paddle.distributed.collective._get_default_group()
if group is None
else group
)
comm_op = (
group.process_group.recv_partial_on_calc_stream
if use_calc_stream
else group.process_group.recv_partial
)
return comm_op(tensor, src_rank_in_group, nranks, rank_id)
def recv_partial(
......@@ -297,7 +257,7 @@ def recv_partial(
tensor, group, use_calc_stream, ring_id, src_rank, nranks, rank_id
)
else:
if _in_legacy_dygraph() or use_calc_stream:
if use_calc_stream:
recv_op = paddle.distributed.recv
elif in_dygraph_mode():
recv_op = paddle.distributed.irecv
......@@ -307,30 +267,17 @@ def recv_partial(
def _partial_allgather_op(
tensor, group, use_calc_stream, ring_id, nranks, rank_id
):
if _in_legacy_dygraph():
return _legacy_C_ops.partial_allgather_(
tensor.detach(),
'use_calc_stream',
use_calc_stream,
'ring_id',
ring_id,
'nranks',
nranks,
'rank',
rank_id,
)
elif in_dygraph_mode():
group = (
paddle.distributed.collective._get_default_group()
if group is None
else group
)
comm_op = (
group.process_group.all_gather_partial_on_calc_stream
if use_calc_stream
else group.process_group.all_gather_partial
)
return comm_op(tensor, tensor, nranks, rank_id)
group = (
paddle.distributed.collective._get_default_group()
if group is None
else group
)
comm_op = (
group.process_group.all_gather_partial_on_calc_stream
if use_calc_stream
else group.process_group.all_gather_partial
)
return comm_op(tensor, tensor, nranks, rank_id)
def allgather_partial(
......
......@@ -14,8 +14,8 @@
import copy
import paddle
from paddle.distributed import fleet
from paddle.fluid.framework import in_dygraph_mode
from .meta_optimizers import HeterParallelOptimizer, HybridParallelOptimizer
from .utils.log_util import logger
......@@ -74,7 +74,7 @@ def _dygraph_distributed_optimizer(optimizer, strategy=None):
def distributed_optimizer(*args, **kwargs):
if paddle.framework._non_static_mode():
if in_dygraph_mode():
return _dygraph_distributed_optimizer(*args, **kwargs)
else:
return fleet.fleet.distributed_optimizer(*args, **kwargs)
......@@ -20,7 +20,8 @@ import paddle.distributed.fleet as fleet
# (TODO: GhostScreaming) It will be removed later.
import paddle.fluid.core as core
from paddle.framework import Block, Program, _non_static_mode
from paddle.fluid.framework import in_dygraph_mode
from paddle.framework import Block, Program
class HybridParallelInferenceHelper:
......@@ -205,7 +206,7 @@ class HybridParallelInferenceHelper:
elif core.is_compiled_with_cuda():
self._device = "gpu"
assert self._device, "Only gpu and npu are supported."
assert not _non_static_mode(), "Only static mode is supported."
assert not in_dygraph_mode(), "Only static mode is supported."
op_maker = core.op_proto_and_checker_maker
self._op_role = op_maker.OpRole
......
......@@ -18,7 +18,6 @@ from paddle import framework
# (TODO: GhostScreaming) It will be removed later.
from paddle.fluid import core
from paddle.framework import (
_in_legacy_dygraph,
_split_tensors,
build_groups,
in_dygraph_mode,
......@@ -215,39 +214,12 @@ def sharding_reduce_gradients(parameter_list, hcg):
sharding_nrank = hcg.get_sharding_parallel_group().nranks
for param in parameter_list:
if param.trainable and (param._grad_ivar() is not None):
if in_dygraph_mode():
param.grad.scale_(1.0 / sharding_nrank)
paddle.distributed.all_reduce(
param.grad,
group=hcg.get_sharding_parallel_group(),
sync_op=True,
)
elif _in_legacy_dygraph():
g_var = param._grad_ivar()
# need use trace_op to allreduce
# paddle.distributed.all_reduce(
# g_var, group=hcg.get_sharding_parallel_group(), use_calc_stream=True)
paddle.fluid.framework._dygraph_tracer().trace_op(
type="c_allreduce_sum",
inputs={'X': g_var},
outputs={'Out': g_var},
attrs={
'ring_id': hcg.get_sharding_parallel_group().id,
'use_calc_stream': True,
},
)
# grad / sharding_rank
div_factor = paddle.to_tensor(
sharding_nrank, dtype=g_var.dtype
)
paddle.fluid.framework._dygraph_tracer().trace_op(
type="elementwise_div",
inputs={'X': g_var, 'Y': div_factor},
outputs={'Out': g_var},
attrs={'axis': -1},
)
param.grad.scale_(1.0 / sharding_nrank)
paddle.distributed.all_reduce(
param.grad,
group=hcg.get_sharding_parallel_group(),
sync_op=True,
)
def broadcast_sharding_parameters(model, hcg):
......
......@@ -13,9 +13,8 @@
# limitations under the License.
from paddle import _legacy_C_ops
from paddle.fluid import core
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.framework import _in_legacy_dygraph, in_dygraph_mode
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper
......@@ -43,8 +42,6 @@ def _number_count(numbers, upper_range):
"""
if in_dygraph_mode():
return _legacy_C_ops.number_count(numbers, 'upper_range', upper_range)
elif _in_legacy_dygraph():
return core.ops.number_count(numbers, 'upper_range', upper_range)
else:
op_type = 'number_count'
......@@ -92,8 +89,6 @@ def _assign_pos(x, cum_count):
"""
if in_dygraph_mode():
return _legacy_C_ops.assign_pos(x, cum_count, cum_count[-1])
elif _in_legacy_dygraph():
return core.ops.assign_pos(x, cum_count, cum_count[-1])
else:
op_type = 'assign_pos'
......@@ -129,8 +124,6 @@ def _random_routing(topk_idx, topk_value, prob, topk=2):
if topk == 2:
if in_dygraph_mode():
return _legacy_C_ops.random_routing(prob, topk_value, topk_idx)
elif _in_legacy_dygraph():
return core.ops.random_routing(prob, topk_value, topk_idx)
else:
raise RuntimeError("Not supporting static mode now")
else:
......@@ -162,10 +155,6 @@ def _limit_by_capacity(expert_count, capacity, n_worker):
return _legacy_C_ops.limit_by_capacity(
expert_count, capacity, 'n_worker', n_worker
)
elif _in_legacy_dygraph():
return core.ops.limit_by_capacity(
expert_count, capacity, 'n_worker', n_worker
)
else:
op_type = 'limit_by_capacity'
......@@ -211,32 +200,29 @@ def _prune_gate_by_capacity(gate_idx, expert_count, n_expert, n_worker):
return _legacy_C_ops.prune_gate_by_capacity(
gate_idx, expert_count, "n_expert", n_expert, "n_worker", n_worker
)
elif _in_legacy_dygraph():
return core.ops.prune_gate_by_capacity(
gate_idx, expert_count, "n_expert", n_expert, "n_worker", n_worker
else:
check_variable_and_dtype(
gate_idx,
'GateIdx',
['int32', 'int64'],
'paddle.distributed.utils.prune_gate_by_capacity',
)
check_variable_and_dtype(
expert_count,
'ExpertCount',
['int32', 'int64'],
'paddle.distributed.utils.prune_gate_by_capacity',
)
check_variable_and_dtype(
gate_idx,
'GateIdx',
['int32', 'int64'],
'paddle.distributed.utils.prune_gate_by_capacity',
)
check_variable_and_dtype(
expert_count,
'ExpertCount',
['int32', 'int64'],
'paddle.distributed.utils.prune_gate_by_capacity',
)
helper = LayerHelper('prune_gate_by_capacity', **locals())
new_gate_idx = helper.create_variable_for_type_inference(
dtype=gate_idx.dtype
)
helper.append_op(
type='prune_gate_by_capacity',
inputs={'GateIdx': gate_idx, "ExpertCount": expert_count},
outputs={'NewGateIdx': new_gate_idx},
attrs={"n_expert": n_expert, "n_worker": n_worker},
)
return new_gate_idx
helper = LayerHelper('prune_gate_by_capacity', **locals())
new_gate_idx = helper.create_variable_for_type_inference(
dtype=gate_idx.dtype
)
helper.append_op(
type='prune_gate_by_capacity',
inputs={'GateIdx': gate_idx, "ExpertCount": expert_count},
outputs={'NewGateIdx': new_gate_idx},
attrs={"n_expert": n_expert, "n_worker": n_worker},
)
return new_gate_idx
......@@ -14,7 +14,7 @@
from paddle import _legacy_C_ops
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.framework import _non_static_mode
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper
......@@ -103,7 +103,7 @@ def global_scatter(
return
ring_id = 0 if group is None else group.id
if _non_static_mode():
if in_dygraph_mode():
return _legacy_C_ops.global_scatter(
x,
local_count,
......@@ -220,7 +220,7 @@ def global_gather(
return
ring_id = 0 if group is None else group.id
if _non_static_mode():
if in_dygraph_mode():
return _legacy_C_ops.global_gather(
x,
local_count,
......
......@@ -15,7 +15,7 @@
import paddle
from paddle.distribution import exponential_family
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.framework import _in_legacy_dygraph, in_dygraph_mode
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.layer_helper import LayerHelper
......@@ -166,8 +166,6 @@ def _dirichlet(concentration, name=None):
if in_dygraph_mode():
return paddle._C_ops.dirichlet(concentration)
elif _in_legacy_dygraph():
return paddle._legacy_C_ops.dirichlet(concentration)
else:
helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(
......
......@@ -24,13 +24,9 @@ import warnings
import numpy as np
import paddle
from paddle import _C_ops, _legacy_C_ops
from paddle import _C_ops
from paddle.fluid.data_feeder import check_variable_and_dtype, convert_dtype
from paddle.fluid.framework import (
_in_legacy_dygraph,
_non_static_mode,
in_dygraph_mode,
)
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.layers import tensor
......@@ -221,7 +217,7 @@ class Distribution:
Returns:
value (Tensor): Change value's dtype if value's dtype is different from param.
"""
if _non_static_mode():
if in_dygraph_mode():
if value.dtype != param.dtype and convert_dtype(value.dtype) in [
'float32',
'float64',
......@@ -229,12 +225,7 @@ class Distribution:
warnings.warn(
"dtype of input 'value' needs to be the same as parameters of distribution class. dtype of 'value' will be converted."
)
if in_dygraph_mode():
return _C_ops.cast(value, param.dtype)
if _in_legacy_dygraph():
return _legacy_C_ops.cast(
value, 'in_dtype', value.dtype, 'out_dtype', param.dtype
)
return _C_ops.cast(value, param.dtype)
return value
check_variable_and_dtype(
......
......@@ -15,14 +15,10 @@
import numpy as np
import paddle
from paddle import _C_ops, _legacy_C_ops
from paddle import _C_ops
from paddle.distribution import distribution
from paddle.fluid.data_feeder import check_type, convert_dtype
from paddle.fluid.framework import (
_in_legacy_dygraph,
_non_static_mode,
in_dygraph_mode,
)
from paddle.fluid.framework import _non_static_mode, in_dygraph_mode
from paddle.fluid.layers import tensor
from paddle.tensor import random
......@@ -210,33 +206,23 @@ class Uniform(distribution.Distribution):
"""
value = self._check_values_dtype_in_probs(self.low, value)
if _non_static_mode():
if in_dygraph_mode():
# ensure value in [low, high]
lb_bool = self.low < value
ub_bool = value < self.high
if in_dygraph_mode():
lb = _C_ops.cast(lb_bool, value.dtype)
ub = _C_ops.cast(ub_bool, value.dtype)
return paddle.log(lb * ub) - paddle.log(self.high - self.low)
if _in_legacy_dygraph():
lb = _legacy_C_ops.cast(
lb_bool, 'in_dtype', lb_bool.dtype, 'out_dtype', value.dtype
)
ub = _legacy_C_ops.cast(
ub_bool, 'in_dtype', ub_bool.dtype, 'out_dtype', value.dtype
)
return paddle.log(lb * ub) - paddle.log(self.high - self.low)
name = self.name + '_log_prob'
lb_bool = self.low < value
ub_bool = value < self.high
lb = tensor.cast(lb_bool, dtype=value.dtype)
ub = tensor.cast(ub_bool, dtype=value.dtype)
return paddle.subtract(
paddle.log(lb * ub), paddle.log(self.high - self.low), name=name
)
lb = _C_ops.cast(lb_bool, value.dtype)
ub = _C_ops.cast(ub_bool, value.dtype)
return paddle.log(lb * ub) - paddle.log(self.high - self.low)
else:
name = self.name + '_log_prob'
lb_bool = self.low < value
ub_bool = value < self.high
lb = tensor.cast(lb_bool, dtype=value.dtype)
ub = tensor.cast(ub_bool, dtype=value.dtype)
return paddle.subtract(
paddle.log(lb * ub), paddle.log(self.high - self.low), name=name
)
def probs(self, value):
"""Probability density/mass function.
......@@ -249,30 +235,19 @@ class Uniform(distribution.Distribution):
"""
value = self._check_values_dtype_in_probs(self.low, value)
if _non_static_mode():
if in_dygraph_mode():
lb_bool = self.low < value
ub_bool = value < self.high
if in_dygraph_mode():
lb = _C_ops.cast(lb_bool, value.dtype)
ub = _C_ops.cast(ub_bool, value.dtype)
return (lb * ub) / (self.high - self.low)
if _in_legacy_dygraph():
lb = _legacy_C_ops.cast(
lb_bool, 'in_dtype', lb_bool.dtype, 'out_dtype', value.dtype
)
ub = _legacy_C_ops.cast(
ub_bool, 'in_dtype', ub_bool.dtype, 'out_dtype', value.dtype
)
return (lb * ub) / (self.high - self.low)
name = self.name + '_probs'
lb_bool = self.low < value
ub_bool = value < self.high
lb = tensor.cast(lb_bool, dtype=value.dtype)
ub = tensor.cast(ub_bool, dtype=value.dtype)
return paddle.divide((lb * ub), (self.high - self.low), name=name)
lb = _C_ops.cast(lb_bool, value.dtype)
ub = _C_ops.cast(ub_bool, value.dtype)
return (lb * ub) / (self.high - self.low)
else:
name = self.name + '_probs'
lb_bool = self.low < value
ub_bool = value < self.high
lb = tensor.cast(lb_bool, dtype=value.dtype)
ub = tensor.cast(ub_bool, dtype=value.dtype)
return paddle.divide((lb * ub), (self.high - self.low), name=name)
def entropy(self):
r"""Shannon entropy in nats.
......