提交 98b14a3a 编写于 作者: T tensor-tang

Merge remote-tracking branch 'ups/develop' into fix

......@@ -112,7 +112,7 @@ $$out = \frac{1}{1 + e^{-x}}$$
__attribute__((unused)) constexpr char LogSigmoidDoc[] = R"DOC(
Logsigmoid Activation Operator
$$out = \log \frac{1}{1 + e^{-x}}$$
$$out = \\log \\frac{1}{1 + e^{-x}}$$
)DOC";
......
......@@ -106,23 +106,36 @@ class BoxCoderOpMaker : public framework::OpProtoAndCheckerMaker {
"and M represents the number of deocded boxes.");
AddComment(R"DOC(
Bounding Box Coder Operator.
Bounding Box Coder.
Encode/Decode the target bounding box with the priorbox information.
The Encoding schema described below:
ox = (tx - px) / pw / pxv
oy = (ty - py) / ph / pyv
ow = log(abs(tw / pw)) / pwv
oh = log(abs(th / ph)) / phv
ox = (tx - px) / pw / pxv
oy = (ty - py) / ph / pyv
ow = log(abs(tw / pw)) / pwv
oh = log(abs(th / ph)) / phv
The Decoding schema described below:
ox = (pw * pxv * tx * + px) - tw / 2
oy = (ph * pyv * ty * + py) - th / 2
ow = exp(pwv * tw) * pw + tw / 2
oh = exp(phv * th) * ph + th / 2
where tx, ty, tw, th denote the target box's center coordinates, width and
height respectively. Similarly, px, py, pw, ph denote the priorbox's(anchor)
center coordinates, width and height. pxv, pyv, pwv, phv denote the variance
of the priorbox and ox, oy, ow, oh denote the encoded/decoded coordinates,
width and height.
ox = (pw * pxv * tx * + px) - tw / 2
oy = (ph * pyv * ty * + py) - th / 2
ow = exp(pwv * tw) * pw + tw / 2
oh = exp(phv * th) * ph + th / 2
where `tx`, `ty`, `tw`, `th` denote the target box's center coordinates, width
and height respectively. Similarly, `px`, `py`, `pw`, `ph` denote the
priorbox's (anchor) center coordinates, width and height. `pxv`, `pyv`, `pwv`,
`phv` denote the variance of the priorbox and `ox`, `oy`, `ow`, `oh` denote the
encoded/decoded coordinates, width and height.
)DOC");
}
};
......
......@@ -36,11 +36,12 @@ class GaussianRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
void Apply() override {
AddAttr<float>("mean",
"(float, default 0.0) "
"mean of random tensor.")
"The mean (or center) of the gaussian distribution.")
.SetDefault(.0f);
AddAttr<float>("std",
"(float, default 1.0) "
"std of random tensor.")
"The standard deviation (std, or spread) of the "
"gaussian distribution.")
.SetDefault(1.0f);
AddAttr<int>("seed",
"(int, default 0) "
......@@ -55,9 +56,11 @@ class GaussianRandomBatchSizeLikeOpMaker : public BatchSizeLikeOpMaker {
.SetDefault(framework::proto::VarType::FP32);
AddComment(R"DOC(
GaussianRandom Operator.
Used to initialize tensors with gaussian random generator.
The defalut mean of the distribution is 0. and defalut standard
deviation (std) of the distribution is 1.. Uers can set mean and std
by input arguments.
)DOC");
}
};
......
......@@ -348,7 +348,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
};
void SignalHandler::StopAndExit(int signal_num) {
VLOG(3) << "Catch interrupt signal: " << signal_num << ", program will exit";
// Do not use VLOG here for the device for printing maybe already released.
// exit will release interal allocated resoureces.
exit(0);
}
......
......@@ -15,11 +15,13 @@
import framework
import numpy as np
import contextlib
from framework import convert_np_dtype_to_dtype_
from core import VarDesc
__all__ = [
'Constant', 'Uniform', 'Normal', 'Xavier', 'force_init_on_cpu',
'Constant', 'Uniform', 'Normal', 'Xavier', 'Bilinear', 'force_init_on_cpu',
'init_on_cpu', 'ConstantInitializer', 'UniformInitializer',
'NormalInitializer', 'XavierInitializer'
'NormalInitializer', 'XavierInitializer', 'BilinearInitializer'
]
_force_init_on_cpu_ = False
......@@ -422,6 +424,101 @@ class MSRAInitializer(Initializer):
return op
class BilinearInitializer(Initializer):
"""Implements the bilinear initializer.
This initializer can be used in transposed convolution operator to
act as upsampling. Users can upsample a feature map with shape of
(B, C, H, W) by any integer factor. The usage is:
>>> factor = 2
>>> w_attr = ParamAttr(learning_rate=0., regularizer=L2Decay(0.),
>>> initializer=Bilinear())
>>> conv_up = fluid.layers.conv2d_transpose(
>>> input,
>>> num_filters=C,
>>> output_size=None,
>>> filter_size=2 * factor - factor % 2,
>>> padding=ceil((factor - 1) / 2.),
>>> stride=factor,
>>> groups=C,
>>> param_attr=w_attr,
>>> bias_attr=False)
Where, `num_filters=C` and `groups=C` means this is channel-wise tranposed
convolution. The filter shape will be (C, 1, K, K) where K is `filer_size`,
This initializer will set a (K, K) interpolation kernel for every channel
of the filter identically. The resulting shape of the output feature map
will be (B, C, factor * H, factor * W). Note that the learning rate and the
weight decay are set to 0 in order to keep coefficient values of bilinear
interpolation unchanged during training.
"""
def __init__(self):
"""Constructor for BilinearInitializer.
"""
super(BilinearInitializer, self).__init__()
def __call__(self, var, block):
"""Add biliear initialization ops for a variable
Args:
var (Variable): Variable that needs to be initialized.
block (Block): The block in which initialization ops should
be added.
Returns:
the initialization op
Raises:
ValueError: If type of `var` and `block` is not right.
If the shape of `var` size is not 4 and
var.shape[2] != var.shape[3].
"""
if not isinstance(var, framework.Variable):
raise ValueError("var must be framework.Variable.")
if not isinstance(block, framework.Block):
raise ValueError("block must be framework.Block.")
shape = var.shape
if len(shape) != 4:
raise ValueError("the length of shape must be 4.")
if shape[2] != shape[3]:
raise ValueError("shape[2] must be equal to shape[3].")
weight = np.zeros(np.prod(var.shape), dtype='float32')
size = shape[3]
# factor
f = np.ceil(size / 2.)
# center
c = (2 * f - 1 - f % 2) / (2. * f)
for i in range(np.prod(shape)):
x = i % size
y = (i / size) % size
weight[i] = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
weight = np.reshape(weight, shape)
if var.dtype == VarDesc.VarType.FP32:
value_name = "fp32_values"
values = [float(v) for v in weight.flat]
else:
raise ValueError("Unsupported dtype %s", input.dtype)
if np.prod(shape) > 1024 * 1024:
raise ValueError("The size of input is too big. ")
op = block.append_op(
type='assign_value',
outputs={'Out': [var]},
attrs={
'dtype': var.dtype,
'shape': list(shape),
value_name: values
})
var.op = op
return op
# We short the class name, since users will use the initializer with the package
# name. The sample code:
#
......@@ -436,3 +533,4 @@ Uniform = UniformInitializer
Normal = NormalInitializer
Xavier = XavierInitializer
MSRA = MSRAInitializer
Bilinear = BilinearInitializer
......@@ -22,9 +22,9 @@ from ..executor import global_scope
from layer_function_generator import generate_layer_fn, templatedoc
__all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer',
'random_data_generator', 'Preprocessor', 'load'
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'Recv',
'open_recordio_file', 'open_files', 'read_file', 'shuffle', 'batch',
'double_buffer', 'random_data_generator', 'Preprocessor', 'load'
]
......@@ -177,18 +177,17 @@ class ListenAndServ(object):
})
def Send(endpoints, send_vars, get_vars=None):
def Send(endpoints, send_vars, sync=True):
"""
Send layer
Send variables to the server side, and get vars from server
side when server have finished running server side program.
Args:
endpoints: comma seperated IP:PORT pairs in the order
endpoints (str): comma seperated IP:PORT pairs in the order
of send_vars to send
send_vars: vars to send
get_vars: vars to get from server after send completes.
Send variables to the server side, and get vars from server
side when server have finished running server side program.
send_vars (list): variables to send to server
sync (bool): whether to wait the request finish
"""
assert (type(send_vars) == list)
......@@ -196,40 +195,33 @@ def Send(endpoints, send_vars, get_vars=None):
endpoints = list(set(epmap))
helper = LayerHelper("Send", **locals())
if not get_vars:
get_vars = []
for s in send_vars:
v = helper.create_tmp_variable(dtype=s.dtype, stop_gradient=True)
get_vars.append(v)
rpc_op_role_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
helper.append_op(
type="send",
inputs={"X": send_vars},
outputs={"Out": get_vars},
attrs={
"endpoints": endpoints,
"epmap": epmap,
rpc_op_role_name: core.op_proto_and_checker_maker.OpRole.RPC
})
return get_vars
if sync:
helper.append_op(type="send_barrier", attrs={"endpoints": endpoints})
def Recv(endpoints, get_vars):
def Recv(endpoints, get_vars, sync=True):
"""
Recv layer
Receive variables from server side
Args:
endpoints: comma seperated IP:PORT pairs in the order
endpoints (str): comma seperated IP:PORT pairs in the order
of send_vars to send
send_vars: vars to send
get_vars: vars to get from server after send completes.
get_vars (list): vars to get from server after send completes.
sync (bool): whether to wait the request finish
Send variables to the server side, and get vars from server
side when server have finished running server side program.
Returns:
list: list of received variables
"""
assert (type(send_vars) == list)
assert (type(get_vars) == list)
epmap = endpoints.split(",")
......@@ -242,6 +234,9 @@ def Recv(endpoints, get_vars):
outputs={"Out": get_vars},
attrs={"endpoints": endpoints,
"epmap": epmap})
if sync:
helper.append_op(type="fetch_barrier", attrs={"endpoints": endpoints})
return get_vars
def monkey_patch_reader_methods(reader):
......@@ -383,16 +378,16 @@ def random_data_generator(low, high, shapes, lod_levels, for_parallel=True):
Variable: A Reader Variable from which we can get random data.
Examples:
.. code-block:: python
reader = fluid.layers.io.random_data_generator(
low=0.0,
high=1.0,
shapes=[(3,224,224), (1)],
lod_levels=[0, 0])
.. code-block:: python
# Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.io.read_file(reader)
reader = fluid.layers.random_data_generator(
low=0.0,
high=1.0,
shapes=[[3,224,224], [1]],
lod_levels=[0, 0])
# Via the reader, we can use 'read_file' layer to get data:
image, label = fluid.layers.read_file(reader)
"""
dtypes = [core.VarDesc.VarType.FP32] * len(shapes)
shape_concat = []
......
......@@ -44,6 +44,11 @@ def _type_to_str_(tp):
return framework_pb2.AttrType.Name(tp)
_two_dollar_pattern_ = re.compile(r"\$\$([^\$]+)\$\$")
_single_dollar_pattern_ = re.compile(r"\$([^\$]+)\$")
_two_bang_pattern_ = re.compile(r"!!([^!]+)!!")
def _generate_doc_string_(op_proto):
"""
Generate docstring by OpProto
......@@ -55,22 +60,26 @@ def _generate_doc_string_(op_proto):
str: the document string
"""
def escape_math(text):
return _two_bang_pattern_.sub(
r'$$\1$$',
_single_dollar_pattern_.sub(
r':math:`\1`', _two_dollar_pattern_.sub(r"!!\1!!", text)))
if not isinstance(op_proto, framework_pb2.OpProto):
raise TypeError("OpProto should be `framework_pb2.OpProto`")
buf = cStringIO.StringIO()
buf.write(op_proto.comment)
buf.write(escape_math(op_proto.comment))
buf.write('\nArgs:\n')
for each_input in op_proto.inputs:
line_begin = ' {0}: '.format(_convert_(each_input.name))
buf.write(line_begin)
buf.write(each_input.comment)
buf.write('\n')
buf.write(' ' * len(line_begin))
buf.write('Duplicable: ')
buf.write(str(each_input.duplicable))
buf.write(' Optional: ')
buf.write(str(each_input.dispensable))
buf.write(escape_math(each_input.comment))
if each_input.duplicable:
buf.write(" Duplicatable.")
if each_input.dispensable:
buf.write(" Optional.")
buf.write('\n')
skip_attrs = OpProtoHolder.generated_op_attr_names()
......@@ -83,7 +92,7 @@ def _generate_doc_string_(op_proto):
buf.write(' (')
buf.write(_type_to_str_(each_attr.type))
buf.write('): ')
buf.write(each_attr.comment)
buf.write(escape_math(each_attr.comment))
buf.write('\n')
if len(op_proto.outputs) != 0:
......@@ -92,7 +101,7 @@ def _generate_doc_string_(op_proto):
for each_opt in op_proto.outputs:
if not each_opt.intermediate:
break
buf.write(each_opt.comment)
buf.write(escape_math(each_opt.comment))
return buf.getvalue()
......
......@@ -364,8 +364,7 @@ def dynamic_lstm(input,
cell_activation(str): The activation for cell output. Choices = ["sigmoid",
"tanh", "relu", "identity"], default "tanh".
candidate_activation(str): The activation for candidate hidden state.
Choices = ["sigmoid", "tanh",
"relu", "identity"],
Choices = ["sigmoid", "tanh", "relu", "identity"],
default "tanh".
dtype(str): Data type. Choices = ["float32", "float64"], default "float32".
name(str|None): A name for this layer(optional). If set None, the layer
......@@ -540,27 +539,31 @@ def dynamic_lstmp(input,
cell_activation(str): The activation for cell output. Choices = ["sigmoid",
"tanh", "relu", "identity"], default "tanh".
candidate_activation(str): The activation for candidate hidden state.
Choices = ["sigmoid", "tanh",
"relu", "identity"],
Choices = ["sigmoid", "tanh", "relu", "identity"],
default "tanh".
proj_activation(str): The activation for projection output.
Choices = ["sigmoid", "tanh",
"relu", "identity"],
Choices = ["sigmoid", "tanh", "relu", "identity"],
default "tanh".
dtype(str): Data type. Choices = ["float32", "float64"], default "float32".
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
tuple: The projection of hidden state, and cell state of LSTMP. The \
shape of projection is (T x P), for the cell state which is \
(T x D), and both LoD is the same with the `input`.
tuple: A tuple of two output variable: the projection of hidden state, \
and cell state of LSTMP. The shape of projection is (T x P), \
for the cell state which is (T x D), and both LoD is the same \
with the `input`.
Examples:
.. code-block:: python
dict_dim, emb_dim = 128, 64
data = fluid.layers.data(name='sequence', shape=[1],
dtype='int32', lod_level=1)
emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
hidden_dim, proj_dim = 512, 256
fc_out = fluid.layers.fc(input=input_seq, size=hidden_dim * 4,
fc_out = fluid.layers.fc(input=emb, size=hidden_dim * 4,
act=None, bias_attr=None)
proj_out, _ = fluid.layers.dynamic_lstmp(input=fc_out,
size=hidden_dim * 4,
......@@ -626,10 +629,10 @@ def dynamic_gru(input,
candidate_activation='tanh',
h_0=None):
"""
**Dynamic GRU Layer**
**Gated Recurrent Unit (GRU) Layer**
Refer to `Empirical Evaluation of Gated Recurrent Neural Networks on
Sequence Modeling <https://arxiv.org/abs/1412.3555>`_
Sequence Modeling <https://arxiv.org/abs/1412.3555>`_ .
The formula is as follows:
......@@ -676,17 +679,25 @@ def dynamic_gru(input,
Choices = ["sigmoid", "tanh", "relu", "identity"], default "sigmoid".
candidate_activation(str): The activation for candidate hidden state.
Choices = ["sigmoid", "tanh", "relu", "identity"], default "tanh".
h_0 (Variable): The hidden output of the first time step.
h_0 (Variable): This is initial hidden state. If not set, default is
zero. This is a tensor with shape (N x D), where N is the number of
total time steps of input mini-batch feature and D is the hidden
size.
Returns:
Variable: The hidden state of GRU. The shape is :math:`(T \\times D)`, \
and lod is the same with the input.
and sequence length is the same with the input.
Examples:
.. code-block:: python
dict_dim, emb_dim = 128, 64
data = fluid.layers.data(name='sequence', shape=[1],
dtype='int32', lod_level=1)
emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
hidden_dim = 512
x = fluid.layers.fc(input=data, size=hidden_dim * 3)
x = fluid.layers.fc(input=emb, size=hidden_dim * 3)
hidden = fluid.layers.dynamic_gru(input=x, dim=hidden_dim)
"""
......@@ -924,13 +935,13 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
Drop or keep each element of `x` independently. Dropout is a regularization
technique for reducing overfitting by preventing neuron co-adaption during
training. The dropout operator randomly set (according to the given dropout
training. The dropout operator randomly sets (according to the given dropout
probability) the outputs of some units to zero, while others are remain
unchanged.
Args:
x (Variable): The input tensor.
dropout_prob (float): Probability of setting units to zero.
x (Variable): The input tensor variable.
dropout_prob (float): Probability of setting units to zero.
is_test (bool): A flag indicating whether it is in test phrase or not.
seed (int): A Python integer used to create random seeds. If this
parameter is set to None, a random seed is used.
......@@ -940,13 +951,14 @@ def dropout(x, dropout_prob, is_test=False, seed=None, name=None):
will be named automatically.
Returns:
Variable: A tensor variable.
Variable: A tensor variable is the shape with `x`.
Examples:
.. code-block:: python
x = fluid.layers.data(name="data", shape=[32, 32], dtype="float32")
droped = fluid.layers.dropout(input=x, dropout_rate=0.5)
x = fluid.layers.data(name="data", shape=[32, 32], dtype="float32")
droped = fluid.layers.dropout(x, dropout_prob=0.5)
"""
helper = LayerHelper('dropout', **locals())
......@@ -2990,32 +3002,33 @@ def l2_normalize(x, axis, epsilon=1e-12, name=None):
norm. For a 1-D tensor (`dim` is fixed to 0), this layer computes
.. math::
y = \frac{x}{ \sqrt{\sum {x^2} + epsion }}
y = \\frac{x}{ \sqrt{\sum {x^2} + epsion }}
For `x` with more dimensions, this layer independently normalizes each 1-D
slice along dimension `axis`.
Args:
x(Variable|list): The input tensor to l2_normalize layer.
axis(int): The axis on which to apply normalization. If `axis < 0`,
axis(int): The axis on which to apply normalization. If `axis < 0`, \
the dimension to normalization is rank(X) + axis. -1 is the
last dimension.
epsilon(float): The epsilon value is used to avoid division by zero,
epsilon(float): The epsilon value is used to avoid division by zero, \
the defalut value is 1e-10.
name(str|None): A name for this layer(optional). If set None, the layer
name(str|None): A name for this layer(optional). If set None, the layer \
will be named automatically.
Returns:
Variable: The output tensor variable.
Variable: The output tensor variable is the same shape with `x`.
Examples:
.. code-block:: python
data = fluid.layers.data(name="data",
shape=(3, 17, 13),
dtype="float32")
normed = fluid.layers.l2_normalize(x=data, axis=1)
data = fluid.layers.data(name="data",
shape=(3, 17, 13),
dtype="float32")
normed = fluid.layers.l2_normalize(x=data, axis=1)
"""
if len(x.shape) == 1:
......
......@@ -497,11 +497,27 @@ def save_combine(x, file_path, overwrite=True):
Saves a list of variables into a single file.
Args:
x(list): A list of Tensor/LoDTensor to be saved together in a single file.
x(list): A list of Tensor/LoDTensor variables to be saved together in
a single file.
file_path(str): The file path where variables will be saved.
overwrite(bool): Whether or not cover the given file when it has already
overwrite(bool): Whether or not cover the given file when it has already
existed. If it's set 'False' and the file is existed, a runtime
error will be thrown.
Returns:
There is no return value.
Examples:
.. code-block:: python
v1 = fluid.layers.data(name="data",
shape=(4, 6),
dtype="float32")
v2 = fluid.layers.data(name="data",
shape=(6, 8, 4),
dtype="float32")
normed = fluid.layers.save_combine([v1, v2], file_path="output")
"""
helper = LayerHelper("save_combine", **locals())
helper.append_op(
......
......@@ -16,6 +16,7 @@ import os
import time
import unittest
from multiprocessing import Process
import signal
import numpy
......@@ -24,9 +25,6 @@ import paddle.fluid.layers as layers
class TestSendOp(unittest.TestCase):
@unittest.skip(
"This test is buggy. We cannot use time.sleep to sync processes, the connection may fail in unittest."
)
def test_send(self):
# Run init_serv in a thread
place = fluid.CPUPlace()
......@@ -35,7 +33,9 @@ class TestSendOp(unittest.TestCase):
p.daemon = True
p.start()
time.sleep(10)
self.ps_timeout = 5
self._wait_ps_ready(p.pid)
with open("/tmp/paddle.%d.port" % p.pid, "r") as fn:
selected_port = int(fn.readlines()[0])
self.init_client(place, selected_port)
......@@ -44,9 +44,23 @@ class TestSendOp(unittest.TestCase):
self.assertTrue(numpy.allclose(self.local_out, self.dist_out))
# FIXME(typhoonzero): find a way to gracefully shutdown the server.
os.system("kill -9 %d" % p.pid)
os.kill(p.pid, signal.SIGKILL)
p.join()
def _wait_ps_ready(self, pid):
start_left_time = self.ps_timeout
sleep_time = 0.5
while True:
assert start_left_time >= 0, "wait ps ready failed"
time.sleep(sleep_time)
try:
# the listen_and_serv_op would touch a file which contains the listen port
# on the /tmp directory until it was ready to process all the RPC call.
os.stat("/tmp/paddle.%d.port" % pid)
return
except os.error:
start_left_time -= sleep_time
def init_serv(self, place):
main = fluid.Program()
......@@ -84,7 +98,10 @@ class TestSendOp(unittest.TestCase):
dtype="float32",
persistable=False,
shape=[32, 32])
o = layers.Send("127.0.0.1:%d" % port, [x], [get_var])
fluid.initializer.Constant(value=2.3)(get_var, main.global_block())
layers.Send("127.0.0.1:%d" % port, [x])
o = layers.Recv("127.0.0.1:%d" % port, [get_var])
exe = fluid.Executor(place)
self.dist_out = exe.run(main, fetch_list=o) # o is a list
......
......@@ -364,5 +364,22 @@ class TestMSRAInitializer(unittest.TestCase):
self.assertEqual(init_op.attr('seed'), 134)
class TestMSRAInitializer(unittest.TestCase):
def test_bilinear_initializer(self):
"""Test the bilinear initializer with supplied arguments
"""
program = framework.Program()
block = program.global_block()
block.create_parameter(
dtype="float32",
shape=[8, 1, 3, 3],
lod_level=0,
name="param",
initializer=initializer.BilinearInitializer())
self.assertEqual(len(block.ops), 1)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'assign_value')
if __name__ == '__main__':
unittest.main()
......@@ -57,17 +57,18 @@ class TestListenAndServOp(OpTest):
def setUp(self):
self.ps_timeout = 5
self.ip = "127.0.0.1"
self.port = "6173"
self.port = "0"
self.trainers = 1
self.trainer_id = 1
self.trainer_id = 0
def _start_pserver(self, use_cuda, sync_mode):
p = Process(
target=run_pserver,
args=(use_cuda, sync_mode, self.ip, self.port, self.trainers,
self.trainer_id))
p.daemon = True
p.start()
return p.pid
return p
def _wait_ps_ready(self, pid):
start_left_time = self.ps_timeout
......@@ -89,18 +90,20 @@ class TestListenAndServOp(OpTest):
def test_handle_signal_in_serv_op(self):
# run pserver on CPU in sync mode
pid = self._start_pserver(False, True)
self._wait_ps_ready(pid)
p1 = self._start_pserver(False, True)
self._wait_ps_ready(p1.pid)
# raise SIGTERM to pserver
os.kill(pid, signal.SIGTERM)
os.kill(p1.pid, signal.SIGKILL)
p1.join()
# run pserver on CPU in async mode
pid = self._start_pserver(False, False)
self._wait_ps_ready(pid)
p2 = self._start_pserver(False, False)
self._wait_ps_ready(p2.pid)
# raise SIGTERM to pserver
os.kill(pid, signal.SIGTERM)
os.kill(p2.pid, signal.SIGKILL)
p2.join()
if __name__ == '__main__':
......
......@@ -173,6 +173,7 @@ class TestCRFModel(unittest.TestCase):
pe.run(feed=feeder.feed(cur_batch),
fetch_list=[avg_cost.name]))[0]
@unittest.skip(reason="CI hangs")
def test_update_sparse_parameter_all_reduce(self):
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
......@@ -181,6 +182,7 @@ class TestCRFModel(unittest.TestCase):
self.check_network_convergence(
is_sparse=True, build_strategy=build_strategy, use_cuda=False)
@unittest.skip(reason="CI hangs")
def test_update_dense_parameter_all_reduce(self):
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
......@@ -189,6 +191,7 @@ class TestCRFModel(unittest.TestCase):
self.check_network_convergence(
is_sparse=False, build_strategy=build_strategy, use_cuda=False)
@unittest.skip(reason="CI hangs")
def test_update_sparse_parameter_reduce(self):
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
......@@ -197,6 +200,7 @@ class TestCRFModel(unittest.TestCase):
self.check_network_convergence(
is_sparse=True, build_strategy=build_strategy, use_cuda=False)
@unittest.skip(reason="CI hangs")
def test_update_dense_parameter_reduce(self):
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册