未验证 提交 a072fca8 编写于 作者: S Sing_chan 提交者: GitHub

【code format check upgrade】 step2:yapf (#42944)

* use yapf to format all python file

* yapf exclude two unittests file for they rely on writing and reading file, and format will break them

* disable diff_py_file because too many diff files cause command following failed
上级 92568edb

要显示的变更太多。

To preserve performance only 1000 of 1000+ files are displayed.
...@@ -4,11 +4,16 @@ repos: ...@@ -4,11 +4,16 @@ repos:
hooks: hooks:
- id: remove-crlf - id: remove-crlf
files: (?!.*third_party)^.*$ | (?!.*book)^.*$ files: (?!.*third_party)^.*$ | (?!.*book)^.*$
- repo: https://github.com/PaddlePaddle/mirrors-yapf.git - repo: https://github.com/google/yapf
sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37 sha: v0.32.0
hooks: hooks:
- id: yapf - id: yapf
files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$ files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
exclude: |
(?x)^(
python/paddle/fluid/tests/unittests/dygraph_to_static/test_error.py|
python/paddle/fluid/tests/unittests/dygraph_to_static/test_origin_info.py
)$
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.1.0 rev: v4.1.0
hooks: hooks:
......
...@@ -481,10 +481,10 @@ EOF ...@@ -481,10 +481,10 @@ EOF
} }
function cmake_gen_and_build() { function cmake_gen_and_build() {
startTime_s=`date +%s` startTime_s=100
cmake_gen $1 cmake_gen $1
build $2 build $2
endTime_s=`date +%s` endTime_s=200
[ -n "$startTime_firstBuild" ] && startTime_s=$startTime_firstBuild [ -n "$startTime_firstBuild" ] && startTime_s=$startTime_firstBuild
echo "Build Time: $[ $endTime_s - $startTime_s ]s" echo "Build Time: $[ $endTime_s - $startTime_s ]s"
echo "ipipe_log_param_Build_Time: $[ $endTime_s - $startTime_s ]s" >> ${PADDLE_ROOT}/build/build_summary.txt echo "ipipe_log_param_Build_Time: $[ $endTime_s - $startTime_s ]s" >> ${PADDLE_ROOT}/build/build_summary.txt
...@@ -1130,8 +1130,8 @@ EOF ...@@ -1130,8 +1130,8 @@ EOF
function check_diff_file_for_coverage() { function check_diff_file_for_coverage() {
diff_h_file=$(git diff --name-status test develop | awk '$1 != "D" {print $2}' | grep '\.h$' | awk -F "/" '{printf "%s,",$NF}') diff_h_file=$(git diff --name-status test develop | awk '$1 != "D" {print $2}' | grep '\.h$' | awk -F "/" '{printf "%s,",$NF}')
diff_cc_file=$(git diff --name-status test develop | awk '$1 != "D" {print $2}' | grep -E '\.(cc|c)$' | awk -F "/" '{printf "%s,",$NF}') diff_cc_file=$(git diff --name-status test develop | awk '$1 != "D" {print $2}' | grep -E '\.(cc|c)$' | awk -F "/" '{printf "%s,",$NF}')
diff_py_file=$(git diff --name-status test develop | grep '\.py$' | awk '$1 != "D" {printf "%s,",$2}') #diff_py_file=$(git diff --name-status test develop | grep '\.py$' | awk '$1 != "D" {printf "%s,",$2}')
diff_py_file='tools/test_sampcd_processor.py,tools/timeline.py'
export PADDLE_GIT_DIFF_H_FILE=${diff_h_file%*,} export PADDLE_GIT_DIFF_H_FILE=${diff_h_file%*,}
export PADDLE_GIT_DIFF_CC_FILE=${diff_cc_file%*,} export PADDLE_GIT_DIFF_CC_FILE=${diff_cc_file%*,}
export PADDLE_GIT_DIFF_PY_FILE=${diff_py_file%*,} export PADDLE_GIT_DIFF_PY_FILE=${diff_py_file%*,}
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
from paddle.fluid import core from paddle.fluid import core
from .fluid import framework from .fluid import framework
__all__ = [] __all__ = []
_already_switch_to_eager_ = False _already_switch_to_eager_ = False
......
...@@ -24,6 +24,7 @@ except ImportError: ...@@ -24,6 +24,7 @@ except ImportError:
from .batch import batch # noqa: F401 from .batch import batch # noqa: F401
from .framework import monkey_patch_variable from .framework import monkey_patch_variable
from .framework import monkey_patch_math_varbase from .framework import monkey_patch_math_varbase
monkey_patch_variable() monkey_patch_variable()
monkey_patch_math_varbase() monkey_patch_math_varbase()
...@@ -52,6 +53,7 @@ if fluid.framework._in_eager_mode_: ...@@ -52,6 +53,7 @@ if fluid.framework._in_eager_mode_:
Tensor = framework.core.eager.Tensor Tensor = framework.core.eager.Tensor
else: else:
from .framework import VarBase as Tensor # noqa: F401 from .framework import VarBase as Tensor # noqa: F401
Tensor.__qualname__ = 'Tensor' # noqa: F401 Tensor.__qualname__ = 'Tensor' # noqa: F401
import paddle.compat # noqa: F401 import paddle.compat # noqa: F401
import paddle.distributed # noqa: F401 import paddle.distributed # noqa: F401
...@@ -372,272 +374,272 @@ if is_compiled_with_cinn(): ...@@ -372,272 +374,272 @@ if is_compiled_with_cinn():
disable_static() disable_static()
__all__ = [ # noqa __all__ = [ # noqa
'dtype', 'dtype',
'uint8', 'uint8',
'int8', 'int8',
'int16', 'int16',
'int32', 'int32',
'int64', 'int64',
'float16', 'float16',
'float32', 'float32',
'float64', 'float64',
'bfloat16', 'bfloat16',
'bool', 'bool',
'complex64', 'complex64',
'complex128', 'complex128',
'addmm', 'addmm',
'allclose', 'allclose',
'isclose', 'isclose',
't', 't',
'add', 'add',
'subtract', 'subtract',
'diag', 'diag',
'diagflat', 'diagflat',
'isnan', 'isnan',
'scatter_nd_add', 'scatter_nd_add',
'unstack', 'unstack',
'get_default_dtype', 'get_default_dtype',
'save', 'save',
'multinomial', 'multinomial',
'get_cuda_rng_state', 'get_cuda_rng_state',
'rank', 'rank',
'empty_like', 'empty_like',
'eye', 'eye',
'cumsum', 'cumsum',
'cumprod', 'cumprod',
'logit', 'logit',
'sign', 'sign',
'is_empty', 'is_empty',
'equal', 'equal',
'equal_all', 'equal_all',
'is_tensor', 'is_tensor',
'is_complex', 'is_complex',
'is_integer', 'is_integer',
'cross', 'cross',
'where', 'where',
'log1p', 'log1p',
'cos', 'cos',
'tan', 'tan',
'mean', 'mean',
'mode', 'mode',
'mv', 'mv',
'in_dynamic_mode', 'in_dynamic_mode',
'min', 'min',
'amin', 'amin',
'any', 'any',
'slice', 'slice',
'normal', 'normal',
'logsumexp', 'logsumexp',
'full', 'full',
'unsqueeze', 'unsqueeze',
'unsqueeze_', 'unsqueeze_',
'argmax', 'argmax',
'Model', 'Model',
'summary', 'summary',
'flops', 'flops',
'sort', 'sort',
'searchsorted', 'searchsorted',
'split', 'split',
'logical_and', 'logical_and',
'full_like', 'full_like',
'less_than', 'less_than',
'kron', 'kron',
'clip', 'clip',
'Tensor', 'Tensor',
'crop', 'crop',
'ParamAttr', 'ParamAttr',
'stanh', 'stanh',
'randint', 'randint',
'randint_like', 'randint_like',
'assign', 'assign',
'gather', 'gather',
'scale', 'scale',
'zeros', 'zeros',
'rsqrt', 'rsqrt',
'squeeze', 'squeeze',
'squeeze_', 'squeeze_',
'to_tensor', 'to_tensor',
'gather_nd', 'gather_nd',
'isinf', 'isinf',
'uniform', 'uniform',
'floor_divide', 'floor_divide',
'remainder', 'remainder',
'floor_mod', 'floor_mod',
'roll', 'roll',
'batch', 'batch',
'max', 'max',
'amax', 'amax',
'logical_or', 'logical_or',
'bitwise_and', 'bitwise_and',
'bitwise_or', 'bitwise_or',
'bitwise_xor', 'bitwise_xor',
'bitwise_not', 'bitwise_not',
'mm', 'mm',
'flip', 'flip',
'rot90', 'rot90',
'bincount', 'bincount',
'histogram', 'histogram',
'multiplex', 'multiplex',
'CUDAPlace', 'CUDAPlace',
'NPUPlace', 'NPUPlace',
'empty', 'empty',
'shape', 'shape',
'real', 'real',
'imag', 'imag',
'is_floating_point', 'is_floating_point',
'complex', 'complex',
'reciprocal', 'reciprocal',
'rand', 'rand',
'less_equal', 'less_equal',
'triu', 'triu',
'sin', 'sin',
'dist', 'dist',
'unbind', 'unbind',
'meshgrid', 'meshgrid',
'arange', 'arange',
'load', 'load',
'numel', 'numel',
'median', 'median',
'nanmedian', 'nanmedian',
'quantile', 'quantile',
'nanquantile', 'nanquantile',
'no_grad', 'no_grad',
'set_grad_enabled', 'set_grad_enabled',
'is_grad_enabled', 'is_grad_enabled',
'mod', 'mod',
'abs', 'abs',
'tril', 'tril',
'pow', 'pow',
'zeros_like', 'zeros_like',
'maximum', 'maximum',
'topk', 'topk',
'index_select', 'index_select',
'CPUPlace', 'CPUPlace',
'matmul', 'matmul',
'seed', 'seed',
'acos', 'acos',
'logical_xor', 'logical_xor',
'exp', 'exp',
'expm1', 'expm1',
'bernoulli', 'bernoulli',
'poisson', 'poisson',
'sinh', 'sinh',
'round', 'round',
'DataParallel', 'DataParallel',
'argmin', 'argmin',
'prod', 'prod',
'broadcast_shape', 'broadcast_shape',
'conj', 'conj',
'neg', 'neg',
'lgamma', 'lgamma',
'lerp', 'lerp',
'erfinv', 'erfinv',
'inner', 'inner',
'outer', 'outer',
'square', 'square',
'divide', 'divide',
'ceil', 'ceil',
'atan', 'atan',
'atan2', 'atan2',
'rad2deg', 'rad2deg',
'deg2rad', 'deg2rad',
'gcd', 'gcd',
'lcm', 'lcm',
'expand', 'expand',
'broadcast_to', 'broadcast_to',
'ones_like', 'ones_like',
'index_sample', 'index_sample',
'cast', 'cast',
'grad', 'grad',
'all', 'all',
'ones', 'ones',
'not_equal', 'not_equal',
'sum', 'sum',
'nansum', 'nansum',
'nanmean', 'nanmean',
'tile', 'tile',
'greater_equal', 'greater_equal',
'isfinite', 'isfinite',
'create_parameter', 'create_parameter',
'dot', 'dot',
'increment', 'increment',
'erf', 'erf',
'bmm', 'bmm',
'chunk', 'chunk',
'tolist', 'tolist',
'tensordot', 'tensordot',
'greater_than', 'greater_than',
'shard_index', 'shard_index',
'argsort', 'argsort',
'tanh', 'tanh',
'tanh_', 'tanh_',
'transpose', 'transpose',
'randn', 'randn',
'strided_slice', 'strided_slice',
'unique', 'unique',
'unique_consecutive', 'unique_consecutive',
'set_cuda_rng_state', 'set_cuda_rng_state',
'set_printoptions', 'set_printoptions',
'std', 'std',
'flatten', 'flatten',
'asin', 'asin',
'multiply', 'multiply',
'disable_static', 'disable_static',
'masked_select', 'masked_select',
'var', 'var',
'trace', 'trace',
'enable_static', 'enable_static',
'scatter_nd', 'scatter_nd',
'set_default_dtype', 'set_default_dtype',
'disable_signal_handler', 'disable_signal_handler',
'expand_as', 'expand_as',
'stack', 'stack',
'sqrt', 'sqrt',
'randperm', 'randperm',
'linspace', 'linspace',
'logspace', 'logspace',
'reshape', 'reshape',
'reshape_', 'reshape_',
'reverse', 'reverse',
'nonzero', 'nonzero',
'CUDAPinnedPlace', 'CUDAPinnedPlace',
'logical_not', 'logical_not',
'add_n', 'add_n',
'minimum', 'minimum',
'scatter', 'scatter',
'scatter_', 'scatter_',
'floor', 'floor',
'cosh', 'cosh',
'log', 'log',
'log2', 'log2',
'log10', 'log10',
'concat', 'concat',
'check_shape', 'check_shape',
'trunc', 'trunc',
'frac', 'frac',
'digamma', 'digamma',
'standard_normal', 'standard_normal',
'diagonal', 'diagonal',
'broadcast_tensors', 'broadcast_tensors',
'einsum', 'einsum',
'set_flags', 'set_flags',
'get_flags', 'get_flags',
'asinh', 'asinh',
'acosh', 'acosh',
'atanh', 'atanh',
'as_complex', 'as_complex',
'as_real', 'as_real',
'diff', 'diff',
'angle', 'angle',
'fmax', 'fmax',
'fmin', 'fmin',
'moveaxis', 'moveaxis',
'repeat_interleave', 'repeat_interleave',
'clone', 'clone',
'kthvalue', 'kthvalue',
'renorm', 'renorm',
'take_along_axis', 'take_along_axis',
'put_along_axis', 'put_along_axis',
'heaviside', 'heaviside',
'tril_indices', 'tril_indices',
] ]
...@@ -83,10 +83,10 @@ class GradScaler(AmpScaler): ...@@ -83,10 +83,10 @@ class GradScaler(AmpScaler):
incr_every_n_steps=1000, incr_every_n_steps=1000,
decr_every_n_nan_or_inf=2, decr_every_n_nan_or_inf=2,
use_dynamic_loss_scaling=True): use_dynamic_loss_scaling=True):
super(GradScaler, self).__init__(enable, init_loss_scaling, incr_ratio, super(GradScaler,
decr_ratio, incr_every_n_steps, self).__init__(enable, init_loss_scaling, incr_ratio, decr_ratio,
decr_every_n_nan_or_inf, incr_every_n_steps, decr_every_n_nan_or_inf,
use_dynamic_loss_scaling) use_dynamic_loss_scaling)
def scale(self, var): def scale(self, var):
""" """
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -16,6 +16,7 @@ from paddle.fluid import core ...@@ -16,6 +16,7 @@ from paddle.fluid import core
from paddle.fluid import framework from paddle.fluid import framework
from paddle.fluid.backward import gradients_with_optimizer from paddle.fluid.backward import gradients_with_optimizer
import paddle import paddle
__all__ = [] __all__ = []
...@@ -81,14 +82,16 @@ def backward(tensors, grad_tensors=None, retain_graph=False): ...@@ -81,14 +82,16 @@ def backward(tensors, grad_tensors=None, retain_graph=False):
if isinstance(in_out_list, (list, tuple)): if isinstance(in_out_list, (list, tuple)):
assert len(in_out_list) > 0, "{} connot be empyt".format(name) assert len(in_out_list) > 0, "{} connot be empyt".format(name)
for each_var in in_out_list: for each_var in in_out_list:
assert isinstance(each_var, ( assert isinstance(
paddle.Tensor, core.eager.Tensor each_var,
)), "Elements of {} must be paddle.Tensor".format(name) (paddle.Tensor, core.eager.Tensor
)), "Elements of {} must be paddle.Tensor".format(name)
return in_out_list return in_out_list
else: else:
assert isinstance(in_out_list, ( assert isinstance(
paddle.Tensor, core.eager.Tensor in_out_list,
)), "{} must be Tensor or list of Tensor".format(name) (paddle.Tensor, core.eager.Tensor
)), "{} must be Tensor or list of Tensor".format(name)
return [in_out_list] return [in_out_list]
tensors = check_tensors(tensors, "tensors") tensors = check_tensors(tensors, "tensors")
......
...@@ -139,7 +139,7 @@ def _double_backward_trick(ys, xs, v): ...@@ -139,7 +139,7 @@ def _double_backward_trick(ys, xs, v):
"""Double backward trick for computing ``jvp`` by ``vjp`` """Double backward trick for computing ``jvp`` by ``vjp``
see details: https://j-towns.github.io/2017/06/12/A-new-trick.html see details: https://j-towns.github.io/2017/06/12/A-new-trick.html
""" """
# The value of ys_grad is not important, it can be any random value in # The value of ys_grad is not important, it can be any random value in
# theory, but it's required to set stop_gradient=False. # theory, but it's required to set stop_gradient=False.
ys_grad = _zeros_like_with_grad(ys) ys_grad = _zeros_like_with_grad(ys)
xs_grad = _grad(ys, xs, ys_grad) xs_grad = _grad(ys, xs, ys_grad)
...@@ -302,10 +302,11 @@ class Hessian(object): ...@@ -302,10 +302,11 @@ class Hessian(object):
""" """
def __init__(self, func, xs, is_batched=False): def __init__(self, func, xs, is_batched=False):
def _jac_func(*xs): def _jac_func(*xs):
jac = Jacobian(func, xs, is_batched=is_batched) jac = Jacobian(func, xs, is_batched=is_batched)
if (is_batched and jac.shape[1] != 1) or (not is_batched and if (is_batched and jac.shape[1] != 1) or (not is_batched
jac.shape[0] != 1): and jac.shape[0] != 1):
raise RuntimeError( raise RuntimeError(
"The function given to Hessian shoud return as single element Tensor or batched single element Tensor." "The function given to Hessian shoud return as single element Tensor or batched single element Tensor."
) )
...@@ -362,18 +363,18 @@ class _Jacobian(object): ...@@ -362,18 +363,18 @@ class _Jacobian(object):
def _lazy_indexes(self, indexes): def _lazy_indexes(self, indexes):
idx = indexes[self._lazy_axis] idx = indexes[self._lazy_axis]
return (idx, ) if isinstance( return (idx, ) if isinstance(idx, int) else tuple(
idx, int) else tuple(range(idx.start, idx.stop, idx.step)) range(idx.start, idx.stop, idx.step))
def _flatten(self, xs): def _flatten(self, xs):
raise NotImplementedError raise NotImplementedError
def _shifted_indexes(self, indexes, lazy_axis_size=0): def _shifted_indexes(self, indexes, lazy_axis_size=0):
idx = indexes[self._lazy_axis] idx = indexes[self._lazy_axis]
shifted_lazy_axis_idx = 0 if isinstance( shifted_lazy_axis_idx = 0 if isinstance(idx, int) else slice(
idx, int) else slice(0, lazy_axis_size, 1) 0, lazy_axis_size, 1)
return indexes[:self._lazy_axis] + (shifted_lazy_axis_idx, return indexes[:self._lazy_axis] + (
) + indexes[self._lazy_axis + 1:] shifted_lazy_axis_idx, ) + indexes[self._lazy_axis + 1:]
def __getitem__(self, indexes): def __getitem__(self, indexes):
indexes = _multi_index(indexes, self.shape) indexes = _multi_index(indexes, self.shape)
...@@ -381,8 +382,8 @@ class _Jacobian(object): ...@@ -381,8 +382,8 @@ class _Jacobian(object):
if isinstance(indexes[self._lazy_axis], int): if isinstance(indexes[self._lazy_axis], int):
other_indexes = indexes[:self._lazy_axis] + \ other_indexes = indexes[:self._lazy_axis] + \
indexes[self._lazy_axis+1:] indexes[self._lazy_axis+1:]
return self._cached_evaluate(indexes[self._lazy_axis])[ return self._cached_evaluate(
other_indexes] indexes[self._lazy_axis])[other_indexes]
lazy_indexes = self._lazy_indexes(indexes) lazy_indexes = self._lazy_indexes(indexes)
part_jac = paddle.stack( part_jac = paddle.stack(
[self._cached_evaluate(i) for i in lazy_indexes], [self._cached_evaluate(i) for i in lazy_indexes],
...@@ -424,7 +425,8 @@ class _JacobianNoBatch(_Jacobian): ...@@ -424,7 +425,8 @@ class _JacobianNoBatch(_Jacobian):
def _evaluate(self, row_index): def _evaluate(self, row_index):
return self._flatten(_grad( return self._flatten(_grad(
self._flatten_ys[row_index], self._flatten_ys[row_index],
self._xs, )) self._xs,
))
class _JacobianBatchLast(_Jacobian): class _JacobianBatchLast(_Jacobian):
...@@ -508,8 +510,8 @@ def _multi_index(indexes, shape): ...@@ -508,8 +510,8 @@ def _multi_index(indexes, shape):
positive_indexes = [] positive_indexes = []
for i, index in enumerate(indexes): for i, index in enumerate(indexes):
if isinstance(index, slice): if isinstance(index, slice):
index = slice(index.start or 0, index.stop or shape[i], index = slice(index.start or 0, index.stop or shape[i], index.step
index.step or 1) or 1)
positive_indexes.append( positive_indexes.append(
slice( slice(
index.start + shape[i] if index.start < 0 else index.start, index.start + shape[i] if index.start < 0 else index.start,
...@@ -530,9 +532,8 @@ def _as_tensors(xs): ...@@ -530,9 +532,8 @@ def _as_tensors(xs):
def _stack_tensor_or_return_none(origin_list): def _stack_tensor_or_return_none(origin_list):
assert len(origin_list) > 0, "Can't not stack an empty list" assert len(origin_list) > 0, "Can't not stack an empty list"
return paddle.stack( return paddle.stack(origin_list, axis=0) if isinstance(
origin_list, axis=0) if isinstance( origin_list[0], paddle.fluid.framework.Variable) else None
origin_list[0], paddle.fluid.framework.Variable) else None
def _replace_none_with_zero_tensor(xs, refs): def _replace_none_with_zero_tensor(xs, refs):
...@@ -809,23 +810,20 @@ def jacobian(func, inputs, create_graph=False, allow_unused=False): ...@@ -809,23 +810,20 @@ def jacobian(func, inputs, create_graph=False, allow_unused=False):
fin_size = len(inputs) fin_size = len(inputs)
fout_size = len(outputs) fout_size = len(outputs)
flat_outputs = tuple( flat_outputs = tuple(
paddle.reshape( paddle.reshape(output, shape=[-1]) for output in outputs)
output, shape=[-1]) for output in outputs)
jacobian = tuple() jacobian = tuple()
for i, flat_output in enumerate(flat_outputs): for i, flat_output in enumerate(flat_outputs):
jac_i = list([] for _ in range(fin_size)) jac_i = list([] for _ in range(fin_size))
for k in range(len(flat_output)): for k in range(len(flat_output)):
row_k = paddle.grad( row_k = paddle.grad(flat_output[k],
flat_output[k], inputs,
inputs, create_graph=create_graph,
create_graph=create_graph, retain_graph=True,
retain_graph=True, allow_unused=allow_unused)
allow_unused=allow_unused)
for j in range(fin_size): for j in range(fin_size):
jac_i[j].append( jac_i[j].append(
paddle.reshape( paddle.reshape(row_k[j], shape=[-1]) if isinstance(
row_k[j], shape=[-1]) row_k[j], paddle.Tensor) else None)
if isinstance(row_k[j], paddle.Tensor) else None)
jacobian += (tuple( jacobian += (tuple(
_stack_tensor_or_return_none(jac_i_j) for jac_i_j in jac_i), ) _stack_tensor_or_return_none(jac_i_j) for jac_i_j in jac_i), )
if fin_size == 1 and fout_size == 1: if fin_size == 1 and fout_size == 1:
...@@ -957,25 +955,22 @@ def batch_jacobian(func, inputs, create_graph=False, allow_unused=False): ...@@ -957,25 +955,22 @@ def batch_jacobian(func, inputs, create_graph=False, allow_unused=False):
fin_size = len(inputs) fin_size = len(inputs)
fout_size = len(outputs) fout_size = len(outputs)
flat_outputs = tuple( flat_outputs = tuple(
paddle.reshape( paddle.reshape(output, shape=[batch_size, -1]) for output in outputs)
output, shape=[batch_size, -1]) for output in outputs)
jacobian = tuple() jacobian = tuple()
for i, flat_output in enumerate(flat_outputs): for i, flat_output in enumerate(flat_outputs):
jac_i = list([] for _ in range(fin_size)) jac_i = list([] for _ in range(fin_size))
for k in range(flat_output.shape[1]): for k in range(flat_output.shape[1]):
row_k = paddle.grad( row_k = paddle.grad(flat_output[:, k],
flat_output[:, k], inputs,
inputs, create_graph=create_graph,
create_graph=create_graph, retain_graph=True,
retain_graph=True, allow_unused=allow_unused)
allow_unused=allow_unused)
for j in range(fin_size): for j in range(fin_size):
jac_i[j].append( jac_i[j].append(
paddle.reshape( paddle.reshape(row_k[j], shape=[-1]) if isinstance(
row_k[j], shape=[-1]) row_k[j], paddle.Tensor) else None)
if isinstance(row_k[j], paddle.Tensor) else None)
jacobian += (tuple( jacobian += (tuple(
_stack_tensor_or_return_none(jac_i_j) for jac_i_j in jac_i), ) _stack_tensor_or_return_none(jac_i_j) for jac_i_j in jac_i), )
if fin_size == 1 and fout_size == 1: if fin_size == 1 and fout_size == 1:
...@@ -1119,18 +1114,19 @@ def batch_hessian(func, inputs, create_graph=False, allow_unused=False): ...@@ -1119,18 +1114,19 @@ def batch_hessian(func, inputs, create_graph=False, allow_unused=False):
], "The function to compute batched Hessian matrix should return a Tensor of shape [batch_size, 1]" ], "The function to compute batched Hessian matrix should return a Tensor of shape [batch_size, 1]"
def jac_func(*ins): def jac_func(*ins):
grad_inputs = paddle.grad( grad_inputs = paddle.grad(outputs,
outputs, ins,
ins, create_graph=True,
create_graph=True, retain_graph=True,
retain_graph=True, allow_unused=allow_unused)
allow_unused=allow_unused)
return tuple( return tuple(
_replace_none_with_zero_tensor(grad_inputs[i], inputs[i]) _replace_none_with_zero_tensor(grad_inputs[i], inputs[i])
for i in range(len(inputs))) for i in range(len(inputs)))
return batch_jacobian( return batch_jacobian(jac_func,
jac_func, inputs, create_graph=create_graph, allow_unused=allow_unused) inputs,
create_graph=create_graph,
allow_unused=allow_unused)
@framework.dygraph_only @framework.dygraph_only
...@@ -1245,18 +1241,19 @@ def hessian(func, inputs, create_graph=False, allow_unused=False): ...@@ -1245,18 +1241,19 @@ def hessian(func, inputs, create_graph=False, allow_unused=False):
], "The function to compute Hessian matrix should return a Tensor with a single element" ], "The function to compute Hessian matrix should return a Tensor with a single element"
def jac_func(*ins): def jac_func(*ins):
grad_inputs = paddle.grad( grad_inputs = paddle.grad(outputs,
outputs, ins,
ins, create_graph=True,
create_graph=True, retain_graph=True,
retain_graph=True, allow_unused=allow_unused)
allow_unused=allow_unused)
return tuple( return tuple(
_replace_none_with_zero_tensor(grad_inputs[i], inputs[i]) _replace_none_with_zero_tensor(grad_inputs[i], inputs[i])
for i in range(len(inputs))) for i in range(len(inputs)))
return jacobian( return jacobian(jac_func,
jac_func, inputs, create_graph=create_graph, allow_unused=allow_unused) inputs,
create_graph=create_graph,
allow_unused=allow_unused)
def vhp(func, inputs, v=None, create_graph=False, allow_unused=False): def vhp(func, inputs, v=None, create_graph=False, allow_unused=False):
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -17,6 +17,7 @@ from paddle.fluid.framework import dygraph_only ...@@ -17,6 +17,7 @@ from paddle.fluid.framework import dygraph_only
from paddle.fluid.dygraph.amp.auto_cast import amp_state from paddle.fluid.dygraph.amp.auto_cast import amp_state
from paddle.amp.auto_cast import auto_cast from paddle.amp.auto_cast import auto_cast
from paddle.fluid import core from paddle.fluid import core
__all__ = [] __all__ = []
...@@ -123,7 +124,9 @@ class PyLayerContext(object): ...@@ -123,7 +124,9 @@ class PyLayerContext(object):
def with_mateclass(meta, *bases): def with_mateclass(meta, *bases):
class impl(meta): class impl(meta):
def __new__(cls, name, temp_bases, attrs): def __new__(cls, name, temp_bases, attrs):
return meta(name, bases, attrs) return meta(name, bases, attrs)
...@@ -131,6 +134,7 @@ def with_mateclass(meta, *bases): ...@@ -131,6 +134,7 @@ def with_mateclass(meta, *bases):
class CPyLayer(object): class CPyLayer(object):
@classmethod @classmethod
@dygraph_only @dygraph_only
def apply(cls, *args, **kwargs): def apply(cls, *args, **kwargs):
...@@ -178,6 +182,7 @@ class CPyLayer(object): ...@@ -178,6 +182,7 @@ class CPyLayer(object):
class PyLayerBackward(PyLayerContext): class PyLayerBackward(PyLayerContext):
def backward(self, *args, **kwargs): def backward(self, *args, **kwargs):
with paddle.fluid.dygraph.guard(): with paddle.fluid.dygraph.guard():
with paddle.fluid.dygraph.no_grad(): with paddle.fluid.dygraph.no_grad():
...@@ -192,6 +197,7 @@ class PyLayerBackward(PyLayerContext): ...@@ -192,6 +197,7 @@ class PyLayerBackward(PyLayerContext):
class LayerMeta(type): class LayerMeta(type):
def __init__(cls, name, bases, attrs): def __init__(cls, name, bases, attrs):
cls._backward_function = type(name + '_backward', (PyLayerBackward, ), cls._backward_function = type(name + '_backward', (PyLayerBackward, ),
{"_forward_cls": cls}) {"_forward_cls": cls})
...@@ -330,6 +336,7 @@ class PyLayer(with_mateclass(LayerMeta, CPyLayer)): ...@@ -330,6 +336,7 @@ class PyLayer(with_mateclass(LayerMeta, CPyLayer)):
class EagerPyLayerContext(object): class EagerPyLayerContext(object):
def save_for_backward(self, *tensors): def save_for_backward(self, *tensors):
""" """
Saves given tensors that backward need. Use ``saved_tensor`` in the `backward` to get the saved tensors. Saves given tensors that backward need. Use ``saved_tensor`` in the `backward` to get the saved tensors.
...@@ -494,11 +501,13 @@ class EagerPyLayerContext(object): ...@@ -494,11 +501,13 @@ class EagerPyLayerContext(object):
class EagerPyLayerBackward(core.eager.PyLayer, EagerPyLayerContext): class EagerPyLayerBackward(core.eager.PyLayer, EagerPyLayerContext):
def backward(self, *args): def backward(self, *args):
return self._forward_cls.backward(self, *args) return self._forward_cls.backward(self, *args)
class EagerPyLayerMeta(type): class EagerPyLayerMeta(type):
def __init__(cls, name, bases, attrs): def __init__(cls, name, bases, attrs):
cls._backward_function = type(name + '_backward', cls._backward_function = type(name + '_backward',
(EagerPyLayerBackward, ), (EagerPyLayerBackward, ),
...@@ -510,6 +519,7 @@ class EagerPyLayerMeta(type): ...@@ -510,6 +519,7 @@ class EagerPyLayerMeta(type):
class EagerPyLayer( class EagerPyLayer(
with_mateclass(EagerPyLayerMeta, core.eager.PyLayer, with_mateclass(EagerPyLayerMeta, core.eager.PyLayer,
EagerPyLayerContext)): EagerPyLayerContext)):
@staticmethod @staticmethod
def forward(ctx, *args, **kwargs): def forward(ctx, *args, **kwargs):
""" """
...@@ -590,6 +600,7 @@ class EagerPyLayer( ...@@ -590,6 +600,7 @@ class EagerPyLayer(
def once_differentiable(backward): def once_differentiable(backward):
def wrapper(ctx, *args): def wrapper(ctx, *args):
with paddle.fluid.dygraph.no_grad(): with paddle.fluid.dygraph.no_grad():
outputs = backward(ctx, *args) outputs = backward(ctx, *args)
......
...@@ -21,11 +21,6 @@ from .hapi.callbacks import EarlyStopping # noqa: F401 ...@@ -21,11 +21,6 @@ from .hapi.callbacks import EarlyStopping # noqa: F401
from .hapi.callbacks import ReduceLROnPlateau # noqa: F401 from .hapi.callbacks import ReduceLROnPlateau # noqa: F401
__all__ = [ #noqa __all__ = [ #noqa
'Callback', 'Callback', 'ProgBarLogger', 'ModelCheckpoint', 'VisualDL', 'LRScheduler',
'ProgBarLogger', 'EarlyStopping', 'ReduceLROnPlateau'
'ModelCheckpoint',
'VisualDL',
'LRScheduler',
'EarlyStopping',
'ReduceLROnPlateau'
] ]
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -13,4 +13,5 @@ ...@@ -13,4 +13,5 @@
# limitations under the License. # limitations under the License.
from .cost_model import CostModel # noqa: F401 from .cost_model import CostModel # noqa: F401
__all__ = ['CostModel'] __all__ = ['CostModel']
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -21,6 +21,7 @@ from paddle.fluid import core ...@@ -21,6 +21,7 @@ from paddle.fluid import core
class CostModel(): class CostModel():
def __init__(self): def __init__(self):
pass pass
...@@ -29,10 +30,11 @@ class CostModel(): ...@@ -29,10 +30,11 @@ class CostModel():
main_program = static.Program() main_program = static.Program()
startup_program = static.Program() startup_program = static.Program()
with static.program_guard( with static.program_guard(main_program=main_program,
main_program=main_program, startup_program=startup_program): startup_program=startup_program):
data = paddle.static.data( data = paddle.static.data(name='X',
name='X', shape=[None, 1], dtype='float32') shape=[None, 1],
dtype='float32')
hidden = paddle.static.nn.fc(data, 10) hidden = paddle.static.nn.fc(data, 10)
loss = paddle.mean(hidden) loss = paddle.mean(hidden)
paddle.optimizer.SGD(learning_rate=0.01).minimize(loss) paddle.optimizer.SGD(learning_rate=0.01).minimize(loss)
...@@ -59,8 +61,8 @@ class CostModel(): ...@@ -59,8 +61,8 @@ class CostModel():
cost_data = cost_model.ProfileMeasure(device) cost_data = cost_model.ProfileMeasure(device)
def static_cost_data(self): def static_cost_data(self):
static_cost_data_path = os.path.join( static_cost_data_path = os.path.join(os.path.dirname(__file__),
os.path.dirname(__file__), "static_op_benchmark.json") "static_op_benchmark.json")
with open(static_cost_data_path, 'r') as load_f: with open(static_cost_data_path, 'r') as load_f:
load_dict = json.load(load_f) load_dict = json.load(load_f)
self._static_cost_data = load_dict self._static_cost_data = load_dict
......
...@@ -47,10 +47,11 @@ CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85' ...@@ -47,10 +47,11 @@ CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'
def reader_creator(filename, sub_name, cycle=False): def reader_creator(filename, sub_name, cycle=False):
def read_batch(batch): def read_batch(batch):
data = batch[six.b('data')] data = batch[six.b('data')]
labels = batch.get( labels = batch.get(six.b('labels'), batch.get(six.b('fine_labels'),
six.b('labels'), batch.get(six.b('fine_labels'), None)) None))
assert labels is not None assert labels is not None
for sample, label in six.moves.zip(data, labels): for sample, label in six.moves.zip(data, labels):
yield (sample / 255.0).astype(numpy.float32), int(label) yield (sample / 255.0).astype(numpy.float32), int(label)
...@@ -129,10 +130,10 @@ def train10(cycle=False): ...@@ -129,10 +130,10 @@ def train10(cycle=False):
:return: Training reader creator :return: Training reader creator
:rtype: callable :rtype: callable
""" """
return reader_creator( return reader_creator(paddle.dataset.common.download(
paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5), CIFAR10_URL, 'cifar', CIFAR10_MD5),
'data_batch', 'data_batch',
cycle=cycle) cycle=cycle)
@deprecated( @deprecated(
...@@ -152,10 +153,10 @@ def test10(cycle=False): ...@@ -152,10 +153,10 @@ def test10(cycle=False):
:return: Test reader creator. :return: Test reader creator.
:rtype: callable :rtype: callable
""" """
return reader_creator( return reader_creator(paddle.dataset.common.download(
paddle.dataset.common.download(CIFAR10_URL, 'cifar', CIFAR10_MD5), CIFAR10_URL, 'cifar', CIFAR10_MD5),
'test_batch', 'test_batch',
cycle=cycle) cycle=cycle)
@deprecated( @deprecated(
......
...@@ -64,9 +64,9 @@ def download(url, module_name, md5sum, save_name=None): ...@@ -64,9 +64,9 @@ def download(url, module_name, md5sum, save_name=None):
if not os.path.exists(dirname): if not os.path.exists(dirname):
os.makedirs(dirname) os.makedirs(dirname)
filename = os.path.join(dirname, filename = os.path.join(
url.split('/')[-1] dirname,
if save_name is None else save_name) url.split('/')[-1] if save_name is None else save_name)
if os.path.exists(filename) and md5file(filename) == md5sum: if os.path.exists(filename) and md5file(filename) == md5sum:
return filename return filename
...@@ -79,8 +79,9 @@ def download(url, module_name, md5sum, save_name=None): ...@@ -79,8 +79,9 @@ def download(url, module_name, md5sum, save_name=None):
if retry < retry_limit: if retry < retry_limit:
retry += 1 retry += 1
else: else:
raise RuntimeError("Cannot download {0} within retry limit {1}". raise RuntimeError(
format(url, retry_limit)) "Cannot download {0} within retry limit {1}".format(
url, retry_limit))
sys.stderr.write("Cache file %s not found, downloading %s \n" % sys.stderr.write("Cache file %s not found, downloading %s \n" %
(filename, url)) (filename, url))
sys.stderr.write("Begin to download\n") sys.stderr.write("Begin to download\n")
...@@ -98,8 +99,8 @@ def download(url, module_name, md5sum, save_name=None): ...@@ -98,8 +99,8 @@ def download(url, module_name, md5sum, save_name=None):
total_iter = total_length / chunk_size + 1 total_iter = total_length / chunk_size + 1
log_interval = total_iter // 20 if total_iter > 20 else 1 log_interval = total_iter // 20 if total_iter > 20 else 1
log_index = 0 log_index = 0
bar = paddle.hapi.progressbar.ProgressBar( bar = paddle.hapi.progressbar.ProgressBar(total_iter,
total_iter, name='item') name='item')
for data in r.iter_content(chunk_size=chunk_size): for data in r.iter_content(chunk_size=chunk_size):
f.write(data) f.write(data)
log_index += 1 log_index += 1
...@@ -121,9 +122,8 @@ def fetch_all(): ...@@ -121,9 +122,8 @@ def fetch_all():
]: ]:
if "fetch" in dir( if "fetch" in dir(
importlib.import_module("paddle.dataset.%s" % module_name)): importlib.import_module("paddle.dataset.%s" % module_name)):
getattr( getattr(importlib.import_module("paddle.dataset.%s" % module_name),
importlib.import_module("paddle.dataset.%s" % module_name), "fetch")()
"fetch")()
def split(reader, line_count, suffix="%05d.pickle", dumper=pickle.dump): def split(reader, line_count, suffix="%05d.pickle", dumper=pickle.dump):
...@@ -206,5 +206,5 @@ def _check_exists_and_download(path, url, md5, module_name, download=True): ...@@ -206,5 +206,5 @@ def _check_exists_and_download(path, url, md5, module_name, download=True):
if download: if download:
return paddle.dataset.common.download(url, module_name, md5) return paddle.dataset.common.download(url, module_name, md5)
else: else:
raise ValueError('{} not exists and auto download disabled'.format( raise ValueError(
path)) '{} not exists and auto download disabled'.format(path))
...@@ -152,6 +152,7 @@ def reader_creator(corpus_reader, ...@@ -152,6 +152,7 @@ def reader_creator(corpus_reader,
word_dict=None, word_dict=None,
predicate_dict=None, predicate_dict=None,
label_dict=None): label_dict=None):
def reader(): def reader():
for sentence, predicate, labels in corpus_reader(): for sentence, predicate, labels in corpus_reader():
......
...@@ -73,8 +73,11 @@ def default_mapper(is_train, sample): ...@@ -73,8 +73,11 @@ def default_mapper(is_train, sample):
''' '''
img, label = sample img, label = sample
img = load_image_bytes(img) img = load_image_bytes(img)
img = simple_transform( img = simple_transform(img,
img, 256, 224, is_train, mean=[103.94, 116.78, 123.68]) 256,
224,
is_train,
mean=[103.94, 116.78, 123.68])
return img.flatten().astype('float32'), label return img.flatten().astype('float32'), label
...@@ -164,15 +167,14 @@ def train(mapper=train_mapper, buffered_size=1024, use_xmap=True, cycle=False): ...@@ -164,15 +167,14 @@ def train(mapper=train_mapper, buffered_size=1024, use_xmap=True, cycle=False):
:return: train data reader :return: train data reader
:rtype: callable :rtype: callable
''' '''
return reader_creator( return reader_creator(download(DATA_URL, 'flowers', DATA_MD5),
download(DATA_URL, 'flowers', DATA_MD5), download(LABEL_URL, 'flowers', LABEL_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5), download(SETID_URL, 'flowers', SETID_MD5),
download(SETID_URL, 'flowers', SETID_MD5), TRAIN_FLAG,
TRAIN_FLAG, mapper,
mapper, buffered_size,
buffered_size, use_xmap,
use_xmap, cycle=cycle)
cycle=cycle)
@deprecated( @deprecated(
...@@ -198,15 +200,14 @@ def test(mapper=test_mapper, buffered_size=1024, use_xmap=True, cycle=False): ...@@ -198,15 +200,14 @@ def test(mapper=test_mapper, buffered_size=1024, use_xmap=True, cycle=False):
:return: test data reader :return: test data reader
:rtype: callable :rtype: callable
''' '''
return reader_creator( return reader_creator(download(DATA_URL, 'flowers', DATA_MD5),
download(DATA_URL, 'flowers', DATA_MD5), download(LABEL_URL, 'flowers', LABEL_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5), download(SETID_URL, 'flowers', SETID_MD5),
download(SETID_URL, 'flowers', SETID_MD5), TEST_FLAG,
TEST_FLAG, mapper,
mapper, buffered_size,
buffered_size, use_xmap,
use_xmap, cycle=cycle)
cycle=cycle)
@deprecated( @deprecated(
...@@ -230,11 +231,10 @@ def valid(mapper=test_mapper, buffered_size=1024, use_xmap=True): ...@@ -230,11 +231,10 @@ def valid(mapper=test_mapper, buffered_size=1024, use_xmap=True):
:return: test data reader :return: test data reader
:rtype: callable :rtype: callable
''' '''
return reader_creator( return reader_creator(download(DATA_URL, 'flowers', DATA_MD5),
download(DATA_URL, 'flowers', DATA_MD5), download(LABEL_URL, 'flowers', LABEL_MD5),
download(LABEL_URL, 'flowers', LABEL_MD5), download(SETID_URL, 'flowers', SETID_MD5), VALID_FLAG,
download(SETID_URL, 'flowers', SETID_MD5), VALID_FLAG, mapper, mapper, buffered_size, use_xmap)
buffered_size, use_xmap)
def fetch(): def fetch():
......
...@@ -45,10 +45,9 @@ if six.PY3: ...@@ -45,10 +45,9 @@ if six.PY3:
# will be the C++ execubable on Windows # will be the C++ execubable on Windows
if sys.platform == 'win32' and 'python.exe' not in interpreter: if sys.platform == 'win32' and 'python.exe' not in interpreter:
interpreter = sys.exec_prefix + os.sep + 'python.exe' interpreter = sys.exec_prefix + os.sep + 'python.exe'
import_cv2_proc = subprocess.Popen( import_cv2_proc = subprocess.Popen([interpreter, "-c", "import cv2"],
[interpreter, "-c", "import cv2"], stdout=subprocess.PIPE,
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stderr=subprocess.PIPE)
out, err = import_cv2_proc.communicate() out, err = import_cv2_proc.communicate()
retcode = import_cv2_proc.poll() retcode = import_cv2_proc.poll()
if retcode != 0: if retcode != 0:
...@@ -123,10 +122,9 @@ def batch_images_from_tar(data_file, ...@@ -123,10 +122,9 @@ def batch_images_from_tar(data_file,
output = {} output = {}
output['label'] = labels output['label'] = labels
output['data'] = data output['data'] = data
pickle.dump( pickle.dump(output,
output, open('%s/batch_%d' % (out_path, file_id), 'wb'),
open('%s/batch_%d' % (out_path, file_id), 'wb'), protocol=2)
protocol=2)
file_id += 1 file_id += 1
data = [] data = []
labels = [] labels = []
...@@ -134,8 +132,9 @@ def batch_images_from_tar(data_file, ...@@ -134,8 +132,9 @@ def batch_images_from_tar(data_file,
output = {} output = {}
output['label'] = labels output['label'] = labels
output['data'] = data output['data'] = data
pickle.dump( pickle.dump(output,
output, open('%s/batch_%d' % (out_path, file_id), 'wb'), protocol=2) open('%s/batch_%d' % (out_path, file_id), 'wb'),
protocol=2)
with open(meta_file, 'a') as meta: with open(meta_file, 'a') as meta:
for file in os.listdir(out_path): for file in os.listdir(out_path):
......
...@@ -51,9 +51,9 @@ def tokenize(pattern): ...@@ -51,9 +51,9 @@ def tokenize(pattern):
while tf != None: while tf != None:
if bool(pattern.match(tf.name)): if bool(pattern.match(tf.name)):
# newline and punctuations removal and ad-hoc tokenization. # newline and punctuations removal and ad-hoc tokenization.
yield tarf.extractfile(tf).read().rstrip(six.b( yield tarf.extractfile(tf).read().rstrip(
"\n\r")).translate( six.b("\n\r")).translate(None, six.b(
None, six.b(string.punctuation)).lower().split() string.punctuation)).lower().split()
tf = tarf.next() tf = tarf.next()
...@@ -117,9 +117,8 @@ def train(word_idx): ...@@ -117,9 +117,8 @@ def train(word_idx):
:return: Training reader creator :return: Training reader creator
:rtype: callable :rtype: callable
""" """
return reader_creator( return reader_creator(re.compile(r"aclImdb/train/pos/.*\.txt$"),
re.compile(r"aclImdb/train/pos/.*\.txt$"), re.compile(r"aclImdb/train/neg/.*\.txt$"), word_idx)
re.compile(r"aclImdb/train/neg/.*\.txt$"), word_idx)
@deprecated( @deprecated(
...@@ -139,9 +138,8 @@ def test(word_idx): ...@@ -139,9 +138,8 @@ def test(word_idx):
:return: Test reader creator :return: Test reader creator
:rtype: callable :rtype: callable
""" """
return reader_creator( return reader_creator(re.compile(r"aclImdb/test/pos/.*\.txt$"),
re.compile(r"aclImdb/test/pos/.*\.txt$"), re.compile(r"aclImdb/test/neg/.*\.txt$"), word_idx)
re.compile(r"aclImdb/test/neg/.*\.txt$"), word_idx)
@deprecated( @deprecated(
......
...@@ -83,6 +83,7 @@ def build_dict(min_word_freq=50): ...@@ -83,6 +83,7 @@ def build_dict(min_word_freq=50):
def reader_creator(filename, word_idx, n, data_type): def reader_creator(filename, word_idx, n, data_type):
def reader(): def reader():
with tarfile.open( with tarfile.open(
paddle.dataset.common.download( paddle.dataset.common.download(
......
...@@ -41,6 +41,7 @@ TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432' ...@@ -41,6 +41,7 @@ TRAIN_LABEL_MD5 = 'd53e105ee54ea40749a09fcbcd1e9432'
def reader_creator(image_filename, label_filename, buffer_size): def reader_creator(image_filename, label_filename, buffer_size):
def reader(): def reader():
with gzip.GzipFile(image_filename, 'rb') as image_file: with gzip.GzipFile(image_filename, 'rb') as image_file:
img_buf = image_file.read() img_buf = image_file.read()
...@@ -61,8 +62,8 @@ def reader_creator(image_filename, label_filename, buffer_size): ...@@ -61,8 +62,8 @@ def reader_creator(image_filename, label_filename, buffer_size):
offset_lab = 0 offset_lab = 0
# label file : 8B # label file : 8B
magic_byte_lab = '>II' magic_byte_lab = '>II'
magic_lab, label_num = struct.unpack_from(magic_byte_lab, magic_lab, label_num = struct.unpack_from(
lab_buf, offset_lab) magic_byte_lab, lab_buf, offset_lab)
offset_lab += struct.calcsize(magic_byte_lab) offset_lab += struct.calcsize(magic_byte_lab)
while True: while True:
...@@ -76,8 +77,9 @@ def reader_creator(image_filename, label_filename, buffer_size): ...@@ -76,8 +77,9 @@ def reader_creator(image_filename, label_filename, buffer_size):
fmt_images = '>' + str(buffer_size * rows * cols) + 'B' fmt_images = '>' + str(buffer_size * rows * cols) + 'B'
images_temp = struct.unpack_from(fmt_images, img_buf, images_temp = struct.unpack_from(fmt_images, img_buf,
offset_img) offset_img)
images = numpy.reshape(images_temp, ( images = numpy.reshape(
buffer_size, rows * cols)).astype('float32') images_temp,
(buffer_size, rows * cols)).astype('float32')
offset_img += struct.calcsize(fmt_images) offset_img += struct.calcsize(fmt_images)
images = images / 255.0 images = images / 255.0
......
...@@ -89,8 +89,8 @@ class UserInfo(object): ...@@ -89,8 +89,8 @@ class UserInfo(object):
def __str__(self): def __str__(self):
return "<UserInfo id(%d), gender(%s), age(%d), job(%d)>" % ( return "<UserInfo id(%d), gender(%s), age(%d), job(%d)>" % (
self.index, "M" self.index, "M" if self.is_male else "F", age_table[self.age],
if self.is_male else "F", age_table[self.age], self.job_id) self.job_id)
def __repr__(self): def __repr__(self):
return str(self) return str(self)
...@@ -142,8 +142,10 @@ def __initialize_meta_info__(): ...@@ -142,8 +142,10 @@ def __initialize_meta_info__():
for line in user_file: for line in user_file:
line = cpt.to_text(line, encoding='latin') line = cpt.to_text(line, encoding='latin')
uid, gender, age, job, _ = line.strip().split("::") uid, gender, age, job, _ = line.strip().split("::")
USER_INFO[int(uid)] = UserInfo( USER_INFO[int(uid)] = UserInfo(index=uid,
index=uid, gender=gender, age=age, job_id=job) gender=gender,
age=age,
job_id=job)
return fn return fn
......
...@@ -21,6 +21,7 @@ __all__ = [] ...@@ -21,6 +21,7 @@ __all__ = []
class TestCIFAR(unittest.TestCase): class TestCIFAR(unittest.TestCase):
def check_reader(self, reader): def check_reader(self, reader):
sum = 0 sum = 0
label = 0 label = 0
......
...@@ -21,6 +21,7 @@ __all__ = [] ...@@ -21,6 +21,7 @@ __all__ = []
class TestFlowers(unittest.TestCase): class TestFlowers(unittest.TestCase):
def check_reader(self, reader): def check_reader(self, reader):
sum = 0 sum = 0
label = 0 label = 0
......
...@@ -23,6 +23,7 @@ __all__ = [] ...@@ -23,6 +23,7 @@ __all__ = []
class TestMikolov(unittest.TestCase): class TestMikolov(unittest.TestCase):
def check_reader(self, reader, n): def check_reader(self, reader, n):
for l in reader(): for l in reader():
self.assertEqual(len(l), n) self.assertEqual(len(l), n)
......
...@@ -21,6 +21,7 @@ __all__ = [] ...@@ -21,6 +21,7 @@ __all__ = []
class TestMNIST(unittest.TestCase): class TestMNIST(unittest.TestCase):
def check_reader(self, reader): def check_reader(self, reader):
sum = 0 sum = 0
label = 0 label = 0
......
...@@ -23,6 +23,7 @@ __all__ = [] ...@@ -23,6 +23,7 @@ __all__ = []
class Image(unittest.TestCase): class Image(unittest.TestCase):
def test_resize_flip_chw(self): def test_resize_flip_chw(self):
# resize # resize
im = image.load_image('cat.jpg') im = image.load_image('cat.jpg')
......
...@@ -21,6 +21,7 @@ __all__ = [] ...@@ -21,6 +21,7 @@ __all__ = []
class TestVOC(unittest.TestCase): class TestVOC(unittest.TestCase):
def check_reader(self, reader): def check_reader(self, reader):
sum = 0 sum = 0
label = 0 label = 0
......
...@@ -21,6 +21,7 @@ __all__ = [] ...@@ -21,6 +21,7 @@ __all__ = []
class TestWMT16(unittest.TestCase): class TestWMT16(unittest.TestCase):
def checkout_one_sample(self, sample): def checkout_one_sample(self, sample):
# train data has 3 field: source language word indices, # train data has 3 field: source language word indices,
# target language word indices, and target next word indices. # target language word indices, and target next word indices.
...@@ -38,22 +39,22 @@ class TestWMT16(unittest.TestCase): ...@@ -38,22 +39,22 @@ class TestWMT16(unittest.TestCase):
def test_train(self): def test_train(self):
for idx, sample in enumerate( for idx, sample in enumerate(
paddle.dataset.wmt16.train( paddle.dataset.wmt16.train(src_dict_size=100000,
src_dict_size=100000, trg_dict_size=100000)()): trg_dict_size=100000)()):
if idx >= 10: break if idx >= 10: break
self.checkout_one_sample(sample) self.checkout_one_sample(sample)
def test_test(self): def test_test(self):
for idx, sample in enumerate( for idx, sample in enumerate(
paddle.dataset.wmt16.test( paddle.dataset.wmt16.test(src_dict_size=1000,
src_dict_size=1000, trg_dict_size=1000)()): trg_dict_size=1000)()):
if idx >= 10: break if idx >= 10: break
self.checkout_one_sample(sample) self.checkout_one_sample(sample)
def test_val(self): def test_val(self):
for idx, sample in enumerate( for idx, sample in enumerate(
paddle.dataset.wmt16.validation( paddle.dataset.wmt16.validation(src_dict_size=1000,
src_dict_size=1000, trg_dict_size=1000)()): trg_dict_size=1000)()):
if idx >= 10: break if idx >= 10: break
self.checkout_one_sample(sample) self.checkout_one_sample(sample)
......
...@@ -73,8 +73,8 @@ def load_data(filename, feature_num=14, ratio=0.8): ...@@ -73,8 +73,8 @@ def load_data(filename, feature_num=14, ratio=0.8):
data = np.fromfile(filename, sep=' ') data = np.fromfile(filename, sep=' ')
data = data.reshape(data.shape[0] // feature_num, feature_num) data = data.reshape(data.shape[0] // feature_num, feature_num)
maximums, minimums, avgs = data.max(axis=0), data.min(axis=0), data.sum( maximums, minimums, avgs = data.max(axis=0), data.min(
axis=0) / data.shape[0] axis=0), data.sum(axis=0) / data.shape[0]
# if you want to print the distribution of input data, you could use function of feature_range # if you want to print the distribution of input data, you could use function of feature_range
#feature_range(maximums[:-1], minimums[:-1]) #feature_range(maximums[:-1], minimums[:-1])
for i in six.moves.range(feature_num - 1): for i in six.moves.range(feature_num - 1):
...@@ -135,8 +135,10 @@ def test(): ...@@ -135,8 +135,10 @@ def test():
def fluid_model(): def fluid_model():
parameter_tar = paddle.dataset.common.download( parameter_tar = paddle.dataset.common.download(FLUID_URL_MODEL,
FLUID_URL_MODEL, 'uci_housing', FLUID_MD5_MODEL, 'fit_a_line.fluid.tar') 'uci_housing',
FLUID_MD5_MODEL,
'fit_a_line.fluid.tar')
tar = tarfile.TarFile(parameter_tar, mode='r') tar = tarfile.TarFile(parameter_tar, mode='r')
dirpath = tempfile.mkdtemp() dirpath = tempfile.mkdtemp()
......
...@@ -50,6 +50,7 @@ UNK_IDX = 2 ...@@ -50,6 +50,7 @@ UNK_IDX = 2
def __read_to_dict(tar_file, dict_size): def __read_to_dict(tar_file, dict_size):
def __to_dict(fd, size): def __to_dict(fd, size):
out_dict = dict() out_dict = dict()
for line_count, line in enumerate(fd): for line_count, line in enumerate(fd):
...@@ -76,6 +77,7 @@ def __read_to_dict(tar_file, dict_size): ...@@ -76,6 +77,7 @@ def __read_to_dict(tar_file, dict_size):
def reader_creator(tar_file, file_name, dict_size): def reader_creator(tar_file, file_name, dict_size):
def reader(): def reader():
src_dict, trg_dict = __read_to_dict(tar_file, dict_size) src_dict, trg_dict = __read_to_dict(tar_file, dict_size)
with tarfile.open(tar_file, mode='r') as f: with tarfile.open(tar_file, mode='r') as f:
......
...@@ -68,9 +68,9 @@ def __build_dict(tar_file, dict_size, save_path, lang): ...@@ -68,9 +68,9 @@ def __build_dict(tar_file, dict_size, save_path, lang):
fout.write( fout.write(
cpt.to_bytes("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK))) cpt.to_bytes("%s\n%s\n%s\n" % (START_MARK, END_MARK, UNK_MARK)))
for idx, word in enumerate( for idx, word in enumerate(
sorted( sorted(six.iteritems(word_dict),
six.iteritems(word_dict), key=lambda x: x[1], key=lambda x: x[1],
reverse=True)): reverse=True)):
if idx + 3 == dict_size: break if idx + 3 == dict_size: break
fout.write(cpt.to_bytes(word[0])) fout.write(cpt.to_bytes(word[0]))
fout.write(cpt.to_bytes('\n')) fout.write(cpt.to_bytes('\n'))
...@@ -79,8 +79,8 @@ def __build_dict(tar_file, dict_size, save_path, lang): ...@@ -79,8 +79,8 @@ def __build_dict(tar_file, dict_size, save_path, lang):
def __load_dict(tar_file, dict_size, lang, reverse=False): def __load_dict(tar_file, dict_size, lang, reverse=False):
dict_path = os.path.join(paddle.dataset.common.DATA_HOME, dict_path = os.path.join(paddle.dataset.common.DATA_HOME,
"wmt16/%s_%d.dict" % (lang, dict_size)) "wmt16/%s_%d.dict" % (lang, dict_size))
if not os.path.exists(dict_path) or ( if not os.path.exists(dict_path) or (len(open(dict_path, "rb").readlines())
len(open(dict_path, "rb").readlines()) != dict_size): != dict_size):
__build_dict(tar_file, dict_size, dict_path, lang) __build_dict(tar_file, dict_size, dict_path, lang)
word_dict = {} word_dict = {}
...@@ -94,14 +94,15 @@ def __load_dict(tar_file, dict_size, lang, reverse=False): ...@@ -94,14 +94,15 @@ def __load_dict(tar_file, dict_size, lang, reverse=False):
def __get_dict_size(src_dict_size, trg_dict_size, src_lang): def __get_dict_size(src_dict_size, trg_dict_size, src_lang):
src_dict_size = min(src_dict_size, (TOTAL_EN_WORDS if src_lang == "en" else src_dict_size = min(
TOTAL_DE_WORDS)) src_dict_size, (TOTAL_EN_WORDS if src_lang == "en" else TOTAL_DE_WORDS))
trg_dict_size = min(trg_dict_size, (TOTAL_DE_WORDS if src_lang == "en" else trg_dict_size = min(
TOTAL_EN_WORDS)) trg_dict_size, (TOTAL_DE_WORDS if src_lang == "en" else TOTAL_EN_WORDS))
return src_dict_size, trg_dict_size return src_dict_size, trg_dict_size
def reader_creator(tar_file, file_name, src_dict_size, trg_dict_size, src_lang): def reader_creator(tar_file, file_name, src_dict_size, trg_dict_size, src_lang):
def reader(): def reader():
src_dict = __load_dict(tar_file, src_dict_size, src_lang) src_dict = __load_dict(tar_file, src_dict_size, src_lang)
trg_dict = __load_dict(tar_file, trg_dict_size, trg_dict = __load_dict(tar_file, trg_dict_size,
...@@ -124,9 +125,9 @@ def reader_creator(tar_file, file_name, src_dict_size, trg_dict_size, src_lang): ...@@ -124,9 +125,9 @@ def reader_creator(tar_file, file_name, src_dict_size, trg_dict_size, src_lang):
if len(line_split) != 2: if len(line_split) != 2:
continue continue
src_words = line_split[src_col].split() src_words = line_split[src_col].split()
src_ids = [start_id] + [ src_ids = [start_id
src_dict.get(w, unk_id) for w in src_words ] + [src_dict.get(w, unk_id)
] + [end_id] for w in src_words] + [end_id]
trg_words = line_split[trg_col].split() trg_words = line_split[trg_col].split()
trg_ids = [trg_dict.get(w, unk_id) for w in trg_words] trg_ids = [trg_dict.get(w, unk_id) for w in trg_words]
...@@ -184,13 +185,12 @@ def train(src_dict_size, trg_dict_size, src_lang="en"): ...@@ -184,13 +185,12 @@ def train(src_dict_size, trg_dict_size, src_lang="en"):
src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size, src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size,
src_lang) src_lang)
return reader_creator( return reader_creator(tar_file=paddle.dataset.common.download(
tar_file=paddle.dataset.common.download(DATA_URL, "wmt16", DATA_MD5, DATA_URL, "wmt16", DATA_MD5, "wmt16.tar.gz"),
"wmt16.tar.gz"), file_name="wmt16/train",
file_name="wmt16/train", src_dict_size=src_dict_size,
src_dict_size=src_dict_size, trg_dict_size=trg_dict_size,
trg_dict_size=trg_dict_size, src_lang=src_lang)
src_lang=src_lang)
@deprecated( @deprecated(
...@@ -238,13 +238,12 @@ def test(src_dict_size, trg_dict_size, src_lang="en"): ...@@ -238,13 +238,12 @@ def test(src_dict_size, trg_dict_size, src_lang="en"):
src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size, src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size,
src_lang) src_lang)
return reader_creator( return reader_creator(tar_file=paddle.dataset.common.download(
tar_file=paddle.dataset.common.download(DATA_URL, "wmt16", DATA_MD5, DATA_URL, "wmt16", DATA_MD5, "wmt16.tar.gz"),
"wmt16.tar.gz"), file_name="wmt16/test",
file_name="wmt16/test", src_dict_size=src_dict_size,
src_dict_size=src_dict_size, trg_dict_size=trg_dict_size,
trg_dict_size=trg_dict_size, src_lang=src_lang)
src_lang=src_lang)
@deprecated( @deprecated(
...@@ -290,13 +289,12 @@ def validation(src_dict_size, trg_dict_size, src_lang="en"): ...@@ -290,13 +289,12 @@ def validation(src_dict_size, trg_dict_size, src_lang="en"):
src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size, src_dict_size, trg_dict_size = __get_dict_size(src_dict_size, trg_dict_size,
src_lang) src_lang)
return reader_creator( return reader_creator(tar_file=paddle.dataset.common.download(
tar_file=paddle.dataset.common.download(DATA_URL, "wmt16", DATA_MD5, DATA_URL, "wmt16", DATA_MD5, "wmt16.tar.gz"),
"wmt16.tar.gz"), file_name="wmt16/val",
file_name="wmt16/val", src_dict_size=src_dict_size,
src_dict_size=src_dict_size, trg_dict_size=trg_dict_size,
trg_dict_size=trg_dict_size, src_lang=src_lang)
src_lang=src_lang)
@deprecated( @deprecated(
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# TODO: define the functions to manipulate devices # TODO: define the functions to manipulate devices
import re import re
import os import os
from paddle.fluid import core from paddle.fluid import core
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -178,8 +178,8 @@ def extract_cuda_device_id(device, op_name): ...@@ -178,8 +178,8 @@ def extract_cuda_device_id(device, op_name):
else: else:
raise ValueError( raise ValueError(
"The current string {} is not expected. Because {} only support string which is like 'gpu:x'. " "The current string {} is not expected. Because {} only support string which is like 'gpu:x'. "
"Please input appropriate string again!".format(device, "Please input appropriate string again!".format(
op_name)) device, op_name))
else: else:
raise ValueError( raise ValueError(
"The device type {} is not expected. Because {} only support int, str or paddle.CUDAPlace. " "The device type {} is not expected. Because {} only support int, str or paddle.CUDAPlace. "
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -31,6 +31,7 @@ ALL_MODES = ["global", "thread_local", "relaxed"] ...@@ -31,6 +31,7 @@ ALL_MODES = ["global", "thread_local", "relaxed"]
class CUDAGraph: class CUDAGraph:
def __init__(self, place=None, mode="thread_local"): def __init__(self, place=None, mode="thread_local"):
assert CoreCUDAGraph is not None, "CUDA Graph is only supported on PaddlePaddle compiled with NVIDIA GPU." assert CoreCUDAGraph is not None, "CUDA Graph is only supported on PaddlePaddle compiled with NVIDIA GPU."
...@@ -61,7 +62,7 @@ class CUDAGraph: ...@@ -61,7 +62,7 @@ class CUDAGraph:
assert os.path.isdir( assert os.path.isdir(
dirname), "The dirname {} should be a directory".format(dirname) dirname), "The dirname {} should be a directory".format(dirname)
if flags is None: if flags is None:
flags = 2047 # only all information. It can be any integer inside [1, 2048) flags = 2047 # only all information. It can be any integer inside [1, 2048)
self._graph.print_to_dot_files(dirname, flags) self._graph.print_to_dot_files(dirname, flags)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......
...@@ -59,33 +59,33 @@ from . import utils # noqa: F401 ...@@ -59,33 +59,33 @@ from . import utils # noqa: F401
from .sharding import * # noqa: F401 from .sharding import * # noqa: F401
__all__ = [ # noqa __all__ = [ # noqa
"spawn", "spawn",
"launch", "launch",
"scatter", "scatter",
"broadcast", "broadcast",
"ParallelEnv", "ParallelEnv",
"new_group", "new_group",
"init_parallel_env", "init_parallel_env",
"gloo_init_parallel_env", "gloo_init_parallel_env",
"gloo_barrier", "gloo_barrier",
"gloo_release", "gloo_release",
"QueueDataset", "QueueDataset",
"split", "split",
"CountFilterEntry", "CountFilterEntry",
"ShowClickEntry", "ShowClickEntry",
"get_world_size", "get_world_size",
"get_group", "get_group",
"all_gather", "all_gather",
"InMemoryDataset", "InMemoryDataset",
"barrier", "barrier",
"all_reduce", "all_reduce",
"alltoall", "alltoall",
"send", "send",
"reduce", "reduce",
"recv", "recv",
"ReduceOp", "ReduceOp",
"wait", "wait",
"get_rank", "get_rank",
"ProbabilityEntry", "ProbabilityEntry",
"ParallelMode", "ParallelMode",
] ]
...@@ -50,14 +50,14 @@ class Device: ...@@ -50,14 +50,14 @@ class Device:
self._local_id = local_id self._local_id = local_id
self._machine = machine self._machine = machine
self._type = None self._type = None
# Different device have different models, such as # Different device have different models, such as
# "Tesla V100-SXM2-32GB" and "A100-SXM4-40GB" etc. # "Tesla V100-SXM2-32GB" and "A100-SXM4-40GB" etc.
self._model = None self._model = None
# Double precision GFLOPS # Double precision GFLOPS
self._dp_gflops = None self._dp_gflops = None
# Single precision GFLOPS # Single precision GFLOPS
self._sp_gflops = None self._sp_gflops = None
# Memory is stored by GB # Memory is stored by GB
self._memory = None self._memory = None
@property @property
...@@ -144,9 +144,9 @@ class Link: ...@@ -144,9 +144,9 @@ class Link:
self._src = source self._src = source
self._tgt = target self._tgt = target
self._type = None self._type = None
# bandwidth is stored by GB/s # bandwidth is stored by GB/s
self._bandwidth = None self._bandwidth = None
# latency is stored by millisecond # latency is stored by millisecond
self._latency = None self._latency = None
self._hop = None self._hop = None
...@@ -210,6 +210,7 @@ class Link: ...@@ -210,6 +210,7 @@ class Link:
class Machine: class Machine:
def __init__(self, id): def __init__(self, id):
self._id = id self._id = id
self._hostname = None self._hostname = None
...@@ -290,6 +291,7 @@ class Machine: ...@@ -290,6 +291,7 @@ class Machine:
class AlphaLatency: class AlphaLatency:
def __init__(self, alpha_latency): def __init__(self, alpha_latency):
assert isinstance(alpha_latency, dict) assert isinstance(alpha_latency, dict)
self._base = alpha_latency.get("base", None) self._base = alpha_latency.get("base", None)
......
...@@ -137,6 +137,7 @@ def _validate_dims_mapping(dims_mapping, process_mesh): ...@@ -137,6 +137,7 @@ def _validate_dims_mapping(dims_mapping, process_mesh):
class Completer: class Completer:
def __init__(self, dist_context): def __init__(self, dist_context):
assert dist_context is not None assert dist_context is not None
self._dist_context = dist_context self._dist_context = dist_context
...@@ -248,8 +249,8 @@ class Completer: ...@@ -248,8 +249,8 @@ class Completer:
tensor_desc.name(), compatible_dims_mapping) tensor_desc.name(), compatible_dims_mapping)
changed = True changed = True
# Find the most compatible implemenetations from the distributed operator # Find the most compatible implemenetations from the distributed operator
op_dist_impls = find_compatible_distributed_operator_impls( op_dist_impls = find_compatible_distributed_operator_impls(dist_op,
dist_op, fwd=True) fwd=True)
if op_dist_impls is not None: if op_dist_impls is not None:
not_compatible = True not_compatible = True
backup_op_dist_attr = copy.deepcopy(op_dist_attr) backup_op_dist_attr = copy.deepcopy(op_dist_attr)
...@@ -451,6 +452,7 @@ class Completer: ...@@ -451,6 +452,7 @@ class Completer:
tensor_dist_attr.process_mesh = compatible_process_mesh tensor_dist_attr.process_mesh = compatible_process_mesh
def _update_process_mesh_for_specials(self): def _update_process_mesh_for_specials(self):
def _find_nearest_tensor_node_before(nodes, idx, var_name): def _find_nearest_tensor_node_before(nodes, idx, var_name):
for node in reversed(nodes[:idx]): for node in reversed(nodes[:idx]):
if node.is_var() and node.var() is not None \ if node.is_var() and node.var() is not None \
...@@ -694,8 +696,8 @@ class Completer: ...@@ -694,8 +696,8 @@ class Completer:
# Step 2.2: set the process meshes of ops by the nearest op node after the first op node # Step 2.2: set the process meshes of ops by the nearest op node after the first op node
if idx_of_first_op_node_has_process_mesh + 1 > len(ordered_op_nodes): if idx_of_first_op_node_has_process_mesh + 1 > len(ordered_op_nodes):
return None return None
for idx, op_node in enumerate(ordered_op_nodes[ for idx, op_node in enumerate(
idx_of_first_op_node_has_process_mesh + 1:]): ordered_op_nodes[idx_of_first_op_node_has_process_mesh + 1:]):
original_idx = idx_of_first_op_node_has_process_mesh + idx + 1 original_idx = idx_of_first_op_node_has_process_mesh + idx + 1
nearest_op_node = ordered_op_nodes[original_idx - 1] nearest_op_node = ordered_op_nodes[original_idx - 1]
nearest_op_dist_attr = self._dist_context.get_dist_attr_for_graph( nearest_op_dist_attr = self._dist_context.get_dist_attr_for_graph(
...@@ -831,9 +833,9 @@ class Completer: ...@@ -831,9 +833,9 @@ class Completer:
if grad_op.desc.original_id( if grad_op.desc.original_id(
) in dist_op_context.grad_op_id_to_op_id: ) in dist_op_context.grad_op_id_to_op_id:
# TODO support the case where one forward op corresponding to multiple xxx_grad op # TODO support the case where one forward op corresponding to multiple xxx_grad op
forward_op = _get_op_by_id(ops, forward_op = _get_op_by_id(
dist_op_context.grad_op_id_to_op_id[ ops, dist_op_context.grad_op_id_to_op_id[
grad_op.desc.original_id()]) grad_op.desc.original_id()])
assert forward_op is not None assert forward_op is not None
fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program( fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
...@@ -862,8 +864,8 @@ class Completer: ...@@ -862,8 +864,8 @@ class Completer:
input_name) input_name)
assert ref_dims_mapping is not None, "[{}] 's dims mapping is NONE".format( assert ref_dims_mapping is not None, "[{}] 's dims mapping is NONE".format(
input_name) input_name)
grad_op_dist_attr.set_input_dims_mapping(input_name, grad_op_dist_attr.set_input_dims_mapping(
ref_dims_mapping) input_name, ref_dims_mapping)
for output_name in grad_op.output_arg_names: for output_name in grad_op.output_arg_names:
assert output_name in grad_var_to_var[appended_grad_times] assert output_name in grad_var_to_var[appended_grad_times]
...@@ -878,8 +880,8 @@ class Completer: ...@@ -878,8 +880,8 @@ class Completer:
self._dist_context.set_tensor_dist_attr_for_program( self._dist_context.set_tensor_dist_attr_for_program(
output_var, tensor_dist_attr) output_var, tensor_dist_attr)
# op # op
grad_op_dist_attr.set_output_dims_mapping(output_name, grad_op_dist_attr.set_output_dims_mapping(
ref_dims_mapping) output_name, ref_dims_mapping)
self._dist_context.set_op_dist_attr_for_program( self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr) grad_op, grad_op_dist_attr)
...@@ -934,10 +936,10 @@ class Completer: ...@@ -934,10 +936,10 @@ class Completer:
# op # op
grad_op_dist_attr = OperatorDistributedAttribute() grad_op_dist_attr = OperatorDistributedAttribute()
grad_op_dist_attr.process_mesh = ref_process_mesh grad_op_dist_attr.process_mesh = ref_process_mesh
grad_op_dist_attr.set_input_dims_mapping(ref_var_name, grad_op_dist_attr.set_input_dims_mapping(
ref_dims_mapping) ref_var_name, ref_dims_mapping)
grad_op_dist_attr.set_output_dims_mapping(output_var_name, grad_op_dist_attr.set_output_dims_mapping(
ref_dims_mapping) output_var_name, ref_dims_mapping)
elif grad_op.type in ['shape', 'fill_constant']: elif grad_op.type in ['shape', 'fill_constant']:
continue continue
...@@ -977,8 +979,8 @@ class Completer: ...@@ -977,8 +979,8 @@ class Completer:
first_backward_op_idx = -1 first_backward_op_idx = -1
for idx, op in enumerate(serial_main_program.global_block().ops): for idx, op in enumerate(serial_main_program.global_block().ops):
if int(op.attr('op_role')) == int( if int(op.attr('op_role')) == int(
int(core.op_proto_and_checker_maker.OpRole.Backward) | int( int(core.op_proto_and_checker_maker.OpRole.Backward)
core.op_proto_and_checker_maker.OpRole.Loss)): | int(core.op_proto_and_checker_maker.OpRole.Loss)):
assert op.type == "fill_constant" assert op.type == "fill_constant"
first_backward_op_idx = idx first_backward_op_idx = idx
break break
...@@ -1025,8 +1027,8 @@ class Completer: ...@@ -1025,8 +1027,8 @@ class Completer:
op_dist_attr.process_mesh = process_mesh op_dist_attr.process_mesh = process_mesh
op_dist_attr.set_output_dims_mapping(grad_var.name, op_dist_attr.set_output_dims_mapping(grad_var.name,
dims_mapping) dims_mapping)
self._dist_context.set_op_dist_attr_for_program(ops[idx], self._dist_context.set_op_dist_attr_for_program(
op_dist_attr) ops[idx], op_dist_attr)
continue continue
# complete the annotation of grad op (xxx_grad op or sum op) # complete the annotation of grad op (xxx_grad op or sum op)
...@@ -1035,9 +1037,10 @@ class Completer: ...@@ -1035,9 +1037,10 @@ class Completer:
if grad_op.desc.original_id( if grad_op.desc.original_id(
) in dist_op_context.grad_op_id_to_op_id: ) in dist_op_context.grad_op_id_to_op_id:
# TODO support the case where one forward op corresponding to multiple xxx_grad op # TODO support the case where one forward op corresponding to multiple xxx_grad op
forward_op = _get_op_by_id(ops[:first_backward_op_idx], forward_op = _get_op_by_id(
dist_op_context.grad_op_id_to_op_id[ ops[:first_backward_op_idx],
grad_op.desc.original_id()]) dist_op_context.grad_op_id_to_op_id[
grad_op.desc.original_id()])
assert forward_op is not None assert forward_op is not None
if grad_op.type == "concat" and forward_op.type == "split": if grad_op.type == "concat" and forward_op.type == "split":
...@@ -1060,8 +1063,8 @@ class Completer: ...@@ -1060,8 +1063,8 @@ class Completer:
self._dist_context.set_tensor_dist_attr_for_program( self._dist_context.set_tensor_dist_attr_for_program(
output_var, output_var_dist_attr) output_var, output_var_dist_attr)
grad_op_dist_attr.set_output_dims_mapping(output_var.name, grad_op_dist_attr.set_output_dims_mapping(
ref_dims_mapping) output_var.name, ref_dims_mapping)
grad_op_dist_attr.process_mesh = ref_mesh grad_op_dist_attr.process_mesh = ref_mesh
self._dist_context.set_op_dist_attr_for_program( self._dist_context.set_op_dist_attr_for_program(
grad_op, grad_op_dist_attr) grad_op, grad_op_dist_attr)
...@@ -1095,8 +1098,8 @@ class Completer: ...@@ -1095,8 +1098,8 @@ class Completer:
input_name) input_name)
assert ref_dims_mapping is not None, "[{}] 's dims mapping is NONE".format( assert ref_dims_mapping is not None, "[{}] 's dims mapping is NONE".format(
input_name) input_name)
grad_op_dist_attr.set_input_dims_mapping(input_name, grad_op_dist_attr.set_input_dims_mapping(
ref_dims_mapping) input_name, ref_dims_mapping)
for output_name in grad_op.output_arg_names: for output_name in grad_op.output_arg_names:
assert output_name in grad_var_to_var assert output_name in grad_var_to_var
...@@ -1111,8 +1114,8 @@ class Completer: ...@@ -1111,8 +1114,8 @@ class Completer:
self._dist_context.set_tensor_dist_attr_for_program( self._dist_context.set_tensor_dist_attr_for_program(
output_var, tensor_dist_attr) output_var, tensor_dist_attr)
# op # op
grad_op_dist_attr.set_output_dims_mapping(output_name, grad_op_dist_attr.set_output_dims_mapping(
ref_dims_mapping) output_name, ref_dims_mapping)
grad_op_dist_attr.impl_type = fwd_op_dist_attr.impl_type grad_op_dist_attr.impl_type = fwd_op_dist_attr.impl_type
grad_op_dist_attr.impl_idx = fwd_op_dist_attr.impl_idx grad_op_dist_attr.impl_idx = fwd_op_dist_attr.impl_idx
...@@ -1170,10 +1173,10 @@ class Completer: ...@@ -1170,10 +1173,10 @@ class Completer:
# op # op
grad_op_dist_attr = OperatorDistributedAttribute() grad_op_dist_attr = OperatorDistributedAttribute()
grad_op_dist_attr.process_mesh = ref_process_mesh grad_op_dist_attr.process_mesh = ref_process_mesh
grad_op_dist_attr.set_input_dims_mapping(ref_var_name, grad_op_dist_attr.set_input_dims_mapping(
ref_dims_mapping) ref_var_name, ref_dims_mapping)
grad_op_dist_attr.set_output_dims_mapping(output_var_name, grad_op_dist_attr.set_output_dims_mapping(
ref_dims_mapping) output_var_name, ref_dims_mapping)
else: else:
raise ValueError("got unexpect op [{}]".format( raise ValueError("got unexpect op [{}]".format(
...@@ -1186,7 +1189,7 @@ class Completer: ...@@ -1186,7 +1189,7 @@ class Completer:
"""Complete the annotation of vars and ops in the update phase for parallel program.""" """Complete the annotation of vars and ops in the update phase for parallel program."""
# Notice: serial_main_program is actually a dist_main_program of current rank, # Notice: serial_main_program is actually a dist_main_program of current rank,
# and must be passed into this function. # and must be passed into this function.
# TODO: We should fix this behavior. # TODO: We should fix this behavior.
ops = list(serial_main_program.global_block().ops) ops = list(serial_main_program.global_block().ops)
...@@ -1223,10 +1226,10 @@ class Completer: ...@@ -1223,10 +1226,10 @@ class Completer:
op, op_dist_attr) op, op_dist_attr)
if "Grad" in op.input_names and "Param" in ops[idx].input_names: if "Grad" in op.input_names and "Param" in ops[idx].input_names:
assert len(op.input( assert len(
"Param")) == 1, "Only support one-to-one now." op.input("Param")) == 1, "Only support one-to-one now."
assert len(op.input( assert len(
"Grad")) == 1, "Only support one-to-one now." op.input("Grad")) == 1, "Only support one-to-one now."
param = vars[op.input("Param")[0]] param = vars[op.input("Param")[0]]
grad_var = vars[op.input("Grad")[0]] grad_var = vars[op.input("Grad")[0]]
...@@ -1245,12 +1248,12 @@ class Completer: ...@@ -1245,12 +1248,12 @@ class Completer:
ref_dims_mapping) ref_dims_mapping)
op_dist_attr.set_input_dims_mapping(param.name, op_dist_attr.set_input_dims_mapping(param.name,
ref_dims_mapping) ref_dims_mapping)
op_dist_attr.set_output_dims_mapping(param.name, op_dist_attr.set_output_dims_mapping(
ref_dims_mapping) param.name, ref_dims_mapping)
learning_var = vars[op.input("LearningRate")[0]] learning_var = vars[op.input("LearningRate")[0]]
op_dist_attr.set_input_dims_mapping(learning_var.name, [-1]) op_dist_attr.set_input_dims_mapping(learning_var.name, [-1])
op_dist_attr.set_output_dims_mapping(learning_var.name, op_dist_attr.set_output_dims_mapping(
[-1]) learning_var.name, [-1])
if not learning_rate_completed: if not learning_rate_completed:
learning_rate_completed = True learning_rate_completed = True
...@@ -1275,10 +1278,10 @@ class Completer: ...@@ -1275,10 +1278,10 @@ class Completer:
if "Beta1Pow" in input_name or "Beta2Pow" in input_name: if "Beta1Pow" in input_name or "Beta2Pow" in input_name:
input_var_attr.dims_mapping = [-1] input_var_attr.dims_mapping = [-1]
op_dist_attr.set_input_dims_mapping(input_var.name, op_dist_attr.set_input_dims_mapping(
[-1]) input_var.name, [-1])
op_dist_attr.set_output_dims_mapping(input_var.name, op_dist_attr.set_output_dims_mapping(
[-1]) input_var.name, [-1])
else: else:
assert "Moment" in input_name assert "Moment" in input_name
input_var_attr.dims_mapping = ref_dims_mapping input_var_attr.dims_mapping = ref_dims_mapping
......
...@@ -133,8 +133,9 @@ class Converter(object): ...@@ -133,8 +133,9 @@ class Converter(object):
tensors_dict[tensor_name] = Converter.merge_and_slice( tensors_dict[tensor_name] = Converter.merge_and_slice(
tensor_list, pre_dist_attr, cur_dist_attr) tensor_list, pre_dist_attr, cur_dist_attr)
except ValueError as err: except ValueError as err:
raise ValueError("Fail to convert tensor '{}'. " raise ValueError(
.format(str(tensor_name)) + str(err)) "Fail to convert tensor '{}'. ".format(str(tensor_name)) +
str(err))
for tensor_name in self._pre_strategy: for tensor_name in self._pre_strategy:
if tensor_name not in self._cur_strategy: if tensor_name not in self._cur_strategy:
...@@ -150,17 +151,17 @@ class Converter(object): ...@@ -150,17 +151,17 @@ class Converter(object):
tensor_not_in_cur = set(tensor_not_in_cur) - set(tensor_match_with_cur) tensor_not_in_cur = set(tensor_not_in_cur) - set(tensor_match_with_cur)
if tensor_not_in_pre: if tensor_not_in_pre:
warnings.warn( warnings.warn(
"tensors [{}] are not found in last training strategy." "tensors [{}] are not found in last training strategy.".format(
.format(str(tensor_not_in_pre))) str(tensor_not_in_pre)))
if tensor_not_in_cur: if tensor_not_in_cur:
warnings.warn( warnings.warn(
"tensors [{}] are not found in current training strategy." "tensors [{}] are not found in current training strategy.".
.format(str(tensor_not_in_cur))) format(str(tensor_not_in_cur)))
if tensor_not_in_ckpt: if tensor_not_in_ckpt:
warnings.warn( warnings.warn(
"tensors [{}] are found in pre_strategy, but are not found" "tensors [{}] are found in pre_strategy, but are not found"
"in checkpoint files, please check your checkpoint files." "in checkpoint files, please check your checkpoint files.".
.format(str(tensor_not_in_ckpt))) format(str(tensor_not_in_ckpt)))
return tensors_dict return tensors_dict
...@@ -360,8 +361,9 @@ class Converter(object): ...@@ -360,8 +361,9 @@ class Converter(object):
""" """
sliced_tensor_list = [] sliced_tensor_list = []
axis = len(complete_tensor.shape) - length axis = len(complete_tensor.shape) - length
sliced_tensor = np.split( sliced_tensor = np.split(complete_tensor,
complete_tensor, partition_index_list[axis], axis=axis) partition_index_list[axis],
axis=axis)
if length == 1: if length == 1:
return sliced_tensor return sliced_tensor
for tensor in sliced_tensor: for tensor in sliced_tensor:
......
...@@ -85,8 +85,8 @@ def _parse_op_to_desc(op, dist_context=None): ...@@ -85,8 +85,8 @@ def _parse_op_to_desc(op, dist_context=None):
def parse_to_desc(op=None, dist_op=None, dist_context=None): def parse_to_desc(op=None, dist_op=None, dist_context=None):
desc = None desc = None
if op is None and dist_op is not None and dist_context is not None: if op is None and dist_op is not None and dist_context is not None:
desc = _parse_op_to_desc( desc = _parse_op_to_desc(op=dist_op.serial_op,
op=dist_op.serial_op, dist_context=dist_context) dist_context=dist_context)
elif op is not None and dist_op is None and dist_context is None: elif op is not None and dist_op is None and dist_context is None:
desc = _parse_op_to_desc(op) desc = _parse_op_to_desc(op)
...@@ -94,6 +94,7 @@ def parse_to_desc(op=None, dist_op=None, dist_context=None): ...@@ -94,6 +94,7 @@ def parse_to_desc(op=None, dist_op=None, dist_context=None):
def parse_desc_to_str(desc): def parse_desc_to_str(desc):
def _parse_dtype(dtype): def _parse_dtype(dtype):
dtype_str = "" dtype_str = ""
if dtype == paddle.float32: if dtype == paddle.float32:
...@@ -248,10 +249,10 @@ class CommContext: ...@@ -248,10 +249,10 @@ class CommContext:
else: else:
for i in range(len(ranks)): for i in range(len(ranks)):
for j in range(i + 1, len(ranks)): for j in range(i + 1, len(ranks)):
forward_order_beta = self.cluster.get_beta(ranks[i], forward_order_beta = self.cluster.get_beta(
ranks[j]) ranks[i], ranks[j])
backward_order_beta = self.cluster.get_beta(ranks[j], backward_order_beta = self.cluster.get_beta(
ranks[i]) ranks[j], ranks[i])
beta = forward_order_beta if forward_order_beta > backward_order_beta else backward_order_beta beta = forward_order_beta if forward_order_beta > backward_order_beta else backward_order_beta
if max_beta == None: if max_beta == None:
max_beta = beta max_beta = beta
...@@ -275,6 +276,7 @@ class CommContext: ...@@ -275,6 +276,7 @@ class CommContext:
class Cost: class Cost:
def __init__(self, time=0, memory=0, flops=0): def __init__(self, time=0, memory=0, flops=0):
self.time = time self.time = time
self.memory = memory self.memory = memory
...@@ -338,6 +340,7 @@ class Cost: ...@@ -338,6 +340,7 @@ class Cost:
class OpCost: class OpCost:
def __init__(self, op=None, op_desc=None): def __init__(self, op=None, op_desc=None):
self._op = op self._op = op
self._op_desc = op_desc self._op_desc = op_desc
...@@ -462,8 +465,8 @@ class CommOpCost(OpCost): ...@@ -462,8 +465,8 @@ class CommOpCost(OpCost):
elif dtype == paddle.float16: elif dtype == paddle.float16:
factor = 2 factor = 2
else: else:
raise TypeError("This dtype {} is not supported now".format( raise TypeError(
dtype)) "This dtype {} is not supported now".format(dtype))
comm_count = reduce(lambda x, y: x * y, shape) * factor comm_count = reduce(lambda x, y: x * y, shape) * factor
self._comm_count = comm_count self._comm_count = comm_count
...@@ -506,8 +509,9 @@ class CommOpCost(OpCost): ...@@ -506,8 +509,9 @@ class CommOpCost(OpCost):
def _check_comm_op_type(cls): def _check_comm_op_type(cls):
if cls.OP_TYPE != "COMM": if cls.OP_TYPE != "COMM":
if cls.OP_TYPE not in COMM_OP_TYPE: if cls.OP_TYPE not in COMM_OP_TYPE:
raise TypeError("Please Check op type in {}, but got {}.". raise TypeError(
format(COMM_OP_TYPE, cls.OP_TYPE)) "Please Check op type in {}, but got {}.".format(
COMM_OP_TYPE, cls.OP_TYPE))
class CompOpCost(OpCost): class CompOpCost(OpCost):
...@@ -523,8 +527,9 @@ class CompOpCost(OpCost): ...@@ -523,8 +527,9 @@ class CompOpCost(OpCost):
def _check_comp_op_type(cls): def _check_comp_op_type(cls):
if cls.OP_TYPE != "COMP": if cls.OP_TYPE != "COMP":
if cls.OP_TYPE in NON_COMP_TYPE: if cls.OP_TYPE in NON_COMP_TYPE:
raise TypeError("Please Check op type not in {}, but got {}.". raise TypeError(
format(NON_COMP_TYPE, cls.OP_TYPE)) "Please Check op type not in {}, but got {}.".format(
NON_COMP_TYPE, cls.OP_TYPE))
def register_op_cost(cls): def register_op_cost(cls):
......
...@@ -22,8 +22,9 @@ class AllreduceSumOpCost(CommOpCost): ...@@ -22,8 +22,9 @@ class AllreduceSumOpCost(CommOpCost):
OP_TYPE = "c_allreduce_sum" OP_TYPE = "c_allreduce_sum"
def __init__(self, op=None, op_desc=None, comm_context=None): def __init__(self, op=None, op_desc=None, comm_context=None):
super(AllreduceSumOpCost, self).__init__( super(AllreduceSumOpCost, self).__init__(op=op,
op=op, op_desc=op_desc, comm_context=comm_context) op_desc=op_desc,
comm_context=comm_context)
def calc_time(self): def calc_time(self):
# use tree if cross machine and use ring if in a single machine # use tree if cross machine and use ring if in a single machine
...@@ -38,20 +39,20 @@ class AllreduceSumOpCost(CommOpCost): ...@@ -38,20 +39,20 @@ class AllreduceSumOpCost(CommOpCost):
def calc_time_ring(self): def calc_time_ring(self):
alpha = self.comm_context.base_ring alpha = self.comm_context.base_ring
alpha += 2 * ( alpha += 2 * (self.rank_count -
self.rank_count - self.machine_count) * self.comm_context.intra_ring self.machine_count) * self.comm_context.intra_ring
alpha += 2 * (self.machine_count - 1) * ( alpha += 2 * (self.machine_count - 1) * (
self.comm_context.inter_ring + self.hops * self.comm_context.switch) self.comm_context.inter_ring + self.hops * self.comm_context.switch)
beta = self.comm_context.get_max_beta(self.group_ranks) beta = self.comm_context.get_max_beta(self.group_ranks)
time = alpha + 2 * (self.rank_count - 1 time = alpha + 2 * (self.rank_count -
) / self.rank_count * self.comm_count * beta 1) / self.rank_count * self.comm_count * beta
return time return time
def calc_time_tree(self): def calc_time_tree(self):
alpha = self.comm_context.base_tree alpha = self.comm_context.base_tree
alpha += 2 * (self.rank_count / self.machine_count - 1 alpha += 2 * (self.rank_count / self.machine_count -
) * self.comm_context.intra_tree 1) * self.comm_context.intra_tree
alpha += math.log2(self.machine_count) * ( alpha += math.log2(self.machine_count) * (
self.comm_context.inter_tree + self.hops * self.comm_context.switch) self.comm_context.inter_tree + self.hops * self.comm_context.switch)
beta = self.comm_context.get_max_beta(self.group_ranks) beta = self.comm_context.get_max_beta(self.group_ranks)
...@@ -66,8 +67,9 @@ class AllgatherOpCost(CommOpCost): ...@@ -66,8 +67,9 @@ class AllgatherOpCost(CommOpCost):
OP_TYPE = "c_allgather" OP_TYPE = "c_allgather"
def __init__(self, op=None, op_desc=None, comm_context=None): def __init__(self, op=None, op_desc=None, comm_context=None):
super(AllgatherOpCost, self).__init__( super(AllgatherOpCost, self).__init__(op=op,
op=op, op_desc=op_desc, comm_context=comm_context) op_desc=op_desc,
comm_context=comm_context)
def calc_time(self): def calc_time(self):
time = self.calc_time_ring() time = self.calc_time_ring()
...@@ -75,13 +77,13 @@ class AllgatherOpCost(CommOpCost): ...@@ -75,13 +77,13 @@ class AllgatherOpCost(CommOpCost):
def calc_time_ring(self): def calc_time_ring(self):
alpha = self.comm_context.base_ring alpha = self.comm_context.base_ring
alpha += ( alpha += (self.rank_count -
self.rank_count - self.machine_count) * self.comm_context.intra_ring self.machine_count) * self.comm_context.intra_ring
alpha += (self.machine_count - 1) * ( alpha += (self.machine_count - 1) * (
self.comm_context.inter_ring + self.hops * self.comm_context.switch) self.comm_context.inter_ring + self.hops * self.comm_context.switch)
beta = self.comm_context.get_max_beta(self.group_ranks) beta = self.comm_context.get_max_beta(self.group_ranks)
time = alpha + (self.rank_count - 1 time = alpha + (self.rank_count -
) / self.rank_count * self.comm_count * beta 1) / self.rank_count * self.comm_count * beta
return time return time
...@@ -90,8 +92,9 @@ class BroadcastOpCost(CommOpCost): ...@@ -90,8 +92,9 @@ class BroadcastOpCost(CommOpCost):
OP_TYPE = "c_broadcast" OP_TYPE = "c_broadcast"
def __init__(self, op=None, op_desc=None, comm_context=None): def __init__(self, op=None, op_desc=None, comm_context=None):
super(BroadcastOpCost, self).__init__( super(BroadcastOpCost, self).__init__(op=op,
op=op, op_desc=op_desc, comm_context=comm_context) op_desc=op_desc,
comm_context=comm_context)
def calc_time(self): def calc_time(self):
time = self.calc_time_ring() time = self.calc_time_ring()
...@@ -114,8 +117,9 @@ class IdentityOpCost(CommOpCost): ...@@ -114,8 +117,9 @@ class IdentityOpCost(CommOpCost):
OP_TYPE = "c_identity" OP_TYPE = "c_identity"
def __init__(self, op=None, op_desc=None, comm_context=None): def __init__(self, op=None, op_desc=None, comm_context=None):
super(IdentityOpCost, self).__init__( super(IdentityOpCost, self).__init__(op=op,
op=op, op_desc=op_desc, comm_context=comm_context) op_desc=op_desc,
comm_context=comm_context)
def calc_time(self): def calc_time(self):
return 0 return 0
...@@ -126,8 +130,9 @@ class RecvOpCost(CommOpCost): ...@@ -126,8 +130,9 @@ class RecvOpCost(CommOpCost):
OP_TYPE = "recv_v2" OP_TYPE = "recv_v2"
def __init__(self, op=None, op_desc=None, comm_context=None): def __init__(self, op=None, op_desc=None, comm_context=None):
super(RecvOpCost, self).__init__( super(RecvOpCost, self).__init__(op=op,
op=op, op_desc=op_desc, comm_context=comm_context) op_desc=op_desc,
comm_context=comm_context)
def calc_time(self): def calc_time(self):
alpha = self.comm_context.base_ring alpha = self.comm_context.base_ring
...@@ -145,8 +150,9 @@ class SendOpCost(CommOpCost): ...@@ -145,8 +150,9 @@ class SendOpCost(CommOpCost):
OP_TYPE = "send_v2" OP_TYPE = "send_v2"
def __init__(self, op=None, op_desc=None, comm_context=None): def __init__(self, op=None, op_desc=None, comm_context=None):
super(SendOpCost, self).__init__( super(SendOpCost, self).__init__(op=op,
op=op, op_desc=op_desc, comm_context=comm_context) op_desc=op_desc,
comm_context=comm_context)
def calc_time(self): def calc_time(self):
alpha = self.comm_context.base_ring alpha = self.comm_context.base_ring
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
class CostEstimator: class CostEstimator:
def __init__(self, def __init__(self,
program, program,
cluster=None, cluster=None,
......
...@@ -22,6 +22,7 @@ from .base_cost import Cost ...@@ -22,6 +22,7 @@ from .base_cost import Cost
class TensorCost: class TensorCost:
def __init__(self, tensor=None, dist_tensor=None, shape=None, dtype=None): def __init__(self, tensor=None, dist_tensor=None, shape=None, dtype=None):
self._check_args(tensor, dist_tensor, shape, dtype) self._check_args(tensor, dist_tensor, shape, dtype)
self._tensor = tensor self._tensor = tensor
...@@ -59,20 +60,20 @@ class TensorCost: ...@@ -59,20 +60,20 @@ class TensorCost:
assert (tensor is None and shape is None) assert (tensor is None and shape is None)
if not isinstance(dist_tensor, DistributedTensor): if not isinstance(dist_tensor, DistributedTensor):
raise TypeError( raise TypeError(
"Please check dist_tensor type is DistributedTensor, but got {}". "Please check dist_tensor type is DistributedTensor, but got {}"
format(type(dist_tensor))) .format(type(dist_tensor)))
elif shape is not None: elif shape is not None:
assert (tensor is None and dist_tensor is None and assert (tensor is None and dist_tensor is None
dtype is not None) and dtype is not None)
if not isinstance(shape, (list, set)): if not isinstance(shape, (list, set)):
raise TypeError( raise TypeError(
"Please check shape type is list or set, but got {}".format( "Please check shape type is list or set, but got {}".format(
type(shape))) type(shape)))
elif dtype is not None: elif dtype is not None:
assert (tensor is None and dist_tensor is None and assert (tensor is None and dist_tensor is None
shape is not None) and shape is not None)
@property @property
def cost(self): def cost(self):
......
...@@ -37,6 +37,7 @@ class CostNodeType(Enum): ...@@ -37,6 +37,7 @@ class CostNodeType(Enum):
class Cost(object): class Cost(object):
def __init__(self): def __init__(self):
self.runtime = None self.runtime = None
self.static_mem = None self.static_mem = None
...@@ -51,6 +52,7 @@ class CostModelMode(Enum): ...@@ -51,6 +52,7 @@ class CostModelMode(Enum):
class CostNode(object): class CostNode(object):
def __init__(self, node, node_type, id=None): def __init__(self, node, node_type, id=None):
self.id = id self.id = id
self.node = node self.node = node
...@@ -71,6 +73,7 @@ class CostNode(object): ...@@ -71,6 +73,7 @@ class CostNode(object):
class MergedOpsCostNode(CostNode): class MergedOpsCostNode(CostNode):
def __init__(self, node_type, id=None, base_node_list=None, is_bwd=False): def __init__(self, node_type, id=None, base_node_list=None, is_bwd=False):
super(MergedOpsCostNode, self).__init__(None, node_type, id) super(MergedOpsCostNode, self).__init__(None, node_type, id)
self.node_list = base_node_list self.node_list = base_node_list
...@@ -78,6 +81,7 @@ class MergedOpsCostNode(CostNode): ...@@ -78,6 +81,7 @@ class MergedOpsCostNode(CostNode):
class CommOpCostNode(CostNode): class CommOpCostNode(CostNode):
def __init__(self, def __init__(self,
node, node,
node_type, node_type,
...@@ -118,6 +122,7 @@ class CommOpCostNode(CostNode): ...@@ -118,6 +122,7 @@ class CommOpCostNode(CostNode):
class TensorCostNode(CostNode): class TensorCostNode(CostNode):
def __init__(self, def __init__(self,
node, node,
node_type, node_type,
...@@ -159,6 +164,7 @@ class TensorCostNode(CostNode): ...@@ -159,6 +164,7 @@ class TensorCostNode(CostNode):
class CompOpCostNode(CostNode): class CompOpCostNode(CostNode):
def __init__(self, node, node_type, id=None, is_bwd=False, is_optim=False): def __init__(self, node, node_type, id=None, is_bwd=False, is_optim=False):
super(CompOpCostNode, self).__init__(node, node_type, id) super(CompOpCostNode, self).__init__(node, node_type, id)
self.is_bwd = is_bwd self.is_bwd = is_bwd
...@@ -174,6 +180,7 @@ class CompOpCostNode(CostNode): ...@@ -174,6 +180,7 @@ class CompOpCostNode(CostNode):
class PipeEvent(object): class PipeEvent(object):
def __init__(self, stage_id, event_name, duration, start_time=-1): def __init__(self, stage_id, event_name, duration, start_time=-1):
self.stage_id = stage_id self.stage_id = stage_id
self.name = event_name self.name = event_name
...@@ -183,6 +190,7 @@ class PipeEvent(object): ...@@ -183,6 +190,7 @@ class PipeEvent(object):
class CostModel(object): class CostModel(object):
def __init__(self, def __init__(self,
mode=CostModelMode.BENCHMARKING, mode=CostModelMode.BENCHMARKING,
cluster=None, cluster=None,
...@@ -261,8 +269,8 @@ class CostModel(object): ...@@ -261,8 +269,8 @@ class CostModel(object):
op_node = CommOpCostNode(op, CostNodeType.COMMUNICATION, op_id, op_node = CommOpCostNode(op, CostNodeType.COMMUNICATION, op_id,
is_bwd) is_bwd)
else: else:
is_bwd = (int(op.attr('op_role')) == int(OpRole.Backward) is_bwd = (int(op.attr('op_role')) == int(
) or "@GRAD" in op.input_arg_names OpRole.Backward)) or "@GRAD" in op.input_arg_names
is_optim = 'LearningRate' in op.input_names is_optim = 'LearningRate' in op.input_names
op_node = CompOpCostNode(op, CostNodeType.COMPUTATION, op_id, op_node = CompOpCostNode(op, CostNodeType.COMPUTATION, op_id,
is_bwd, is_optim) is_bwd, is_optim)
...@@ -310,11 +318,10 @@ class CostModel(object): ...@@ -310,11 +318,10 @@ class CostModel(object):
write_op_cnt += 1 write_op_cnt += 1
new_var_id = node_id + '_write_{}'.format(write_op_cnt) new_var_id = node_id + '_write_{}'.format(write_op_cnt)
new_var = TensorCostNode( new_var = TensorCostNode(node.node,
node.node, CostNodeType.VARIABLE,
CostNodeType.VARIABLE, new_var_id,
new_var_id, shared_node_id=node_id)
shared_node_id=node_id)
graph[new_var_id] = [[], []] graph[new_var_id] = [[], []]
graph[pred_id][SUCC].append(new_var_id) graph[pred_id][SUCC].append(new_var_id)
...@@ -341,8 +348,8 @@ class CostModel(object): ...@@ -341,8 +348,8 @@ class CostModel(object):
self.runtime_graph.append({}) self.runtime_graph.append({})
self._parse_sub_program( self._parse_sub_program(
sub_prog, self.nodes[sub_idx], self.origin_graph[sub_idx], sub_prog, self.nodes[sub_idx], self.origin_graph[sub_idx],
self.cost_data[0 if self.rank2pp is None else self.rank2pp[ self.cost_data[0 if self.rank2pp is None else self.
sub_idx]], sub_idx) rank2pp[sub_idx]], sub_idx)
return self.nodes return self.nodes
def _find_succ_op(self, node_id, sub_idx=0): def _find_succ_op(self, node_id, sub_idx=0):
...@@ -417,11 +424,10 @@ class CostModel(object): ...@@ -417,11 +424,10 @@ class CostModel(object):
merge_type)) merge_type))
merged_node_id = 'merged_' + str(len(nodes)) merged_node_id = 'merged_' + str(len(nodes))
is_bwd = to_merge_node_list[0].is_bwd is_bwd = to_merge_node_list[0].is_bwd
merged_node = MergedOpsCostNode( merged_node = MergedOpsCostNode(CostNodeType.MERGED,
CostNodeType.MERGED, id=merged_node_id,
id=merged_node_id, base_node_list=nodes_list,
base_node_list=nodes_list, is_bwd=is_bwd)
is_bwd=is_bwd)
merged_node.cost = node_cost merged_node.cost = node_cost
return merged_node_id, merged_node return merged_node_id, merged_node
...@@ -435,10 +441,12 @@ class CostModel(object): ...@@ -435,10 +441,12 @@ class CostModel(object):
''' '''
cnt = 0 cnt = 0
for sub_idx in range(self.total_rank): for sub_idx in range(self.total_rank):
cnt += self._merge_linear( cnt += self._merge_linear(self.nodes[sub_idx],
self.nodes[sub_idx], self.runtime_graph[sub_idx], is_bwd=False) self.runtime_graph[sub_idx],
cnt += self._merge_linear( is_bwd=False)
self.nodes[sub_idx], self.runtime_graph[sub_idx], is_bwd=True) cnt += self._merge_linear(self.nodes[sub_idx],
self.runtime_graph[sub_idx],
is_bwd=True)
return cnt return cnt
def merge_branch(self): def merge_branch(self):
...@@ -454,10 +462,12 @@ class CostModel(object): ...@@ -454,10 +462,12 @@ class CostModel(object):
''' '''
cnt = 0 cnt = 0
for sub_idx in range(self.total_rank): for sub_idx in range(self.total_rank):
cnt += self._merge_branch( cnt += self._merge_branch(self.nodes[sub_idx],
self.nodes[sub_idx], self.runtime_graph[sub_idx], is_bwd=False) self.runtime_graph[sub_idx],
cnt += self._merge_branch( is_bwd=False)
self.nodes[sub_idx], self.runtime_graph[sub_idx], is_bwd=True) cnt += self._merge_branch(self.nodes[sub_idx],
self.runtime_graph[sub_idx],
is_bwd=True)
return cnt return cnt
def _merge_linear(self, nodes, runtime_graph, is_bwd=False): def _merge_linear(self, nodes, runtime_graph, is_bwd=False):
...@@ -482,8 +492,8 @@ class CostModel(object): ...@@ -482,8 +492,8 @@ class CostModel(object):
# delete edges and add new edges # delete edges and add new edges
succ = None succ = None
try: try:
runtime_graph[merged_node_id][SUCC] = copy.deepcopy(edges[ runtime_graph[merged_node_id][SUCC] = copy.deepcopy(
SUCC]) edges[SUCC])
if len(runtime_graph[pred_id][SUCC]) > 1: if len(runtime_graph[pred_id][SUCC]) > 1:
# predecessor has more than 1 successor # predecessor has more than 1 successor
...@@ -558,8 +568,8 @@ class CostModel(object): ...@@ -558,8 +568,8 @@ class CostModel(object):
to_merge = True to_merge = True
try: try:
if len(edges[SUCC]) < 1 or len(runtime_graph[edges[SUCC][0]] if len(edges[SUCC]) < 1 or len(
[SUCC]) < 1: runtime_graph[edges[SUCC][0]][SUCC]) < 1:
continue continue
except: except:
continue continue
...@@ -596,6 +606,7 @@ class CostModel(object): ...@@ -596,6 +606,7 @@ class CostModel(object):
return reduct_cnt return reduct_cnt
def get_runtime_cost(self): def get_runtime_cost(self):
def get_node_cost(node): def get_node_cost(node):
node_cost = node.cost + self.opcall_overhead node_cost = node.cost + self.opcall_overhead
if isinstance(node, MergedOpsCostNode): if isinstance(node, MergedOpsCostNode):
...@@ -660,8 +671,8 @@ class CostModel(object): ...@@ -660,8 +671,8 @@ class CostModel(object):
static_mem += size static_mem += size
cur_mem += size cur_mem += size
edges = sim_graph[node_id] edges = sim_graph[node_id]
if not (node.type == CostNodeType.VARIABLE and if not (node.type == CostNodeType.VARIABLE
node.node.persistable): and node.node.persistable):
for succ_id in edges[SUCC]: for succ_id in edges[SUCC]:
sim_graph[succ_id][PRED].remove(node_id) sim_graph[succ_id][PRED].remove(node_id)
if len(sim_graph[succ_id][PRED]) == 0: if len(sim_graph[succ_id][PRED]) == 0:
...@@ -670,8 +681,8 @@ class CostModel(object): ...@@ -670,8 +681,8 @@ class CostModel(object):
pred = nodes pred = nodes
if pred.type == CostNodeType.VARIABLE: if pred.type == CostNodeType.VARIABLE:
sim_graph[pred_id][SUCC].remove(node_id) sim_graph[pred_id][SUCC].remove(node_id)
if len(sim_graph[pred_id][ if len(sim_graph[pred_id]
SUCC]) == 0 and not pred.node.persistable: [SUCC]) == 0 and not pred.node.persistable:
cur_mem -= pred.get_size() cur_mem -= pred.get_size()
return static_mem, cur_mem, top_mem return static_mem, cur_mem, top_mem
...@@ -703,18 +714,16 @@ class CostModel(object): ...@@ -703,18 +714,16 @@ class CostModel(object):
event_list.append(e) event_list.append(e)
if stid != stage_num - 1: if stid != stage_num - 1:
q.put( q.put(
PipeEvent( PipeEvent(stid + 1,
stid + 1, 'fwd',
'fwd', self.fwd_time[stid + 1],
self.fwd_time[stid + 1], start_time=e.e_time))
start_time=e.e_time))
else: else:
q.put( q.put(
PipeEvent( PipeEvent(stid,
stid, 'bwd',
'bwd', self.bwd_time[stid],
self.bwd_time[stid], start_time=e.e_time))
start_time=e.e_time))
fwd_cnt[stid] -= 1 fwd_cnt[stid] -= 1
global_time[stid] = e.e_time global_time[stid] = e.e_time
else: else:
...@@ -725,20 +734,18 @@ class CostModel(object): ...@@ -725,20 +734,18 @@ class CostModel(object):
event_list.append(e) event_list.append(e)
if stid != 0: if stid != 0:
q.put( q.put(
PipeEvent( PipeEvent(stid - 1,
stid - 1, 'bwd',
'bwd', self.bwd_time[stid - 1],
self.bwd_time[stid - 1], start_time=e.e_time))
start_time=e.e_time))
fwd_cnt[stid] += 1 fwd_cnt[stid] += 1
bwd_cnt[stid] -= 1 bwd_cnt[stid] -= 1
if bwd_cnt[stid] == 0: if bwd_cnt[stid] == 0:
q.put( q.put(
PipeEvent( PipeEvent(stid,
stid, 'optim',
'optim', self.optim_time[stid],
self.optim_time[stid], start_time=e.e_time))
start_time=e.e_time))
global_time[stid] = e.e_time global_time[stid] = e.e_time
elif e.name == 'optim': elif e.name == 'optim':
e.s_time = max(global_time[stid], e.s_time) e.s_time = max(global_time[stid], e.s_time)
...@@ -792,11 +799,10 @@ def estimate_cost(distributed_program, cluster, pipeline_config, ...@@ -792,11 +799,10 @@ def estimate_cost(distributed_program, cluster, pipeline_config,
""" """
# the following line is left for now, cluster model will be involved in the future # the following line is left for now, cluster model will be involved in the future
assert cluster is None, "For now, cluster remains None" assert cluster is None, "For now, cluster remains None"
cm_ctx = CostModel( cm_ctx = CostModel(cluster=cluster,
cluster=cluster, batch_size=batch_size,
batch_size=batch_size, standalone_cost_data=standalone_cost_data,
standalone_cost_data=standalone_cost_data, pipeline_config=pipeline_config)
pipeline_config=pipeline_config)
cm_ctx.init(distributed_program) cm_ctx.init(distributed_program)
cost = cm_ctx.get_cost() cost = cm_ctx.get_cost()
return cost return cost
...@@ -51,6 +51,7 @@ def append_op_output_suffix(name): ...@@ -51,6 +51,7 @@ def append_op_output_suffix(name):
class TensorDistributedAttribute: class TensorDistributedAttribute:
def __init__(self): def __init__(self):
# The process mesh of distributed operator attribute must is the same as # The process mesh of distributed operator attribute must is the same as
# the process meshes of all input and output distributed attributed # the process meshes of all input and output distributed attributed
...@@ -123,8 +124,8 @@ class TensorDistributedAttribute: ...@@ -123,8 +124,8 @@ class TensorDistributedAttribute:
key, dist_attr) key, dist_attr)
elif isinstance(dist_attr, TensorDistributedAttribute): elif isinstance(dist_attr, TensorDistributedAttribute):
for key in get_tensor_dist_attr_field_keys(): for key in get_tensor_dist_attr_field_keys():
field_property = TensorDistributedAttribute.__dict__.get(key, field_property = TensorDistributedAttribute.__dict__.get(
None) key, None)
if field_property: if field_property:
field_property.fset(self, field_property.fget(dist_attr)) field_property.fset(self, field_property.fget(dist_attr))
else: else:
...@@ -192,6 +193,7 @@ class TensorDistributedAttribute: ...@@ -192,6 +193,7 @@ class TensorDistributedAttribute:
class OperatorDistributedAttribute: class OperatorDistributedAttribute:
def __init__(self): def __init__(self):
self._process_mesh = None self._process_mesh = None
self._op_type = None self._op_type = None
...@@ -356,8 +358,8 @@ class OperatorDistributedAttribute: ...@@ -356,8 +358,8 @@ class OperatorDistributedAttribute:
tensor_name, dist_attr.get_output_dist_attr(tensor_name)) tensor_name, dist_attr.get_output_dist_attr(tensor_name))
self._is_annotated = copy.deepcopy(dist_attr._is_annotated) self._is_annotated = copy.deepcopy(dist_attr._is_annotated)
for key in get_op_dist_attr_field_keys(): for key in get_op_dist_attr_field_keys():
field_property = OperatorDistributedAttribute.__dict__.get(key, field_property = OperatorDistributedAttribute.__dict__.get(
None) key, None)
if field_property: if field_property:
field_property.fset(self, field_property.fget(dist_attr)) field_property.fset(self, field_property.fget(dist_attr))
else: else:
......
...@@ -203,8 +203,8 @@ class DistributedContext: ...@@ -203,8 +203,8 @@ class DistributedContext:
self._serial_main_program.clone()) self._serial_main_program.clone())
self._backup_serial_startup_program_stack.append( self._backup_serial_startup_program_stack.append(
self._serial_startup_program.clone()) self._serial_startup_program.clone())
self._backup_pass_context_stack.append( self._backup_pass_context_stack.append(copy.deepcopy(
copy.deepcopy(self._pass_context)) self._pass_context))
self._backup_block_state_stack.append(copy.deepcopy(self._block_state)) self._backup_block_state_stack.append(copy.deepcopy(self._block_state))
def _backup_dist_info(self, mode): def _backup_dist_info(self, mode):
...@@ -398,8 +398,8 @@ class DistributedContext: ...@@ -398,8 +398,8 @@ class DistributedContext:
return dist_tensor return dist_tensor
else: else:
serial_tensor_id = serial_tensor.desc.original_id() serial_tensor_id = serial_tensor.desc.original_id()
dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, dist_tensor = self._dist_tensors_for_program.get(
None) serial_tensor_id, None)
if dist_tensor: if dist_tensor:
return dist_tensor return dist_tensor
else: else:
...@@ -438,8 +438,8 @@ class DistributedContext: ...@@ -438,8 +438,8 @@ class DistributedContext:
return dist_tensor.dist_attr return dist_tensor.dist_attr
else: else:
serial_tensor_id = serial_tensor.desc.original_id() serial_tensor_id = serial_tensor.desc.original_id()
dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, dist_tensor = self._dist_tensors_for_program.get(
None) serial_tensor_id, None)
if dist_tensor: if dist_tensor:
return dist_tensor.dist_attr return dist_tensor.dist_attr
else: else:
...@@ -548,6 +548,7 @@ class DistributedContext: ...@@ -548,6 +548,7 @@ class DistributedContext:
self._dist_ops_for_program) self._dist_ops_for_program)
def _order_nodes_by_program_order(self): def _order_nodes_by_program_order(self):
def _contains(nodes, target_node): def _contains(nodes, target_node):
for node in nodes: for node in nodes:
if _node_id(node) == _node_id(target_node): if _node_id(node) == _node_id(target_node):
...@@ -719,8 +720,8 @@ class DistributedContext: ...@@ -719,8 +720,8 @@ class DistributedContext:
# here we just set there process_mesh to the first one. # here we just set there process_mesh to the first one.
for orphan_node in self._serial_orphan_tensor_nodes: for orphan_node in self._serial_orphan_tensor_nodes:
serial_tensor_id = orphan_node.var().id() serial_tensor_id = orphan_node.var().id()
dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, dist_tensor = self._dist_tensors_for_program.get(
None) serial_tensor_id, None)
if dist_tensor: if dist_tensor:
dist_tensor.dist_attr.process_mesh = self._process_meshes[0] dist_tensor.dist_attr.process_mesh = self._process_meshes[0]
else: else:
...@@ -807,11 +808,10 @@ class DistributedContext: ...@@ -807,11 +808,10 @@ class DistributedContext:
assert dist_tensor is not None, \ assert dist_tensor is not None, \
"Tensor {} does not have a distributed attribute.".format( "Tensor {} does not have a distributed attribute.".format(
dist_tensor.serial_tensor.name) dist_tensor.serial_tensor.name)
if (dist_tensor is not None) and ( if (dist_tensor
not dist_tensor.validate_dist_attr()): is not None) and (not dist_tensor.validate_dist_attr()):
assert False, "Tensor {} (id: {}, original_id: {}) has a wrong distributed attributes {}.".format( assert False, "Tensor {} (id: {}, original_id: {}) has a wrong distributed attributes {}.".format(
dist_tensor.serial_tensor.name, dist_tensor.serial_tensor.name, dist_tensor.desc.id(),
dist_tensor.desc.id(),
dist_tensor.desc.original_id(), dist_tensor.dist_attr) dist_tensor.desc.original_id(), dist_tensor.dist_attr)
for op in block.ops: for op in block.ops:
dist_op = self.get_dist_op_for_program(op) dist_op = self.get_dist_op_for_program(op)
...@@ -820,8 +820,7 @@ class DistributedContext: ...@@ -820,8 +820,7 @@ class DistributedContext:
dist_op.serial_op.type) dist_op.serial_op.type)
if (dist_op is not None) and (not dist_op.validate_dist_attr()): if (dist_op is not None) and (not dist_op.validate_dist_attr()):
assert False, "Operator {} (id: {}, original_id: {}) has a wrong distributed attributes {} .".format( assert False, "Operator {} (id: {}, original_id: {}) has a wrong distributed attributes {} .".format(
dist_op.serial_op.type, dist_op.serial_op.type, dist_op.serial_op.desc.id(),
dist_op.serial_op.desc.id(),
dist_op.serial_op.desc.original_id(), dist_op.dist_attr) dist_op.serial_op.desc.original_id(), dist_op.dist_attr)
return True return True
...@@ -947,6 +946,7 @@ class DistributedOperatorContext: ...@@ -947,6 +946,7 @@ class DistributedOperatorContext:
class BlockState(object): class BlockState(object):
def __init__(self): def __init__(self):
self.nblock = 0 self.nblock = 0
self.forward_indices = [] self.forward_indices = []
......
...@@ -21,6 +21,7 @@ from paddle.io import DataLoader, DistributedBatchSampler ...@@ -21,6 +21,7 @@ from paddle.io import DataLoader, DistributedBatchSampler
class DistributedDataLoader(metaclass=abc.ABCMeta): class DistributedDataLoader(metaclass=abc.ABCMeta):
def __init__(self, def __init__(self,
dataset, dataset,
batch_size=1, batch_size=1,
...@@ -47,6 +48,7 @@ class DistributedDataLoader(metaclass=abc.ABCMeta): ...@@ -47,6 +48,7 @@ class DistributedDataLoader(metaclass=abc.ABCMeta):
class NonIterableGeneratorLoader(DistributedDataLoader): class NonIterableGeneratorLoader(DistributedDataLoader):
def __init__(self, def __init__(self,
dataset, dataset,
feed_list, feed_list,
...@@ -63,9 +65,10 @@ class NonIterableGeneratorLoader(DistributedDataLoader): ...@@ -63,9 +65,10 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
self.dp_world_size = 1 if data_parallel_world_size is None else data_parallel_world_size self.dp_world_size = 1 if data_parallel_world_size is None else data_parallel_world_size
self.dp_rank = 0 if data_parallel_rank is None else data_parallel_rank self.dp_rank = 0 if data_parallel_rank is None else data_parallel_rank
super(NonIterableGeneratorLoader, self).__init__( super(NonIterableGeneratorLoader,
dataset, batch_size, epochs, data_parallel_world_size, self).__init__(dataset, batch_size, epochs,
data_parallel_rank, drop_last) data_parallel_world_size, data_parallel_rank,
drop_last)
self._inner_dataloader = self._create_inner_dataloader() self._inner_dataloader = self._create_inner_dataloader()
self._steps = self._infer_steps() self._steps = self._infer_steps()
...@@ -96,6 +99,7 @@ class NonIterableGeneratorLoader(DistributedDataLoader): ...@@ -96,6 +99,7 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
return steps_per_epoch return steps_per_epoch
def _create_inner_dataloader(self): def _create_inner_dataloader(self):
def sample_data_generator(): def sample_data_generator():
batch_data = None batch_data = None
for step, data in enumerate(self.dataset): for step, data in enumerate(self.dataset):
......
...@@ -26,6 +26,7 @@ from .dist_attribute import get_op_dist_attr_field_keys ...@@ -26,6 +26,7 @@ from .dist_attribute import get_op_dist_attr_field_keys
class DistributedOperator: class DistributedOperator:
def __init__(self, serial_op, dist_attr=None): def __init__(self, serial_op, dist_attr=None):
self._serial_op = serial_op self._serial_op = serial_op
self._serial_inputs = {} self._serial_inputs = {}
...@@ -248,6 +249,7 @@ class DistributedOperator: ...@@ -248,6 +249,7 @@ class DistributedOperator:
class DistributedModule: class DistributedModule:
def __init__(self, serial_module, dist_attr=None): def __init__(self, serial_module, dist_attr=None):
self._serial_module = serial_module self._serial_module = serial_module
self._dist_attr = dist_attr self._dist_attr = dist_attr
......
...@@ -53,6 +53,7 @@ def _process_path(path): ...@@ -53,6 +53,7 @@ def _process_path(path):
class DistributedSaver: class DistributedSaver:
def __init__(self): def __init__(self):
self._logger = get_logger(logging.INFO) self._logger = get_logger(logging.INFO)
...@@ -114,8 +115,8 @@ class DistributedSaver: ...@@ -114,8 +115,8 @@ class DistributedSaver:
param_file): param_file):
param_file_list.append(os.path.join(dirname, param_file)) param_file_list.append(os.path.join(dirname, param_file))
param_file_list.sort() param_file_list.sort()
self._logger.info("Load distributed attribute file: {}".format( self._logger.info(
param_file_list)) "Load distributed attribute file: {}".format(param_file_list))
param_dict = {} param_dict = {}
for param_file in param_file_list: for param_file in param_file_list:
with open(param_file, 'rb') as f: with open(param_file, 'rb') as f:
...@@ -131,11 +132,11 @@ class DistributedSaver: ...@@ -131,11 +132,11 @@ class DistributedSaver:
for dist_attr_file in os.listdir(dirname): for dist_attr_file in os.listdir(dirname):
if check_filename('{}(.*)_dist(.*).pdattr'.format(filename), if check_filename('{}(.*)_dist(.*).pdattr'.format(filename),
dist_attr_file): dist_attr_file):
dist_attr_file_list.append( dist_attr_file_list.append(os.path.join(dirname,
os.path.join(dirname, dist_attr_file)) dist_attr_file))
dist_attr_file_list.sort() dist_attr_file_list.sort()
self._logger.info("Load distributed attribute file: {}".format( self._logger.info(
dist_attr_file_list)) "Load distributed attribute file: {}".format(dist_attr_file_list))
pre_dist_attr = {} pre_dist_attr = {}
for dist_attr_file in dist_attr_file_list: for dist_attr_file in dist_attr_file_list:
with open(dist_attr_file, 'rb') as f: with open(dist_attr_file, 'rb') as f:
...@@ -206,12 +207,11 @@ class DistributedSaver: ...@@ -206,12 +207,11 @@ class DistributedSaver:
# NOTE: `paddle.static.save_inference_model` does not support subblock. # NOTE: `paddle.static.save_inference_model` does not support subblock.
dist_filename = filename + "_dist" + str(rank_id) dist_filename = filename + "_dist" + str(rank_id)
dist_path = os.path.join(dirname, dist_filename) dist_path = os.path.join(dirname, dist_filename)
paddle.static.save_inference_model( paddle.static.save_inference_model(dist_path,
dist_path, dist_feed_vars,
dist_feed_vars, dist_fetch_vars,
dist_fetch_vars, exe,
exe, program=dist_main_prog)
program=dist_main_prog)
def _save_rank_mapping(self, dirname): def _save_rank_mapping(self, dirname):
path = os.path.join(dirname, 'rank_mapping.csv') path = os.path.join(dirname, 'rank_mapping.csv')
......
...@@ -40,26 +40,26 @@ class DistributedTensor: ...@@ -40,26 +40,26 @@ class DistributedTensor:
processes, processes,
rank=None, rank=None,
shard_sizes=None): shard_sizes=None):
if not (isinstance(sizes, (list, tuple)) and if not (isinstance(sizes, (list, tuple))
all(map(lambda x: isinstance(x, int) and x >= 0, sizes))): and all(map(lambda x: isinstance(x, int) and x >= 0, sizes))):
raise ValueError( raise ValueError(
"The sizes must be list or tuple and item in sizes must be non-negative integer, but got {}". "The sizes must be list or tuple and item in sizes must be non-negative integer, but got {}"
format(sizes)) .format(sizes))
if not (isinstance(dims_mapping, (list, tuple)) and all( if not (isinstance(dims_mapping, (list, tuple)) and all(
map(lambda x: isinstance(x, int) and x >= -1, dims_mapping))): map(lambda x: isinstance(x, int) and x >= -1, dims_mapping))):
raise ValueError( raise ValueError(
"The dims_mapping must be list or tuple and item in dims_mapping must >= -1, but got {}". "The dims_mapping must be list or tuple and item in dims_mapping must >= -1, but got {}"
format(dims_mapping)) .format(dims_mapping))
if not (isinstance(processes, (list, tuple)) and if not (isinstance(processes, (list, tuple)) and all(
all(map(lambda x: isinstance(x, int) and x >= 0, processes))): map(lambda x: isinstance(x, int) and x >= 0, processes))):
raise ValueError( raise ValueError(
"The processes must be list or tuple and item in processes must be integer, but got {}". "The processes must be list or tuple and item in processes must be integer, but got {}"
format(processes)) .format(processes))
if not (isinstance(topology, (list, tuple)) and if not (isinstance(topology, (list, tuple))
all(map(lambda x: isinstance(x, int) and x > 0, topology))): and all(map(lambda x: isinstance(x, int) and x > 0, topology))):
raise ValueError( raise ValueError(
"The topology must be list or tuple and item in topology must be non-negative integer, but got {}". "The topology must be list or tuple and item in topology must be non-negative integer, but got {}"
format(topology)) .format(topology))
if rank is not None and not (isinstance(rank, int) and rank >= 0): if rank is not None and not (isinstance(rank, int) and rank >= 0):
raise ValueError("The rank must >= 0, but got {}".format(rank)) raise ValueError("The rank must >= 0, but got {}".format(rank))
...@@ -74,8 +74,10 @@ class DistributedTensor: ...@@ -74,8 +74,10 @@ class DistributedTensor:
processes, processes,
rank=None, rank=None,
shard_sizes=None): shard_sizes=None):
DistributedTensor._validate_sizes_and_dist_attr( DistributedTensor._validate_sizes_and_dist_attr(global_sizes,
global_sizes, dims_mapping, topology, processes, rank, shard_sizes) dims_mapping, topology,
processes, rank,
shard_sizes)
local_sizes = [] local_sizes = []
# for even sharding, the local sizes of every rank are equal # for even sharding, the local sizes of every rank are equal
...@@ -97,8 +99,10 @@ class DistributedTensor: ...@@ -97,8 +99,10 @@ class DistributedTensor:
processes, processes,
rank, rank,
shard_sizes=None): shard_sizes=None):
local_sizes = DistributedTensor.get_local_sizes( local_sizes = DistributedTensor.get_local_sizes(global_sizes,
global_sizes, dims_mapping, topology, processes, rank, shard_sizes) dims_mapping, topology,
processes, rank,
shard_sizes)
local_offsets = [] local_offsets = []
rank_relatvie = processes.index(rank) rank_relatvie = processes.index(rank)
coordinate = _linear_idx2coordinate(topology, rank_relatvie) coordinate = _linear_idx2coordinate(topology, rank_relatvie)
...@@ -118,8 +122,10 @@ class DistributedTensor: ...@@ -118,8 +122,10 @@ class DistributedTensor:
processes, processes,
rank=None, rank=None,
shard_sizes=None): shard_sizes=None):
DistributedTensor._validate_sizes_and_dist_attr( DistributedTensor._validate_sizes_and_dist_attr(local_sizes,
local_sizes, dims_mapping, topology, processes, rank, shard_sizes) dims_mapping, topology,
processes, rank,
shard_sizes)
global_sizes = [] global_sizes = []
for idx, item in enumerate(local_sizes): for idx, item in enumerate(local_sizes):
if dims_mapping[idx] == -1: if dims_mapping[idx] == -1:
...@@ -137,8 +143,10 @@ class DistributedTensor: ...@@ -137,8 +143,10 @@ class DistributedTensor:
shard_sizes=None): shard_sizes=None):
local_offsets = DistributedTensor.get_local_offsets( local_offsets = DistributedTensor.get_local_offsets(
global_sizes, dims_mapping, topology, processes, rank, shard_sizes) global_sizes, dims_mapping, topology, processes, rank, shard_sizes)
local_sizes = DistributedTensor.get_local_sizes( local_sizes = DistributedTensor.get_local_sizes(global_sizes,
global_sizes, dims_mapping, topology, processes, rank, shard_sizes) dims_mapping, topology,
processes, rank,
shard_sizes)
assert len(local_sizes) == len( assert len(local_sizes) == len(
local_offsets local_offsets
), "The length of local_sizes must be equal to local_offsets, but got {} and {}.".format( ), "The length of local_sizes must be equal to local_offsets, but got {} and {}.".format(
......
...@@ -48,6 +48,7 @@ from .dist_context import DistributedContext, get_default_distributed_context ...@@ -48,6 +48,7 @@ from .dist_context import DistributedContext, get_default_distributed_context
class Engine: class Engine:
def __init__(self, def __init__(self,
model=None, model=None,
inputs_spec=None, inputs_spec=None,
...@@ -88,8 +89,9 @@ class Engine: ...@@ -88,8 +89,9 @@ class Engine:
gradient_scale=True, gradient_scale=True,
metrics=None, metrics=None,
all_ranks=False): all_ranks=False):
if optimizer and not isinstance(optimizer, ( if optimizer and not isinstance(
paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer)): optimizer,
(paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer)):
raise TypeError( raise TypeError(
"'optimizer' must be object of class `paddle.optimizer.Optimizer`" \ "'optimizer' must be object of class `paddle.optimizer.Optimizer`" \
" or `paddle.fluid.optimizer.Optimizer`." " or `paddle.fluid.optimizer.Optimizer`."
...@@ -194,7 +196,7 @@ class Engine: ...@@ -194,7 +196,7 @@ class Engine:
parallelizer.parallel_all() parallelizer.parallel_all()
def _init_dist_context(self, mode): def _init_dist_context(self, mode):
# Init dist_context['mode'] with the first planned dist_context # Init dist_context['mode'] with the first planned dist_context
# to guarantee that train/eval/predict mode have same parallel strategy # to guarantee that train/eval/predict mode have same parallel strategy
dist_context = self._dist_contexts[mode] dist_context = self._dist_contexts[mode]
origin_main_prog = dist_context._original_serial_main_program origin_main_prog = dist_context._original_serial_main_program
...@@ -212,7 +214,7 @@ class Engine: ...@@ -212,7 +214,7 @@ class Engine:
dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr) dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr)
def _initialize(self, mode): def _initialize(self, mode):
# Get the current content from the distributed context # Get the current content from the distributed context
self._serial_main_progs[mode] = self._dist_contexts[ self._serial_main_progs[mode] = self._dist_contexts[
mode].serial_main_program mode].serial_main_program
self._serial_startup_progs[mode] = self._dist_contexts[ self._serial_startup_progs[mode] = self._dist_contexts[
...@@ -380,7 +382,7 @@ class Engine: ...@@ -380,7 +382,7 @@ class Engine:
dist_context = self._dist_contexts[self.mode] dist_context = self._dist_contexts[self.mode]
dist_main_block = dist_main_prog.global_block() dist_main_block = dist_main_prog.global_block()
# NOTE: Get feed_list from dist_program, then insert dataloader op # NOTE: Get feed_list from dist_program, then insert dataloader op
# with sharded var shape. Because predict_program does not contain # with sharded var shape. Because predict_program does not contain
# labels var, so we will filter dataset's value with length of feed_list. # labels var, so we will filter dataset's value with length of feed_list.
inputs_var = self._feed_vars[self.mode]["inputs"] inputs_var = self._feed_vars[self.mode]["inputs"]
...@@ -389,8 +391,8 @@ class Engine: ...@@ -389,8 +391,8 @@ class Engine:
for var in inputs_var + labels_var: for var in inputs_var + labels_var:
if var.name in dist_main_block.vars: if var.name in dist_main_block.vars:
feed_list.append(dist_main_block.vars[var.name]) feed_list.append(dist_main_block.vars[var.name])
dp_world_size, dp_rank = self._get_data_parallel_info(feed_list[0], dp_world_size, dp_rank = self._get_data_parallel_info(
dist_context) feed_list[0], dist_context)
# remove the first three ops if multi run fit/evaluate/predict # remove the first three ops if multi run fit/evaluate/predict
op_size = len(dist_main_block.ops) op_size = len(dist_main_block.ops)
...@@ -418,8 +420,9 @@ class Engine: ...@@ -418,8 +420,9 @@ class Engine:
op = dist_main_block.ops[new_op_size - 1] op = dist_main_block.ops[new_op_size - 1]
new_op_desc = dist_main_block.desc._prepend_op() new_op_desc = dist_main_block.desc._prepend_op()
new_op_desc.copy_from(op.desc) new_op_desc.copy_from(op.desc)
new_op = Operator( new_op = Operator(dist_main_block,
dist_main_block, new_op_desc, type=new_op_desc.type()) new_op_desc,
type=new_op_desc.type())
dist_main_block.ops.insert(0, new_op) dist_main_block.ops.insert(0, new_op)
dist_op = DistributedOperator(new_op) dist_op = DistributedOperator(new_op)
dist_context.add_dist_op_for_program(dist_op) dist_context.add_dist_op_for_program(dist_op)
...@@ -442,21 +445,21 @@ class Engine: ...@@ -442,21 +445,21 @@ class Engine:
def _set_data_parallel(self, var): def _set_data_parallel(self, var):
if self._nranks == 1: if self._nranks == 1:
self._default_strategy = 'serial' self._default_strategy = 'serial'
auto.shard_tensor( auto.shard_tensor(var,
var, dist_attr={
dist_attr={ "process_mesh": [0],
"process_mesh": [0], "dims_mapping":
"dims_mapping": [-1 for _ in range(len(var.shape))] [-1 for _ in range(len(var.shape))]
}) })
else: else:
self._default_strategy = 'dp' self._default_strategy = 'dp'
auto.shard_tensor( auto.shard_tensor(var,
var, dist_attr={
dist_attr={ "process_mesh":
"process_mesh": list(range(self._nranks)), list(range(self._nranks)),
"dims_mapping": "dims_mapping":
[0] + [-1 for _ in range(len(var.shape) - 1)] [0] + [-1 for _ in range(len(var.shape) - 1)]
}) })
return var return var
...@@ -492,22 +495,20 @@ class Engine: ...@@ -492,22 +495,20 @@ class Engine:
serial_program = self._serial_main_progs["train"] serial_program = self._serial_main_progs["train"]
dist_main_prog = self._dist_main_progs["train"][self._cur_rank] dist_main_prog = self._dist_main_progs["train"][self._cur_rank]
dist_context = self._dist_contexts["train"] dist_context = self._dist_contexts["train"]
self._saver.save( self._saver.save(path,
path, serial_program=serial_program,
serial_program=serial_program, dist_main_program=dist_main_prog,
dist_main_program=dist_main_prog, dist_context=dist_context)
dist_context=dist_context)
else: else:
assert mode, "Please set the 'mode' you want to save." assert mode, "Please set the 'mode' you want to save."
feed_vars = self._feed_vars[mode]['inputs'] feed_vars = self._feed_vars[mode]['inputs']
fetch_vars = self._fetch_vars[mode]['outputs'] fetch_vars = self._fetch_vars[mode]['outputs']
dist_main_prog = self._dist_main_progs[mode][self._cur_rank] dist_main_prog = self._dist_main_progs[mode][self._cur_rank]
self._saver.save_inference_model( self._saver.save_inference_model(path,
path, feed_vars,
feed_vars, fetch_vars,
fetch_vars, self._executor,
self._executor, program=dist_main_prog)
program=dist_main_prog)
def load(self, path, strict=True, load_optimizer=True, mode=None): def load(self, path, strict=True, load_optimizer=True, mode=None):
if not mode: if not mode:
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
class Node: class Node:
def __init__(self, id, **attrs): def __init__(self, id, **attrs):
# Each node must has a unique id # Each node must has a unique id
self._id = id self._id = id
...@@ -47,6 +48,7 @@ class Node: ...@@ -47,6 +48,7 @@ class Node:
class Edge: class Edge:
def __init__(self, src_id, tgt_id, **attrs): def __init__(self, src_id, tgt_id, **attrs):
# The id of source node in an Edge # The id of source node in an Edge
self._src_id = src_id self._src_id = src_id
...@@ -88,6 +90,7 @@ class Edge: ...@@ -88,6 +90,7 @@ class Edge:
class Graph: class Graph:
def __init__(self, **attrs): def __init__(self, **attrs):
# _nodes is dict for storing the nodes of the graph. # _nodes is dict for storing the nodes of the graph.
# The key of this dict is the node id. # The key of this dict is the node id.
......
...@@ -171,8 +171,9 @@ def build_process_graph(distributed_program): ...@@ -171,8 +171,9 @@ def build_process_graph(distributed_program):
src_info, src_rank) src_info, src_rank)
graph.add_node(src_rank, resource_requirements=resource_requirements) graph.add_node(src_rank, resource_requirements=resource_requirements)
for tgt_rank, comm_requirements in comm_requirements_to_ranks.items(): for tgt_rank, comm_requirements in comm_requirements_to_ranks.items():
graph.add_edge( graph.add_edge(src_rank,
src_rank, tgt_rank, comm_requirements=comm_requirements) tgt_rank,
comm_requirements=comm_requirements)
return graph return graph
...@@ -192,8 +193,9 @@ def build_cluster_graph(cluster): ...@@ -192,8 +193,9 @@ def build_cluster_graph(cluster):
else: else:
graph.nodes[device.global_id]["occupied"] = False graph.nodes[device.global_id]["occupied"] = False
for link in machine.links.values(): for link in machine.links.values():
graph.add_edge( graph.add_edge(link.source.global_id,
link.source.global_id, link.target.global_id, link=link) link.target.global_id,
link=link)
return graph return graph
...@@ -233,8 +235,8 @@ def mapping(distributed_program, cluster): ...@@ -233,8 +235,8 @@ def mapping(distributed_program, cluster):
device_type = cur_rank_node["resource_requirements"]["device_type"] device_type = cur_rank_node["resource_requirements"]["device_type"]
cur_device_node = None cur_device_node = None
for device_node in cluster_graph.nodes.values(): for device_node in cluster_graph.nodes.values():
if (device_node["device"].type == device_type) and ( if (device_node["device"].type
not device_node["occupied"]): == device_type) and (not device_node["occupied"]):
device_node["occupied"] = True device_node["occupied"] = True
cur_rank_node["visited"] = True cur_rank_node["visited"] = True
cur_rank_node["device"] = device_node["device"] cur_rank_node["device"] = device_node["device"]
...@@ -257,8 +259,8 @@ def mapping(distributed_program, cluster): ...@@ -257,8 +259,8 @@ def mapping(distributed_program, cluster):
nbr_device_edges.sort(key=sort_by_comm_bandwidth) nbr_device_edges.sort(key=sort_by_comm_bandwidth)
for nbr_rank_edge in nbr_rank_edges: for nbr_rank_edge in nbr_rank_edges:
src_rank_node = process_graph.nodes[nbr_rank_edge.src_id][ src_rank_node = process_graph.nodes[
"visited"] nbr_rank_edge.src_id]["visited"]
if src_rank_node: if src_rank_node:
continue continue
device_type = src_rank_node["resource_requirements"][ device_type = src_rank_node["resource_requirements"][
......
...@@ -32,6 +32,7 @@ def is_elementwise_op(op_type): ...@@ -32,6 +32,7 @@ def is_elementwise_op(op_type):
class DistributedOperatorImplContainer: class DistributedOperatorImplContainer:
def __init__(self, op_type): def __init__(self, op_type):
self._type = op_type self._type = op_type
self._impls = [] self._impls = []
...@@ -81,6 +82,7 @@ class DistributedOperatorImplContainer: ...@@ -81,6 +82,7 @@ class DistributedOperatorImplContainer:
class DistributedOperatorImpl(abc.ABC): class DistributedOperatorImpl(abc.ABC):
def __init__(self, name): def __init__(self, name):
self._name = name self._name = name
self._type = None self._type = None
......
...@@ -30,6 +30,7 @@ world_process_group = get_world_process_group() ...@@ -30,6 +30,7 @@ world_process_group = get_world_process_group()
class DistributedCheckFiniteAndUnscale(DistributedOperatorImplContainer): class DistributedCheckFiniteAndUnscale(DistributedOperatorImplContainer):
def __init__(self, op_type): def __init__(self, op_type):
super(DistributedCheckFiniteAndUnscale, self).__init__(op_type) super(DistributedCheckFiniteAndUnscale, self).__init__(op_type)
...@@ -39,6 +40,7 @@ register_distributed_operator_impl_container( ...@@ -39,6 +40,7 @@ register_distributed_operator_impl_container(
class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl): class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedCheckFiniteAndUnscaleImpl, self).__init__(name) super(DistributedCheckFiniteAndUnscaleImpl, self).__init__(name)
self._name = name self._name = name
...@@ -122,41 +124,37 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl): ...@@ -122,41 +124,37 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
group = new_process_group(world_process_group.ranks) group = new_process_group(world_process_group.ranks)
inf_var = main_block.var(kwargs['FoundInfinite'][0]) inf_var = main_block.var(kwargs['FoundInfinite'][0])
inf_var_int32 = main_block.create_var( inf_var_int32 = main_block.create_var(name=inf_var.name + "@cast_int32",
name=inf_var.name + "@cast_int32", shape=inf_var.shape,
shape=inf_var.shape, dtype=core.VarDesc.VarType.INT32)
dtype=core.VarDesc.VarType.INT32)
set_var_dist_attr( set_var_dist_attr(
ctx, inf_var_int32, ctx, inf_var_int32,
ctx.get_tensor_dist_attr_for_program(inf_var).dims_mapping, ctx.get_tensor_dist_attr_for_program(inf_var).dims_mapping,
ctx.get_tensor_dist_attr_for_program(inf_var).process_mesh) ctx.get_tensor_dist_attr_for_program(inf_var).process_mesh)
cast_op1 = main_block.append_op( cast_op1 = main_block.append_op(type='cast',
type='cast', inputs={'X': inf_var},
inputs={'X': inf_var}, outputs={'Out': inf_var_int32},
outputs={'Out': inf_var_int32}, attrs={
attrs={ "in_dtype": inf_var.dtype,
"in_dtype": inf_var.dtype, "out_dtype": inf_var_int32.dtype,
"out_dtype": inf_var_int32.dtype, OP_ROLE_KEY: OpRole.Backward
OP_ROLE_KEY: OpRole.Backward })
}) allreduce_op = main_block.append_op(type='c_allreduce_max',
allreduce_op = main_block.append_op( inputs={'X': inf_var_int32},
type='c_allreduce_max', outputs={'Out': inf_var_int32},
inputs={'X': inf_var_int32}, attrs={
outputs={'Out': inf_var_int32}, 'ring_id': group.id,
attrs={ 'use_calc_stream': True,
'ring_id': group.id, OP_ROLE_KEY: OpRole.Backward
'use_calc_stream': True, })
OP_ROLE_KEY: OpRole.Backward cast_op2 = main_block.append_op(type='cast',
}) inputs={'X': inf_var_int32},
cast_op2 = main_block.append_op( outputs={'Out': inf_var},
type='cast', attrs={
inputs={'X': inf_var_int32}, "in_dtype": inf_var_int32.dtype,
outputs={'Out': inf_var}, "out_dtype": inf_var.dtype,
attrs={ OP_ROLE_KEY: OpRole.Backward
"in_dtype": inf_var_int32.dtype, })
"out_dtype": inf_var.dtype,
OP_ROLE_KEY: OpRole.Backward
})
main_block._sync_with_cpp() main_block._sync_with_cpp()
for op in [cast_op1, allreduce_op, cast_op2]: for op in [cast_op1, allreduce_op, cast_op2]:
......
...@@ -47,28 +47,26 @@ def prim_operator_data_parallel_functor(ctx, src_op): ...@@ -47,28 +47,26 @@ def prim_operator_data_parallel_functor(ctx, src_op):
ctx.synced_gradient.add(var_name) ctx.synced_gradient.add(var_name)
sync_group = new_process_group(ctx.data_parallel_group) sync_group = new_process_group(ctx.data_parallel_group)
allreduce_op = main_block.append_op( allreduce_op = main_block.append_op(type='c_allreduce_sum',
type='c_allreduce_sum', inputs={'X': [var_name]},
inputs={'X': [var_name]}, outputs={'Out': [var_name]},
outputs={'Out': [var_name]}, attrs={
attrs={ 'ring_id': sync_group.id,
'ring_id': sync_group.id, 'use_calc_stream': True,
'use_calc_stream': True, OP_ROLE_KEY: OpRole.Backward
OP_ROLE_KEY: OpRole.Backward })
})
param = ctx.grads_params[var_name] param = ctx.grads_params[var_name]
startup_block = dist_op_context.startup_block startup_block = dist_op_context.startup_block
new_op = startup_block.append_op( new_op = startup_block.append_op(type='c_broadcast',
type='c_broadcast', inputs={'X': [param]},
inputs={'X': [param]}, outputs={'Out': [param]},
outputs={'Out': [param]}, attrs={
attrs={ 'ring_id': sync_group.id,
'ring_id': sync_group.id, 'root': 0,
'root': 0, 'use_calc_stream': True,
'use_calc_stream': True, OP_ROLE_KEY: OpRole.Forward
OP_ROLE_KEY: OpRole.Forward })
})
grad_var = main_block.var(var_name) grad_var = main_block.var(var_name)
dims_mapping = ctx.get_tensor_dist_attr_for_program( dims_mapping = ctx.get_tensor_dist_attr_for_program(
...@@ -85,6 +83,7 @@ def prim_operator_data_parallel_functor(ctx, src_op): ...@@ -85,6 +83,7 @@ def prim_operator_data_parallel_functor(ctx, src_op):
class DistributedDefault(DistributedOperatorImplContainer): class DistributedDefault(DistributedOperatorImplContainer):
def __init__(self, op_type): def __init__(self, op_type):
super(DistributedDefault, self).__init__(op_type) super(DistributedDefault, self).__init__(op_type)
...@@ -94,6 +93,7 @@ register_distributed_operator_impl_container(DistributedDefault("default")) ...@@ -94,6 +93,7 @@ register_distributed_operator_impl_container(DistributedDefault("default"))
# Replicated Default # Replicated Default
class DistributedDefaultImpl0(DistributedOperatorImpl): class DistributedDefaultImpl0(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedDefaultImpl0, self).__init__(name) super(DistributedDefaultImpl0, self).__init__(name)
self._forward_implemented = True self._forward_implemented = True
...@@ -277,8 +277,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -277,8 +277,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
batch_dim_mappings.append(dims_mapping[1]) batch_dim_mappings.append(dims_mapping[1])
for arg_name in op_desc.output_arg_names(): for arg_name in op_desc.output_arg_names():
if op_desc.type() == "fill_zeros_like": if op_desc.type() == "fill_zeros_like":
input_tensor = dist_op.get_serial_input(op_desc.input_arg_names( input_tensor = dist_op.get_serial_input(
)[0]) op_desc.input_arg_names()[0])
if input_tensor.is_parameter: if input_tensor.is_parameter:
continue continue
serial_tensor = dist_op.get_serial_output(arg_name) serial_tensor = dist_op.get_serial_output(arg_name)
...@@ -316,8 +316,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -316,8 +316,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
changed = True changed = True
for arg_name in op_desc.output_arg_names(): for arg_name in op_desc.output_arg_names():
if op_desc.type() == "fill_zeros_like": if op_desc.type() == "fill_zeros_like":
input_tensor = dist_op.get_serial_input(op_desc.input_arg_names( input_tensor = dist_op.get_serial_input(
)[0]) op_desc.input_arg_names()[0])
if input_tensor.is_parameter: if input_tensor.is_parameter:
continue continue
if op_desc.type() in ["shape", "slice"]: if op_desc.type() in ["shape", "slice"]:
...@@ -409,16 +409,19 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -409,16 +409,19 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
axis, rank_id) axis, rank_id)
sync_group = new_process_group(group_ranks) sync_group = new_process_group(group_ranks)
new_op = startup_block.append_op( new_op = startup_block.append_op(type='c_broadcast',
type='c_broadcast', inputs={'X': param},
inputs={'X': param}, outputs={'Out': param},
outputs={'Out': param}, attrs={
attrs={ 'ring_id':
'ring_id': sync_group.id, sync_group.id,
'root': 0, 'root':
'use_calc_stream': True, 0,
OP_ROLE_KEY: OpRole.Forward 'use_calc_stream':
}) True,
OP_ROLE_KEY:
OpRole.Forward
})
# set distributed attribute # set distributed attribute
op_attr = OperatorDistributedAttribute() op_attr = OperatorDistributedAttribute()
...@@ -484,8 +487,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -484,8 +487,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
# FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism
if rank_id not in process_mesh.processes: if rank_id not in process_mesh.processes:
rank_id = _get_corresponding_rank(ctx, process_mesh, rank_id = _get_corresponding_rank(
rank_id) ctx, process_mesh, rank_id)
mesh_shape = process_mesh.topology mesh_shape = process_mesh.topology
batch_size_axis = var_dim_mapping[0] batch_size_axis = var_dim_mapping[0]
......
...@@ -35,6 +35,7 @@ from .dist_default import DistributedDefaultImpl0 ...@@ -35,6 +35,7 @@ from .dist_default import DistributedDefaultImpl0
class DistributedElementwise(DistributedOperatorImplContainer): class DistributedElementwise(DistributedOperatorImplContainer):
def __init__(self, op_type): def __init__(self, op_type):
super(DistributedElementwise, self).__init__(op_type) super(DistributedElementwise, self).__init__(op_type)
...@@ -45,6 +46,7 @@ register_distributed_operator_impl_container( ...@@ -45,6 +46,7 @@ register_distributed_operator_impl_container(
# Replicated Elementwise # Replicated Elementwise
class DistributedElementwiseImpl0(DistributedOperatorImpl): class DistributedElementwiseImpl0(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedElementwiseImpl0, self).__init__(name) super(DistributedElementwiseImpl0, self).__init__(name)
self._forward_implemented = False self._forward_implemented = False
...@@ -208,8 +210,8 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl): ...@@ -208,8 +210,8 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
changed = True changed = True
else: else:
if compatible_dims_mapping != input_dims_mapping_dict[arg_name]: if compatible_dims_mapping != input_dims_mapping_dict[arg_name]:
op_dist_attr.set_input_dims_mapping(arg_name, op_dist_attr.set_input_dims_mapping(
compatible_dims_mapping) arg_name, compatible_dims_mapping)
changed = True changed = True
for arg_name in output_arg_names: for arg_name in output_arg_names:
...@@ -222,12 +224,11 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl): ...@@ -222,12 +224,11 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
output_dims_mapping_lens[arg_name]) + i output_dims_mapping_lens[arg_name]) + i
new_dims_mapping[i] = compatible_dims_mapping[new_idx] new_dims_mapping[i] = compatible_dims_mapping[new_idx]
if new_dims_mapping != output_dims_mapping_dict[arg_name]: if new_dims_mapping != output_dims_mapping_dict[arg_name]:
op_dist_attr.set_output_dims_mapping(arg_name, op_dist_attr.set_output_dims_mapping(
new_dims_mapping) arg_name, new_dims_mapping)
changed = True changed = True
else: else:
if compatible_dims_mapping != output_dims_mapping_dict[ if compatible_dims_mapping != output_dims_mapping_dict[arg_name]:
arg_name]:
op_dist_attr.set_output_dims_mapping( op_dist_attr.set_output_dims_mapping(
arg_name, compatible_dims_mapping) arg_name, compatible_dims_mapping)
changed = True changed = True
......
...@@ -34,6 +34,7 @@ from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank ...@@ -34,6 +34,7 @@ from ..utils import _get_comm_group, _get_idx_in_axis, _get_corresponding_rank
class DistributedEmbedding(DistributedOperatorImplContainer): class DistributedEmbedding(DistributedOperatorImplContainer):
def __init__(self, op_type): def __init__(self, op_type):
super(DistributedEmbedding, self).__init__(op_type) super(DistributedEmbedding, self).__init__(op_type)
...@@ -46,6 +47,7 @@ register_distributed_operator_impl_container( ...@@ -46,6 +47,7 @@ register_distributed_operator_impl_container(
# RowParallel # RowParallel
class DistributedEmbeddingImpl(DistributedOperatorImpl): class DistributedEmbeddingImpl(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedEmbeddingImpl, self).__init__(name) super(DistributedEmbeddingImpl, self).__init__(name)
self._forward_implemented = True self._forward_implemented = True
...@@ -58,8 +60,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -58,8 +60,8 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
w_name = op_desc.input('W')[0] w_name = op_desc.input('W')[0]
ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name) ids_dims_mapping = op_dist_attr.get_input_dims_mapping(ids_name)
w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name) w_dims_mapping = op_dist_attr.get_input_dims_mapping(w_name)
if is_dim_replicate(w_dims_mapping[-2]) or is_dim_shard(w_dims_mapping[ if is_dim_replicate(w_dims_mapping[-2]) or is_dim_shard(
-1]): w_dims_mapping[-1]):
return False return False
# Other dimensions must be replicate except the batch dimension # Other dimensions must be replicate except the batch dimension
for mapping in ids_dims_mapping[1:]: for mapping in ids_dims_mapping[1:]:
...@@ -215,8 +217,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -215,8 +217,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
c_embedding_op = main_block.append_op( c_embedding_op = main_block.append_op(
type='c_embedding', type='c_embedding',
inputs={'Ids': [Ids_var], inputs={
'W': [Weight_var]}, 'Ids': [Ids_var],
'W': [Weight_var]
},
outputs={'Out': [intermediate_var_0]}, outputs={'Out': [intermediate_var_0]},
attrs={"start_index": relative_idx}) attrs={"start_index": relative_idx})
if intermediate_var_0.shape != ref_shape: if intermediate_var_0.shape != ref_shape:
...@@ -295,16 +299,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -295,16 +299,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
rank_id) rank_id)
sync_group = new_process_group(group_ranks) sync_group = new_process_group(group_ranks)
startup_block.append_op( startup_block.append_op(type='c_broadcast',
type='c_broadcast', inputs={'X': param},
inputs={'X': param}, outputs={'Out': param},
outputs={'Out': param}, attrs={
attrs={ 'ring_id': sync_group.id,
'ring_id': sync_group.id, 'root': 0,
'root': 0, 'use_calc_stream': True,
'use_calc_stream': True, OP_ROLE_KEY: OpRole.Forward
OP_ROLE_KEY: OpRole.Forward })
})
startup_block._sync_with_cpp() startup_block._sync_with_cpp()
@staticmethod @staticmethod
...@@ -440,21 +443,21 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -440,21 +443,21 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
if need_gradient_allreduce: if need_gradient_allreduce:
W_Grad_var = main_block.var(kwargs['W@GRAD'][0]) W_Grad_var = main_block.var(kwargs['W@GRAD'][0])
allreduce_op = main_block.append_op( allreduce_op = main_block.append_op(type='c_allreduce_sum',
type='c_allreduce_sum', inputs={'X': [W_Grad_var]},
inputs={'X': [W_Grad_var]}, outputs={'Out': [W_Grad_var]},
outputs={'Out': [W_Grad_var]}, attrs={
attrs={ 'ring_id': dp_group.id,
'ring_id': dp_group.id, 'use_calc_stream': True,
'use_calc_stream': True, OP_ROLE_KEY: OpRole.Backward
OP_ROLE_KEY: OpRole.Backward })
}) scale_op = main_block.append_op(type='scale',
scale_op = main_block.append_op( inputs={'X': W_Grad_var},
type='scale', outputs={'Out': W_Grad_var},
inputs={'X': W_Grad_var}, attrs={
outputs={'Out': W_Grad_var}, 'scale': 1.0 / dp_degree,
attrs={'scale': 1.0 / dp_degree, OP_ROLE_KEY: OpRole.Backward
OP_ROLE_KEY: OpRole.Backward}) })
main_block._sync_with_cpp() main_block._sync_with_cpp()
dims_mapping = ctx.get_tensor_dist_attr_for_program( dims_mapping = ctx.get_tensor_dist_attr_for_program(
......
...@@ -31,6 +31,7 @@ from .dist_default import DistributedDefaultImpl0 ...@@ -31,6 +31,7 @@ from .dist_default import DistributedDefaultImpl0
class DistributedFillConstantBatchSizeLike(DistributedOperatorImplContainer): class DistributedFillConstantBatchSizeLike(DistributedOperatorImplContainer):
def __init__(self, op_type): def __init__(self, op_type):
super(DistributedFillConstantBatchSizeLike, self).__init__(op_type) super(DistributedFillConstantBatchSizeLike, self).__init__(op_type)
...@@ -40,6 +41,7 @@ register_distributed_operator_impl_container( ...@@ -40,6 +41,7 @@ register_distributed_operator_impl_container(
class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl): class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedFillConstantBatchSizeLikeImpl0, self).__init__(name) super(DistributedFillConstantBatchSizeLikeImpl0, self).__init__(name)
self._forward_implemented = True self._forward_implemented = True
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -27,6 +27,7 @@ from ..process_group import new_process_group ...@@ -27,6 +27,7 @@ from ..process_group import new_process_group
class DistributedFusedAttention(DistributedOperatorImplContainer): class DistributedFusedAttention(DistributedOperatorImplContainer):
def __init__(self, op_type): def __init__(self, op_type):
super(DistributedFusedAttention, self).__init__(op_type) super(DistributedFusedAttention, self).__init__(op_type)
...@@ -36,6 +37,7 @@ register_distributed_operator_impl_container( ...@@ -36,6 +37,7 @@ register_distributed_operator_impl_container(
class DistributedFusedAttentionImpl(DistributedOperatorImpl): class DistributedFusedAttentionImpl(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedFusedAttentionImpl, self).__init__(name) super(DistributedFusedAttentionImpl, self).__init__(name)
self._forward_implemented = True self._forward_implemented = True
...@@ -60,8 +62,8 @@ class DistributedFusedAttentionImpl(DistributedOperatorImpl): ...@@ -60,8 +62,8 @@ class DistributedFusedAttentionImpl(DistributedOperatorImpl):
for mapping in x_dims_mapping[1:-1]: for mapping in x_dims_mapping[1:-1]:
if is_dim_shard(mapping): if is_dim_shard(mapping):
return False return False
if len(qkv_w_dims_mapping) != 4 or is_dim_replicate(qkv_w_dims_mapping[ if len(qkv_w_dims_mapping) != 4 or is_dim_replicate(
head_axis]): qkv_w_dims_mapping[head_axis]):
return False return False
if len(qkv_bias_dims_mapping) != 3 or is_dim_replicate( if len(qkv_bias_dims_mapping) != 3 or is_dim_replicate(
qkv_bias_dims_mapping[head_axis]): qkv_bias_dims_mapping[head_axis]):
...@@ -91,7 +93,7 @@ class DistributedFusedAttentionImpl(DistributedOperatorImpl): ...@@ -91,7 +93,7 @@ class DistributedFusedAttentionImpl(DistributedOperatorImpl):
op_desc = dist_op.serial_op.desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
# none of output should be sharded # none of output should be sharded
for out_name in op_desc.output_names(): for out_name in op_desc.output_names():
out = op_desc.output(out_name)[0] out = op_desc.output(out_name)[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out) out_dims_mapping = op_dist_attr.get_output_dims_mapping(out)
...@@ -152,8 +154,8 @@ class DistributedFusedAttentionImpl(DistributedOperatorImpl): ...@@ -152,8 +154,8 @@ class DistributedFusedAttentionImpl(DistributedOperatorImpl):
# infer logic comm presentation # infer logic comm presentation
head_axis = 1 head_axis = 1
qkv_w = src_op.input('QKVW')[0] qkv_w = src_op.input('QKVW')[0]
qkv_w_col_dim_mapping = op_dist_attr.get_input_dims_mapping(qkv_w)[ qkv_w_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
head_axis] qkv_w)[head_axis]
assert qkv_w_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format( assert qkv_w_col_dim_mapping >= 0, "col_parallel_matmul's row should be divided by a specific mesh axis, but got [{}]".format(
qkv_w_col_dim_mapping) qkv_w_col_dim_mapping)
process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_shape = op_dist_attr.process_mesh.topology
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -27,6 +27,7 @@ from ..process_group import new_process_group ...@@ -27,6 +27,7 @@ from ..process_group import new_process_group
class DistributedFusedFeedForward(DistributedOperatorImplContainer): class DistributedFusedFeedForward(DistributedOperatorImplContainer):
def __init__(self, op_type): def __init__(self, op_type):
super(DistributedFusedFeedForward, self).__init__(op_type) super(DistributedFusedFeedForward, self).__init__(op_type)
...@@ -36,6 +37,7 @@ register_distributed_operator_impl_container( ...@@ -36,6 +37,7 @@ register_distributed_operator_impl_container(
class DistributedFusedFeedForwardImpl(DistributedOperatorImpl): class DistributedFusedFeedForwardImpl(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedFusedFeedForwardImpl, self).__init__(name) super(DistributedFusedFeedForwardImpl, self).__init__(name)
self._forward_implemented = True self._forward_implemented = True
...@@ -82,7 +84,7 @@ class DistributedFusedFeedForwardImpl(DistributedOperatorImpl): ...@@ -82,7 +84,7 @@ class DistributedFusedFeedForwardImpl(DistributedOperatorImpl):
op_desc = dist_op.serial_op.desc op_desc = dist_op.serial_op.desc
op_dist_attr = dist_op.dist_attr op_dist_attr = dist_op.dist_attr
# none of output should be sharded # none of output should be sharded
for out_name in op_desc.output_names(): for out_name in op_desc.output_names():
out = op_desc.output(out_name)[0] out = op_desc.output(out_name)[0]
out_dims_mapping = op_dist_attr.get_output_dims_mapping(out) out_dims_mapping = op_dist_attr.get_output_dims_mapping(out)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -34,6 +34,7 @@ from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype ...@@ -34,6 +34,7 @@ from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype
class DistributedPNorm(DistributedOperatorImplContainer): class DistributedPNorm(DistributedOperatorImplContainer):
def __init__(self, op_type): def __init__(self, op_type):
super(DistributedPNorm, self).__init__(op_type) super(DistributedPNorm, self).__init__(op_type)
...@@ -52,19 +53,21 @@ def _insert_fill_constant_op(block, op_role): ...@@ -52,19 +53,21 @@ def _insert_fill_constant_op(block, op_role):
attrs['value'] = int("1") attrs['value'] = int("1")
attrs['dtype'] = out.dtype attrs['dtype'] = out.dtype
attrs['op_role'] = op_role attrs['op_role'] = op_role
utils.get_shape_tensor_inputs( utils.get_shape_tensor_inputs(inputs=inputs,
inputs=inputs, attrs=attrs, shape=[0], op_type='fill_constant') attrs=attrs,
fill_constant_op = block.append_op( shape=[0],
type='fill_constant', op_type='fill_constant')
inputs=inputs, fill_constant_op = block.append_op(type='fill_constant',
outputs={'Out': [out]}, inputs=inputs,
attrs=attrs) outputs={'Out': [out]},
attrs=attrs)
out.stop_gradient = True out.stop_gradient = True
return out, fill_constant_op return out, fill_constant_op
# Row Parallel # Row Parallel
class DistributedPNormImpl(DistributedOperatorImpl): class DistributedPNormImpl(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedPNormImpl, self).__init__(name) super(DistributedPNormImpl, self).__init__(name)
self._forward_implemented = True self._forward_implemented = True
...@@ -193,15 +196,14 @@ class DistributedPNormImpl(DistributedOperatorImpl): ...@@ -193,15 +196,14 @@ class DistributedPNormImpl(DistributedOperatorImpl):
# set fill_constant op dist_attr # set fill_constant op dist_attr
constant_op_dist_attr = OperatorDistributedAttribute() constant_op_dist_attr = OperatorDistributedAttribute()
constant_op_dist_attr.process_mesh = ref_process_mesh constant_op_dist_attr.process_mesh = ref_process_mesh
constant_op_dist_attr.set_output_dims_mapping(fill_constant_out.name, constant_op_dist_attr.set_output_dims_mapping(
constant_out_dims_mapping) fill_constant_out.name, constant_out_dims_mapping)
ctx.set_op_dist_attr_for_program(fill_constant_op, ctx.set_op_dist_attr_for_program(fill_constant_op,
constant_op_dist_attr) constant_op_dist_attr)
barrier_op = main_block.append_op( barrier_op = main_block.append_op(type='barrier',
type='barrier', inputs={'X': [fill_constant_out]},
inputs={'X': [fill_constant_out]}, outputs={'Out': [fill_constant_out]},
outputs={'Out': [fill_constant_out]}, attrs={'ring_id': group.id})
attrs={'ring_id': group.id})
# set barrier op dist attr # set barrier op dist attr
set_comm_op_dist_attr_for_program(barrier_op, ref_process_mesh, set_comm_op_dist_attr_for_program(barrier_op, ref_process_mesh,
constant_out_dist_attr, ctx) constant_out_dist_attr, ctx)
...@@ -223,16 +225,16 @@ class DistributedPNormImpl(DistributedOperatorImpl): ...@@ -223,16 +225,16 @@ class DistributedPNormImpl(DistributedOperatorImpl):
] ]
ctx.set_tensor_dist_attr_for_program(allgather_out, ctx.set_tensor_dist_attr_for_program(allgather_out,
allgather_out_dist_attr) allgather_out_dist_attr)
c_allgather_op = main_block.append_op( c_allgather_op = main_block.append_op(type='c_allgather',
type='c_allgather', inputs={'X': [X_var]},
inputs={'X': [X_var]}, outputs={'Out': [allgather_out]},
outputs={'Out': [allgather_out]}, attrs={
attrs={ 'ring_id': group.id,
'ring_id': group.id, 'use_calc_stream': True,
'use_calc_stream': True, 'nranks': group.nranks,
'nranks': group.nranks, 'op_role':
'op_role': src_op.attr('op_role') src_op.attr('op_role')
}) })
# set c_allgather op dist_attr # set c_allgather op dist_attr
allgather_op_dist_attr = OperatorDistributedAttribute() allgather_op_dist_attr = OperatorDistributedAttribute()
allgather_op_dist_attr.process_mesh = op_dist_attr.process_mesh allgather_op_dist_attr.process_mesh = op_dist_attr.process_mesh
...@@ -344,11 +346,10 @@ class DistributedPNormImpl(DistributedOperatorImpl): ...@@ -344,11 +346,10 @@ class DistributedPNormImpl(DistributedOperatorImpl):
"infer_flags": infer_flags, "infer_flags": infer_flags,
"op_role": backward_op.attr('op_role') "op_role": backward_op.attr('op_role')
} }
slice_op = main_block.append_op( slice_op = main_block.append_op(type='slice',
type='slice', inputs={'Input': [new_X_grad]},
inputs={'Input': [new_X_grad]}, outputs={'Out': [X_grad_var]},
outputs={'Out': [X_grad_var]}, attrs=attrs)
attrs=attrs)
X_grad_var_dims_mapping = op_dist_attr.get_output_dims_mapping( X_grad_var_dims_mapping = op_dist_attr.get_output_dims_mapping(
X_grad_var.name) X_grad_var.name)
slice_op_dist_attr = OperatorDistributedAttribute() slice_op_dist_attr = OperatorDistributedAttribute()
......
...@@ -34,6 +34,7 @@ from ..utils import _get_comm_group, _get_corresponding_rank ...@@ -34,6 +34,7 @@ from ..utils import _get_comm_group, _get_corresponding_rank
class DistributedReducePrimtive(DistributedOperatorImplContainer): class DistributedReducePrimtive(DistributedOperatorImplContainer):
def __init__(self, op_type): def __init__(self, op_type):
super(DistributedReducePrimtive, self).__init__(op_type) super(DistributedReducePrimtive, self).__init__(op_type)
...@@ -44,6 +45,7 @@ register_distributed_operator_impl_container( ...@@ -44,6 +45,7 @@ register_distributed_operator_impl_container(
# Batch Dimension Reduce Primitive # Batch Dimension Reduce Primitive
class DistributedReducePrimtiveImpl0(DistributedOperatorImpl): class DistributedReducePrimtiveImpl0(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedReducePrimtiveImpl0, self).__init__(name) super(DistributedReducePrimtiveImpl0, self).__init__(name)
self._forward_implemented = True self._forward_implemented = True
...@@ -119,15 +121,14 @@ class DistributedReducePrimtiveImpl0(DistributedOperatorImpl): ...@@ -119,15 +121,14 @@ class DistributedReducePrimtiveImpl0(DistributedOperatorImpl):
# batch dimension synchronization # batch dimension synchronization
var_name = src_op.output_arg_names[0] var_name = src_op.output_arg_names[0]
sync_group = new_process_group(ctx.data_parallel_group) sync_group = new_process_group(ctx.data_parallel_group)
allreduce_op = main_block.append_op( allreduce_op = main_block.append_op(type='c_allreduce_sum',
type='c_allreduce_sum', inputs={'X': [var_name]},
inputs={'X': [var_name]}, outputs={'Out': [var_name]},
outputs={'Out': [var_name]}, attrs={
attrs={ 'ring_id': sync_group.id,
'ring_id': sync_group.id, 'use_calc_stream': True,
'use_calc_stream': True, OP_ROLE_KEY: OpRole.Forward
OP_ROLE_KEY: OpRole.Forward })
})
# dist attr # dist attr
var = main_block.var(var_name) var = main_block.var(var_name)
......
...@@ -31,6 +31,7 @@ from .dist_default import DistributedDefaultImpl0 ...@@ -31,6 +31,7 @@ from .dist_default import DistributedDefaultImpl0
class DistributedReshape2(DistributedOperatorImplContainer): class DistributedReshape2(DistributedOperatorImplContainer):
def __init__(self, op_type): def __init__(self, op_type):
super(DistributedReshape2, self).__init__(op_type) super(DistributedReshape2, self).__init__(op_type)
...@@ -39,6 +40,7 @@ register_distributed_operator_impl_container(DistributedReshape2("reshape2")) ...@@ -39,6 +40,7 @@ register_distributed_operator_impl_container(DistributedReshape2("reshape2"))
class DistributedReshapeImpl0(DistributedOperatorImpl): class DistributedReshapeImpl0(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedReshapeImpl0, self).__init__(name) super(DistributedReshapeImpl0, self).__init__(name)
self._forward_implemented = True self._forward_implemented = True
...@@ -171,8 +173,8 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -171,8 +173,8 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
for idx, axis in enumerate(dim_mapping): for idx, axis in enumerate(dim_mapping):
if axis >= 0: if axis >= 0:
if len(shape_list) > idx: if len(shape_list) > idx:
shape_list[idx] = shape_list[idx] // process_mesh_shape[ shape_list[
axis] idx] = shape_list[idx] // process_mesh_shape[axis]
# create op # create op
new_op_desc = main_block.desc.append_op() new_op_desc = main_block.desc.append_op()
...@@ -193,6 +195,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -193,6 +195,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
class DistributedReshapeImpl1(DistributedOperatorImpl): class DistributedReshapeImpl1(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedReshapeImpl1, self).__init__(name) super(DistributedReshapeImpl1, self).__init__(name)
self._forward_implemented = True self._forward_implemented = True
...@@ -328,8 +331,8 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -328,8 +331,8 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
for idx, axis in enumerate(dim_mapping): for idx, axis in enumerate(dim_mapping):
if axis >= 0: if axis >= 0:
if len(shape_list) > idx: if len(shape_list) > idx:
shape_list[idx] = shape_list[idx] // process_mesh_shape[ shape_list[
axis] idx] = shape_list[idx] // process_mesh_shape[axis]
# create op # create op
new_op_desc = main_block.desc.append_op() new_op_desc = main_block.desc.append_op()
...@@ -350,6 +353,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -350,6 +353,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
class DistributedReshapeImpl2(DistributedOperatorImpl): class DistributedReshapeImpl2(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedReshapeImpl2, self).__init__(name) super(DistributedReshapeImpl2, self).__init__(name)
self._forward_implemented = True self._forward_implemented = True
...@@ -478,8 +482,8 @@ class DistributedReshapeImpl2(DistributedOperatorImpl): ...@@ -478,8 +482,8 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
for idx, axis in enumerate(out_dim_mapping): for idx, axis in enumerate(out_dim_mapping):
if axis >= 0: if axis >= 0:
if len(shape_list) > idx: if len(shape_list) > idx:
shape_list[idx] = shape_list[idx] // process_mesh_shape[ shape_list[
axis] idx] = shape_list[idx] // process_mesh_shape[axis]
# create op # create op
new_op_desc = main_block.desc.append_op() new_op_desc = main_block.desc.append_op()
......
...@@ -23,6 +23,7 @@ from .dist_default import DistributedDefaultImpl0 ...@@ -23,6 +23,7 @@ from .dist_default import DistributedDefaultImpl0
class DistributedSlice(DistributedOperatorImplContainer): class DistributedSlice(DistributedOperatorImplContainer):
def __init__(self, op_type): def __init__(self, op_type):
super(DistributedSlice, self).__init__(op_type) super(DistributedSlice, self).__init__(op_type)
...@@ -31,6 +32,7 @@ register_distributed_operator_impl_container(DistributedSlice("slice")) ...@@ -31,6 +32,7 @@ register_distributed_operator_impl_container(DistributedSlice("slice"))
class DistributedSliceImpl(DistributedOperatorImpl): class DistributedSliceImpl(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedSliceImpl, self).__init__(name) super(DistributedSliceImpl, self).__init__(name)
self._forward_implemented = True self._forward_implemented = True
......
...@@ -26,6 +26,7 @@ from .dist_default import DistributedDefaultImpl0 ...@@ -26,6 +26,7 @@ from .dist_default import DistributedDefaultImpl0
class DistributedSoftmax(DistributedOperatorImplContainer): class DistributedSoftmax(DistributedOperatorImplContainer):
def __init__(self, op_type): def __init__(self, op_type):
super(DistributedSoftmax, self).__init__(op_type) super(DistributedSoftmax, self).__init__(op_type)
...@@ -34,6 +35,7 @@ register_distributed_operator_impl_container(DistributedSoftmax("softmax")) ...@@ -34,6 +35,7 @@ register_distributed_operator_impl_container(DistributedSoftmax("softmax"))
class DistributedSoftmaxImpl(DistributedOperatorImpl): class DistributedSoftmaxImpl(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedSoftmaxImpl, self).__init__(name) super(DistributedSoftmaxImpl, self).__init__(name)
self._forward_implemented = False self._forward_implemented = False
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...@@ -25,6 +25,7 @@ from .dist_default import DistributedDefaultImpl0 ...@@ -25,6 +25,7 @@ from .dist_default import DistributedDefaultImpl0
class DistributedSplit(DistributedOperatorImplContainer): class DistributedSplit(DistributedOperatorImplContainer):
def __init__(self, op_type): def __init__(self, op_type):
super(DistributedSplit, self).__init__(op_type) super(DistributedSplit, self).__init__(op_type)
...@@ -33,6 +34,7 @@ register_distributed_operator_impl_container(DistributedSplit("split")) ...@@ -33,6 +34,7 @@ register_distributed_operator_impl_container(DistributedSplit("split"))
class DistributedSplitImpl(DistributedOperatorImpl): class DistributedSplitImpl(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedSplitImpl, self).__init__(name) super(DistributedSplitImpl, self).__init__(name)
self._forward_implemented = True self._forward_implemented = True
......
...@@ -26,6 +26,7 @@ from .dist_default import DistributedDefaultImpl0 ...@@ -26,6 +26,7 @@ from .dist_default import DistributedDefaultImpl0
class DistributedTranspose2(DistributedOperatorImplContainer): class DistributedTranspose2(DistributedOperatorImplContainer):
def __init__(self, op_type): def __init__(self, op_type):
super(DistributedTranspose2, self).__init__(op_type) super(DistributedTranspose2, self).__init__(op_type)
...@@ -35,6 +36,7 @@ register_distributed_operator_impl_container( ...@@ -35,6 +36,7 @@ register_distributed_operator_impl_container(
class DistributedTranspose2Impl(DistributedOperatorImpl): class DistributedTranspose2Impl(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedTranspose2Impl, self).__init__(name) super(DistributedTranspose2Impl, self).__init__(name)
self._forward_implemented = False self._forward_implemented = False
......
...@@ -20,6 +20,7 @@ from ..utils import set_dist_op_desc_original_id ...@@ -20,6 +20,7 @@ from ..utils import set_dist_op_desc_original_id
class DistributedUpdateLossScaling(DistributedOperatorImplContainer): class DistributedUpdateLossScaling(DistributedOperatorImplContainer):
def __init__(self, op_type): def __init__(self, op_type):
super(DistributedUpdateLossScaling, self).__init__(op_type) super(DistributedUpdateLossScaling, self).__init__(op_type)
...@@ -29,6 +30,7 @@ register_distributed_operator_impl_container( ...@@ -29,6 +30,7 @@ register_distributed_operator_impl_container(
class DistributedUpdateLossScalingImpl(DistributedOperatorImpl): class DistributedUpdateLossScalingImpl(DistributedOperatorImpl):
def __init__(self, name): def __init__(self, name):
super(DistributedUpdateLossScalingImpl, self).__init__(name) super(DistributedUpdateLossScalingImpl, self).__init__(name)
self._name = name self._name = name
......
...@@ -108,8 +108,8 @@ class AutoParallelizer: ...@@ -108,8 +108,8 @@ class AutoParallelizer:
if config["use_pure_fp16"]: if config["use_pure_fp16"]:
config["base_opt"] = self._optimizer config["base_opt"] = self._optimizer
auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config) auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
auto_parallel_fp16_pass.apply( auto_parallel_fp16_pass.apply([main_program], [startup_program],
[main_program], [startup_program], self._pass_context) self._pass_context)
else: else:
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config) auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply([main_program], [startup_program], auto_parallel_amp_pass.apply([main_program], [startup_program],
...@@ -123,8 +123,9 @@ class AutoParallelizer: ...@@ -123,8 +123,9 @@ class AutoParallelizer:
config["loss"] = loss config["loss"] = loss
auto_parallel_recompute_pass = new_pass("auto_parallel_recompute", auto_parallel_recompute_pass = new_pass("auto_parallel_recompute",
config) config)
auto_parallel_recompute_pass.apply( auto_parallel_recompute_pass.apply([main_program],
[main_program], [startup_program], self._pass_context) [startup_program],
self._pass_context)
def _generate_backward(self, main_program, startup_program, loss, def _generate_backward(self, main_program, startup_program, loss,
parameter_list, no_grad_set, callbacks): parameter_list, no_grad_set, callbacks):
...@@ -144,10 +145,10 @@ class AutoParallelizer: ...@@ -144,10 +145,10 @@ class AutoParallelizer:
def _apply_optimize(self, main_program, startup_program, params_grads): def _apply_optimize(self, main_program, startup_program, params_grads):
with program_guard(main_program, startup_program): with program_guard(main_program, startup_program):
optimize_ops = copy.deepcopy(self._optimizer).apply_gradients( optimize_ops = copy.deepcopy(
params_grads) self._optimizer).apply_gradients(params_grads)
# update completion # update completion
self._completer = Completer(self._dist_context) self._completer = Completer(self._dist_context)
self._completer.complete_update_annotation(main_program) self._completer.complete_update_annotation(main_program)
...@@ -163,8 +164,8 @@ class AutoParallelizer: ...@@ -163,8 +164,8 @@ class AutoParallelizer:
config["global_rank"] = rank config["global_rank"] = rank
auto_parallel_sharding_pass = new_pass("auto_parallel_sharding", auto_parallel_sharding_pass = new_pass("auto_parallel_sharding",
config) config)
auto_parallel_sharding_pass.apply( auto_parallel_sharding_pass.apply([main_program], [startup_program],
[main_program], [startup_program], self._pass_context) self._pass_context)
if self._dist_strategy.gradient_merge: if self._dist_strategy.gradient_merge:
config = copy.deepcopy(self._dist_strategy.gradient_merge_configs) config = copy.deepcopy(self._dist_strategy.gradient_merge_configs)
...@@ -172,8 +173,9 @@ class AutoParallelizer: ...@@ -172,8 +173,9 @@ class AutoParallelizer:
config["params_grads"] = params_grads config["params_grads"] = params_grads
auto_parallel_gradient_merge_pass = new_pass( auto_parallel_gradient_merge_pass = new_pass(
"auto_parallel_gradient_merge_pass", config) "auto_parallel_gradient_merge_pass", config)
auto_parallel_gradient_merge_pass.apply( auto_parallel_gradient_merge_pass.apply([main_program],
[main_program], [startup_program], self._pass_context) [startup_program],
self._pass_context)
def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False): def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False):
completed_main_program = None completed_main_program = None
...@@ -181,7 +183,7 @@ class AutoParallelizer: ...@@ -181,7 +183,7 @@ class AutoParallelizer:
serial_startup_program = self._startup_program.clone() serial_startup_program = self._startup_program.clone()
serial_loss = serial_main_program.global_block().var(self._loss.name) serial_loss = serial_main_program.global_block().var(self._loss.name)
# generating serial # generating serial
if dist_context is None: if dist_context is None:
# Annotation completion # Annotation completion
self._dist_context = DistributedContext() self._dist_context = DistributedContext()
...@@ -205,15 +207,16 @@ class AutoParallelizer: ...@@ -205,15 +207,16 @@ class AutoParallelizer:
self._apply_pre_optimization_passes(completed_main_program, self._apply_pre_optimization_passes(completed_main_program,
serial_startup_program, serial_loss, serial_startup_program, serial_loss,
params_grads, self._no_grad_set) params_grads, self._no_grad_set)
# Logical partition # Logical partition
partitioner = Partitioner(self._dist_context, rank) partitioner = Partitioner(self._dist_context, rank)
dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition( dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
completed_main_program, serial_startup_program, params_grads) completed_main_program, serial_startup_program, params_grads)
# TODO refactor the placement of optimizer # TODO refactor the placement of optimizer
# generate optimize program # generate optimize program
dist_optimize_ops = self._apply_optimize( dist_optimize_ops = self._apply_optimize(dist_main_prog,
dist_main_prog, dist_startup_prog, dist_params_grads) dist_startup_prog,
dist_params_grads)
set_grad_var_shape(dist_main_prog, self._dist_context) set_grad_var_shape(dist_main_prog, self._dist_context)
...@@ -258,14 +261,17 @@ class AutoParallelizer: ...@@ -258,14 +261,17 @@ class AutoParallelizer:
# auto search # auto search
if self._dist_strategy.auto_search: if self._dist_strategy.auto_search:
logging.info("Start searching dist attr.") logging.info("Start searching dist attr.")
serial_program_info = SerialProgramInfo( serial_program_info = SerialProgramInfo(self._main_program,
self._main_program, self._startup_program, self._loss, self._startup_program,
self._optimizer, self._cluster) self._loss,
planner = Planner( self._optimizer,
serial_program_info, self._cluster)
self, planner = Planner(serial_program_info,
algorithm_config={"name": "mcmc", self,
"max_search_times": 5}) algorithm_config={
"name": "mcmc",
"max_search_times": 5
})
dist_context, _ = planner.search() dist_context, _ = planner.search()
logging.info("End searching dist attr.") logging.info("End searching dist attr.")
...@@ -325,8 +331,8 @@ class AutoParallelizer: ...@@ -325,8 +331,8 @@ class AutoParallelizer:
else: else:
coverage_args = [] coverage_args = []
new_cmd_args = "-m paddle.distributed.fleet.launch" + " " + rank_mapping_args + " " + original_cmd_args new_cmd_args = "-m paddle.distributed.fleet.launch" + " " + rank_mapping_args + " " + original_cmd_args
new_cmd = [sys.executable, "-u"] + coverage_args + shlex.split( new_cmd = [sys.executable, "-u"
new_cmd_args) ] + coverage_args + shlex.split(new_cmd_args)
new_process = subprocess.Popen(new_cmd) new_process = subprocess.Popen(new_cmd)
new_process.wait() new_process.wait()
assert new_process.returncode == 0, \ assert new_process.returncode == 0, \
...@@ -368,13 +374,12 @@ class AutoParallelizer: ...@@ -368,13 +374,12 @@ class AutoParallelizer:
self._loss, self._loss,
self._optimizer, self._optimizer,
cluster=self._cluster) cluster=self._cluster)
planner = Planner( planner = Planner(serial_program_info,
serial_program_info, self,
self, algorithm_config={
algorithm_config={ "name": "mcmc",
"name": "mcmc", "max_search_times": 5
"max_search_times": 5 })
})
dist_context, _ = planner.search() dist_context, _ = planner.search()
# rebuild g_process_group # rebuild g_process_group
......
...@@ -20,6 +20,7 @@ from .utils import print_program_with_dist_attr ...@@ -20,6 +20,7 @@ from .utils import print_program_with_dist_attr
class Planner: class Planner:
def __init__(self, mode, dist_context): def __init__(self, mode, dist_context):
self._mode = mode self._mode = mode
self._dist_context = dist_context self._dist_context = dist_context
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册