提交 4b94494a 编写于 作者: L liaogang

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into external

# External dependency to Google protobuf.
http_archive(
name="protobuf",
url="http://github.com/google/protobuf/archive/v3.1.0.tar.gz",
sha256="0a0ae63cbffc274efb573bdde9a253e3f32e458c41261df51c5dbc5ad541e8f7",
strip_prefix="protobuf-3.1.0")
# External dependency to gtest 1.7.0. This method comes from
# https://www.bazel.io/versions/master/docs/tutorial/cpp.html.
new_http_archive(
name="gtest",
url="https://github.com/google/googletest/archive/release-1.7.0.zip",
sha256="b58cb7547a28b2c718d1e38aee18a3659c9e3ff52440297e965f5edffe34b6d0",
build_file="third_party/gtest.BUILD",
strip_prefix="googletest-release-1.7.0")
# External dependency to gflags. This method comes from
# https://github.com/gflags/example/blob/master/WORKSPACE.
new_git_repository(
name="gflags",
tag="v2.2.0",
remote="https://github.com/gflags/gflags.git",
build_file="third_party/gflags.BUILD")
# External dependency to glog. This method comes from
# https://github.com/reyoung/bazel_playground/blob/master/WORKSPACE
new_git_repository(
name="glog",
remote="https://github.com/google/glog.git",
commit="b6a5e0524c28178985f0d228e9eaa43808dbec3c",
build_file="third_party/glog.BUILD")
......@@ -4,3 +4,4 @@ mnist_vgg_model
plot.png
train.log
*pyc
.ipynb_checkpoints
"""
A very basic example for how to use current Raw SWIG API to train mnist network.
Current implementation uses Raw SWIG, which means the API call is directly \
passed to C++ side of Paddle.
The user api could be simpler and carefully designed.
"""
import py_paddle.swig_paddle as api
from py_paddle import DataProviderConverter
import paddle.trainer.PyDataProvider2 as dp
import numpy as np
import random
from mnist_util import read_from_mnist
from paddle.trainer_config_helpers import *
def optimizer_config():
settings(
learning_rate=1e-4,
learning_method=AdamOptimizer(),
batch_size=1000,
model_average=ModelAverage(average_window=0.5),
regularization=L2Regularization(rate=0.5))
def network_config():
imgs = data_layer(name='pixel', size=784)
hidden1 = fc_layer(input=imgs, size=200)
hidden2 = fc_layer(input=hidden1, size=200)
inference = fc_layer(input=hidden2, size=10, act=SoftmaxActivation())
cost = classification_cost(
input=inference, label=data_layer(
name='label', size=10))
outputs(cost)
def init_parameter(network):
assert isinstance(network, api.GradientMachine)
for each_param in network.getParameters():
assert isinstance(each_param, api.Parameter)
array_size = len(each_param)
array = np.random.uniform(-1.0, 1.0, array_size).astype('float32')
each_param.getBuf(api.PARAMETER_VALUE).copyFromNumpyArray(array)
def generator_to_batch(generator, batch_size):
ret_val = list()
for each_item in generator:
ret_val.append(each_item)
if len(ret_val) == batch_size:
yield ret_val
ret_val = list()
if len(ret_val) != 0:
yield ret_val
class BatchPool(object):
def __init__(self, generator, batch_size):
self.data = list(generator)
self.batch_size = batch_size
def __call__(self):
random.shuffle(self.data)
for offset in xrange(0, len(self.data), self.batch_size):
limit = min(offset + self.batch_size, len(self.data))
yield self.data[offset:limit]
def input_order_converter(generator):
for each_item in generator:
yield each_item['pixel'], each_item['label']
def main():
api.initPaddle("-use_gpu=false", "-trainer_count=4") # use 4 cpu cores
# get enable_types for each optimizer.
# enable_types = [value, gradient, momentum, etc]
# For each optimizer(SGD, Adam), GradientMachine should enable different
# buffers.
opt_config_proto = parse_optimizer_config(optimizer_config)
opt_config = api.OptimizationConfig.createFromProto(opt_config_proto)
_temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
enable_types = _temp_optimizer_.getParameterTypes()
# Create Simple Gradient Machine.
model_config = parse_network_config(network_config)
m = api.GradientMachine.createFromConfigProto(
model_config, api.CREATE_MODE_NORMAL, enable_types)
# This type check is not useful. Only enable type hint in IDE.
# Such as PyCharm
assert isinstance(m, api.GradientMachine)
# Initialize Parameter by numpy.
init_parameter(network=m)
# Create Local Updater. Local means not run in cluster.
# For a cluster training, here we can change to createRemoteUpdater
# in future.
updater = api.ParameterUpdater.createLocalUpdater(opt_config)
assert isinstance(updater, api.ParameterUpdater)
# Initialize ParameterUpdater.
updater.init(m)
# DataProvider Converter is a utility convert Python Object to Paddle C++
# Input. The input format is as same as Paddle's DataProvider.
converter = DataProviderConverter(
input_types=[dp.dense_vector(784), dp.integer_value(10)])
train_file = './data/raw_data/train'
test_file = './data/raw_data/t10k'
# start gradient machine.
# the gradient machine must be started before invoke forward/backward.
# not just for training, but also for inference.
m.start()
# evaluator can print error rate, etc. It is a C++ class.
batch_evaluator = m.makeEvaluator()
test_evaluator = m.makeEvaluator()
# Get Train Data.
# TrainData will stored in a data pool. Currently implementation is not care
# about memory, speed. Just a very naive implementation.
train_data_generator = input_order_converter(read_from_mnist(train_file))
train_data = BatchPool(train_data_generator, 512)
# outArgs is Neural Network forward result. Here is not useful, just passed
# to gradient_machine.forward
outArgs = api.Arguments.createArguments(0)
for pass_id in xrange(2): # we train 2 passes.
updater.startPass()
for batch_id, data_batch in enumerate(train_data()):
# data_batch is input images.
# here, for online learning, we could get data_batch from network.
# Start update one batch.
pass_type = updater.startBatch(len(data_batch))
# Start BatchEvaluator.
# batch_evaluator can be used between start/finish.
batch_evaluator.start()
# forwardBackward is a shortcut for forward and backward.
# It is sometimes faster than invoke forward/backward separately,
# because in GradientMachine, it may be async.
m.forwardBackward(converter(data_batch), outArgs, pass_type)
for each_param in m.getParameters():
updater.update(each_param)
# Get cost. We use numpy to calculate total cost for this batch.
cost_vec = outArgs.getSlotValue(0)
cost_vec = cost_vec.copyToNumpyMat()
cost = cost_vec.sum() / len(data_batch)
# Make evaluator works.
m.eval(batch_evaluator)
# Print logs.
print 'Pass id', pass_id, 'Batch id', batch_id, 'with cost=', \
cost, batch_evaluator
batch_evaluator.finish()
# Finish batch.
# * will clear gradient.
# * ensure all values should be updated.
updater.finishBatch(cost)
# testing stage. use test data set to test current network.
updater.apply()
test_evaluator.start()
test_data_generator = input_order_converter(read_from_mnist(test_file))
for data_batch in generator_to_batch(test_data_generator, 512):
# in testing stage, only forward is needed.
m.forward(converter(data_batch), outArgs, api.PASS_TEST)
m.eval(test_evaluator)
# print error rate for test data set
print 'Pass', pass_id, ' test evaluator: ', test_evaluator
test_evaluator.finish()
updater.restore()
updater.catchUpWith()
params = m.getParameters()
for each_param in params:
assert isinstance(each_param, api.Parameter)
value = each_param.getBuf(api.PARAMETER_VALUE)
value = value.copyToNumpyArray()
# Here, we could save parameter to every where you want
print each_param.getName(), value
updater.finishPass()
m.finish()
if __name__ == '__main__':
main()
from paddle.trainer.PyDataProvider2 import *
import numpy
from mnist_util import read_from_mnist
# Define a py data provider
......@@ -8,27 +8,5 @@ import numpy
'label': integer_value(10)},
cache=CacheType.CACHE_PASS_IN_MEM)
def process(settings, filename): # settings is not used currently.
imgf = filename + "-images-idx3-ubyte"
labelf = filename + "-labels-idx1-ubyte"
f = open(imgf, "rb")
l = open(labelf, "rb")
f.read(16)
l.read(8)
# Define number of samples for train/test
if "train" in filename:
n = 60000
else:
n = 10000
images = numpy.fromfile(
f, 'ubyte', count=n * 28 * 28).reshape((n, 28 * 28)).astype('float32')
images = images / 255.0 * 2.0 - 1.0
labels = numpy.fromfile(l, 'ubyte', count=n).astype("int")
for i in xrange(n):
yield {"pixel": images[i, :], 'label': labels[i]}
f.close()
l.close()
for each in read_from_mnist(filename):
yield each
import numpy
__all__ = ['read_from_mnist']
def read_from_mnist(filename):
imgf = filename + "-images-idx3-ubyte"
labelf = filename + "-labels-idx1-ubyte"
f = open(imgf, "rb")
l = open(labelf, "rb")
f.read(16)
l.read(8)
# Define number of samples for train/test
if "train" in filename:
n = 60000
else:
n = 10000
images = numpy.fromfile(
f, 'ubyte', count=n * 28 * 28).reshape((n, 28 * 28)).astype('float32')
images = images / 255.0 * 2.0 - 1.0
labels = numpy.fromfile(l, 'ubyte', count=n).astype("int")
for i in xrange(n):
yield {"pixel": images[i, :], 'label': labels[i]}
f.close()
l.close()
......@@ -17,7 +17,7 @@ set -e
#Note the default model is pass-00002, you shold make sure the model path
#exists or change the mode path.
#only test on trainer_config.lr.py
model=output/pass-00001/
model=output/model/pass-00001/
config=trainer_config.lr.py
label=data/labels.list
dict=data/dict.txt
......
#!/bin/bash
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
# Should run pserver.sh before run this script.
bin_dir=$(cd `dirname $0`; pwd)
home_dir=$(cd "${bin_dir}/.."; pwd)
source "$bin_dir/env.sh"
model_dir="$bin_dir/output"
log_file="$bin_dir/train.log"
pushd "$home_dir"
cfg=trainer_config.lr.py
paddle train \
--config=$cfg \
--save_dir=${model_dir} \
--trainer_count=4 \
--local=0 \
--log_period=100 \
--num_passes=15 \
--use_gpu=false \
--show_parameter_stats_period=100 \
--test_all_data_in_one_period=1 \
--num_gradient_servers=1 \
--nics=`get_nics` \
--port=7164 \
--ports_num=1 \
--pservers="127.0.0.1" \
--comment="paddle_trainer" \
2>&1 | tee "$log_file"
popd
#!/bin/bash
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
function get_nics() {
machine=`uname -s`
local nics=""
if [ "$machine" == "Linux" ]; then
nics="lo"
elif [ "$machine" == "Darwin" ]; then
nics="lo0"
else
nics="unsupport"
fi
echo $nics
}
#!/bin/bash
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
bin_dir=$(cd `dirname $0`; pwd)
source "$bin_dir/env.sh"
paddle pserver \
--nics=`get_nics` \
--port=7164 \
--ports_num=1 \
--ports_num_for_sparse=1 \
--num_gradient_servers=1 \
--comment="paddle_pserver" \
2>&1 | tee 'pserver.log'
......@@ -65,16 +65,13 @@ The general development workflow with Docker and Bazel is as follows:
--name paddle \
-p 2022:22 \
-v $PWD:/paddle \
-v $HOME/.cache/bazel:/root/.cache/bazel \
paddle:dev
where :code:`-d` makes the container running in background,
:code:`--name paddle` allows us to run a nginx container to serve
documents in this container, :code:`-p 2022:22` allows us to SSH
into this container, :code:`-v $PWD:/paddle` shares the source code
on the host with the container, :code:`-v
$HOME/.cache/bazel:/root/.cache/bazel` shares Bazel cache on the
host with the container.
on the host with the container.
4. SSH into the container:
......@@ -94,13 +91,6 @@ The general development workflow with Docker and Bazel is as follows:
make -j `nproc`
CTEST_OUTPUT_ON_FAILURE=1 ctest
or Bazel in the container:
.. code-block:: bash
cd /paddle
bazel test ...
CPU-only and GPU Images
-----------------------
......
# Generative Adversarial Networks (GAN)
This demo implements GAN training described in the original [GAN paper](https://arxiv.org/abs/1406.2661) and deep convolutional generative adversarial networks [DCGAN paper](https://arxiv.org/abs/1511.06434).
The high-level structure of GAN is shown in Figure. 1 below. It is composed of two major parts: a generator and a discriminator, both of which are based on neural networks. The generator takes in some kind of noise with a known distribution and transforms it into an image. The discriminator takes in an image and determines whether it is artificially generated by the generator or a real image. So the generator and the discriminator are in a competitive game in which generator is trying to generate image to look as real as possible to fool the discriminator, while the discriminator is trying to distinguish between real and fake images.
<p align="center">
<img src="./gan.png" width="500" height="300">
</p>
<p align="center">
Figure 1. GAN-Model-Structure
<a href="https://ishmaelbelghazi.github.io/ALI/">figure credit</a>
</p>
The generator and discriminator take turn to be trained using SGD. The objective function of the generator is for its generated images being classified as real by the discriminator, and the objective function of the discriminator is to correctly classify real and fake images. When the GAN model is trained to converge to the equilibrium state, the generator will transform the given noise distribution to the distribution of real images, and the discriminator will not be able to distinguish between real and fake images at all.
## Implementation of GAN Model Structure
Since GAN model involves multiple neural networks, it requires to use paddle python API. So the code walk-through below can also partially serve as an introduction to the usage of Paddle Python API.
There are three networks defined in gan_conf.py, namely **generator_training**, **discriminator_training** and **generator**. The relationship to the model structure we defined above is that **discriminator_training** is the discriminator, **generator** is the generator, and the **generator_training** combined the generator and discriminator since training generator would require the discriminator to provide loss function. This relationship is described in the following code:
```python
if is_generator_training:
noise = data_layer(name="noise", size=noise_dim)
sample = generator(noise)
if is_discriminator_training:
sample = data_layer(name="sample", size=sample_dim)
if is_generator_training or is_discriminator_training:
label = data_layer(name="label", size=1)
prob = discriminator(sample)
cost = cross_entropy(input=prob, label=label)
classification_error_evaluator(
input=prob, label=label, name=mode + '_error')
outputs(cost)
if is_generator:
noise = data_layer(name="noise", size=noise_dim)
outputs(generator(noise))
```
In order to train the networks defined in gan_conf.py, one first needs to initialize a Paddle environment, parse the config, create GradientMachine from the config and create trainer from GradientMachine as done in the code chunk below:
```python
import py_paddle.swig_paddle as api
# init paddle environment
api.initPaddle('--use_gpu=' + use_gpu, '--dot_period=10',
'--log_period=100', '--gpu_id=' + args.gpu_id,
'--save_dir=' + "./%s_params/" % data_source)
# Parse config
gen_conf = parse_config(conf, "mode=generator_training,data=" + data_source)
dis_conf = parse_config(conf, "mode=discriminator_training,data=" + data_source)
generator_conf = parse_config(conf, "mode=generator,data=" + data_source)
# Create GradientMachine
dis_training_machine = api.GradientMachine.createFromConfigProto(
dis_conf.model_config)
gen_training_machine = api.GradientMachine.createFromConfigProto(
gen_conf.model_config)
generator_machine = api.GradientMachine.createFromConfigProto(
generator_conf.model_config)
# Create trainer
dis_trainer = api.Trainer.create(dis_conf, dis_training_machine)
gen_trainer = api.Trainer.create(gen_conf, gen_training_machine)
```
In order to balance the strength between generator and discriminator, we schedule to train whichever one is performing worse by comparing their loss function value. The loss function value can be calculated by a forward pass through the GradientMachine.
```python
def get_training_loss(training_machine, inputs):
outputs = api.Arguments.createArguments(0)
training_machine.forward(inputs, outputs, api.PASS_TEST)
loss = outputs.getSlotValue(0).copyToNumpyMat()
return numpy.mean(loss)
```
After training one network, one needs to sync the new parameters to the other networks. The code below demonstrates one example of such use case:
```python
# Train the gen_training
gen_trainer.trainOneDataBatch(batch_size, data_batch_gen)
# Copy the parameters from gen_training to dis_training and generator
copy_shared_parameters(gen_training_machine,
dis_training_machine)
copy_shared_parameters(gen_training_machine, generator_machine)
```
## A Toy Example
With the infrastructure explained above, we can now walk you through a toy example of generating two dimensional uniform distribution using 10 dimensional Gaussian noise.
The Gaussian noises are generated using the code below:
```python
def get_noise(batch_size, noise_dim):
return numpy.random.normal(size=(batch_size, noise_dim)).astype('float32')
```
The real samples (2-D uniform) are generated using the code below:
```python
# synthesize 2-D uniform data in gan_trainer.py:114
def load_uniform_data():
data = numpy.random.rand(1000000, 2).astype('float32')
return data
```
The generator and discriminator network are built using fully-connected layer and batch_norm layer, and are defined in gan_conf.py.
To train the GAN model, one can use the command below. The flag -d specifies the training data (cifar, mnist or uniform) and flag --useGpu specifies whether to use gpu for training (0 is cpu, 1 is gpu).
```bash
$python gan_trainer.py -d uniform --useGpu 1
```
The generated samples can be found in ./uniform_samples/ and one example is shown below as Figure 2. One can see that it roughly recovers the 2D uniform distribution.
<p align="center">
<img src="./uniform_sample.png" width="300" height="300">
</p>
<p align="center">
Figure 2. Uniform Sample
</p>
## MNIST Example
### Data preparation
To download the MNIST data, one can use the following commands:
```bash
$cd data/
$./get_mnist_data.sh
```
### Model description
Following the DC-Gan paper (https://arxiv.org/abs/1511.06434), we use convolution/convolution-transpose layer in the discriminator/generator network to better deal with images. The details of the network structures are defined in gan_conf_image.py.
### Training the model
To train the GAN model on mnist data, one can use the following command:
```bash
$python gan_trainer.py -d mnist --useGpu 1
```
The generated sample images can be found at ./mnist_samples/ and one example is shown below as Figure 3.
<p align="center">
<img src="./mnist_sample.png" width="300" height="300">
</p>
<p align="center">
Figure 3. MNIST Sample
</p>
set(API_SOURCES
Arguments.cpp
ConfigParser.cpp
Evaluator.cpp
GradientMachine.cpp
Matrix.cpp
Parameter.cpp
ParameterOptimizer.cpp
ParameterUpdater.cpp
SequenceGenerator.cpp
Trainer.cpp
Util.cpp
......@@ -63,6 +65,15 @@ install(DIRECTORY ${PROJ_ROOT}/paddle/dist/
add_custom_target(python_api_wheel ALL DEPENDS
${PROJ_ROOT}/paddle/dist/.timestamp)
add_dependencies(python_api_wheel python_swig_sources
paddle_parameter
paddle_math
paddle_utils
paddle_gserver
paddle_pserver
paddle_trainer
paddle_api
paddle_cuda)
if(WITH_TESTING)
add_subdirectory(test)
......
......@@ -11,13 +11,19 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <sstream>
#include "PaddleAPI.h"
#include "PaddleAPIPrivate.h"
#pragma once
Evaluator::Evaluator() : m(new EvaluatorPrivate()) {}
Evaluator::~Evaluator() { delete m; }
/**
* Disable copy macro.
*/
#define DISABLE_COPY(CLASS_NAME) \
CLASS_NAME(CLASS_NAME &&) = delete; \
CLASS_NAME(const CLASS_NAME &other) = delete; \
CLASS_NAME &operator=(const CLASS_NAME &other) = delete
void Evaluator::start() { m->rawPtr->start(); }
void Evaluator::finish() { m->rawPtr->finish(); }
std::string Evaluator::toString() {
std::ostringstream sout;
m->rawPtr->printStats(sout);
return sout.str();
}
......@@ -64,6 +64,18 @@ GradientMachine* GradientMachine::createByModelConfig(
return GradientMachine::createFromPaddleModelPtr(confPtr, mode, types);
}
void GradientMachine::start() { m->machine->start(); }
void GradientMachine::finish() { m->machine->finish(); }
void GradientMachine::onPassEnd() { m->machine->onPassEnd(); }
void GradientMachine::prefetch(const Arguments& inArgs) {
auto& in =
m->cast<std::vector<paddle::Argument>>(inArgs.getInternalArgumentsPtr());
m->machine->prefetch(in);
}
void GradientMachine::forward(const Arguments& inArgs,
Arguments* outArgs,
PassType passType) {
......@@ -158,3 +170,13 @@ SequenceGenerator* GradientMachine::asSequenceGenerator(
r->setBeamSize(beam_size);
return r;
}
Evaluator* GradientMachine::makeEvaluator() {
auto ev = new Evaluator();
ev->m->rawPtr = m->machine->makeEvaluator();
return ev;
}
void GradientMachine::eval(Evaluator* evaluator) {
m->machine->eval(evaluator->m->rawPtr);
}
......@@ -96,7 +96,9 @@ namespace std {
%rename(__getitem__) Vector::get;
%rename(__setitem__) Vector::set;
%rename(__len__) Vector::getSize;
%rename(__len__) Parameter::getSize;
%rename(__call__) ParameterTraverseCallback::apply;
%rename(__repr__) Evaluator::toString;
%apply (float* INPLACE_ARRAY2, int DIM1, int DIM2) {
(float* data, int dim1, int dim2)
......@@ -167,6 +169,7 @@ namespace std {
%newobject GradientMachine::asSequenceGenerator;
%newobject GradientMachine::getParameter;
%newobject GradientMachine::getLayerOutput;
%newobject GradientMachine::makeEvaluator;
%newobject TrainerConfig::createFromTrainerConfigFile;
%newobject TrainerConfig::getModelConfig;
%newobject TrainerConfig::getOptimizationConfig;
......@@ -174,6 +177,7 @@ namespace std {
%newobject Parameter::getConfig;
%newobject ParameterOptimizer::create;
%newobject ParameterOptimizer::needSpecialTraversal;
%newobject ParameterUpdater::createLocalUpdater;
%feature("director") UpdateCallback;
%feature("autodoc", 1); // To generate method stub, for code hint in ide
......@@ -193,4 +197,4 @@ namespace std {
%ignore OptimizationConfigPrivate;
%ignore ParameterTraverseCallbackPrivate;
%include "utils/GlobalConstants.h"
%include "api/PaddleAPI.h"
\ No newline at end of file
%include "api/PaddleAPI.h"
......@@ -20,15 +20,11 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/utils/GlobalConstants.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/common.h"
/// Import PaddlePaddle's enumeration into global namespace.
using namespace paddle::enumeration_wrapper; // NOLINT
#define DISABLE_COPY_AND_ASSIGN(classname) \
classname(const classname& other); \
classname& operator=(const classname& other)
/**
* @brief Initialize paddle.
*
......@@ -102,7 +98,7 @@ const size_t NO_SPARSE_ID = -1UL;
struct MatrixPrivate;
class Matrix {
Matrix(); // User Cannot Create Matrix.
DISABLE_COPY_AND_ASSIGN(Matrix);
DISABLE_COPY(Matrix);
static Matrix* createByPaddleMatrixPtr(void* sharedPtr);
public:
......@@ -242,7 +238,7 @@ private:
struct VectorPrivate;
class Vector {
DISABLE_COPY_AND_ASSIGN(Vector);
DISABLE_COPY(Vector);
Vector();
static Vector* createByPaddleVectorPtr(void* ptr);
......@@ -322,7 +318,7 @@ private:
struct IVectorPrivate;
class IVector {
IVector();
DISABLE_COPY_AND_ASSIGN(IVector);
DISABLE_COPY(IVector);
static IVector* createByPaddleVectorPtr(void* ptr);
public:
......@@ -402,7 +398,7 @@ struct ArgumentsPrivate;
class Arguments {
private:
Arguments(); // Internal Create.
DISABLE_COPY_AND_ASSIGN(Arguments);
DISABLE_COPY(Arguments);
public:
/**
......@@ -472,7 +468,7 @@ enum GradientMatchineCreateMode {
struct ParameterConfigPrivate;
class ParameterConfig {
DISABLE_COPY_AND_ASSIGN(ParameterConfig);
DISABLE_COPY(ParameterConfig);
ParameterConfig();
/**
......@@ -502,7 +498,7 @@ private:
struct OptimizationConfigPrivate;
class OptimizationConfig {
DISABLE_COPY_AND_ASSIGN(OptimizationConfig);
DISABLE_COPY(OptimizationConfig);
OptimizationConfig();
public:
......@@ -519,6 +515,7 @@ private:
friend class TrainerConfig;
friend class ParameterOptimizer;
friend class ParameterUpdater;
friend class Trainer;
};
......@@ -526,7 +523,7 @@ struct ParameterPrivate;
class Parameter {
private:
Parameter();
DISABLE_COPY_AND_ASSIGN(Parameter);
DISABLE_COPY(Parameter);
public:
virtual ~Parameter();
......@@ -549,6 +546,8 @@ public:
ParameterConfig* getConfig();
void setValueUpdated();
size_t getSize() const;
private:
static Parameter* createFromRawPtr(void* ptr);
static Parameter* createFromSharedPtr(void* ptr);
......@@ -557,6 +556,7 @@ private:
ParameterPrivate* m;
friend class UpdateCallbackWrapper;
friend class GradientMachine;
friend class ParameterUpdater;
};
struct ModelConfigPrivate;
......@@ -568,7 +568,7 @@ struct ModelConfigPrivate;
class ModelConfig {
private:
ModelConfig();
DISABLE_COPY_AND_ASSIGN(ModelConfig);
DISABLE_COPY(ModelConfig);
public:
virtual ~ModelConfig();
......@@ -589,7 +589,7 @@ struct TrainerConfigPrivate;
class TrainerConfig {
private:
TrainerConfig();
DISABLE_COPY_AND_ASSIGN(TrainerConfig);
DISABLE_COPY(TrainerConfig);
public:
virtual ~TrainerConfig();
......@@ -629,7 +629,7 @@ public:
struct ParameterTraverseCallbackPrivate;
class ParameterTraverseCallback {
DISABLE_COPY_AND_ASSIGN(ParameterTraverseCallback);
DISABLE_COPY(ParameterTraverseCallback);
ParameterTraverseCallback();
public:
......@@ -651,7 +651,7 @@ private:
*/
struct ParameterOptimizerPrivate;
class ParameterOptimizer {
DISABLE_COPY_AND_ASSIGN(ParameterOptimizer);
DISABLE_COPY(ParameterOptimizer);
ParameterOptimizer();
public:
......@@ -683,12 +683,12 @@ private:
};
class SequenceGenerator;
class Evaluator;
struct GradientMachinePrivate;
class GradientMachine {
private:
GradientMachine();
DISABLE_COPY_AND_ASSIGN(GradientMachine);
DISABLE_COPY(GradientMachine);
public:
virtual ~GradientMachine();
......@@ -714,6 +714,23 @@ public:
GradientMatchineCreateMode mode = CREATE_MODE_NORMAL,
const std::vector<int>& parameterTypes = defaultParamTypes);
/**
* @brief finish
*/
void finish();
void start();
/**
* Prefetch row ids of sparse parameter.
*/
void prefetch(const Arguments& inArgs);
/**
* Do some thing when train pass ended.
*/
void onPassEnd();
/**
* The forward stage of GradientMachine.
*
......@@ -761,6 +778,10 @@ public:
size_t max_length = 100UL,
size_t beam_size = -1UL);
Evaluator* makeEvaluator();
void eval(Evaluator* evaluator);
private:
GradientMachinePrivate* m;
......@@ -772,6 +793,109 @@ private:
// Not to use c++ 11 init-list, so we use static var as function default arg.
static std::vector<int> defaultParamTypes;
friend class Trainer;
friend class ParameterUpdater;
};
struct ParameterUpdaterPrivate;
class ParameterUpdater {
private:
ParameterUpdater();
public:
static ParameterUpdater* createLocalUpdater(OptimizationConfig* config);
~ParameterUpdater();
/**
* @brief initialize Parameter Updater by GradientMachine.
* @param gm
*/
void init(const GradientMachine& gm);
/**
* @brief begin of a training/testing of one pass.
*/
void startPass();
/**
* @brief end of a traning/testing of one pass.
*/
void finishPass();
/**
* @brief begin of a training/testing of one batch.
* @param data batch's size
* @return PassType, mostly will be training.
*/
PassType startBatch(size_t batchSize);
/**
* @brief end of a traning/testing of one batch
* @param cost current batch cost.
*/
void finishBatch(float cost);
/**
* @brief update a parameter (by local optimizer or by cluster pserver)
* @param param
*/
void update(Parameter* param);
/**
* @brief restore the average parameter.
* @note It is only used in AverageOptimizer. Restore will get the current
* PARAMETER_VALUE back.
*/
void restore();
/**
* @brief apply. Store the average parameter.
* @note It is only used in AverageOptimizer. Apply will store the current
* PARAMETER_VALUE to buffer, calcaualte current Average Parameter, and save
* it to PARAMETER_VALUE.
*/
void apply();
/**
* @brief catchUpWith The Regularization will be delayed in many situations(
* pserver, local sparse). Catch Up means catch the regularization up, apply
* regularization to all params.
*/
void catchUpWith();
private:
ParameterUpdaterPrivate* m;
};
struct EvaluatorPrivate;
class Evaluator {
private:
Evaluator();
DISABLE_COPY(Evaluator);
public:
~Evaluator();
/**
* @brief begin an evaluate stage.
*/
void start();
/**
* @brief end an evaluate stage.
*/
void finish();
/**
* @brief toString will get a evaluate result.
*
* __repr__ method in python
*/
std::string toString();
private:
EvaluatorPrivate* m;
friend class GradientMachine;
};
struct TrainerPrivate;
......@@ -780,7 +904,7 @@ private:
TrainerPrivate* m;
Trainer();
Trainer(TrainerConfig* optConfig, GradientMachine* gm);
DISABLE_COPY_AND_ASSIGN(Trainer);
DISABLE_COPY(Trainer);
public:
virtual ~Trainer();
......@@ -846,7 +970,7 @@ public:
struct SequenceGeneratorPrivate;
class SequenceGenerator {
DISABLE_COPY_AND_ASSIGN(SequenceGenerator);
DISABLE_COPY(SequenceGenerator);
SequenceGenerator();
public:
......
......@@ -11,12 +11,14 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include "PaddleAPI.h"
#include "paddle/gserver/evaluators/Evaluator.h"
#include "paddle/gserver/gradientmachines/GradientMachine.h"
#include "paddle/parameter/ParameterUpdaterBase.h"
#include "paddle/trainer/TrainerConfigHelper.h"
#pragma once
struct GradientMachinePrivate {
std::shared_ptr<paddle::GradientMachine> machine;
......@@ -65,3 +67,31 @@ struct ArgumentsPrivate {
return *(std::shared_ptr<T>*)(rawPtr);
}
};
struct ParameterUpdaterPrivate {
std::unique_ptr<paddle::ParameterUpdater> updater;
};
struct ParameterPrivate {
std::shared_ptr<paddle::Parameter> sharedPtr;
paddle::Parameter* rawPtr; // rawPtr only used in ParameterUpdater,
// in other situation sharedPtr should
// contains value.
ParameterPrivate() : sharedPtr(nullptr), rawPtr(nullptr) {}
paddle::Parameter* getPtr() {
if (sharedPtr) {
return sharedPtr.get();
} else {
return rawPtr;
}
}
};
struct EvaluatorPrivate {
paddle::Evaluator* rawPtr;
EvaluatorPrivate() : rawPtr(nullptr) {}
~EvaluatorPrivate() { delete rawPtr; }
};
......@@ -14,21 +14,7 @@ limitations under the License. */
#include "paddle/parameter/Parameter.h"
#include "PaddleAPI.h"
struct ParameterPrivate {
std::shared_ptr<paddle::Parameter> sharedPtr;
paddle::Parameter* rawPtr;
ParameterPrivate() : sharedPtr(nullptr), rawPtr(nullptr) {}
paddle::Parameter* getPtr() {
if (sharedPtr) {
return sharedPtr.get();
} else {
return rawPtr;
}
}
};
#include "PaddleAPIPrivate.h"
Parameter::Parameter() : m(new ParameterPrivate()) {}
......@@ -70,3 +56,5 @@ ParameterConfig* Parameter::getConfig() {
size_t Parameter::getID() const { return m->getPtr()->getID(); }
void Parameter::setValueUpdated() { m->getPtr()->setValueUpdated(); }
size_t Parameter::getSize() const { return m->getPtr()->getSize(); }
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "PaddleAPI.h"
#include "PaddleAPIPrivate.h"
#include "paddle/trainer/ThreadParameterUpdater.h"
ParameterUpdater::ParameterUpdater() : m(new ParameterUpdaterPrivate()) {}
ParameterUpdater *ParameterUpdater::createLocalUpdater(
OptimizationConfig *config) {
auto param = new ParameterUpdater();
param->m->updater.reset(new paddle::SgdThreadUpdater(config->m->getConfig()));
return param;
}
ParameterUpdater::~ParameterUpdater() { delete m; }
void ParameterUpdater::init(const GradientMachine &gm) {
m->updater->init(gm.m->machine->getNonStaticParameters());
}
void ParameterUpdater::startPass() { m->updater->startPass(); }
void ParameterUpdater::finishPass() { m->updater->finishPass(); }
PassType ParameterUpdater::startBatch(size_t batchSize) {
return m->updater->startBatch((int64_t)batchSize);
}
void ParameterUpdater::finishBatch(float cost) {
m->updater->finishBatch(cost);
}
void ParameterUpdater::update(Parameter *param) {
auto paddleParam = param->m->getPtr();
m->updater->update(paddleParam);
}
void ParameterUpdater::restore() { m->updater->restore(); }
void ParameterUpdater::apply() { m->updater->apply(); }
void ParameterUpdater::catchUpWith() { m->updater->catchUpWith(); }
......@@ -253,7 +253,7 @@ void Vector::copyToNumpyArray(float** view_m_data, int* dim1) {
*view_m_data = new float[*dim1];
if (auto cpuVec = dynamic_cast<paddle::CpuVector*>(m->vec.get())) {
std::memcpy(*view_m_data, cpuVec->getData(), sizeof(float) * (*dim1));
} else if (auto gpuVec = dynamic_cast<paddle::CpuVector*>(m->vec.get())) {
} else if (auto gpuVec = dynamic_cast<paddle::GpuVector*>(m->vec.get())) {
hl_memcpy_device2host(
*view_m_data, gpuVec->getData(), sizeof(float) * (*dim1));
} else {
......
......@@ -141,9 +141,12 @@ try:
def c_flag(self):
if self.with_coverage:
return ["-fprofile-arcs", "-ftest-coverage", "-O0", "-g"]
return [
"-fprofile-arcs", "-ftest-coverage", "-O0", "-g",
"-std=c++11"
]
else:
return None
return ["-std=c++11"]
except ImportError:
class PaddleLDFlag(object):
......
......@@ -16,7 +16,31 @@ limitations under the License. */
#define HL_BASE_H_
#include <cstddef>
#include "paddle/utils/TypeDefs.h"
#ifdef PADDLE_TYPE_DOUBLE
#define HL_FLOAT_MAX 3.40282347e+38F
#define HL_FLOAT_MIN 1.17549435e-38F
using real = double;
#else
#define HL_FLOAT_MAX 1.7976931348623157e+308
#define HL_FLOAT_MIN 2.2250738585072014e-308
using real = float;
#endif
/**
* The maximum input value for exp, used to avoid overflow problem.
* currently only used for tanh function.
*/
#define EXP_MAX_INPUT 40.0
/**
* @brief DIVUP(x, y) is similar to ceil(x / y).
* @note For CUDA, DIVUP will be used to specify
* the size of blockDim.
*/
#ifndef DIVUP
#define DIVUP(x, y) (((x) + (y)-1) / (y))
#endif
/**
* HPPL is an internal high performance parallel computing library
......@@ -181,46 +205,6 @@ typedef struct {
size_t nnz;
} _hl_sparse_matrix_s, *hl_sparse_matrix_s;
#ifndef PADDLE_TYPE_DOUBLE
/**
* HPPL data type: real (float or double)
*
* if real == float
*
* HL_FLOAT_MAX: 3.40282347e+38F
*
* HL_FLOAT_MIN: 1.17549435e-38F
*/
#define HL_FLOAT_MAX 3.40282347e+38F
/**
* if real == double
*
* HL_FLOAT_MAX: 1.7976931348623157e+308
*
* HL_FLOAT_MIN: 2.2250738585072014e-308
*/
#define HL_FLOAT_MIN 1.17549435e-38F
#else
#define HL_FLOAT_MAX 1.7976931348623157e+308
#define HL_FLOAT_MIN 2.2250738585072014e-308
#endif
/**
* The maximum input value for exp, used to avoid overflow problem.
*
* Currently only used for tanh function.
*/
#define EXP_MAX_INPUT 40.0
/**
* @brief DIVUP(x, y) is similar to ceil(x / y).
* @note For CUDA, DIVUP will be used to specify
* the size of blockDim.
*/
#ifndef DIVUP
#define DIVUP(x, y) (((x) + (y)-1) / (y))
#endif
#ifdef __NVCC__
#include "cuda_runtime.h"
......
......@@ -34,8 +34,8 @@ limitations under the License. */
#include "paddle/utils/Logging.h"
#include "paddle/utils/Queue.h"
#include "paddle/utils/ThreadLocal.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/Util.h"
#include "paddle/utils/common.h"
namespace paddle {
/**
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#include "ModelConfig.pb.h"
#include "hl_gpu.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/common.h"
namespace paddle {
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#include "ModelConfig.pb.h"
#include "hl_gpu.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/common.h"
namespace paddle {
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#include <memory>
#include <random>
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/common.h"
namespace paddle {
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#include <stdint.h>
#include <cstddef>
#include "TensorExpression.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/common.h"
namespace paddle {
......
......@@ -27,7 +27,7 @@ limitations under the License. */
#include "MemoryHandle.h"
#include "Vector.h"
#include "paddle/utils/ThreadLocal.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/common.h"
namespace paddle {
......
......@@ -17,7 +17,7 @@ limitations under the License. */
#include <cstddef>
#include "hl_tensor_ops.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/common.h"
namespace paddle {
......
......@@ -22,7 +22,7 @@ limitations under the License. */
#include "BaseMatrix.h"
#include "MemoryHandle.h"
#include "paddle/utils/Thread.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/common.h"
namespace paddle {
......
......@@ -28,7 +28,7 @@ limitations under the License. */
#include "paddle/parameter/ParameterUpdateFunctions.h"
#include "paddle/utils/Flags.h"
#include "paddle/utils/Locks.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/common.h"
#include "ParameterConfig.pb.h"
......
......@@ -29,8 +29,8 @@ limitations under the License. */
#include "paddle/utils/GlobalConstants.h"
#include "paddle/utils/Locks.h"
#include "paddle/utils/ThreadLocal.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/Util.h"
#include "paddle/utils/common.h"
namespace paddle {
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once
#include "paddle/math/Vector.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/common.h"
namespace paddle {
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/math/Matrix.h"
#include "paddle/pserver/ProtoServer.h"
#include "paddle/utils/Queue.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/common.h"
namespace paddle {
......
......@@ -26,8 +26,8 @@ limitations under the License. */
#include "paddle/utils/Flags.h"
#include "paddle/utils/Locks.h"
#include "paddle/utils/Queue.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/Util.h"
#include "paddle/utils/common.h"
#include "ParameterService.pb.h"
......
......@@ -32,7 +32,7 @@ limitations under the License. */
#include "paddle/utils/Locks.h"
#include "paddle/utils/Stat.h"
#include "paddle/utils/ThreadLocal.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/common.h"
#include "ParameterService.pb.h"
......
......@@ -15,6 +15,7 @@
import paddle.trainer.PyDataProvider2 as dp2
import collections
import swig_paddle
import numpy
__all__ = ['DataProviderConverter']
......@@ -35,18 +36,18 @@ class IScanner(object):
class DenseScanner(IScanner):
def __init__(self, input_type, pos):
IScanner.__init__(self, input_type, pos)
self.__mat__ = []
self.__height__ = 0
self.__mat__ = None
def scan(self, dat):
self.__mat__.extend(dat)
self.__height__ += 1
if self.__mat__ is None:
self.__mat__ = numpy.array([dat], dtype='float32')
else:
self.__mat__ = numpy.append(self.__mat__, [dat], axis=0)
def finish_scan(self, argument):
assert isinstance(argument, swig_paddle.Arguments)
assert isinstance(self.input_type, dp2.InputType)
m = swig_paddle.Matrix.createDense(self.__mat__, self.__height__,
self.input_type.dim, False)
m = swig_paddle.Matrix.createDenseFromNumpy(self.__mat__, True, False)
argument.setSlotValue(self.pos, m)
......
......@@ -17,18 +17,6 @@ RUN cd /usr/src/gtest && cmake . && make && cp *.a /usr/lib
RUN pip install -U BeautifulSoup docopt PyYAML pillow \
sphinx sphinx_rtd_theme recommonmark
# cmake tends to hide and blur the dependencies between code modules, as
# noted here https://github.com/PaddlePaddle/Paddle/issues/763. We are
# thinking about using Bazel to fix this problem, e.g.,
# https://github.com/PaddlePaddle/Paddle/issues/681#issuecomment-263996102. To
# start the trail of fixing, we add Bazel to our Dockerfiles.
RUN apt-get update && apt-get install -y curl software-properties-common \
&& add-apt-repository ppa:webupd8team/java \
&& echo "oracle-java8-installer shared/accepted-oracle-license-v1-1 select true" | debconf-set-selections \
&& echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list \
&& curl https://bazel.build/bazel-release.pub.gpg | apt-key add - \
&& apt-get update && apt-get install -y oracle-java8-installer bazel
ARG WITH_AVX
ARG WITH_DOC
ARG WITH_SWIG_PY
......
......@@ -17,18 +17,6 @@ RUN cd /usr/src/gtest && cmake . && make && cp *.a /usr/lib
RUN pip install -U BeautifulSoup docopt PyYAML pillow \
sphinx sphinx_rtd_theme recommonmark
# cmake tends to hide and blur the dependencies between code modules, as
# noted here https://github.com/PaddlePaddle/Paddle/issues/763. We are
# thinking about using Bazel to fix this problem, e.g.,
# https://github.com/PaddlePaddle/Paddle/issues/681#issuecomment-263996102. To
# start the trail of fixing, we add Bazel to our Dockerfiles.
RUN apt-get update && apt-get install -y curl software-properties-common \
&& add-apt-repository ppa:webupd8team/java \
&& echo "oracle-java8-installer shared/accepted-oracle-license-v1-1 select true" | debconf-set-selections \
&& echo "deb [arch=amd64] http://storage.googleapis.com/bazel-apt stable jdk1.8" | tee /etc/apt/sources.list.d/bazel.list \
&& curl https://bazel.build/bazel-release.pub.gpg | apt-key add - \
&& apt-get update && apt-get install -y oracle-java8-installer bazel
ARG WITH_AVX
ARG WITH_DOC
ARG WITH_SWIG_PY
......
......@@ -30,8 +30,10 @@ is_lin = (system == 'linux')
# The extra links will passed from COMAKE
# because generate paddle LDFLAGS is too complicated to do in setup.py
# it just read COMAKE generated LDFLAGS.
extra_comps = []
extra_links = []
obj = api.paddle_ld_flags.PaddleLDFlag()
extra_comps = obj.c_flag()
ldflags = obj.ldflag_str()
if ldflags is not None:
extra_links.extend(ldflags.split(" "))
......@@ -51,20 +53,15 @@ elif is_osx == True:
include_dirs = [np.get_include(), "../"] # include numpy and paddle.
extra_c = obj.c_flag()
attr=dict()
if extra_c is not None:
attr["extra_compile_args"] = extra_c
setup(name="py_paddle",
version="@PADDLE_VERSION@",
ext_modules=[
Extension('py_paddle._swig_paddle', # Build SWIG Extension.
['Paddle_wrap.cxx'],
language = "c++",
include_dirs = include_dirs,
extra_link_args = extra_links,
**attr
extra_compile_args = extra_comps
)
],
packages=['py_paddle'],
......
......@@ -33,8 +33,8 @@ namespace paddle {
because at the current moment, the merging on CPU is happening on the
main thread, and the its parameter size can be much larger than the one GPU.
Thus, for GPU, the parameter updates happens in updateImpl() function, which
is called by gradient machines as a callback function as a callback function
supplied to backward() and forwardBackward().
is called by gradient machines as a callback function supplied to backward()
and forwardBackward().
For CPU, the parameter updates happens in separate threads maintained by this
class.
*/
......
......@@ -11,7 +11,7 @@ limitations under the License. */
#pragma once
#include "DisableCopy.h"
#include "common.h"
namespace paddle {
......
......@@ -19,7 +19,7 @@ limitations under the License. */
#include <condition_variable>
#include <mutex>
#include "DisableCopy.h"
#include "common.h"
namespace paddle {
......
......@@ -26,12 +26,11 @@ limitations under the License. */
#include <unordered_map>
#include <vector>
#include "DisableCopy.h"
#include "Logging.h"
#include "TrainerConfig.pb.h"
#include "common.h"
#include "Flags.h"
#include "TypeDefs.h"
#include "hl_gpu.h"
/**
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once
#include <stddef.h>
#include <iostream>
#include "TypeDefs.h"
#include "common.h"
namespace paddle {
......
......@@ -14,13 +14,20 @@ limitations under the License. */
#pragma once
/**
* Disable copy macro.
*/
#define DISABLE_COPY(class_name) \
class_name(class_name &&) = delete; \
class_name(const class_name &other) = delete; \
class_name &operator=(const class_name &other) = delete
namespace paddle {
#ifdef PADDLE_TYPE_DOUBLE
typedef double real;
using real = double;
#else
typedef float real;
using real = float;
#endif
} // namespace paddle
using paddle::real;
......@@ -3416,8 +3416,35 @@ def register_parse_config_hook(f):
_parse_config_hooks.add(f)
def parse_config(config_file, config_arg_str):
def update_g_config():
'''
Update g_config after execute config_file or config_functions.
'''
for k, v in settings.iteritems():
if v is None:
continue
g_config.opt_config.__setattr__(k, v)
for k, v in trainer_settings.iteritems():
if v is None:
continue
g_config.__setattr__(k, v)
for name in g_config.model_config.input_layer_names:
assert name in g_layer_map, \
'input name "%s" does not correspond to a layer name' % name
assert (g_layer_map[name].type == "data" or g_layer_map[name].type == "data_trim"), \
'The type of input layer "%s" is not "data"' % name
for name in g_config.model_config.output_layer_names:
assert name in g_layer_map, \
'input name "%s" does not correspond to a layer name' % name
return g_config
def parse_config(trainer_config, config_arg_str):
'''
@param trainer_config: can be a string of config file name or a function name
with config logic
@param config_arg_str: a string of the form var1=val1,var2=val2. It will be
passed to config script as a dictionary CONFIG_ARGS
'''
......@@ -3451,45 +3478,20 @@ def parse_config(config_file, config_arg_str):
g_root_submodel.is_recurrent_layer_group = False
g_current_submodel = g_root_submodel
# for paddle on spark, need support non-file config.
# you can use parse_config like below:
#
# from paddle.trainer.config_parser import parse_config
# def configs():
# #your paddle config code, which is same as config file.
#
# config = parse_config(configs, "is_predict=1")
# # then you get config proto object.
if hasattr(config_file, '__call__'):
config_file.func_globals.update(
if hasattr(trainer_config, '__call__'):
trainer_config.func_globals.update(
make_config_environment("", config_args))
config_file()
trainer_config()
else:
execfile(config_file, make_config_environment(config_file, config_args))
for k, v in settings.iteritems():
if v is None:
continue
g_config.opt_config.__setattr__(k, v)
for k, v in trainer_settings.iteritems():
if v is None:
continue
g_config.__setattr__(k, v)
execfile(trainer_config,
make_config_environment(trainer_config, config_args))
for name in g_config.model_config.input_layer_names:
assert name in g_layer_map, \
'input name "%s" does not correspond to a layer name' % name
assert (g_layer_map[name].type == "data" or g_layer_map[name].type == "data_trim"), \
'The type of input layer "%s" is not "data"' % name
for name in g_config.model_config.output_layer_names:
assert name in g_layer_map, \
'input name "%s" does not correspond to a layer name' % name
return g_config
return update_g_config()
def parse_config_and_serialize(config_file, config_arg_str):
def parse_config_and_serialize(trainer_config, config_arg_str):
try:
config = parse_config(config_file, config_arg_str)
config = parse_config(trainer_config, config_arg_str)
#logger.info(config)
return config.SerializeToString()
except:
......
......@@ -20,6 +20,6 @@ from layers import *
from networks import *
from optimizers import *
from attrs import *
from config_parser_utils import *
# This will enable operator overload for LayerOutput
import math as layer_math
import layer_math
......@@ -19,34 +19,34 @@ __all__ = [
def convert_and_compare(x, Type):
"""
Convert x to be the same type as Type and then convert back to
check whether there is a loss of information
:param x: object to be checked
:param Type: target type to check x over
"""
Convert x to be the same type as Type and then convert back to
check whether there is a loss of information
:param x: object to be checked
:param Type: target type to check x over
"""
return type(x)(Type(x)) == x
def is_compatible_with(x, Type):
"""
Check if x has a type compatible with Type
:param x: object to be checked
:param Type: target type to check x over
"""
Check if x has a type compatible with Type
:param x: object to be checked
:param Type: target type to check x over
"""
if type(x) == Type:
return True
try:
if float == Type or int == Type:
# avoid those types that can be converted to float/int but not very
# meaningful and could potentially lead to error
# i.e., str and bool typed value should not be used for initializing float/int variable
# avoid those types that can be converted to float/int but not very
# meaningful and could potentially lead to error
# i.e., str and bool typed value should not be used for initializing float/int variable
if not isinstance(x, str) and not isinstance(x, bool):
return convert_and_compare(x, Type)
elif bool == Type:
# should not use string type to initialize bool variable
# should not use string type to initialize bool variable
if not isinstance(x, str):
return convert_and_compare(x, Type)
else:
......@@ -88,6 +88,10 @@ class ParameterAttribute(object):
:type learning_rate: float or None
:param momentum: The parameter momentum. None means use global value.
:type momentum: float or None
:param gradient_clipping_threshold: gradient clipping threshold. If gradient
value larger than some value, will be
clipped.
:type gradient_clipping_threshold: float
:param sparse_update: Enable sparse update for this parameter. It will
enable both local and remote sparse update.
:type sparse_update: bool
......@@ -104,6 +108,7 @@ class ParameterAttribute(object):
l2_rate=None,
learning_rate=None,
momentum=None,
gradient_clipping_threshold=None,
sparse_update=False):
# initialize strategy.
if is_static:
......@@ -152,6 +157,11 @@ class ParameterAttribute(object):
self.attr['sparse_update'] = True
self.attr['sparse_remote_update'] = True
if gradient_clipping_threshold is not None and \
is_compatible_with(gradient_clipping_threshold, float):
self.attr['gradient_clipping_threshold'] = \
gradient_clipping_threshold
def set_default_parameter_name(self, name):
"""
Set default parameter name. If parameter not set, then will use default
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.trainer.config_parser as config_parser
'''
This file is a wrapper of formal config_parser. The main idea of this file is to
separete different config logic into different function, such as network configuration
and optimizer configuration.
'''
__all__ = [
"parse_trainer_config", "parse_network_config", "parse_optimizer_config"
]
def parse_trainer_config(trainer_conf, config_arg_str):
return config_parser.parse_config(trainer_conf, config_arg_str)
def parse_network_config(network_conf):
config = config_parser.parse_config(network_conf, '')
return config.model_config
def parse_optimizer_config(optimizer_conf):
config = config_parser.parse_config(optimizer_conf, '')
return config.opt_config
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.trainer.config_parser as config_parser
'''
This file is a wrapper of formal config_parser. The main idea of this file is to
separete different config logic into different function, such as network configuration
and optimizer configuration.
'''
__all__ = [
"parse_trainer_config", "parse_network_config", "parse_optimizer_config"
]
def parse_trainer_config(trainer_conf, config_arg_str):
return config_parser.parse_config(trainer_conf, config_arg_str)
def parse_network_config(network_conf, config_arg_str=''):
config = config_parser.parse_config(network_conf, config_arg_str)
return config.model_config
def parse_optimizer_config(optimizer_conf, config_arg_str=''):
config = config_parser.parse_config(optimizer_conf, config_arg_str)
return config.opt_config
# Bazel (http://bazel.io/) BUILD file for gflags.
#
# See INSTALL.md for instructions for adding gflags to a Bazel workspace.
licenses(["notice"])
exports_files(["src/gflags_complections.sh", "COPYING.txt"])
load(":bazel/gflags.bzl", "gflags_sources", "gflags_library")
(hdrs, srcs) = gflags_sources(namespace=["google", "gflags"])
gflags_library(hdrs=hdrs, srcs=srcs, threads=0)
gflags_library(hdrs=hdrs, srcs=srcs, threads=1)
licenses(["notice"]) # Apache 2.0
cc_test(
name="gflags_test",
srcs=["gflags_test.cc"],
copts=["-Iexternal/gtest/include"],
deps=[
"@gtest//:gtest",
"@gflags//:gflags",
], )
#include <iostream>
#include <string>
#include "gflags/gflags.h"
#include "gtest/gtest.h"
DEFINE_bool(verbose, false, "Display program name before message");
DEFINE_string(message, "Hello world!", "Message to print");
static bool IsNonEmptyMessage(const char *flagname, const std::string &value) {
return value[0] != '\0';
}
DEFINE_validator(message, &IsNonEmptyMessage);
namespace third_party {
namespace gflags_test {
TEST(GflagsTest, ParseAndPrint) {
gflags::SetUsageMessage("some usage message");
gflags::SetVersionString("1.0.0");
int argc = 1;
char program_name[] = "gflags_test";
char **argv = new char *[2];
argv[0] = program_name;
argv[1] = NULL;
gflags::ParseCommandLineFlags(&argc, reinterpret_cast<char ***>(&argv), true);
EXPECT_EQ("gflags_test", std::string(gflags::ProgramInvocationShortName()));
EXPECT_EQ("Hello world!", FLAGS_message);
gflags::ShutDownCommandLineFlags();
}
} // namespace gflags_test
} // namespace third_party
licenses(["notice"])
cc_library(
visibility=["//visibility:public"],
name="glog",
includes=[
".",
"src",
],
copts=[
"-D_START_GOOGLE_NAMESPACE_='namespace google {'",
"-D_END_GOOGLE_NAMESPACE_='}'",
"-DGOOGLE_NAMESPACE='google'",
"-DGOOGLE_GLOG_DLL_DECL=''",
"-DHAVE_DLADDR",
"-DHAVE_SNPRINTF",
"-DHAVE_DLFCN_H",
"-DHAVE_FCNTL",
"-DHAVE_GLOB_H",
"-DHAVE_INTTYPES_H",
"-DHAVE_LIBPTHREAD",
"-DHAVE_SYS_SYSCALL_H",
"-DHAVE_MEMORY_H",
"-DHAVE_NAMESPACES",
"-DHAVE_PREAD",
"-DHAVE_PTHREAD",
"-DHAVE_PWD_H",
"-DHAVE_PWRITE",
"-DHAVE_RWLOCK",
"-DHAVE_SIGACTION",
"-DHAVE_SIGALTSTACK",
"-DHAVE_STDINT_H",
"-DHAVE_STRING_H",
"-DHAVE_SYS_TIME_H",
"-DHAVE_SYS_TYPES_H",
"-DHAVE_SYS_UCONTEXT_H",
"-DHAVE_SYS_UTSNAME_H",
"-DHAVE_UNISTD_H",
"-DHAVE_USING_OPERATOR",
"-DHAVE_HAVE___ATTRIBUTE___",
"-DHAVE_HAVE___BUILTIN_EXPECT",
#"-DNO_FRAME_POINTER",
"-D_GNU_SOURCE",
#"-fno-sanitize=thread",
#"-fno-sanitize=address",
"-Iexternal/glog/src",
],
srcs=[
"src/demangle.cc",
"src/logging.cc",
"src/raw_logging.cc",
"src/signalhandler.cc",
"src/symbolize.cc",
"src/utilities.cc",
"src/vlog_is_on.cc",
":config_h",
":logging_h",
":raw_logging_h",
":stl_logging_h",
":vlog_is_on_h",
],
hdrs=[
"src/demangle.h",
"src/mock-log.h",
"src/stacktrace.h",
"src/symbolize.h",
"src/utilities.h",
"src/base/commandlineflags.h",
"src/base/googleinit.h",
"src/base/mutex.h",
"src/glog/log_severity.h",
])
genrule(
name="config_h",
srcs=["src/config.h.cmake.in"],
outs=["config.h"],
cmd="awk '{ gsub(/^#cmakedefine/, \"//cmakedefine\"); print; }' $(<) > $(@)",
)
genrule(
name="logging_h",
srcs=["src/glog/logging.h.in"],
outs=["glog/logging.h"],
cmd="$(location :gen_sh) < $(<) > $(@)",
tools=[":gen_sh"])
genrule(
name="raw_logging_h",
srcs=["src/glog/raw_logging.h.in"],
outs=["glog/raw_logging.h"],
cmd="$(location :gen_sh) < $(<) > $(@)",
tools=[":gen_sh"])
genrule(
name="stl_logging_h",
srcs=["src/glog/stl_logging.h.in"],
outs=["glog/stl_logging.h"],
cmd="$(location :gen_sh) < $(<) > $(@)",
tools=[":gen_sh"])
genrule(
name="vlog_is_on_h",
srcs=["src/glog/vlog_is_on.h.in"],
outs=["glog/vlog_is_on.h"],
cmd="$(location :gen_sh) < $(<) > $(@)",
tools=[":gen_sh"])
genrule(
name="gen_sh",
outs=["gen.sh"],
cmd="""
cat > $@ <<"EOF"
#! /bin/sh
sed -e 's/@ac_cv_have_unistd_h@/1/g' \
-e 's/@ac_cv_have_stdint_h@/1/g' \
-e 's/@ac_cv_have_systypes_h@/1/g' \
-e 's/@ac_cv_have_libgflags_h@/1/g' \
-e 's/@ac_cv_have_uint16_t@/1/g' \
-e 's/@ac_cv_have___builtin_expect@/1/g' \
-e 's/@ac_cv_have_.*@/0/g' \
-e 's/@ac_google_start_namespace@/namespace google {/g' \
-e 's/@ac_google_end_namespace@/}/g' \
-e 's/@ac_google_namespace@/google/g' \
-e 's/@ac_cv___attribute___noinline@/__attribute__((noinline))/g' \
-e 's/@ac_cv___attribute___noreturn@/__attribute__((noreturn))/g' \
-e 's/@ac_cv___attribute___printf_4_5@/__attribute__((__format__ (__printf__, 4, 5)))/g'
EOF""")
licenses(["notice"]) # Apache 2.0
cc_test(
name="glog_test",
srcs=["glog_test.cc"],
copts=["-Iexternal/gtest/include"],
deps=[
"@gtest//:gtest",
"@glog//:glog",
], )
#include <iostream>
#include <string>
#include "glog/logging.h"
#include "gtest/gtest.h"
TEST(GlogTest, Logging) { LOG(INFO) << "Hello world"; }
cc_library(
name="gtest",
srcs=glob(
["src/*.cc"], exclude=["src/gtest-all.cc"]),
hdrs=glob(["include/**/*.h", "src/*.h"]),
copts=["-Iexternal/gtest/include"],
linkopts=["-pthread"],
visibility=["//visibility:public"], )
licenses(["notice"]) # Apache 2.0
load("@protobuf//:protobuf.bzl", "cc_proto_library")
cc_proto_library(
name="example_proto",
srcs=["example.proto"],
protoc="@protobuf//:protoc",
default_runtime="@protobuf//:protobuf", )
cc_library(
name="example_lib",
srcs=["example_lib.cc"],
hdrs=["example_lib.h"],
deps=[":example_proto"], )
cc_test(
name="example_lib_test",
srcs=["example_lib_test.cc"],
copts=["-Iexternal/gtest/include"],
deps=[
"@gtest//:gtest",
":example_lib",
], )
This package tests that Bazel can build protobuf related rules.
syntax = "proto3";
package third_party.protobuf_test;
message Greeting {
string name = 1;
}
#include "third_party/protobuf_test/example_lib.h"
namespace third_party {
namespace protobuf_test {
std::string get_greet(const Greeting& who) { return "Hello " + who.name(); }
} // namespace protobuf_test
} // namespace thrid_party
#pragma once
#include "third_party/protobuf_test/example.pb.h"
#include <string>
namespace third_party {
namespace protobuf_test {
std::string get_greet(const Greeting &who);
} // namespace protobuf_test
} // namespace third_party
#include "third_party/protobuf_test/example_lib.h"
#include "gtest/gtest.h"
namespace third_party {
namespace protobuf_test {
TEST(ProtobufTest, GetGreet) {
Greeting g;
g.set_name("Paddle");
EXPECT_EQ("Hello Paddle", get_greet(g));
}
} // namespace protobuf_test
} // namespace third_party
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册