提交 8415e18b 编写于 作者: L Luo Tao

Merge branch 'develop' into softmax_doc

......@@ -64,7 +64,8 @@ class OpConverter {
(*it)(op, scope, test_mode);
}
// convert fluid block to tensorrt network
// Convert a fluid block to tensorrt network, NOTE it just convert operators,
// the INetwork's inputs and outputs should specified in some other modules.
void ConvertBlock(const framework::proto::BlockDesc& block,
const std::unordered_set<std::string>& parameters,
const framework::Scope& scope, TensorRTEngine* engine) {
......
......@@ -51,11 +51,12 @@ class TensorRTEngine : public EngineBase {
nvinfer1::Weights w_;
};
TensorRTEngine(int max_batch, int max_workspace, cudaStream_t* stream,
TensorRTEngine(int max_batch, int max_workspace,
cudaStream_t* stream = nullptr,
nvinfer1::ILogger& logger = NaiveLogger::Global())
: max_batch_(max_batch),
max_workspace_(max_workspace),
stream_(stream),
stream_(stream ? stream : &default_stream_),
logger_(logger) {}
virtual ~TensorRTEngine();
......@@ -121,6 +122,8 @@ class TensorRTEngine : public EngineBase {
// the max memory size the engine uses
int max_workspace_;
cudaStream_t* stream_;
// If stream_ is not set from outside, hold its own stream.
cudaStream_t default_stream_;
nvinfer1::ILogger& logger_;
std::vector<Buffer> buffers_;
......@@ -165,20 +168,31 @@ class TensorRTEngine : public EngineBase {
*/
class TRT_EngineManager {
public:
TensorRTEngine* Create(int max_batch, int max_workspace,
cudaStream_t* stream) {
engines_.emplace_back(new TensorRTEngine(max_batch, max_workspace, stream));
return engines_.back().get();
bool HasEngine(const std::string& name) const {
return engines_.count(name) != 0;
}
// Get an engine called `name`.
TensorRTEngine* Get(const std::string& name) const {
return engines_.at(name).get();
}
// Create or get an engine called `name`
TensorRTEngine* Create(int max_batch, int max_workspace, cudaStream_t* stream,
const std::string& name) {
auto* p = new TensorRTEngine(max_batch, max_workspace, stream);
engines_[name].reset(p);
return p;
}
void DeleteALl() {
for (auto& ptr : engines_) {
ptr.reset(nullptr);
for (auto& item : engines_) {
item.second.reset(nullptr);
}
}
private:
std::vector<std::unique_ptr<TensorRTEngine>> engines_;
std::unordered_map<std::string, std::unique_ptr<TensorRTEngine>> engines_;
};
} // namespace tensorrt
......
......@@ -252,15 +252,14 @@ class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "Output of Softshrink operator");
AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
AddComment(R"DOC(
Softshrink Activation Operator.
$$
out = \begin{cases}
x - \lambda, \text{if } x > \lambda \\
x + \lambda, \text{if } x < -\lambda \\
0, \text{otherwise}
\end{cases}
$$
:strong:`Softshrink Activation Operator`
.. math::
out = \begin{cases}
x - \lambda, \text{if } x > \lambda \\
x + \lambda, \text{if } x < -\lambda \\
0, \text{otherwise}
\end{cases}
)DOC");
}
......
......@@ -348,7 +348,8 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
};
void SignalHandler::StopAndExit(int signal_num) {
VLOG(3) << "Catch interrupt signal: " << signal_num << ", program will exit";
// Do not use VLOG here for the device for printing maybe already released.
// exit will release interal allocated resoureces.
exit(0);
}
......
......@@ -33,12 +33,10 @@ class MeanOp : public framework::OperatorWithKernel {
class MeanOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input of mean op");
AddOutput("Out", "The output of mean op").Reuse("X");
AddInput("X", "(Tensor) The input of mean op");
AddOutput("Out", "(Tensor) The output of mean op").Reuse("X");
AddComment(R"DOC(
Mean Operator.
Out is a scalar which is the mean of all elements in X.
Mean Operator calculates the mean of all elements in X.
)DOC");
}
......
......@@ -66,17 +66,25 @@ nvinfer1::Dims Vec2TRT_Dims(const std::vector<int64_t> &shape) {
} // namespace
template <typename DeviceContext, typename T>
void paddle::operators::TensorRTEngineKernel<DeviceContext, T>::Prepare(
void TensorRTEngineKernel<DeviceContext, T>::Prepare(
const framework::ExecutionContext &context) const {
VLOG(4) << "Prepare engine";
// Get the ProgramDesc and pass to convert.
framework::proto::BlockDesc block_desc;
block_desc.ParseFromString(context.Attr<std::string>("subgraph"));
max_batch_ = context.Attr<int>("max_batch");
int max_batch = context.Attr<int>("max_batch");
auto max_workspace = context.Attr<int>("max_workspace");
engine_ = Singleton<TRT_EngineManager>::Global().Create(
max_batch_, max_workspace, &stream_);
engine_->InitNetwork();
auto params = context.Attr<std::vector<std::string>>("parameters");
std::unordered_set<std::string> parameters;
for (const auto &param : params) {
parameters.insert(param);
}
// TODO(Superjomn) replace this with a different stream
auto *engine = Singleton<TRT_EngineManager>::Global().Create(
max_batch, max_workspace, nullptr /*engine hold its own stream*/,
context.Attr<std::string>("engine_uniq_key"));
engine->InitNetwork();
framework::BlockDesc block(nullptr /*programdesc*/, &block_desc);
// Add inputs
......@@ -87,24 +95,23 @@ void paddle::operators::TensorRTEngineKernel<DeviceContext, T>::Prepare(
PADDLE_ENFORCE_EQ(var->GetType(), FluidDT::VarType_Type_LOD_TENSOR,
"TensorRT engine only takes LoDTensor as input");
auto shape = var->GetShape();
engine_->DeclareInput(
engine->DeclareInput(
input, FluidDataType2TRT(
var->Proto()->type().lod_tensor().tensor().data_type()),
Vec2TRT_Dims(var->GetShape()));
}
// TODO(Superjomn) parameters should be passed after analysised from outside.
inference::Singleton<inference::tensorrt::OpConverter>::Global().ConvertBlock(
block_desc, {}, context.scope(), engine_);
block_desc, parameters, context.scope(), engine);
// Add outputs
VLOG(4) << "declare outputs";
for (auto &output : context.Outputs("Ys")) {
VLOG(4) << "declare output " << output;
engine_->DeclareOutput(output);
engine->DeclareOutput(output);
}
engine_->FreezeNetwork();
engine->FreezeNetwork();
}
class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -113,6 +120,7 @@ class TensorRTEngineOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Xs", "A list of inputs.").AsDuplicable();
AddOutput("Ys", "A list of outputs").AsDuplicable();
AddAttr<std::string>("subgraph", "the subgraph.");
AddAttr<std::string>("engine_uniq_key", "unique key for the TRT engine.");
AddAttr<int>("max_batch", "the maximum batch size.");
AddAttr<int>("max_workspace", "the maximum batch size.");
AddComment("TensorRT engine operator.");
......
......@@ -19,10 +19,14 @@
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
namespace paddle {
namespace operators {
using inference::Singleton;
using inference::tensorrt::TRT_EngineManager;
class TensorRTEngineOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -47,16 +51,18 @@ template <typename DeviceContext, typename T>
class TensorRTEngineKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
if (!engine_) {
auto engine_name = context.Attr<std::string>("engine_uniq_key");
if (!Singleton<TRT_EngineManager>::Global().HasEngine(engine_name)) {
Prepare(context);
}
auto* engine = Singleton<TRT_EngineManager>::Global().Get(engine_name);
auto input_names = context.op().Inputs("Xs");
PADDLE_ENFORCE(!input_names.empty(), "should pass more than one inputs");
// Try to determine a batch_size
auto& tensor0 = inference::analysis::GetFromScope<framework::LoDTensor>(
context.scope(), input_names.front());
int batch_size = tensor0.dims()[0];
PADDLE_ENFORCE_LE(batch_size, max_batch_);
PADDLE_ENFORCE_LE(batch_size, context.Attr<int>("max_batch"));
// Convert input tensor from fluid to engine.
for (const auto& x : context.Inputs("Xs")) {
......@@ -64,20 +70,20 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
auto& t = inference::analysis::GetFromScope<framework::LoDTensor>(
context.scope(), x);
if (platform::is_cpu_place(t.place())) {
engine_->SetInputFromCPU(x, static_cast<const void*>(t.data<void>()),
t.memory_size());
engine->SetInputFromCPU(x, static_cast<const void*>(t.data<void>()),
t.memory_size());
} else {
engine_->SetInputFromGPU(x, static_cast<const void*>(t.data<void>()),
t.memory_size());
engine->SetInputFromGPU(x, static_cast<const void*>(t.data<void>()),
t.memory_size());
}
}
// Execute the engine.
PADDLE_ENFORCE_GT(batch_size, 0);
engine_->Execute(batch_size);
engine->Execute(batch_size);
// Convert output tensor from engine to fluid
for (const auto& y : context.Outputs("Ys")) {
// convert output and copy to fluid.
nvinfer1::ITensor* trt_t = engine_->GetITensor(y);
nvinfer1::ITensor* trt_t = engine->GetITensor(y);
auto dims = trt_t->getDimensions();
// Use the output ITensor's dims to reshape the Fluid Tensor.
std::vector<int> ddim(dims.d, dims.d + dims.nbDims);
......@@ -89,27 +95,22 @@ class TensorRTEngineKernel : public framework::OpKernel<T> {
auto size = inference::analysis::AccuDims(dims.d, dims.nbDims);
if (platform::is_cpu_place(fluid_t->place())) {
// TODO(Superjomn) change this float to dtype size.
engine_->GetOutputInCPU(
engine->GetOutputInCPU(
y, fluid_t->mutable_data<float>(platform::CPUPlace()),
size * sizeof(float));
} else {
engine_->GetOutputInGPU(
engine->GetOutputInGPU(
y, fluid_t->mutable_data<float>(platform::CUDAPlace()),
size * sizeof(float));
}
}
cudaStreamSynchronize(stream_);
cudaStreamSynchronize(*engine->stream());
}
protected:
// Build the engine.
void Prepare(const framework::ExecutionContext& context) const;
private:
mutable cudaStream_t stream_;
mutable inference::tensorrt::TensorRTEngine* engine_{nullptr};
mutable int max_batch_{0};
};
} // namespace operators
......
......@@ -79,6 +79,17 @@ void SetAttr<int64_t>(framework::proto::OpDesc* op, const std::string& name,
attr->set_type(paddle::framework::proto::AttrType::LONG);
attr->set_l(data);
}
template <>
void SetAttr<std::vector<std::string>>(framework::proto::OpDesc* op,
const std::string& name,
const std::vector<std::string>& data) {
auto* attr = op->add_attrs();
attr->set_name(name);
attr->set_type(paddle::framework::proto::AttrType::STRINGS);
for (const auto& s : data) {
attr->add_strings(s.c_str());
}
}
} // namespace
......@@ -123,11 +134,15 @@ TEST(TensorRTEngineOp, manual) {
engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z0"}));
SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
block_->SerializeAsString());
SetAttr<int>(engine_op_desc.Proto(), "max_batch", 30);
SetAttr<int>(engine_op_desc.Proto(), "max_batch", 100);
SetAttr<int>(engine_op_desc.Proto(), "max_workspace", 1 << 10);
SetAttr<std::string>(engine_op_desc.Proto(), "engine_uniq_key", "a_engine");
SetAttr<std::vector<std::string>>(engine_op_desc.Proto(), "parameters",
std::vector<std::string>({}));
LOG(INFO) << "create engine op";
auto engine_op = framework::OpRegistry::CreateOp(*engine_op_desc.Proto());
LOG(INFO) << "engine_op " << engine_op.get();
framework::Scope scope;
platform::CPUPlace place;
......@@ -145,6 +160,88 @@ TEST(TensorRTEngineOp, manual) {
engine_op->Run(scope, place);
}
void Execute(int batch_size, int input_dim, int output_dim, int nlayers = 1) {
framework::ProgramDesc program;
framework::Scope scope;
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
auto* block_ = program.Proto()->add_blocks();
block_->set_idx(0);
block_->set_parent_idx(-1);
using shape_t = std::vector<int64_t>;
LOG(INFO) << "create block desc";
framework::BlockDesc block_desc(&program, block_);
auto AddFCLayer = [&](const std::string& x_name, const std::string& y_name,
const std::string& z_name, bool x_created,
const shape_t& x_shape, const shape_t& y_shape,
const shape_t& z_shape) {
LOG(INFO) << "create fc op";
auto* fc = block_desc.AppendOp();
fc->SetType("mul");
fc->SetInput("X", std::vector<std::string>({x_name}));
fc->SetInput("Y", std::vector<std::string>({y_name}));
fc->SetOutput("Out", std::vector<std::string>({z_name}));
// Set inputs' variable shape in BlockDesc
if (!x_created) {
AddTensorToBlockDesc(block_, x_name,
std::vector<int64_t>({batch_size, input_dim, 1, 1}));
}
AddTensorToBlockDesc(block_, y_name,
std::vector<int64_t>({input_dim, output_dim}));
AddTensorToBlockDesc(block_, z_name,
std::vector<int64_t>({batch_size, output_dim}));
// Prepare variables.
if (!x_created) {
CreateCPUTensor(&scope, x_name, std::vector<int64_t>(x_shape));
}
CreateCPUTensor(&scope, y_name, std::vector<int64_t>(y_shape));
CreateCPUTensor(&scope, z_name, std::vector<int64_t>(z_shape));
// It is wired, need to copy manually.
*block_->add_ops() = *fc->Proto();
};
// Test with 4 layer FC
AddFCLayer("x0", "y0", "z0", false, {batch_size, input_dim},
{input_dim, output_dim}, {batch_size, output_dim});
AddFCLayer("z0", "y1", "z1", true, {}, {output_dim, output_dim},
{batch_size, output_dim});
AddFCLayer("z1", "y2", "z2", true, {}, {output_dim, output_dim},
{batch_size, output_dim});
AddFCLayer("z2", "y3", "z3", true, {}, {output_dim, output_dim},
{batch_size, output_dim});
LOG(INFO) << "create tensorrt desc";
framework::OpDesc engine_op_desc(nullptr);
engine_op_desc.SetType("tensorrt_engine");
engine_op_desc.SetInput("Xs", std::vector<std::string>({"x0"}));
engine_op_desc.SetOutput("Ys", std::vector<std::string>({"z3"}));
SetAttr<std::string>(engine_op_desc.Proto(), "subgraph",
block_->SerializeAsString());
SetAttr<int>(engine_op_desc.Proto(), "max_batch", batch_size);
SetAttr<int>(engine_op_desc.Proto(), "max_workspace", 2 << 10);
SetAttr<std::vector<std::string>>(
engine_op_desc.Proto(), "parameters",
std::vector<std::string>({"y0", "y1", "y2", "y3"}));
SetAttr<std::string>(engine_op_desc.Proto(), "engine_uniq_key", "b_engine");
auto engine_op = framework::OpRegistry::CreateOp(*engine_op_desc.Proto());
// Execute them.
engine_op->Run(scope, place);
}
// Test with a larger FC layer.
TEST(TensorRTEngineOp, fc) { Execute(40, 256, 256); }
} // namespace operators
} // namespace paddle
......
......@@ -754,7 +754,7 @@ def lod_rank_table(x, level=0):
.. code-block:: python
x = fluid.layers.data(name='x', shape=[10],
dtype='float32', lod_level=1)
dtype='float32', lod_level=1)
out = layers.lod_rank_table(x=x, level=0)
"""
helper = LayerHelper("lod_rank_table", **locals())
......
......@@ -22,9 +22,9 @@ from ..executor import global_scope
from layer_function_generator import generate_layer_fn, templatedoc
__all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'open_files', 'read_file', 'shuffle', 'batch', 'double_buffer',
'random_data_generator', 'Preprocessor', 'load'
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'Recv',
'open_recordio_file', 'open_files', 'read_file', 'shuffle', 'batch',
'double_buffer', 'random_data_generator', 'Preprocessor', 'load'
]
......@@ -177,18 +177,17 @@ class ListenAndServ(object):
})
def Send(endpoints, send_vars, get_vars=None):
def Send(endpoints, send_vars, sync=True):
"""
Send layer
Send variables to the server side, and get vars from server
side when server have finished running server side program.
Args:
endpoints: comma seperated IP:PORT pairs in the order
endpoints (str): comma seperated IP:PORT pairs in the order
of send_vars to send
send_vars: vars to send
get_vars: vars to get from server after send completes.
Send variables to the server side, and get vars from server
side when server have finished running server side program.
send_vars (list): variables to send to server
sync (bool): whether to wait the request finish
"""
assert (type(send_vars) == list)
......@@ -196,40 +195,33 @@ def Send(endpoints, send_vars, get_vars=None):
endpoints = list(set(epmap))
helper = LayerHelper("Send", **locals())
if not get_vars:
get_vars = []
for s in send_vars:
v = helper.create_tmp_variable(dtype=s.dtype, stop_gradient=True)
get_vars.append(v)
rpc_op_role_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
helper.append_op(
type="send",
inputs={"X": send_vars},
outputs={"Out": get_vars},
attrs={
"endpoints": endpoints,
"epmap": epmap,
rpc_op_role_name: core.op_proto_and_checker_maker.OpRole.RPC
})
return get_vars
if sync:
helper.append_op(type="send_barrier", attrs={"endpoints": endpoints})
def Recv(endpoints, get_vars):
def Recv(endpoints, get_vars, sync=True):
"""
Recv layer
Receive variables from server side
Args:
endpoints: comma seperated IP:PORT pairs in the order
endpoints (str): comma seperated IP:PORT pairs in the order
of send_vars to send
send_vars: vars to send
get_vars: vars to get from server after send completes.
get_vars (list): vars to get from server after send completes.
sync (bool): whether to wait the request finish
Send variables to the server side, and get vars from server
side when server have finished running server side program.
Returns:
list: list of received variables
"""
assert (type(send_vars) == list)
assert (type(get_vars) == list)
epmap = endpoints.split(",")
......@@ -242,6 +234,9 @@ def Recv(endpoints, get_vars):
outputs={"Out": get_vars},
attrs={"endpoints": endpoints,
"epmap": epmap})
if sync:
helper.append_op(type="fetch_barrier", attrs={"endpoints": endpoints})
return get_vars
def monkey_patch_reader_methods(reader):
......@@ -541,6 +536,9 @@ def __create_unshared_decorated_reader__(op_type, reader, attrs, name=None):
def shuffle(reader, buffer_size):
"""
Shuffle the reader.
"""
return __create_unshared_decorated_reader__(
'create_shuffle_reader', reader, {'buffer_size': int(buffer_size)})
......
......@@ -76,7 +76,6 @@ def _generate_doc_string_(op_proto):
line_begin = ' {0}: '.format(_convert_(each_input.name))
buf.write(line_begin)
buf.write(escape_math(each_input.comment))
buf.write('\n')
if each_input.duplicable:
buf.write(" Duplicatable.")
if each_input.dispensable:
......
......@@ -225,11 +225,11 @@ def embedding(input,
have two elements which indicate the size of the dictionary of
embeddings and the size of each embedding vector respectively.
is_sparse(bool): The flag indicating whether to use sparse update.
is_distributed (bool): Whether to run lookup table from remote parameter server.
is_distributed(bool): Whether to run lookup table from remote parameter server.
padding_idx(int|long|None): If :attr:`None`, it makes no effect to lookup.
Otherwise the given :attr:`padding_idx` indicates padding the output
with zeros whenever lookup encounters it in :attr:`input`. If
:math:`padding_idx < 0`, the padding_idx to use in lookup is
:math:`padding_idx < 0`, the :attr:`padding_idx` to use in lookup is
:math:`size[0] + dim`.
param_attr(ParamAttr): Parameters for this layer
dtype(np.dtype|core.VarDesc.VarType|str): The type of data : float32, float_16, int etc
......@@ -1270,14 +1270,17 @@ def conv2d(input,
act=None,
name=None):
"""
**Convlution2D Layer**
The convolution2D layer calculates the output based on the input, filter
and strides, paddings, dilations, groups parameters. Input(Input) and
Output(Output) are in NCHW format. Where N is batch size, C is the number of
and strides, paddings, dilations, groups parameters. Input and
Output are in NCHW format, where N is batch size, C is the number of
channels, H is the height of the feature, and W is the width of the feature.
The details of convolution layer, please refer UFLDL's `convolution,
<http://ufldl.stanford.edu/tutorial/supervised/FeatureExtractionUsingConvolution/>`_ .
Filter is in MCHW format, where M is the number of output image channels,
C is the number of input image channels, H is the height of the filter,
and W is the width of the filter. If the groups is greater than 1,
C will equal the number of input image channels divided by the groups.
Please refer to UFLDL's `convolution
<http://ufldl.stanford.edu/tutorial/supervised/FeatureExtractionUsingConvolution/>`_
for more detials.
If bias attribution and activation type are provided, bias is added to the
output of the convolution, and the corresponding activation function is
applied to the final result.
......@@ -1288,15 +1291,14 @@ def conv2d(input,
Out = \sigma (W \\ast X + b)
In the above equation:
Where:
* :math:`X`: Input value, a tensor with NCHW format.
* :math:`W`: Filter value, a tensor with MCHW format.
* :math:`\\ast`: Convolution operation.
* :math:`b`: Bias value, a 2-D tensor with shape [M, 1].
* :math:`\\sigma`: Activation function.
* :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be
different.
* :math:`Out`: Output value, the shape of :math:`Out` and :math:`X` may be different.
Example:
......@@ -1307,6 +1309,7 @@ def conv2d(input,
Filter shape: :math:`(C_{out}, C_{in}, H_f, W_f)`
- Output:
Output shape: :math:`(N, C_{out}, H_{out}, W_{out})`
Where
......@@ -1318,7 +1321,7 @@ def conv2d(input,
Args:
input (Variable): The input image with [N, C, H, W] format.
num_filters(int): The number of filter. It is as same as the output
num_filters(int): The number of filter. It is as same as the output
image channel.
filter_size (int|tuple|None): The filter size. If filter_size is a tuple,
it must contain two integers, (filter_size_H, filter_size_W).
......@@ -1341,7 +1344,8 @@ def conv2d(input,
bias_attr (ParamAttr): Bias parameter for the Conv2d layer. Default: None
use_cudnn (bool): Use cudnn kernel or not, it is valid only when the cudnn
library is installed. Default: True
use_mkldnn (bool): Use mkldnn kernels or not.
use_mkldnn (bool): Use mkldnn kernels or not, it is valid only when compiled
with mkldnn library. Default: False
act (str): Activation type. Default: None
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
......
......@@ -201,6 +201,7 @@ def assign(input, output):
Examples:
.. code-block:: python
out = fluid.layers.create_tensor(dtype='float32')
hidden = fluid.layers.fc(input=data, size=10)
fluid.layers.assign(hidden, out)
......
......@@ -16,6 +16,7 @@ import os
import time
import unittest
from multiprocessing import Process
import signal
import numpy
......@@ -24,9 +25,6 @@ import paddle.fluid.layers as layers
class TestSendOp(unittest.TestCase):
@unittest.skip(
"This test is buggy. We cannot use time.sleep to sync processes, the connection may fail in unittest."
)
def test_send(self):
# Run init_serv in a thread
place = fluid.CPUPlace()
......@@ -35,7 +33,9 @@ class TestSendOp(unittest.TestCase):
p.daemon = True
p.start()
time.sleep(10)
self.ps_timeout = 5
self._wait_ps_ready(p.pid)
with open("/tmp/paddle.%d.port" % p.pid, "r") as fn:
selected_port = int(fn.readlines()[0])
self.init_client(place, selected_port)
......@@ -44,9 +44,23 @@ class TestSendOp(unittest.TestCase):
self.assertTrue(numpy.allclose(self.local_out, self.dist_out))
# FIXME(typhoonzero): find a way to gracefully shutdown the server.
os.system("kill -9 %d" % p.pid)
os.kill(p.pid, signal.SIGKILL)
p.join()
def _wait_ps_ready(self, pid):
start_left_time = self.ps_timeout
sleep_time = 0.5
while True:
assert start_left_time >= 0, "wait ps ready failed"
time.sleep(sleep_time)
try:
# the listen_and_serv_op would touch a file which contains the listen port
# on the /tmp directory until it was ready to process all the RPC call.
os.stat("/tmp/paddle.%d.port" % pid)
return
except os.error:
start_left_time -= sleep_time
def init_serv(self, place):
main = fluid.Program()
......@@ -84,7 +98,10 @@ class TestSendOp(unittest.TestCase):
dtype="float32",
persistable=False,
shape=[32, 32])
o = layers.Send("127.0.0.1:%d" % port, [x], [get_var])
fluid.initializer.Constant(value=2.3)(get_var, main.global_block())
layers.Send("127.0.0.1:%d" % port, [x])
o = layers.Recv("127.0.0.1:%d" % port, [get_var])
exe = fluid.Executor(place)
self.dist_out = exe.run(main, fetch_list=o) # o is a list
......
......@@ -57,17 +57,18 @@ class TestListenAndServOp(OpTest):
def setUp(self):
self.ps_timeout = 5
self.ip = "127.0.0.1"
self.port = "6173"
self.port = "0"
self.trainers = 1
self.trainer_id = 1
self.trainer_id = 0
def _start_pserver(self, use_cuda, sync_mode):
p = Process(
target=run_pserver,
args=(use_cuda, sync_mode, self.ip, self.port, self.trainers,
self.trainer_id))
p.daemon = True
p.start()
return p.pid
return p
def _wait_ps_ready(self, pid):
start_left_time = self.ps_timeout
......@@ -89,18 +90,20 @@ class TestListenAndServOp(OpTest):
def test_handle_signal_in_serv_op(self):
# run pserver on CPU in sync mode
pid = self._start_pserver(False, True)
self._wait_ps_ready(pid)
p1 = self._start_pserver(False, True)
self._wait_ps_ready(p1.pid)
# raise SIGTERM to pserver
os.kill(pid, signal.SIGTERM)
os.kill(p1.pid, signal.SIGKILL)
p1.join()
# run pserver on CPU in async mode
pid = self._start_pserver(False, False)
self._wait_ps_ready(pid)
p2 = self._start_pserver(False, False)
self._wait_ps_ready(p2.pid)
# raise SIGTERM to pserver
os.kill(pid, signal.SIGTERM)
os.kill(p2.pid, signal.SIGKILL)
p2.join()
if __name__ == '__main__':
......
......@@ -173,6 +173,7 @@ class TestCRFModel(unittest.TestCase):
pe.run(feed=feeder.feed(cur_batch),
fetch_list=[avg_cost.name]))[0]
@unittest.skip(reason="CI hangs")
def test_update_sparse_parameter_all_reduce(self):
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
......@@ -181,6 +182,7 @@ class TestCRFModel(unittest.TestCase):
self.check_network_convergence(
is_sparse=True, build_strategy=build_strategy, use_cuda=False)
@unittest.skip(reason="CI hangs")
def test_update_dense_parameter_all_reduce(self):
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
......@@ -189,6 +191,7 @@ class TestCRFModel(unittest.TestCase):
self.check_network_convergence(
is_sparse=False, build_strategy=build_strategy, use_cuda=False)
@unittest.skip(reason="CI hangs")
def test_update_sparse_parameter_reduce(self):
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
......@@ -197,6 +200,7 @@ class TestCRFModel(unittest.TestCase):
self.check_network_convergence(
is_sparse=True, build_strategy=build_strategy, use_cuda=False)
@unittest.skip(reason="CI hangs")
def test_update_dense_parameter_reduce(self):
build_strategy = fluid.BuildStrategy()
build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册