Commit 08f927de authored by Tink_Y, committed by Cheerego

cherry-pick API reference for release1.2 (#14750)

* Add examples to some functions. (#14645)

* Fix comments of ctc_greedy_decoder. (#14679)

test=develop

* fix api format and examples

test=develop

* Update executor.py

test=develop

* Update nn.py

* Update nn.py

test=develop

* Update nn.py

test=develop

* Update clip.py

test=release1.2
Parent 8feb99b4
@@ -134,12 +134,12 @@ class GradientClipByValue(BaseGradientClipAttr):
     Examples:
         .. code-block:: python

-            w_param_attrs = ParamAttr(name=None,
-              initializer=UniformInitializer(low=-1.0, high=1.0, seed=0),
+            w_param_attrs = fluid.ParamAttr(name=None,
+              initializer=fluid.initializer.UniformInitializer(low=-1.0, high=1.0, seed=0),
               learning_rate=1.0,
-              regularizer=L1Decay(1.0),
+              regularizer=fluid.regularizer.L1Decay(1.0),
               trainable=True,
-              clip=GradientClipByValue(-1.0, 1.0))
+              clip=fluid.clip.GradientClipByValue(-1.0, 1.0))
             y_predict = fluid.layers.fc(input=x, size=1, param_attr=w_param_attrs)
     """
@@ -185,12 +185,12 @@ class GradientClipByNorm(BaseGradientClipAttr):
     Examples:
         .. code-block:: python

-            w_param_attrs = ParamAttr(name=None,
-              initializer=UniformInitializer(low=-1.0, high=1.0, seed=0),
+            w_param_attrs = fluid.ParamAttr(name=None,
+              initializer=fluid.initializer.UniformInitializer(low=-1.0, high=1.0, seed=0),
               learning_rate=1.0,
-              regularizer=L1Decay(1.0),
+              regularizer=fluid.regularizer.L1Decay(1.0),
               trainable=True,
-              clip=GradientClipByNorm(clip_norm=2.0))
+              clip=fluid.clip.GradientClipByNorm(clip_norm=2.0))
             y_predict = fluid.layers.fc(input=x, size=1, param_attr=w_param_attrs)
     """
...
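GradientClipByNorm from the hunk above rescales the whole gradient when its L2 norm exceeds clip_norm and leaves it untouched otherwise; a NumPy sketch of that rule (an illustration, not the fluid operator):

.. code-block:: python

    # Rescale grad so its L2 norm is at most clip_norm; a sketch of the
    # documented rule, not the fluid implementation.
    import numpy as np

    def clip_by_norm(grad, clip_norm=2.0):
        norm = np.sqrt((grad ** 2).sum())
        return grad if norm <= clip_norm else grad * (clip_norm / norm)

    print(clip_by_norm(np.array([3.0, 4.0])))  # norm 5.0 -> [1.2 1.6], norm 2.0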
@@ -20,7 +20,7 @@ import six
 from .framework import Program, default_main_program, Variable
 from . import core

-__all__ = ['Executor', 'global_scope', 'scope_guard', '_switch_scope']
+__all__ = ['Executor', 'global_scope', 'scope_guard']

 g_scope = core.Scope()
@@ -407,16 +407,17 @@ class Executor(object):
     Examples:

-        >>> data = layers.data(name='X', shape=[1], dtype='float32')
-        >>> hidden = layers.fc(input=data, size=10)
-        >>> layers.assign(hidden, out)
-        >>> loss = layers.mean(out)
+        >>> data = fluid.layers.data(name='X', shape=[1], dtype='float32')
+        >>> out = fluid.layers.create_tensor(dtype='float32')
+        >>> hidden = fluid.layers.fc(input=data, size=10)
+        >>> fluid.layers.assign(hidden, out)
+        >>> loss = fluid.layers.mean(out)
         >>> adam = fluid.optimizer.Adam()
         >>> adam.minimize(loss)
         >>> cpu = core.CPUPlace()
-        >>> exe = Executor(cpu)
-        >>> exe.run(default_startup_program())
+        >>> exe = fluid.Executor(cpu)
+        >>> exe.run(fluid.default_startup_program())
         >>> x = numpy.random.random(size=(10, 1)).astype('float32')
         >>> outs = exe.run(
...
@@ -89,12 +89,13 @@ def name_scope(prefix=None):
     Examples:
         .. code-block:: python
+
           with name_scope("encoder"):
              ...
           with name_scope("decoder"):
              ...
              with name_scope("attention"):
                 ...
     """
     # TODO(panyx0718): Only [0-9a-z].
     assert prefix, "namescope prefix cannot be empty."
...
@@ -943,7 +943,18 @@ def __create_unshared_decorated_reader__(op_type, reader, attrs, name=None):
 def shuffle(reader, buffer_size):
     """
-    Shuffle the reader.
+    Creates a data reader whose data output is shuffled.
+
+    Output from the iterator created by the original reader will be
+    buffered into a shuffle buffer, and then shuffled. The size of the
+    shuffle buffer is determined by the argument buf_size.
+
+    Args:
+        param reader: the original reader whose output will be shuffled.
+        type reader: callable
+        param buf_size: shuffle buffer size.
+        type buf_size: int
+
+        return: the new reader whose output is shuffled.
+        rtype: callable
     """
     return __create_unshared_decorated_reader__(
         'create_shuffle_reader', reader, {'buffer_size': int(buffer_size)})
...
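The buffered shuffle described by the new docstring fits in a few lines of plain Python; this is a sketch of the documented behavior, not the paddle.reader implementation:

.. code-block:: python

    import random

    def shuffle_reader(reader, buf_size):
        # Buffer up to buf_size items from the original reader, shuffle the
        # buffer, emit it, and repeat; flush whatever remains at the end.
        def new_reader():
            buf = []
            for item in reader():
                buf.append(item)
                if len(buf) >= buf_size:
                    random.shuffle(buf)
                    for b in buf:
                        yield b
                    buf = []
            random.shuffle(buf)
            for b in buf:
                yield b
        return new_reader

    print(list(shuffle_reader(lambda: iter(range(10)), buf_size=4)()))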
@@ -308,13 +308,9 @@ def piecewise_decay(boundaries, values):
 def append_LARS(params_grads, learning_rate, weight_decay):
-    """Applies LARS (LAYER-WISE ADAPTIVE RATE SCALING) to learning rate for
-       each layer.
-
-    ```python
-        learning_rate *= local_gw_ratio * sqrt(sumsq(param))
-                         / (sqrt(sumsq(gradient))+ weight_decay * sqrt(sumsq(param)))
-    ```
+    """
+    Applies LARS (LAYER-WISE ADAPTIVE RATE SCALING) to learning rate for
+    each layer.

     Args:
         learning_rate: A learning rate Variable. This
@@ -323,6 +319,11 @@ def append_LARS(params_grads, learning_rate, weight_decay):
     Returns:
         The decayed learning rate
+    Examples:
+        .. code-block:: python
+
+            learning_rate *= local_gw_ratio * sqrt(sumsq(param))
+                             / (sqrt(sumsq(gradient)) + weight_decay * sqrt(sumsq(param)))
     """

 def _balanced_weight(param_norm, grad_norm):
...
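The ratio in the docstring formula above can be checked numerically. A NumPy sketch, where the local_gw_ratio and weight_decay values are illustrative assumptions rather than Paddle defaults:

.. code-block:: python

    import numpy as np

    def lars_scaled_lr(lr, param, grad, local_gw_ratio=0.001, weight_decay=0.0005):
        # learning_rate *= local_gw_ratio * sqrt(sumsq(param))
        #                  / (sqrt(sumsq(gradient)) + weight_decay * sqrt(sumsq(param)))
        p_norm = np.sqrt((param ** 2).sum())
        g_norm = np.sqrt((grad ** 2).sum())
        return lr * local_gw_ratio * p_norm / (g_norm + weight_decay * p_norm)

    print(lars_scaled_lr(0.1, param=np.ones(100), grad=0.01 * np.ones(100)))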
@@ -928,7 +928,7 @@ def dynamic_gru(input,
             emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
             hidden_dim = 512
             x = fluid.layers.fc(input=emb, size=hidden_dim * 3)
-            hidden = fluid.layers.dynamic_gru(input=x, dim=hidden_dim)
+            hidden = fluid.layers.dynamic_gru(input=x, size=hidden_dim)
     """
     helper = LayerHelper('gru', **locals())
@@ -3560,6 +3560,7 @@ def beam_search_decode(ids, scores, beam_size, end_id, name=None):
     Examples:
         .. code-block:: python
+
             # Suppose `ids` and `scores` are LodTensorArray variables reserving
             # the selected ids and scores of all steps
             finished_ids, finished_scores = layers.beam_search_decode(
@@ -4389,8 +4390,15 @@ def ctc_greedy_decoder(input, blank, name=None):
                       [0.5, 0.1, 0.3, 0.1]]
         input.lod = [[4, 4]]

-        Then:
+        Computation:
+
+        step1: Apply argmax to first input sequence which is input.data[0:4]. Then we get:
+               [[0], [2], [1], [0]]
+        step2: Merge repeated tokens and remove blank which is 0. Then we get the first output sequence:
+               [[2], [1]]
+
+        Finally:

         output.data = [[2],
                        [1],
@@ -4398,6 +4406,7 @@ def ctc_greedy_decoder(input, blank, name=None):
         output.lod = [[2, 1]]
+
     Args:
         input(Variable): (LoDTensor<float>), the probabilities of
@@ -4412,8 +4421,10 @@ def ctc_greedy_decoder(input, blank, name=None):
         name (str): The name of this layer. It is optional.

     Returns:
-        Variable: CTC greedy decode result. If all the sequences in result were
-        empty, the result LoDTensor will be [-1] with LoD [[]] and dims [1, 1].
+        Variable: CTC greedy decode result which is a 2-D tensor with shape [Lp, 1].
+        'Lp' is the sum of all output sequences' length. If all the sequences
+        in result were empty, the result LoDTensor will be [-1] with
+        LoD [[]] and dims [1, 1].

     Examples:
         .. code-block:: python
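The two documented steps (per-step argmax, then merging repeats and dropping the blank) fit in a short standalone sketch. Only the last probability row appears in the excerpt above; the first three rows here are filled in for illustration so the argmax reproduces the [[0], [2], [1], [0]] sequence from the example:

.. code-block:: python

    import numpy as np

    def ctc_greedy_decode(probs, blank=0):
        ids = probs.argmax(axis=1)         # step1: argmax over each time step
        out, prev = [], None
        for t in ids:                      # step2: merge repeats, drop blank
            if t != prev and t != blank:
                out.append(int(t))
            prev = t
        return out

    seq = np.array([[0.6, 0.1, 0.3, 0.1],   # argmax -> 0
                    [0.3, 0.2, 0.4, 0.1],   # argmax -> 2
                    [0.1, 0.5, 0.3, 0.1],   # argmax -> 1
                    [0.5, 0.1, 0.3, 0.1]])  # argmax -> 0
    print(ctc_greedy_decode(seq))           # [2, 1]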
@@ -5047,7 +5058,7 @@ def im2sequence(input,
             output.lod = [[4, 4]]

     Examples:

         .. code-block:: python
@@ -5834,24 +5845,23 @@ def pad_constant_like(x, y, pad_value=0., name=None):
             [[38, 39, 40]],
             [[41, 42, 43]]]]
         Y.shape = (1, 3, 1, 3)

         And
             pad_value = -1,

         Return:
             Out = [[[[35, 36, 37],
                      [-1, -1, -1]],
                     [[38, 39, 40],
                      [-1, -1, -1]],
                     [[41, 42, 43],
                      [-1, -1, -1]]],
                    [[[-1, -1, -1],
                      [-1, -1, -1]],
                     [[-1, -1, -1],
                      [-1, -1, -1]],
                     [[-1, -1, -1],
                      [-1, -1, -1]]]]
             Out.shape = (2, 3, 2, 3)

     Args:
         x (Variable): The input tensor variable.
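The worked example above is reproducible with np.pad: Y keeps its values in the leading positions of every axis and pad_value fills the rest, giving Out the shape of X. A sketch of that semantics (not the fluid operator):

.. code-block:: python

    import numpy as np

    def pad_constant_like_np(x, y, pad_value=-1):
        # Pad y at the trailing end of each axis until it matches x's shape.
        pads = [(0, xd - yd) for xd, yd in zip(x.shape, y.shape)]
        return np.pad(y, pads, mode='constant', constant_values=pad_value)

    x = np.zeros((2, 3, 2, 3), dtype=int)
    y = np.arange(35, 44).reshape(1, 3, 1, 3)   # the Y values from the example
    out = pad_constant_like_np(x, y)
    print(out.shape)   # (2, 3, 2, 3)
    print(out[0, 0])   # [[35 36 37] [-1 -1 -1]]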
@@ -6090,6 +6100,7 @@ def image_resize(input,
     Supporting resample methods:
         'BILINEAR' : Bilinear interpolation
         'NEAREST' : Nearest neighbor interpolation
+
     Args:
@@ -6745,7 +6756,7 @@ def crop(x, shape=None, offsets=None, name=None):
             # or
             z = fluid.layers.data(name="z", shape=[3, 5], dtype="float32")
-            crop = fluid.layers.crop(z, shape=[2, 3])
+            crop = fluid.layers.crop(z, shape=[-1, 2, 3])
     """
     helper = LayerHelper('crop', **locals())
@@ -7026,39 +7037,40 @@ def pad2d(input,
           than height-1. And the width dimension has the same condition.

     Example:
+        .. code-block:: text

         Given that X is a channel of image from input:

         X = [[1, 2, 3],
              [4, 5, 6]]

         Case 0:
             paddings = [0, 1, 2, 3],
             mode = 'constant'
             pad_value = 0
             Out = [[0, 0, 1, 2, 3, 0, 0, 0]
                    [0, 0, 4, 5, 6, 0, 0, 0]
                    [0, 0, 0, 0, 0, 0, 0, 0]]

         Case 1:
             paddings = [0, 1, 2, 1],
             mode = 'reflect'
             Out = [[3, 2, 1, 2, 3, 2]
                    [6, 5, 4, 5, 6, 5]
                    [3, 2, 1, 2, 3, 2]]

         Case 2:
             paddings = [0, 1, 2, 1],
             mode = 'edge'
             Out = [[1, 1, 1, 2, 3, 3]
                    [4, 4, 4, 5, 6, 6]
                    [4, 4, 4, 5, 6, 6]]

     Args:
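All three cases above can be checked with np.pad, reading paddings as [top, bottom, left, right]; a quick verification sketch:

.. code-block:: python

    import numpy as np

    X = np.array([[1, 2, 3],
                  [4, 5, 6]])
    # Case 0: paddings = [0, 1, 2, 3], mode 'constant', pad_value = 0
    print(np.pad(X, ((0, 1), (2, 3)), mode='constant', constant_values=0))
    # Case 1: paddings = [0, 1, 2, 1], mode 'reflect'
    print(np.pad(X, ((0, 1), (2, 1)), mode='reflect'))
    # Case 2: paddings = [0, 1, 2, 1], mode 'edge'
    print(np.pad(X, ((0, 1), (2, 1)), mode='edge'))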
@@ -7295,13 +7307,13 @@ def prelu(x, mode, param_attr=None, name=None):
     Args:
         x (Variable): The input tensor.
         param_attr(ParamAttr|None): The parameter attribute for the learnable
           weight (alpha).
         mode (string): The mode for weight sharing. It supports all, channel
           and element. all: all elements share same weight
           channel: elements in a channel share same weight
           element: each element has a weight
         name(str|None): A name for this layer (optional). If set None, the layer
           will be named automatically.

     Returns:
         Variable: The output tensor with the same shape as input.
@@ -7745,6 +7757,11 @@ def uniform_random_batch_size_like(input,
     Returns:
         out (Variable): ${out_comment}

+    Examples:
+        .. code-block:: python
+
+            input = layers.data(name="input", shape=[13, 11], dtype='float32')
+            out = layers.uniform_random_batch_size_like(input, [-1, 11])
     """
     helper = LayerHelper('uniform_random_batch_size_like', **locals())
@@ -7782,6 +7799,10 @@ def gaussian_random(shape, mean=0.0, std=1.0, seed=0, dtype='float32'):
     Returns:
         out (Variable): ${out_comment}

+    Examples:
+        .. code-block:: python
+
+            out = layers.gaussian_random(shape=[20, 30])
     """
     helper = LayerHelper('gaussian_random', **locals())
@@ -7817,6 +7838,16 @@ def sampling_id(x, min=0.0, max=1.0, seed=0, dtype='float32'):
     Returns:
         out (Variable): ${out_comment}

+    Examples:
+        .. code-block:: python
+
+            x = layers.data(
+                name="X",
+                shape=[13, 11],
+                dtype='float32',
+                append_batch_size=False)
+
+            out = layers.sampling_id(x)
     """
     helper = LayerHelper('sampling_id', **locals())
@@ -7856,6 +7887,14 @@ def gaussian_random_batch_size_like(input,
     Returns:
         out (Variable): ${out_comment}

+    Examples:
+        .. code-block:: python
+
+            input = layers.data(name="input", shape=[13, 11], dtype='float32')
+
+            out = layers.gaussian_random_batch_size_like(
+                input, shape=[-1, 11], mean=1.0, std=2.0)
+
     """
     helper = LayerHelper('gaussian_random_batch_size_like', **locals())
@@ -7888,6 +7927,12 @@ def sum(x):
     Returns:
         out (Variable): ${out_comment}

+    Examples:
+        .. code-block:: python
+
+            input = layers.data(name="input", shape=[13, 11], dtype='float32')
+
+            out = layers.sum(input)
     """
     helper = LayerHelper('sum', **locals())
@@ -7916,6 +7961,17 @@ def slice(input, axes, starts, ends):
     Returns:
         out (Variable): ${out_comment}

+    Examples:
+        .. code-block:: python
+
+            starts = [1, 0, 2]
+            ends = [3, 3, 4]
+            axes = [0, 1, 2]
+
+            input = layers.data(
+                name="input", shape=[3, 4, 5, 6], dtype='float32')
+
+            out = layers.slice(input, axes=axes, starts=starts, ends=ends)
     """
     helper = LayerHelper('slice', **locals())
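The added slice example corresponds to ordinary start:end indexing on each listed axis; in NumPy terms:

.. code-block:: python

    import numpy as np

    data = np.random.random((3, 4, 5, 6)).astype('float32')
    # axes=[0, 1, 2], starts=[1, 0, 2], ends=[3, 3, 4] is equivalent to:
    out = data[1:3, 0:3, 2:4]
    print(out.shape)  # (2, 3, 2, 6)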
@@ -7943,6 +7999,12 @@ def shape(input):
     Returns:
         out (Variable): ${out_comment}

+    Examples:
+        .. code-block:: python
+
+            input = layers.data(
+                name="input", shape=[3, 100, 100], dtype="float32")
+            out = layers.shape(input)
     """
     helper = LayerHelper('shape', **locals())
...
@@ -222,13 +222,13 @@ class Precision(MetricBase):
     Examples:
         .. code-block:: python

            metric = fluid.metrics.Precision()
            for pass_id in range(PASSES):
                metric.reset()
                for data in train_reader():
                    loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
                    metric.update(preds=preds, labels=labels)
                numpy_precision = metric.eval()
     """

     def __init__(self, name=None):
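For reference, what the Precision metric accumulates for a binary task is true positives over all positive predictions. A NumPy sketch, where the 0.5 decision threshold is an assumption for illustration:

.. code-block:: python

    import numpy as np

    preds = np.array([0.9, 0.2, 0.8, 0.4])
    labels = np.array([1, 0, 0, 1])
    pred_pos = preds > 0.5                            # assumed threshold
    tp = np.logical_and(pred_pos, labels == 1).sum()
    print(tp / pred_pos.sum())                        # 0.5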
@@ -267,13 +267,13 @@ class Recall(MetricBase):
     Examples:
         .. code-block:: python

            metric = fluid.metrics.Recall()
            for pass_id in range(PASSES):
                metric.reset()
                for data in train_reader():
                    loss, preds, labels = exe.run(fetch_list=[cost, preds, labels])
                    metric.update(preds=preds, labels=labels)
                numpy_recall = metric.eval()
     """

     def __init__(self, name=None):
@@ -449,8 +449,9 @@ class EditDistance(MetricBase):
             distance_evaluator.update(distances, seq_num)
             distance, instance_error = distance_evaluator.eval()
+
         In the above example:
         'distance' is the average of the edit distance in a pass.
         'instance_error' is the instance error rate in a pass.
     """
...
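The per-pair distance that EditDistance averages is the standard Levenshtein distance; a rolling-row dynamic-programming sketch for reference:

.. code-block:: python

    def edit_distance(a, b):
        # dp[j] holds the distance between a[:i] and b[:j] for the current row.
        dp = list(range(len(b) + 1))
        for i, ca in enumerate(a, 1):
            prev, dp[0] = dp[0], i
            for j, cb in enumerate(b, 1):
                prev, dp[j] = dp[j], min(dp[j] + 1,          # deletion
                                         dp[j - 1] + 1,      # insertion
                                         prev + (ca != cb))  # substitution
        return dp[-1]

    print(edit_distance("kitten", "sitting"))  # 3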
@@ -50,8 +50,9 @@ class ParamAttr(object):
             w_param_attrs = fluid.ParamAttr(name="fc_weight",
                                             learning_rate=0.5,
-                                            regularizer=fluid.L2Decay(1.0),
+                                            regularizer=fluid.regularizer.L2Decay(1.0),
                                             trainable=True)
+            x = fluid.layers.data(name='X', shape=[1], dtype='float32')
             y_predict = fluid.layers.fc(input=x, size=10, param_attr=w_param_attrs)
     """
...
@@ -125,13 +125,14 @@ def slice_variable(var_list, slice_count, min_block_size):
 class DistributeTranspilerConfig(object):
     """
-    slice_var_up (bool): Do Tensor slice for pservers, default is True.
-    split_method (PSDispatcher): RoundRobin or HashName can be used
-        try to choose the best method to balance loads for pservers.
-    min_block_size (int): Minimum splitted element number in block.
-        According:https://github.com/PaddlePaddle/Paddle/issues/8638#issuecomment-369912156
-        We can use bandwidth effiently when data size is larger than 2MB.If you
-        want to change it, please be sure you see the slice_variable function.
+    Args:
+        slice_var_up (bool): Do Tensor slice for pservers, default is True.
+        split_method (PSDispatcher): RoundRobin or HashName can be used to
+            try to choose the best method to balance loads for pservers.
+        min_block_size (int): Minimum number of split elements in a block,
+            see https://github.com/PaddlePaddle/Paddle/issues/8638#issuecomment-369912156
+            Bandwidth is used efficiently when data size is larger than 2MB. If you
+            want to change it, please make sure you have read the slice_variable function.
     """

     slice_var_up = True
@@ -163,35 +164,35 @@ class DistributeTranspiler(object):
     Examples:
         .. code-block:: python

            # for pserver mode
            pserver_endpoints = "192.168.0.1:6174,192.168.0.2:6174"
            trainer_endpoints = "192.168.0.1:6174,192.168.0.2:6174"
            current_endpoint = "192.168.0.1:6174"
            trainer_id = 0
            trainers = 4
            role = os.getenv("PADDLE_TRAINING_ROLE")

            t = fluid.DistributeTranspiler()
            t.transpile(
                 trainer_id, pservers=pserver_endpoints, trainers=trainers)
            if role == "PSERVER":
                 pserver_program = t.get_pserver_program(current_endpoint)
                 pserver_startup_program = t.get_startup_program(current_endpoint,
                                                                 pserver_program)
            elif role == "TRAINER":
                 trainer_program = t.get_trainer_program()

            # for nccl2 mode
            config = fluid.DistributeTranspilerConfig()
            config.mode = "nccl2"
            t = fluid.DistributeTranspiler(config=config)
            t.transpile(trainer_id, workers=workers, current_endpoint=curr_ep)
            exe = fluid.ParallelExecutor(
                use_cuda,
                loss_name=loss_var.name,
                num_trainers=len(trainers.split(",")),
                trainer_id=trainer_id
            )
     """

     def __init__(self, config=None):
...