Commit 37015fad authored by dangqingqing

update code

@@ -30,7 +30,8 @@ RUN apt-get update && \
     python-numpy python-matplotlib gcc g++ \
     automake locales clang-format-3.8 swig doxygen cmake \
     liblapack-dev liblapacke-dev libboost-dev \
-    clang-3.8 llvm-3.8 libclang-3.8-dev && \
+    clang-3.8 llvm-3.8 libclang-3.8-dev \
+    net-tools && \
     apt-get clean -y

 # Install Go
...
@@ -452,3 +452,12 @@ eos
 ---
 .. autoclass:: paddle.v2.layer.eos
     :noindex:
+
+Activation with learnable parameter
+===================================
+
+prelu
+--------
+.. autoclass:: paddle.v2.layer.prelu
+    :noindex:
@@ -632,7 +632,7 @@ void Argument::printValueString(std::ostream& stream,
                                 const std::string& prefix) const {
   std::unordered_map<std::string, std::string> out;
   getValueString(&out);
-  for (auto field : {"value", "id", "sequence pos", "sub-sequence pos"}) {
+  for (auto field : {"value", "ids", "sequence pos", "sub-sequence pos"}) {
     auto it = out.find(field);
     if (it != out.end()) {
       stream << prefix << field << ":\n" << it->second;
...
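The lookup key "id" becomes "ids", presumably to match the key that getValueString() actually stores; with the old spelling the find() never hit, so the ids section was silently omitted from the printed output.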
@@ -383,20 +383,23 @@ void SocketClient::TcpClient(const std::string &serverAddr, int serverPort) {
   setOption(sockfd);

   /// Now connect to the server
-  int retry_second = 0;
-  int error = 0;
+  int retry_count = 0;
   do {
-    error = connect(sockfd, (sockaddr *)&serv_addr, sizeof(serv_addr));
-    if (error == ECONNREFUSED) {
+    if (connect(sockfd, (sockaddr *)&serv_addr, sizeof(serv_addr)) == 0) {
+      break;
+    }
+
+    if (errno == ECONNREFUSED) {
       LOG(WARNING) << "connection refused by pserver, try again!";
-      if (retry_second++ >= 7) {
+      if (retry_count++ >= 7) {
         LOG(FATAL) << "connection refused by pserver, maybe pserver failed!";
       }
       std::this_thread::sleep_for(std::chrono::seconds(1));
     } else {
-      PCHECK(error >= 0) << "ERROR connecting to " << serverAddr;
+      PCHECK(errno != 0) << "ERROR connecting to " << serverAddr << ":"
+                         << serverPort << "errorno: " << errno;
     }
-  } while (error == ECONNREFUSED);
+  } while (errno == ECONNREFUSED);

   channel_.reset(new SocketChannel(sockfd, serverAddr));
   tcpRdma_ = F_TCP;
...
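The old code compared the return value of connect() against ECONNREFUSED, but connect() only returns 0 or -1 and reports the failure reason through errno. A minimal sketch of the corrected pattern, written as hypothetical Python rather than the C++ above:

import errno
import socket
import time

def connect_with_retry(host, port, max_retries=7):
    # Mirrors the fixed loop: success breaks out, ECONNREFUSED retries
    # with a one-second back-off, any other error is fatal immediately.
    retry_count = 0
    while True:
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect((host, port))
            return sock
        except socket.error as e:
            sock.close()
            if e.errno != errno.ECONNREFUSED:
                raise  # unexpected error: fail immediately
            retry_count += 1
            if retry_count > max_retries:
                raise RuntimeError("connection refused, maybe server failed!")
            time.sleep(1)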
@@ -73,7 +73,6 @@ To use this from paddle_trainer, paddle_trainer should be called with

 --config_args=extension_module_name=[MODULE_NAME]
 '''
-
 import copy
 import logging
 import os
@@ -1731,9 +1730,10 @@ class ParameterReluLayer(LayerBase):
     def __init__(self, name, inputs, partial_sum=1, **args):
         super(ParameterReluLayer, self).__init__(
             name, self.layer_type, 0, inputs=inputs, **args)
-        config_assert(len(self.inputs) == 1)
-        config_assert(self.input_layer.size % partial_sum == 0)
         input_layer = self.get_input_layer(0)
+        config_assert(len(self.inputs) == 1, "prelu layer has only one input.")
+        config_assert(input_layer.size % partial_sum == 0,
+                      "a wrong setting for partial_sum")
         self.set_layer_size(input_layer.size)
         self.create_input_parameter(0, input_layer.size / partial_sum)
...
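The tightened asserts guard the parameter shape: the layer creates input_layer.size / partial_sum weights, so partial_sum must divide the input size evenly. A hypothetical numpy sketch (not part of the commit) of the weight-sharing semantics:

import numpy as np

def prelu_forward(z, a, partial_sum):
    """z: flat input of size N; a: learned slopes of size N // partial_sum."""
    assert z.size % partial_sum == 0, "a wrong setting for partial_sum"
    slopes = np.repeat(a, partial_sum)     # each slope covers `partial_sum` elements
    return np.where(z > 0, z, slopes * z)  # z_i if z_i > 0 else a_i * z_i

z = np.array([1.0, -2.0, 3.0, -4.0])
a = np.array([0.25, 0.1])                  # partial_sum=2 -> two shared slopes
print(prelu_forward(z, a, partial_sum=2))  # [ 1.  -0.5  3.  -0.4]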
@@ -31,31 +31,31 @@ except ImportError:
 import copy

 __all__ = [
-    "full_matrix_projection",
-    "AggregateLevel",
-    "ExpandLevel",
-    "identity_projection",
-    "dotmul_projection",
-    "dotmul_operator",
-    "repeat_layer",
-    "seq_reshape_layer",
-    "table_projection",
-    "mixed_layer",
-    "data_layer",
-    "embedding_layer",
-    "fc_layer",
-    "grumemory",
-    "pooling_layer",
-    "lstmemory",
-    "last_seq",
-    "first_seq",
-    "cos_sim",
-    "hsigmoid",
-    "conv_projection",
-    "mse_cost",
-    "regression_cost",
+    'full_matrix_projection',
+    'AggregateLevel',
+    'ExpandLevel',
+    'identity_projection',
+    'dotmul_projection',
+    'dotmul_operator',
+    'repeat_layer',
+    'seq_reshape_layer',
+    'table_projection',
+    'mixed_layer',
+    'data_layer',
+    'embedding_layer',
+    'fc_layer',
+    'grumemory',
+    'pooling_layer',
+    'lstmemory',
+    'last_seq',
+    'first_seq',
+    'cos_sim',
+    'hsigmoid',
+    'conv_projection',
+    'mse_cost',
+    'regression_cost',
     'classification_cost',
-    "LayerOutput",
+    'LayerOutput',
     'img_conv_layer',
     'img_pool_layer',
     'batch_norm_layer',
@@ -122,6 +122,7 @@ __all__ = [
     'layer_support',
     'multiplex_layer',
     'row_conv_layer',
+    'prelu_layer',
 ]
@@ -130,26 +131,26 @@ class LayerType(object):
     Layer type enumerations.
     """

-    DATA = "data"
-    MIXED_LAYER = "mixed"
-    LSTMEMORY = "lstmemory"
-    GRUMEMORY = "gated_recurrent"
-    SEQUENCE_LAST_INSTANCE = "seqlastins"
-    SEQUENCE_FIRST_INSTANCE = "seqfirstins"
-    SEQUENCE_RESHAPE = "seqreshape"
-    POOLING_MAX = "max"
+    DATA = 'data'
+    MIXED_LAYER = 'mixed'
+    LSTMEMORY = 'lstmemory'
+    GRUMEMORY = 'gated_recurrent'
+    SEQUENCE_LAST_INSTANCE = 'seqlastins'
+    SEQUENCE_FIRST_INSTANCE = 'seqfirstins'
+    SEQUENCE_RESHAPE = 'seqreshape'
+    POOLING_MAX = 'max'
     POOLING_AVG = 'average'
-    FC_LAYER = "fc"
+    FC_LAYER = 'fc'
     COST = 'cost'
     COSINE_SIM_VEC = 'cos_vm'
     COSINE_SIM = 'cos'
     HSIGMOID = 'hsigmoid'
-    CONV_LAYER = "conv"
-    CONVTRANS_LAYER = "convt"
-    EXCONV_LAYER = "exconv"
-    EXCONVTRANS_LAYER = "exconvt"
-    CUDNNCONV_LAYER = "cudnn_conv"
-    POOL_LAYER = "pool"
+    CONV_LAYER = 'conv'
+    CONVTRANS_LAYER = 'convt'
+    EXCONV_LAYER = 'exconv'
+    EXCONVTRANS_LAYER = 'exconvt'
+    CUDNNCONV_LAYER = 'cudnn_conv'
+    POOL_LAYER = 'pool'
     BATCH_NORM_LAYER = 'batch_norm'
     NORM_LAYER = 'norm'
     SUM_TO_ONE_NORM_LAYER = 'sum_to_one_norm'
@@ -190,25 +191,19 @@ class LayerType(object):
     PAD_LAYER = "pad"
     MULTIPLEX_LAYER = "multiplex"
     ROW_CONV_LAYER = "row_conv"
-    PRINT_LAYER = "print"
-    PRIORBOX_LAYER = "priorbox"
-    CTC_LAYER = "ctc"
-    WARP_CTC_LAYER = "warp_ctc"
-    CRF_LAYER = "crf"
-    CRF_DECODING_LAYER = "crf_decoding"
     NCE_LAYER = 'nce'

-    RANK_COST = "rank-cost"
-    LAMBDA_COST = "lambda_cost"
-    HUBER = "huber"
-    CROSS_ENTROPY = "multi-class-cross-entropy"
-    CROSS_ENTROPY_WITH_SELFNORM = "multi_class_cross_entropy_with_selfnorm"
-    SOFT_BIN_CLASS_CROSS_ENTROPY = "soft_binary_class_cross_entropy"
-    MULTI_BIN_LABEL_CROSS_ENTROPY = "multi_binary_label_cross_entropy"
-    SUM_COST = "sum_cost"
-    SMOOTH_L1 = "smooth_l1"
+    RANK_COST = 'rank-cost'
+    LAMBDA_COST = 'lambda_cost'
+    HUBER = 'huber'
+    CROSS_ENTROPY = 'multi-class-cross-entropy'
+    CROSS_ENTROPY_WITH_SELFNORM = 'multi_class_cross_entropy_with_selfnorm'
+    SOFT_BIN_CLASS_CROSS_ENTROPY = 'soft_binary_class_cross_entropy'
+    MULTI_BIN_LABEL_CROSS_ENTROPY = 'multi_binary_label_cross_entropy'
+    SUM_COST = 'sum_cost'
+    SMOOTH_L1 = 'smooth_l1'
+    PRELU = 'prelu'

     @staticmethod
     def is_layer_type(type_name):
@@ -3862,7 +3857,6 @@ def classification_cost(input,
                         label,
                         weight=None,
                         name=None,
-                        top_k=None,
                         evaluator=classification_error_evaluator,
                         layer_attr=None):
     """
@@ -3877,8 +3871,6 @@ def classification_cost(input,
     :param weight: The weight affects the cost, namely the scale of cost.
                    It is an optional argument.
     :type weight: LayerOutput
-    :param top_k: number k in top-k error rate
-    :type top_k: int
     :param evaluator: Evaluator method.
     :param layer_attr: layer's extra attribute.
     :type layer_attr: ExtraLayerAttribute
@@ -3906,7 +3898,7 @@ def classification_cost(input,
         assert isinstance(e.for_classification, bool)
         assert e.for_classification
-        e(name=e.__name__, input=input, label=label, weight=weight, top_k=top_k)
+        e(name=e.__name__, input=input, label=label, weight=weight)

     if not isinstance(evaluator, collections.Sequence):
         evaluator = [evaluator]
@@ -4727,7 +4719,7 @@ def ctc_layer(input,
     fc_layer with softmax activation, should be num_classes + 1. The size of ctc_layer
     should also be num_classes + 1.

-    The simple usage:
+    The example usage is:

     .. code-block:: python

@@ -4814,7 +4806,7 @@ def warp_ctc_layer(input,
     - As a native 'softmax' activation is interated to the warp-ctc library,
       'linear' activation is expected instead in the 'input' layer.

-    The simple usage:
+    The example usage is:

     .. code-block:: python

@@ -4875,7 +4867,7 @@ def crf_layer(input,
     A layer for calculating the cost of sequential conditional random
     field model.

-    The simple usage:
+    The example usage is:

     .. code-block:: python

@@ -4949,7 +4941,7 @@ def crf_decoding_layer(input,
     this layer will also calculate error. output.value[i] is 1 for incorrect
     decoding or 0 for correct decoding.

-    The simple usage:
+    The example usage is:

     .. code-block:: python

@@ -5142,7 +5134,7 @@ def rank_cost(left,
     - :math:`o_i` and :math:`o_j`: the left output and right output.
       Their dimension is one.

-    The simple usage:
+    The example usage is:

     .. code-block:: python

@@ -5199,7 +5191,7 @@ def lambda_cost(input,
     """
     lambdaCost for lambdaRank LTR approach.

-    The simple usage:
+    The example usage is:

     .. code-block:: python

@@ -5257,6 +5249,8 @@ def cross_entropy(input,
     """
     A loss layer for multi class entropy.

+    The example usage is:
+
     .. code-block:: python

        cost = cross_entropy(input=input_layer,

@@ -5303,6 +5297,8 @@ def cross_entropy_with_selfnorm(input,
     A loss layer for multi class entropy with selfnorm.
     Input should be a vector of positive numbers, without normalization.

+    The example usage is:
+
     .. code-block:: python

        cost = cross_entropy_with_selfnorm(input=input_layer,

@@ -5344,6 +5340,8 @@ def sum_cost(input, name=None, layer_attr=None):
     """
     A loss layer which calculate the sum of the input as loss

+    The example usage is:
+
     .. code-block:: python

        cost = sum_cost(input=input_layer)

@@ -5373,6 +5371,8 @@ def huber_cost(input, label, name=None, coeff=1.0, layer_attr=None):
     """
     A loss layer for huber loss.

+    The example usage is:
+
     .. code-block:: python

        cost = huber_cost(input=input_layer,

@@ -5413,6 +5413,8 @@ def multi_binary_label_cross_entropy(input,
     """
     A loss layer for multi binary label cross entropy.

+    The example usage is:
+
     .. code-block:: python

        cost = multi_binary_label_cross_entropy(input=input_layer,

@@ -5472,6 +5474,8 @@ def smooth_l1_cost(input, label, name=None, coeff=1.0, layer_attr=None):
     More details can be found by referring to `Fast R-CNN
     <https://arxiv.org/pdf/1504.08083v2.pdf>`_

+    The example usage is:
+
     .. code-block:: python

        cost = smooth_l1_cost(input=input_layer,

@@ -5521,6 +5525,8 @@ def multiplex_layer(input, name=None, layer_attr=None):
     where, y is output. :math:`x_{k}` is the k-th input layer and
     :math:`k = x_{0}[i] + 1`.

+    The example usage is:
+
     .. code-block:: python

        maxid = multiplex_layer(input=layers)
@@ -5627,3 +5633,63 @@ def row_conv_layer(input,
         **ExtraLayerAttribute.to_kwargs(layer_attr))
     return LayerOutput(
         name, LayerType.ROW_CONV_LAYER, input, activation=act, size=input.size)
+
+
+@layer_support()
+@wrap_name_default()
+@wrap_param_attr_default()
+def prelu_layer(input,
+                name=None,
+                partial_sum=1,
+                param_attr=None,
+                layer_attr=None):
+    """
+    The Parametric ReLU activation, which scales negative outputs by a
+    learnable weight.
+
+    Reference:
+        Delving Deep into Rectifiers: Surpassing Human-Level Performance on
+        ImageNet Classification http://arxiv.org/pdf/1502.01852v1.pdf
+
+    .. math::
+       z_i &\\quad if \\quad z_i > 0 \\\\
+       a_i * z_i &\\quad \\mathrm{otherwise}
+
+    The example usage is:
+
+    .. code-block:: python
+
+       prelu = prelu_layer(input=layers, partial_sum=1)
+
+    :param name: Name of this layer.
+    :type name: basestring
+    :param input: The input layer.
+    :type input: LayerOutput
+    :param partial_sum: this parameter makes a group of inputs share the same weight.
+
+        - partial_sum = 1: element-wise activation, each element has its own weight.
+        - partial_sum = number of elements in one channel: channel-wise activation,
+          elements in a channel share the same weight.
+        - partial_sum = number of outputs: all elements share the same weight.
+
+    :type partial_sum: int
+    :param param_attr: The parameter attribute. See ParameterAttribute for details.
+    :type param_attr: ParameterAttribute|None
+    :param layer_attr: Extra layer configurations. Default is None.
+    :type layer_attr: ExtraLayerAttribute|None
+    :return: LayerOutput object.
+    :rtype: LayerOutput
+    """
+
+    assert isinstance(input, LayerOutput), 'prelu_layer only accepts one input'
+    assert isinstance(param_attr, ParameterAttribute)
+
+    l = Layer(
+        name=name,
+        type=LayerType.PRELU,
+        inputs=Input(input.name, **param_attr.attr),
+        partial_sum=partial_sum,
+        **ExtraLayerAttribute.to_kwargs(layer_attr))
+    return LayerOutput(
+        name=name,
+        layer_type=LayerType.PRELU,
+        parents=input,
+        size=l.config.size)
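For readers of the new API, a short usage sketch (hypothetical, built on the 300-wide input used by the test config below):

from paddle.trainer_config_helpers import *

data = data_layer(name='input', size=300)

# partial_sum=1: element-wise PReLU, 300 learnable slopes.
prelu_elem = prelu_layer(input=data, partial_sum=1)

# partial_sum=300: all 300 elements share a single learnable slope.
prelu_shared = prelu_layer(input=data, partial_sum=300)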
@@ -5,6 +5,7 @@ last_first_seq test_expand_layer test_ntm_layers test_hsigmoid
 img_layers img_trans_layers util_layers simple_rnn_layers unused_layers test_cost_layers
 test_rnn_group shared_fc shared_lstm shared_gru test_cost_layers_with_weight
 test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops
-test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer test_row_conv)
+test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer
+test_prelu_layer test_row_conv)

 export whole_configs=(test_split_datasource)
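The commit also adds a generated protostr fixture and a minimal config for the new test (presumably test_prelu_layer, matching the entry added above):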
type: "nn"
layers {
name: "input"
type: "data"
size: 300
active_type: ""
}
layers {
name: "__prelu_layer_0__"
type: "prelu"
size: 300
active_type: ""
inputs {
input_layer_name: "input"
input_parameter_name: "___prelu_layer_0__.w0"
}
}
parameters {
name: "___prelu_layer_0__.w0"
size: 300
initial_mean: 0.0
initial_std: 0.057735026919
initial_strategy: 0
initial_smart: true
}
input_layer_names: "input"
output_layer_names: "__prelu_layer_0__"
sub_models {
name: "root"
layer_names: "input"
layer_names: "__prelu_layer_0__"
input_layer_names: "input"
output_layer_names: "__prelu_layer_0__"
is_recurrent_layer_group: false
}
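Note that initial_std above, 0.057735026919, is exactly 1/sqrt(300), consistent with the smart initialization (initial_smart: true) of this 300-element weight. The config that produces this output: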
from paddle.trainer_config_helpers import *
data = data_layer(name='input', size=300)
prelu = prelu_layer(input=data)
outputs(prelu)
@@ -12,7 +12,7 @@ from paddle.trainer.config_parser import logger
 try:
     import cv2
 except ImportError:
-    logger.warning("OpenCV2 is not installed, using PIL to prcoess")
+    logger.warning("OpenCV2 is not installed, using PIL to process")
     cv2 = None

 __all__ = ["CvTransformer", "PILTransformer", "MultiProcessImageTransformer"]
...
@@ -11,17 +11,19 @@ packages=['paddle',
           'paddle.v2.reader',
           'paddle.v2.plot']

+setup_requires=["requests",
+                "numpy",
+                "protobuf==3.1",
+                "matplotlib",
+                "rarfile"]
+
+if '${CMAKE_SYSTEM_PROCESSOR}' not in ['arm', 'armv7-a', 'aarch64']:
+    setup_requires+=["opencv-python"]
+
 setup(name='paddle',
       version='${PADDLE_VERSION}',
       description='Parallel Distributed Deep Learning',
-      install_requires=[
-          "requests",
-          "numpy",
-          "protobuf==${PROTOBUF_VERSION}",
-          "matplotlib",
-          "opencv-python",
-          "rarfile"
-      ],
+      install_requires=setup_requires,
      packages=packages,
      package_dir={
          '': '${CMAKE_CURRENT_SOURCE_DIR}'
...
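Hoisting the dependency list into setup_requires lets the build skip opencv-python on ARM targets ('arm', 'armv7-a', 'aarch64'), presumably because prebuilt wheels are unavailable for those processors; note that the protobuf requirement also changes here from ${PROTOBUF_VERSION} to a fixed 3.1.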