Commit abb3357d, authored by S sweetsky0901

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into my_unpool_max_2d

@@ -12,11 +12,11 @@ Machine:
 System: CentOS release 6.3 (Final), Docker 1.12.1.
-PaddlePaddle: paddlepaddle/paddle:latest (TODO: will rerun after 0.11.0)
-- MKL-DNN tag v0.10
-- MKLML 2018.0.20170720
+PaddlePaddle: paddlepaddle/paddle:latest (for MKLML and MKL-DNN), paddlepaddle/paddle:latest-openblas (for OpenBLAS)
+- MKL-DNN tag v0.11
+- MKLML 2018.0.1.20171007
 - OpenBLAS v0.2.20
+(TODO: will rerun after 0.11.0)
 On each machine, we will test and compare the performance of training on single node using MKL-DNN / MKLML / OpenBLAS respectively.
@@ -31,15 +31,26 @@ Input image size - 3 * 224 * 224, Time: images/second
 | BatchSize | 64 | 128 | 256 |
 |--------------|-------| -----| --------|
-| OpenBLAS | 7.82 | 8.62 | 10.34 |
-| MKLML | 11.02 | 12.86 | 15.33 |
-| MKL-DNN | 27.69 | 28.8 | 29.27 |
+| OpenBLAS | 7.80 | 9.00 | 10.80 |
+| MKLML | 12.12 | 13.70 | 16.18 |
+| MKL-DNN | 28.46 | 29.83 | 30.44 |
+chart on batch size 128
+TBD
+- ResNet-50
+| BatchSize | 64 | 128 | 256 |
+|--------------|-------| ------| -------|
+| OpenBLAS | 25.22 | 25.68 | 27.12 |
+| MKLML | 32.52 | 31.89 | 33.12 |
+| MKL-DNN | 81.69 | 82.35 | 84.08 |
 chart on batch size 128
 TBD
-- ResNet
 - GoogLeNet
 ### Laptop
......
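To make the throughput tables in the hunk above easier to compare, here is a small sketch that computes the relative speedup at batch size 128; the values are copied from the updated tables, not re-measured.

```python
# Relative training throughput at batch size 128, taken from the updated
# VGG-19 and ResNet-50 tables above (images/second).
throughput = {
    "VGG-19":    {"OpenBLAS": 9.00,  "MKLML": 13.70, "MKL-DNN": 29.83},
    "ResNet-50": {"OpenBLAS": 25.68, "MKLML": 31.89, "MKL-DNN": 82.35},
}

for model, t in throughput.items():
    print("%s: MKL-DNN is %.1fx OpenBLAS, %.1fx MKLML"
          % (model, t["MKL-DNN"] / t["OpenBLAS"], t["MKL-DNN"] / t["MKLML"]))
```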
@@ -294,22 +294,8 @@ void MKLDNNLayer::resetMergeGrad(MKLDNNMatrixPtr& out) {
     srcs.push_back(*src);
   }
-  // TODO(TJ): remove me when mkldnn sum support different formats
-  for (size_t i = 1; i < srcPDs.size(); ++i) {
-    CHECK(srcPDs[0] == srcPDs[i]);
-  }
-  tmpOutGrad_ = out;
-  tmpCvt_ = nullptr;
-  if (out->getPrimitiveDesc() != srcPDs[0]) {
-    tmpOutGrad_ = MKLDNNMatrix::create(srcPDs[0]);
-    tmpCvt_ = MKLDNNMatrix::createReorder(tmpOutGrad_, out);
-    CHECK(tmpCvt_);
-    pipelineMergeGrad_.push_back(*tmpCvt_);
-  }
-  auto sumPD =
-      sum::primitive_desc(tmpOutGrad_->getMemoryDesc(), scales, srcPDs);
-  mergeGrad_.reset(new sum(sumPD, srcs, *tmpOutGrad_));
+  auto sumPD = sum::primitive_desc(out->getMemoryDesc(), scales, srcPDs);
+  mergeGrad_.reset(new sum(sumPD, srcs, *out));
   pipelineMergeGrad_.insert(pipelineMergeGrad_.begin(), *mergeGrad_);
 }
......
@@ -36,7 +36,7 @@ class MKLDNNLayer : public Layer {
 protected:
   // batch size
   int bs_;
-  // they sizes are always from the first input layer
+  // their sizes are always from the first input layer
   // input image channel, height and width
   int ic_, ih_, iw_;
   // output image channel, height and width
@@ -94,11 +94,6 @@ protected:
   std::vector<mkldnn::primitive> pipelineMergeGrad_;
   // tmp input argument to save input grad, only used to merge grad
   Argument tmpInArg_;
-  // since mkldnn sum do not support different formats:
-  // can refer to https://github.com/01org/mkl-dnn/issues/134
-  // so need create reorder manually and save tmp MKLDNNMatrix
-  MKLDNNMatrixPtr tmpOutGrad_;
-  std::shared_ptr<mkldnn::primitive> tmpCvt_;
 public:
   explicit MKLDNNLayer(const LayerConfig& config)
......
@@ -315,7 +315,7 @@ TEST(MKLDNNLayer, AddtoLayer) {
 static void getMKLDNNConcatConfig(TestConfig& cfg,
                                   const std::vector<testImageDesc>& inputs) {
-  CHECK_GE(inputs.size(), 2) << "at least two inputs";
+  CHECK_GE(inputs.size(), 2UL) << "at least two inputs";
   int oc = inputs[0].ic;
   for (size_t i = 1; i < inputs.size(); ++i) {
     CHECK_EQ(inputs[i].bs, inputs[0].bs);
......
@@ -139,7 +139,7 @@ bool BeamSearch::NextItemSet(std::vector<BeamSearch::Item> *items) {
   items->reserve(framework::product(ids.dims()));
   for (size_t offset = abs_lod[lod_level_][sent_offset_];
        offset < abs_lod[lod_level_][sent_offset_ + 1]; offset++) {
-    for (int d = 0; d < instance_dim; d++) {
+    for (size_t d = 0; d < instance_dim; d++) {
       const size_t dim_offset = offset * instance_dim + d;
       items->emplace_back(offset, ids_data[dim_offset],
                           scores_data[dim_offset]);
......
@@ -138,7 +138,7 @@ void Trainer::init(const std::shared_ptr<TrainerConfigHelper>& config,
   }
   if (FLAGS_use_mkldnn) {
-    CHECK_EQ(FLAGS_trainer_count, 1UL) << "MKLDNN only need 1 trainer";
+    CHECK_EQ(FLAGS_trainer_count, 1) << "MKLDNN only need 1 trainer";
   }
   if (testing) {
......
@@ -2037,13 +2037,20 @@ class ParameterReluLayer(LayerBase):
     def __init__(self, name, inputs, partial_sum=1, **args):
         super(ParameterReluLayer, self).__init__(
             name, self.layer_type, 0, inputs=inputs, **args)
         input_layer = self.get_input_layer(0)
         config_assert(len(self.inputs) == 1, "prelu layer has only one input.")
         config_assert(input_layer.size % partial_sum == 0,
                       "a wrong setting for partial_sum")
+        dims = [1, input_layer.size / partial_sum]
         self.set_layer_size(input_layer.size)
         self.config.partial_sum = partial_sum
-        self.create_input_parameter(0, input_layer.size / partial_sum)
+        self.create_input_parameter(0, input_layer.size / partial_sum, dims)
+        self.set_layer_height_width(self.get_input_layer(0).height, \
+                                    self.get_input_layer(0).width)
+        self.set_layer_depth(self.get_input_layer(0).depth)
 @config_layer('conv')
......
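For reference, a small standalone sketch of how the parameter shape follows from `partial_sum` in the `ParameterReluLayer` change above; the helper name is hypothetical, and the example values mirror the prelu test expectations further below.

```python
# Hypothetical helper mirroring the dims computation added to
# ParameterReluLayer: the PReLU weight has input_size / partial_sum entries.
def prelu_param_dims(input_size, partial_sum):
    assert input_size % partial_sum == 0, "a wrong setting for partial_sum"
    return [1, input_size // partial_sum]

print(prelu_param_dims(300, 1))    # [1, 300] -> one weight per output element
print(prelu_param_dims(300, 5))    # [1, 60]
print(prelu_param_dims(300, 300))  # [1, 1]  -> a single shared weight
```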
@@ -297,7 +297,7 @@ def auc_evaluator(
 def pnpair_evaluator(
         input,
         label,
-        info,
+        query_id,
         weight=None,
         name=None, ):
     """
@@ -308,16 +308,20 @@ def pnpair_evaluator(
     .. code-block:: python
-       eval = pnpair_evaluator(input, label, info)
+       eval = pnpair_evaluator(input, label, query_id)
     :param input: Input Layer name. The output prediction of network.
     :type input: LayerOutput
     :param label: Label layer name.
     :type label: LayerOutput
-    :param info: Info layer name. (TODO, explaination)
-    :type info: LayerOutput
+    :param query_id: Query_id layer name. Query_id indicates that which query
+                     each sample belongs to. Its shape should be
+                     the same as output of Label layer.
+    :type query_id: LayerOutput
     :param weight: Weight Layer name. It should be a matrix with size
-              [sample_num, 1]. (TODO, explaination)
+              [sample_num, 1] which indicates the weight of each sample.
+              The default weight of sample is 1 if the weight layer is None.
+              And the pair weight is the mean of the two samples' weight.
     :type weight: LayerOutput
     :param name: Evaluator name.
     :type name: None|basestring
@@ -326,8 +330,8 @@ def pnpair_evaluator(
     input = [input]
     if label:
         input.append(label)
-    if info:
-        input.append(info)
+    if query_id:
+        input.append(query_id)
     evaluator_base(
         input=input,
         type="pnpair",
......
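A minimal usage sketch of the renamed argument, assuming a trivial config; only the `query_id` keyword comes from the change above, the layer names and sizes here are illustrative.

```python
# Minimal (hypothetical) config showing pnpair_evaluator with query_id.
from paddle.trainer_config_helpers import *

prediction = data_layer(name='prediction', size=1)
label = data_layer(name='label', size=1)
# query_id marks which query each sample belongs to; its shape should
# match the output of the label layer.
query_id = data_layer(name='query_id', size=1)

pnpair_evaluator(input=prediction, label=label, query_id=query_id)
```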
@@ -6604,10 +6604,11 @@ def row_conv_layer(input,
 @layer_support()
 @wrap_name_default()
-@wrap_param_attr_default()
 def prelu_layer(input,
                 name=None,
                 partial_sum=1,
+                channel_shared=None,
+                num_channels=None,
                 param_attr=None,
                 layer_attr=None):
     """
@@ -6638,6 +6639,14 @@ def prelu_layer(input,
           - partial_sum = number of outputs, indicates all elements share the same weight.
     :type partial_sum: int
+    :param channel_shared: whether or not the parameter are shared across channels.
+                           - channel_shared = True, we set the partial_sum to the number of outputs.
+                           - channel_shared = False, we set the partial_sum to the number of elements in one channel.
+    :type channel_shared: bool
+    :param num_channels: number of input channel.
+    :type num_channels: int
     :param param_attr: The parameter attribute. See ParameterAttribute for details.
     :type param_attr: ParameterAttribute
     :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for
@@ -6648,7 +6657,25 @@ def prelu_layer(input,
     """
     assert isinstance(input, LayerOutput), 'prelu_layer accepts only one input.'
-    assert isinstance(param_attr, ParameterAttribute)
+    if not param_attr:
+        param_attr = ParamAttr(initial_mean=0.25, initial_std=0.0)
+    else:
+        assert isinstance(param_attr, ParameterAttribute)
+    if num_channels is None:
+        assert input.num_filters is not None, \
+            'the input channel cannot be detected, please specify the num_channels parameter'
+        num_channels = input.num_filters
+    if channel_shared is not None:
+        assert isinstance(channel_shared, bool)
+        assert (input.height != 0 and input.width != 0), \
+            'input height and widht must be setted'
+        if channel_shared:
+            partial_sum = input.height * input.width * num_channels
+        else:
+            partial_sum = input.height * input.width
     l = Layer(
         name=name,
@@ -6660,6 +6687,7 @@ def prelu_layer(input,
         name=name,
         layer_type=LayerType.PRELU,
         parents=input,
+        num_filters=num_channels,
         size=l.config.size)
......
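The mapping from `channel_shared` to `partial_sum` added above can be summarized with a small standalone sketch (the helper name is hypothetical); the example values reproduce the expected protobuf below, where height = width = 10 and num_channels = 3.

```python
# Sketch of the branch added to prelu_layer: channel_shared controls how
# many output elements share one PReLU weight via partial_sum.
def resolve_partial_sum(height, width, num_channels, channel_shared):
    assert height != 0 and width != 0, 'input height and width must be set'
    if channel_shared:
        # all elements of all channels share a single weight
        return height * width * num_channels
    # every channel gets its own weight, shared over the spatial positions
    return height * width

print(resolve_partial_sum(10, 10, 3, channel_shared=True))   # 300
print(resolve_partial_sum(10, 10, 3, channel_shared=False))  # 100
```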
@@ -4,6 +4,8 @@ layers {
   type: "data"
   size: 300
   active_type: ""
+  height: 10
+  width: 10
 }
 layers {
   name: "__prelu_layer_0__"
@@ -15,6 +17,9 @@ layers {
     input_parameter_name: "___prelu_layer_0__.w0"
   }
   partial_sum: 1
+  height: 10
+  width: 10
+  depth: 1
 }
 layers {
   name: "__prelu_layer_1__"
@@ -26,6 +31,9 @@ layers {
     input_parameter_name: "___prelu_layer_1__.w0"
   }
   partial_sum: 1
+  height: 10
+  width: 10
+  depth: 1
 }
 layers {
   name: "__prelu_layer_2__"
@@ -37,41 +45,100 @@ layers {
     input_parameter_name: "___prelu_layer_2__.w0"
   }
   partial_sum: 5
+  height: 10
+  width: 10
+  depth: 1
+}
+layers {
+  name: "__prelu_layer_3__"
+  type: "prelu"
+  size: 300
+  active_type: ""
+  inputs {
+    input_layer_name: "input"
+    input_parameter_name: "___prelu_layer_3__.w0"
+  }
+  partial_sum: 300
+  height: 10
+  width: 10
+  depth: 1
+}
+layers {
+  name: "__prelu_layer_4__"
+  type: "prelu"
+  size: 300
+  active_type: ""
+  inputs {
+    input_layer_name: "input"
+    input_parameter_name: "___prelu_layer_4__.w0"
+  }
+  partial_sum: 100
+  height: 10
+  width: 10
+  depth: 1
 }
 parameters {
   name: "___prelu_layer_0__.w0"
   size: 300
-  initial_mean: 0.0
-  initial_std: 0.057735026919
+  initial_mean: 0.25
+  initial_std: 0.0
+  dims: 1
+  dims: 300
   initial_strategy: 0
-  initial_smart: true
+  initial_smart: false
 }
 parameters {
   name: "___prelu_layer_1__.w0"
   size: 300
-  initial_mean: 0.0
-  initial_std: 0.057735026919
+  initial_mean: 0.25
+  initial_std: 0.0
+  dims: 1
+  dims: 300
   initial_strategy: 0
-  initial_smart: true
+  initial_smart: false
 }
 parameters {
   name: "___prelu_layer_2__.w0"
   size: 60
-  initial_mean: 0.0
-  initial_std: 0.129099444874
+  initial_mean: 0.25
+  initial_std: 0.0
+  dims: 1
+  dims: 60
+  initial_strategy: 0
+  initial_smart: false
+}
+parameters {
+  name: "___prelu_layer_3__.w0"
+  size: 1
+  initial_mean: 0.25
+  initial_std: 0.0
+  dims: 1
+  dims: 1
+  initial_strategy: 0
+  initial_smart: false
+}
+parameters {
+  name: "___prelu_layer_4__.w0"
+  size: 3
+  initial_mean: 0.25
+  initial_std: 0.0
+  dims: 1
+  dims: 3
   initial_strategy: 0
-  initial_smart: true
+  initial_smart: false
 }
 input_layer_names: "input"
-output_layer_names: "__prelu_layer_2__"
+output_layer_names: "__prelu_layer_4__"
 sub_models {
   name: "root"
   layer_names: "input"
   layer_names: "__prelu_layer_0__"
   layer_names: "__prelu_layer_1__"
   layer_names: "__prelu_layer_2__"
+  layer_names: "__prelu_layer_3__"
+  layer_names: "__prelu_layer_4__"
   input_layer_names: "input"
-  output_layer_names: "__prelu_layer_2__"
+  output_layer_names: "__prelu_layer_4__"
   is_recurrent_layer_group: false
 }
 from paddle.trainer_config_helpers import *
-data = data_layer(name='input', size=300)
-prelu = prelu_layer(input=data)
-prelu = prelu_layer(input=data, partial_sum=1)
-prelu = prelu_layer(input=data, partial_sum=5)
+data = data_layer(name='input', size=300, height=10, width=10)
+prelu = prelu_layer(input=data, num_channels=3)
+prelu = prelu_layer(input=data, partial_sum=1, num_channels=3)
+prelu = prelu_layer(input=data, partial_sum=5, num_channels=3)
+prelu = prelu_layer(input=data, channel_shared=True, num_channels=3)
+prelu = prelu_layer(input=data, channel_shared=False, num_channels=3)
 outputs(prelu)
@@ -62,21 +62,15 @@ __all__ = [
 cp.begin_parse()
-def init(**kwargs):
-    import py_paddle.swig_paddle as api
-    args = []
-    args_dict = {}
-    # NOTE: append arguments if they are in ENV
-    for ek, ev in os.environ.iteritems():
-        if ek.startswith("PADDLE_INIT_"):
-            args_dict[ek.replace("PADDLE_INIT_", "").lower()] = str(ev)
-    args_dict.update(kwargs)
-    # NOTE: overwrite arguments from ENV if it is in kwargs
-    for key in args_dict.keys():
-        args.append('--%s=%s' % (key, str(args_dict[key])))
-    # auto set cpu environment
+def set_omp_mkl_env_vars(trainer_count):
+    '''Auto set CPU environment if have not set before.
+       export KMP_AFFINITY, OMP_DYNAMIC according to the Hyper Threading status.
+       export OMP_NUM_THREADS, MKL_NUM_THREADS according to trainer_count.
+    '''
+    import platform
+    if not platform.system() in ['Linux', 'Darwin']:
+        return
     def set_env(key, value):
         '''If the key has not been set in the environment, set it with value.'''
         assert isinstance(key, str)
@@ -85,22 +79,59 @@ def init(**kwargs):
         if envset is None:
             os.environ[key] = value
-    ht = os.popen("lscpu |grep \"per core\"|awk -F':' '{print $2}'|xargs")
-    ht = int(ht.read())
-    if ht == 1:  # ht is off
-        set_env("OMP_DYNAMIC", "false")
-        set_env("KMP_AFFINITY", "granularity=fine,compact,0,0")
-    else:
+    def num_physical_cores():
+        '''Get the number of physical cores'''
+        if platform.system() == "Linux":
+            num_sockets = int(
+                os.popen("lscpu |grep \"Socket\" |awk -F':' '{print $2}'|xargs")
+                .read())
+            num_cores_per_socket = int(
+                os.popen(
+                    "lscpu |grep \"per socket\" |awk -F':' '{print $2}'|xargs")
+                .read())
+            return num_sockets * num_cores_per_socket
+        else:
+            cmds = {"Darwin": "sysctl hw.physicalcpu"}
+            return int(os.popen(cmds.get(platform.system(), "expr 1")).read())
+
+    def num_logical_processors():
+        '''Get the number of logical processors'''
+        cmds = {
+            "Linux": "grep \"processor\" /proc/cpuinfo|sort -u|wc -l",
+            "Darwin": "sysctl hw.logicalcpu"
+        }
+        return int(os.popen(cmds.get(platform.system(), "expr 1")).read())
+
+    num_cores = num_physical_cores()
+    num_processors = num_logical_processors()
+    if num_processors > num_cores:  # Hyper Threading is enabled
         set_env("OMP_DYNAMIC", "true")
         set_env("KMP_AFFINITY", "granularity=fine,compact,1,0")
-    processors = os.popen("grep \"processor\" /proc/cpuinfo|sort -u|wc -l")
-    processors = int(processors.read())
-    trainers = kwargs.get('trainer_count', 1)
-    threads = processors / trainers
+    else:
+        set_env("OMP_DYNAMIC", "false")
+        set_env("KMP_AFFINITY", "granularity=fine,compact,0,0")
+    threads = num_processors / trainer_count
     threads = '1' if threads < 1 else str(threads)
     set_env("OMP_NUM_THREADS", threads)
     set_env("MKL_NUM_THREADS", threads)
+
+
+def init(**kwargs):
+    import py_paddle.swig_paddle as api
+    args = []
+    args_dict = {}
+    # NOTE: append arguments if they are in ENV
+    for ek, ev in os.environ.iteritems():
+        if ek.startswith("PADDLE_INIT_"):
+            args_dict[ek.replace("PADDLE_INIT_", "").lower()] = str(ev)
+    args_dict.update(kwargs)
+    # NOTE: overwrite arguments from ENV if it is in kwargs
+    for key in args_dict.keys():
+        args.append('--%s=%s' % (key, str(args_dict[key])))
+
+    set_omp_mkl_env_vars(kwargs.get('trainer_count', 1))
+
     if 'use_gpu' in kwargs:
         cp.g_command_config_args['use_gpu'] = kwargs['use_gpu']
     if 'use_mkldnn' in kwargs:
......
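As a quick reference for the rule implemented by `set_omp_mkl_env_vars` above, here is a reduced sketch: the thread count exported to OMP_NUM_THREADS and MKL_NUM_THREADS is the number of logical processors divided by `trainer_count`, floored at 1. The processor counts in the example are made up.

```python
import os

def omp_threads(num_logical_processors, trainer_count):
    # Same rule as set_omp_mkl_env_vars: one share of the logical
    # processors per trainer, but never fewer than one thread.
    threads = num_logical_processors // trainer_count
    return '1' if threads < 1 else str(threads)

# e.g. 8 logical processors split across 4 trainers -> 2 threads each
os.environ.setdefault("OMP_NUM_THREADS", omp_threads(8, 4))
os.environ.setdefault("MKL_NUM_THREADS", omp_threads(8, 4))
print(os.environ["OMP_NUM_THREADS"])
```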
@@ -285,3 +285,86 @@ class XavierInitializer(Initializer):
             })
         var.op = op
         return op
+
+
+class MSRAInitializer(Initializer):
+    """Implements the MSRA initializer a.k.a. Kaiming Initializer
+
+    This class implements the weight initialization from the paper
+    Delving Deep into Rectifiers: Surpassing Human-Level Performance on
+    ImageNet Classification[1] by Kaiming He, Xiangyu Zhang, Shaoqing Ren
+    and Jian Sun. This is a robust initialization method that particularly
+    considers the rectifier nonlinearities. In case of Uniform distribution,
+    the range is [-x, x], where x = sqrt(6 / fan_in). In case of Normal
+    distribution, the mean is 0 and the standard deviation
+    is sqrt(2/ fan_in).
+
+    References:
+        [1] Delving Deep into Rectifiers: Surpassing Human-Level Performance
+            on ImageNet Classification
+            (https://arxiv.org/abs/1502.01852)
+    """
+
+    def __init__(self, uniform=True, fan_in=None, seed=0):
+        """Constructor for MSRAInitializer
+
+        Args:
+            uniform: whether to use uniform or normal distribution
+            fan_in: fan_in for MSRAInitializer. If None, it is
+                    inferred from the variable.
+            seed: random seed
+
+        Note: It is recommended to set fan_in to None for most cases.
+        """
+        assert uniform is not None
+        assert seed is not None
+        super(MSRAInitializer, self).__init__()
+        self._uniform = uniform
+        self._fan_in = fan_in
+        self._seed = seed
+
+    def __call__(self, var, block):
+        """Add MSRA initialization ops for a variable
+
+        Args:
+            var: Variable that needs to be initialized
+            block: The block in which initialization ops
+                   should be added
+
+        Returns:
+            the initialization op
+        """
+        assert isinstance(var, framework.Variable)
+        assert isinstance(block, framework.Block)
+        f_in, f_out = self._compute_fans(var)
+        # If fan_in is passed, use it
+        fan_in = f_in if self._fan_in is None else self._fan_in
+        if self._uniform:
+            limit = np.sqrt(6.0 / float(fan_in))
+            op = block.prepend_op(
+                type="uniform_random",
+                outputs={"Out": var},
+                attrs={
+                    "shape": var.shape,
+                    "data_type": int(var.data_type),
+                    "min": -limit,
+                    "max": limit,
+                    "seed": self._seed
+                })
+        else:
+            std = np.sqrt(2.0 / float(fan_in))
+            op = block.prepend_op(
+                type="gaussian_random",
+                outputs={"Out": var},
+                attrs={
+                    "shape": var.shape,
+                    "data_type": int(var.data_type),
+                    "mean": 0.0,
+                    "std": std,
+                    "seed": self._seed
+                })
+        var.op = op
+        return op
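The two ranges described in the MSRAInitializer docstring above, written out for a concrete fan_in as a sanity check (fan_in = 12 matches the supplied-arguments unit test further below):

```python
import numpy as np

fan_in = 12

# Uniform variant: weights drawn from U(-limit, limit)
limit = np.sqrt(6.0 / fan_in)
# Normal variant: weights drawn from N(0, std^2)
std = np.sqrt(2.0 / fan_in)

print("uniform limit: %.6f" % limit)   # ~0.707107
print("normal std:    %.6f" % std)     # ~0.408248
```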
@@ -17,13 +17,13 @@ __all__ = [
 def fc(input,
        size,
+       num_flatten_dims=1,
        param_attr=None,
        param_initializer=None,
        bias_attr=None,
        bias_initializer=None,
-       name=None,
        act=None,
-       num_flatten_dims=1,
+       name=None,
        main_program=None,
        startup_program=None):
     """
@@ -32,15 +32,15 @@ def fc(input,
     Args:
         input: The input tensor to the function
         size: The size of the layer
+        num_flatten_dims: Number of columns in input
         param_attr: The parameters/weights to the FC Layer
         param_initializer: Initializer used for the weight/parameter.
                            If None, XavierInitializer() is used
        bias_attr: The bias parameter for the FC layer
        bias_initializer: Initializer used for the bias.
                          If None, then ConstantInitializer() is used
-        name: Name/alias of the function
        act: Activation to be applied to the output of FC layer
-        num_flatten_dims: Number of columns in input
+        name: Name/alias of the function
        main_program: Name of the main program that calls this
        startup_program: Name of the startup program
@@ -111,9 +111,9 @@ def fc(input,
 def embedding(input,
               size,
-              data_type='float32',
               is_sparse=False,
               param_attr=None,
+              data_type='float32',
               main_program=None,
               startup_program=None):
     """
@@ -122,9 +122,9 @@ def embedding(input,
     Args:
         input: The input to the function
         size: The size of the layer
-        data_type: The type of data : float32, float_16, int etc
        is_sparse: A flag that decleares whether the input is sparse
        param_attr: Parameters for this layer
+        data_type: The type of data : float32, float_16, int etc
        main_program: Name of the main program that calls this
        startup_program: Name of the startup program
@@ -152,7 +152,6 @@ def embedding(input,
 # TODO(qijun): expose H0 and C0
 def dynamic_lstm(input,
                  size,
-                 data_type='float32',
                  param_attr=None,
                  bias_attr=None,
                  use_peepholes=True,
@@ -160,6 +159,7 @@ def dynamic_lstm(input,
                  gate_activation='sigmoid',
                  cell_activation='tanh',
                  candidate_activation='tanh',
+                 data_type='float32',
                  main_program=None,
                  startup_program=None):
     helper = LayerHelper('lstm', **locals())
@@ -200,9 +200,9 @@ def dynamic_lstm(input,
 def data(name,
          shape,
+         append_batch_size=True,
          data_type='float32',
          type=core.VarDesc.VarType.LOD_TENSOR,
-         append_batch_size=True,
          main_program=None,
          startup_program=None,
          stop_gradient=True):
@@ -212,9 +212,9 @@ def data(name,
     Args:
         name: The name/alias of the function
        shape: Tuple declaring the shape.
+        append_batch_size: Whether or not to append the data as a batch.
        data_type: The type of data : float32, float_16, int etc
        type: The output type. By default it is LOD_TENSOR.
-        append_batch_size: Whether or not to append the data as a batch.
        main_program: Name of the main program that calls this
        startup_program: Name of the startup program
        stop_gradient: A boolean that mentions whether gradient should flow.
@@ -600,12 +600,12 @@ def sequence_conv(input,
                   num_filters,
                   filter_size=3,
                   filter_stride=1,
-                  act=None,
                   padding=None,
                   bias_attr=None,
                   bias_initializer=None,
                   param_attr=None,
                   param_initializer=None,
+                  act=None,
                   main_program=None,
                   startup_program=None):
     """
@@ -658,16 +658,16 @@ def sequence_conv(input,
 def conv2d(input,
            num_filters,
-           name=None,
-           filter_size=[1, 1],
-           act=None,
-           groups=None,
+           filter_size,
            stride=[1, 1],
            padding=None,
-           bias_attr=None,
-           bias_initializer=None,
+           groups=None,
            param_attr=None,
            param_initializer=None,
+           bias_attr=None,
+           bias_initializer=None,
+           act=None,
+           name=None,
            main_program=None,
            startup_program=None):
     """
......
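Since several arguments change position in these signatures, existing call sites are safest with keyword arguments. A small sketch under assumptions (the import path and layer names here are illustrative, not taken from this diff):

```python
# Hypothetical call sites using keyword arguments only, so they are
# unaffected by the positional reordering above.
import paddle.v2.fluid.layers as layers  # import path is an assumption

image = layers.data(name='image', shape=[784], data_type='float32')
hidden = layers.fc(input=image, size=128, act='relu')
prediction = layers.fc(input=hidden, size=10, act='softmax')
```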
@@ -54,17 +54,17 @@ def to_lodtensor(data, place):
     return res
-def chop_data(data, chop_len=80, batch_len=50):
+def chop_data(data, chop_len=80, batch_size=50):
     data = [(x[0][:chop_len], x[1]) for x in data if len(x[0]) >= chop_len]
-    return data[:batch_len]
+    return data[:batch_size]
 def prepare_feed_data(data, place):
     tensor_words = to_lodtensor(map(lambda x: x[0], data), place)
     label = np.array(map(lambda x: x[1], data)).astype("int64")
-    label = label.reshape([50, 1])
+    label = label.reshape([len(label), 1])
     tensor_label = core.LoDTensor()
     tensor_label.set(label, place)
@@ -72,33 +72,41 @@ def prepare_feed_data(data, place):
 def main():
-    word_dict = paddle.dataset.imdb.word_dict()
-    cost, acc = lstm_net(dict_dim=len(word_dict), class_dim=2)
-    batch_size = 100
-    train_data = paddle.batch(
-        paddle.reader.buffered(
-            paddle.dataset.imdb.train(word_dict), size=batch_size * 10),
-        batch_size=batch_size)
-    data = chop_data(next(train_data()))
+    BATCH_SIZE = 100
+    PASS_NUM = 5
+    word_dict = paddle.dataset.imdb.word_dict()
+    print "load word dict successfully"
+    dict_dim = len(word_dict)
+    class_dim = 2
+    cost, acc = lstm_net(dict_dim=dict_dim, class_dim=class_dim)
+    train_data = paddle.batch(
+        paddle.reader.shuffle(
+            paddle.dataset.imdb.train(word_dict), buf_size=BATCH_SIZE * 10),
+        batch_size=BATCH_SIZE)
     place = core.CPUPlace()
-    tensor_words, tensor_label = prepare_feed_data(data, place)
     exe = Executor(place)
     exe.run(framework.default_startup_program())
-    while True:
-        outs = exe.run(framework.default_main_program(),
-                       feed={"words": tensor_words,
-                             "label": tensor_label},
-                       fetch_list=[cost, acc])
-        cost_val = np.array(outs[0])
-        acc_val = np.array(outs[1])
-        print("cost=" + str(cost_val) + " acc=" + str(acc_val))
-        if acc_val > 0.9:
-            break
+    for pass_id in xrange(PASS_NUM):
+        for data in train_data():
+            chopped_data = chop_data(data)
+            tensor_words, tensor_label = prepare_feed_data(chopped_data, place)
+            outs = exe.run(framework.default_main_program(),
+                           feed={"words": tensor_words,
+                                 "label": tensor_label},
+                           fetch_list=[cost, acc])
+            cost_val = np.array(outs[0])
+            acc_val = np.array(outs[1])
+            print("cost=" + str(cost_val) + " acc=" + str(acc_val))
+            if acc_val > 0.7:
+                exit(0)
+    exit(1)
 if __name__ == '__main__':
......
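A tiny illustration of the `chop_data` helper changed above: sequences shorter than `chop_len` are dropped, the rest are truncated, and at most `batch_size` samples are kept. The toy samples below are made up.

```python
def chop_data(data, chop_len=80, batch_size=50):
    # identical to the helper above: filter, truncate, then cap the batch
    data = [(x[0][:chop_len], x[1]) for x in data if len(x[0]) >= chop_len]
    return data[:batch_size]

samples = [([1] * 100, 0), ([2] * 90, 1), ([3] * 10, 0)]
chopped = chop_data(samples, chop_len=80, batch_size=2)
print([len(words) for words, _ in chopped])  # [80, 80]; the short sample is dropped
```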
@@ -223,5 +223,109 @@ class TestXavierInitializer(unittest.TestCase):
         self.assertEqual(init_op.attr('seed'), 134)
+
+
+class TestMSRAInitializer(unittest.TestCase):
+    def test_uniform_msra_initializer(self):
+        """Test MSRA initializer with uniform distribution on
+           for matrix multiply.
+        """
+        program = framework.Program()
+        block = program.global_block()
+        param = block.create_parameter(
+            dtype="float32",
+            shape=[5, 10],
+            lod_level=0,
+            name="param",
+            initializer=initializer.MSRAInitializer())
+        self.assertEqual(len(block.ops), 1)
+        init_op = block.ops[0]
+        self.assertEqual(init_op.type, 'uniform_random')
+        limit = np.sqrt(6.0 / param.shape[0])
+        self.assertAlmostEqual(init_op.attr('min'), -limit, delta=DELTA)
+        self.assertAlmostEqual(init_op.attr('max'), limit, delta=DELTA)
+        self.assertEqual(init_op.attr('seed'), 0)
+
+    def test_uniform_msra_initializer_conv(self):
+        """Test MSRA initializer with uniform distribution on
+           for convolutions.
+        """
+        program = framework.Program()
+        block = program.global_block()
+        param = block.create_parameter(
+            dtype="float32",
+            shape=[5, 10, 15, 20],
+            lod_level=0,
+            name="param",
+            initializer=initializer.MSRAInitializer())
+        self.assertEqual(len(block.ops), 1)
+        init_op = block.ops[0]
+        self.assertEqual(init_op.type, 'uniform_random')
+        receptive_field_size = float(15 * 20)
+        limit = np.sqrt(6.0 / (param.shape[1] * receptive_field_size))
+        self.assertAlmostEqual(init_op.attr('min'), -limit, delta=DELTA)
+        self.assertAlmostEqual(init_op.attr('max'), limit, delta=DELTA)
+        self.assertEqual(init_op.attr('seed'), 0)
+
+    def test_normal_msra_initializer(self):
+        """Test MSRA initializer with normal distribution on
+           for matrix multiply.
+        """
+        program = framework.Program()
+        block = program.global_block()
+        param = block.create_parameter(
+            dtype="float32",
+            shape=[5, 10],
+            lod_level=0,
+            name="param",
+            initializer=initializer.MSRAInitializer(uniform=False))
+        self.assertEqual(len(block.ops), 1)
+        init_op = block.ops[0]
+        self.assertEqual(init_op.type, 'gaussian_random')
+        std = np.sqrt(2.0 / param.shape[0])
+        self.assertAlmostEqual(init_op.attr('mean'), 0.0, delta=DELTA)
+        self.assertAlmostEqual(init_op.attr('std'), std, delta=DELTA)
+        self.assertEqual(init_op.attr('seed'), 0)
+
+    def test_normal_msra_initializer_conv(self):
+        """Test MSRA initializer with normal distribution on
+           for convolutions.
+        """
+        program = framework.Program()
+        block = program.global_block()
+        param = block.create_parameter(
+            dtype="float32",
+            shape=[5, 10, 15, 20],
+            lod_level=0,
+            name="param",
+            initializer=initializer.MSRAInitializer(uniform=False))
+        self.assertEqual(len(block.ops), 1)
+        init_op = block.ops[0]
+        self.assertEqual(init_op.type, 'gaussian_random')
+        receptive_field_size = float(15 * 20)
+        std = np.sqrt(2.0 / (param.shape[1] * receptive_field_size))
+        self.assertAlmostEqual(init_op.attr('mean'), 0.0, delta=DELTA)
+        self.assertAlmostEqual(init_op.attr('std'), std, delta=DELTA)
+        self.assertEqual(init_op.attr('seed'), 0)
+
+    def test_msra_initializer_supplied_arguments(self):
+        """Test the MSRA initializer with supplied arguments
+        """
+        program = framework.Program()
+        block = program.global_block()
+        block.create_parameter(
+            dtype="float32",
+            shape=[5, 10],
+            lod_level=0,
+            name="param",
+            initializer=initializer.MSRAInitializer(
+                fan_in=12, seed=134))
+        self.assertEqual(len(block.ops), 1)
+        init_op = block.ops[0]
+        self.assertEqual(init_op.type, 'uniform_random')
+        limit = np.sqrt(6.0 / 12)
+        self.assertAlmostEqual(init_op.attr('min'), -limit, delta=DELTA)
+        self.assertAlmostEqual(init_op.attr('max'), limit, delta=DELTA)
+        self.assertEqual(init_op.attr('seed'), 134)
+
 if __name__ == '__main__':
     unittest.main()