提交 eefe5a7c 编写于 作者: Y Yu Yang

Merge branch 'develop' of github.com:baidu/Paddle into feature/mnist_train_api

......@@ -17,7 +17,7 @@ set -e
#Note the default model is pass-00002, you shold make sure the model path
#exists or change the mode path.
#only test on trainer_config.lr.py
model=output/pass-00001/
model=output/model/pass-00001/
config=trainer_config.lr.py
label=data/labels.list
dict=data/dict.txt
......
#!/bin/bash
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
# Should run pserver.sh before run this script.
bin_dir=$(cd `dirname $0`; pwd)
home_dir=$(cd "${bin_dir}/.."; pwd)
source "$bin_dir/env.sh"
model_dir="$bin_dir/output"
log_file="$bin_dir/train.log"
pushd "$home_dir"
cfg=trainer_config.lr.py
paddle train \
--config=$cfg \
--save_dir=${model_dir} \
--trainer_count=4 \
--local=0 \
--log_period=100 \
--num_passes=15 \
--use_gpu=false \
--show_parameter_stats_period=100 \
--test_all_data_in_one_period=1 \
--num_gradient_servers=1 \
--nics=`get_nics` \
--port=7164 \
--ports_num=1 \
--pservers="127.0.0.1" \
--comment="paddle_trainer" \
2>&1 | tee "$log_file"
popd
#!/bin/bash
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
function get_nics() {
machine=`uname -s`
local nics=""
if [ "$machine" == "Linux" ]; then
nics="lo"
elif [ "$machine" == "Darwin" ]; then
nics="lo0"
else
nics="unsupport"
fi
echo $nics
}
#!/bin/bash
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
bin_dir=$(cd `dirname $0`; pwd)
source "$bin_dir/env.sh"
paddle pserver \
--nics=`get_nics` \
--port=7164 \
--ports_num=1 \
--ports_num_for_sparse=1 \
--num_gradient_servers=1 \
--comment="paddle_pserver" \
2>&1 | tee 'pserver.log'
......@@ -20,15 +20,11 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/utils/GlobalConstants.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/common.h"
/// Import PaddlePaddle's enumeration into global namespace.
using namespace paddle::enumeration_wrapper; // NOLINT
#define DISABLE_COPY_AND_ASSIGN(classname) \
classname(const classname& other); \
classname& operator=(const classname& other)
/**
* @brief Initialize paddle.
*
......@@ -102,7 +98,7 @@ const size_t NO_SPARSE_ID = -1UL;
struct MatrixPrivate;
class Matrix {
Matrix(); // User Cannot Create Matrix.
DISABLE_COPY_AND_ASSIGN(Matrix);
DISABLE_COPY(Matrix);
static Matrix* createByPaddleMatrixPtr(void* sharedPtr);
public:
......@@ -242,7 +238,7 @@ private:
struct VectorPrivate;
class Vector {
DISABLE_COPY_AND_ASSIGN(Vector);
DISABLE_COPY(Vector);
Vector();
static Vector* createByPaddleVectorPtr(void* ptr);
......@@ -322,7 +318,7 @@ private:
struct IVectorPrivate;
class IVector {
IVector();
DISABLE_COPY_AND_ASSIGN(IVector);
DISABLE_COPY(IVector);
static IVector* createByPaddleVectorPtr(void* ptr);
public:
......@@ -402,7 +398,7 @@ struct ArgumentsPrivate;
class Arguments {
private:
Arguments(); // Internal Create.
DISABLE_COPY_AND_ASSIGN(Arguments);
DISABLE_COPY(Arguments);
public:
/**
......@@ -472,7 +468,7 @@ enum GradientMatchineCreateMode {
struct ParameterConfigPrivate;
class ParameterConfig {
DISABLE_COPY_AND_ASSIGN(ParameterConfig);
DISABLE_COPY(ParameterConfig);
ParameterConfig();
/**
......@@ -502,7 +498,7 @@ private:
struct OptimizationConfigPrivate;
class OptimizationConfig {
DISABLE_COPY_AND_ASSIGN(OptimizationConfig);
DISABLE_COPY(OptimizationConfig);
OptimizationConfig();
public:
......@@ -527,7 +523,7 @@ struct ParameterPrivate;
class Parameter {
private:
Parameter();
DISABLE_COPY_AND_ASSIGN(Parameter);
DISABLE_COPY(Parameter);
public:
virtual ~Parameter();
......@@ -572,7 +568,7 @@ struct ModelConfigPrivate;
class ModelConfig {
private:
ModelConfig();
DISABLE_COPY_AND_ASSIGN(ModelConfig);
DISABLE_COPY(ModelConfig);
public:
virtual ~ModelConfig();
......@@ -593,7 +589,7 @@ struct TrainerConfigPrivate;
class TrainerConfig {
private:
TrainerConfig();
DISABLE_COPY_AND_ASSIGN(TrainerConfig);
DISABLE_COPY(TrainerConfig);
public:
virtual ~TrainerConfig();
......@@ -633,7 +629,7 @@ public:
struct ParameterTraverseCallbackPrivate;
class ParameterTraverseCallback {
DISABLE_COPY_AND_ASSIGN(ParameterTraverseCallback);
DISABLE_COPY(ParameterTraverseCallback);
ParameterTraverseCallback();
public:
......@@ -655,7 +651,7 @@ private:
*/
struct ParameterOptimizerPrivate;
class ParameterOptimizer {
DISABLE_COPY_AND_ASSIGN(ParameterOptimizer);
DISABLE_COPY(ParameterOptimizer);
ParameterOptimizer();
public:
......@@ -692,7 +688,7 @@ struct GradientMachinePrivate;
class GradientMachine {
private:
GradientMachine();
DISABLE_COPY_AND_ASSIGN(GradientMachine);
DISABLE_COPY(GradientMachine);
public:
virtual ~GradientMachine();
......@@ -908,7 +904,7 @@ private:
TrainerPrivate* m;
Trainer();
Trainer(TrainerConfig* optConfig, GradientMachine* gm);
DISABLE_COPY_AND_ASSIGN(Trainer);
DISABLE_COPY(Trainer);
public:
virtual ~Trainer();
......@@ -974,7 +970,7 @@ public:
struct SequenceGeneratorPrivate;
class SequenceGenerator {
DISABLE_COPY_AND_ASSIGN(SequenceGenerator);
DISABLE_COPY(SequenceGenerator);
SequenceGenerator();
public:
......
......@@ -141,9 +141,12 @@ try:
def c_flag(self):
if self.with_coverage:
return ["-fprofile-arcs", "-ftest-coverage", "-O0", "-g"]
return [
"-fprofile-arcs", "-ftest-coverage", "-O0", "-g",
"-std=c++11"
]
else:
return None
return ["-std=c++11"]
except ImportError:
class PaddleLDFlag(object):
......
......@@ -16,7 +16,31 @@ limitations under the License. */
#define HL_BASE_H_
#include <cstddef>
#include "paddle/utils/TypeDefs.h"
#ifdef PADDLE_TYPE_DOUBLE
#define HL_FLOAT_MAX 3.40282347e+38F
#define HL_FLOAT_MIN 1.17549435e-38F
using real = double;
#else
#define HL_FLOAT_MAX 1.7976931348623157e+308
#define HL_FLOAT_MIN 2.2250738585072014e-308
using real = float;
#endif
/**
* The maximum input value for exp, used to avoid overflow problem.
* currently only used for tanh function.
*/
#define EXP_MAX_INPUT 40.0
/**
* @brief DIVUP(x, y) is similar to ceil(x / y).
* @note For CUDA, DIVUP will be used to specify
* the size of blockDim.
*/
#ifndef DIVUP
#define DIVUP(x, y) (((x) + (y)-1) / (y))
#endif
/**
* HPPL is an internal high performance parallel computing library
......@@ -181,46 +205,6 @@ typedef struct {
size_t nnz;
} _hl_sparse_matrix_s, *hl_sparse_matrix_s;
#ifndef PADDLE_TYPE_DOUBLE
/**
* HPPL data type: real (float or double)
*
* if real == float
*
* HL_FLOAT_MAX: 3.40282347e+38F
*
* HL_FLOAT_MIN: 1.17549435e-38F
*/
#define HL_FLOAT_MAX 3.40282347e+38F
/**
* if real == double
*
* HL_FLOAT_MAX: 1.7976931348623157e+308
*
* HL_FLOAT_MIN: 2.2250738585072014e-308
*/
#define HL_FLOAT_MIN 1.17549435e-38F
#else
#define HL_FLOAT_MAX 1.7976931348623157e+308
#define HL_FLOAT_MIN 2.2250738585072014e-308
#endif
/**
* The maximum input value for exp, used to avoid overflow problem.
*
* Currently only used for tanh function.
*/
#define EXP_MAX_INPUT 40.0
/**
* @brief DIVUP(x, y) is similar to ceil(x / y).
* @note For CUDA, DIVUP will be used to specify
* the size of blockDim.
*/
#ifndef DIVUP
#define DIVUP(x, y) (((x) + (y)-1) / (y))
#endif
#ifdef __NVCC__
#include "cuda_runtime.h"
......
......@@ -34,8 +34,8 @@ limitations under the License. */
#include "paddle/utils/Logging.h"
#include "paddle/utils/Queue.h"
#include "paddle/utils/ThreadLocal.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/Util.h"
#include "paddle/utils/common.h"
namespace paddle {
/**
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#include "ModelConfig.pb.h"
#include "hl_gpu.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/common.h"
namespace paddle {
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#include "ModelConfig.pb.h"
#include "hl_gpu.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/common.h"
namespace paddle {
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#include <memory>
#include <random>
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/common.h"
namespace paddle {
......
......@@ -16,7 +16,7 @@ limitations under the License. */
#include <stdint.h>
#include <cstddef>
#include "TensorExpression.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/common.h"
namespace paddle {
......
......@@ -27,7 +27,7 @@ limitations under the License. */
#include "MemoryHandle.h"
#include "Vector.h"
#include "paddle/utils/ThreadLocal.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/common.h"
namespace paddle {
......
......@@ -17,7 +17,7 @@ limitations under the License. */
#include <cstddef>
#include "hl_tensor_ops.h"
#include "paddle/utils/Logging.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/common.h"
namespace paddle {
......
......@@ -22,7 +22,7 @@ limitations under the License. */
#include "BaseMatrix.h"
#include "MemoryHandle.h"
#include "paddle/utils/Thread.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/common.h"
namespace paddle {
......
......@@ -28,7 +28,7 @@ limitations under the License. */
#include "paddle/parameter/ParameterUpdateFunctions.h"
#include "paddle/utils/Flags.h"
#include "paddle/utils/Locks.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/common.h"
#include "ParameterConfig.pb.h"
......
......@@ -29,8 +29,8 @@ limitations under the License. */
#include "paddle/utils/GlobalConstants.h"
#include "paddle/utils/Locks.h"
#include "paddle/utils/ThreadLocal.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/Util.h"
#include "paddle/utils/common.h"
namespace paddle {
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once
#include "paddle/math/Vector.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/common.h"
namespace paddle {
......
......@@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/math/Matrix.h"
#include "paddle/pserver/ProtoServer.h"
#include "paddle/utils/Queue.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/common.h"
namespace paddle {
......
......@@ -26,8 +26,8 @@ limitations under the License. */
#include "paddle/utils/Flags.h"
#include "paddle/utils/Locks.h"
#include "paddle/utils/Queue.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/Util.h"
#include "paddle/utils/common.h"
#include "ParameterService.pb.h"
......
......@@ -32,7 +32,7 @@ limitations under the License. */
#include "paddle/utils/Locks.h"
#include "paddle/utils/Stat.h"
#include "paddle/utils/ThreadLocal.h"
#include "paddle/utils/TypeDefs.h"
#include "paddle/utils/common.h"
#include "ParameterService.pb.h"
......
......@@ -30,8 +30,10 @@ is_lin = (system == 'linux')
# The extra links will passed from COMAKE
# because generate paddle LDFLAGS is too complicated to do in setup.py
# it just read COMAKE generated LDFLAGS.
extra_comps = []
extra_links = []
obj = api.paddle_ld_flags.PaddleLDFlag()
extra_comps = obj.c_flag()
ldflags = obj.ldflag_str()
if ldflags is not None:
extra_links.extend(ldflags.split(" "))
......@@ -51,20 +53,15 @@ elif is_osx == True:
include_dirs = [np.get_include(), "../"] # include numpy and paddle.
extra_c = obj.c_flag()
attr=dict()
if extra_c is not None:
attr["extra_compile_args"] = extra_c
setup(name="py_paddle",
version="@PADDLE_VERSION@",
ext_modules=[
Extension('py_paddle._swig_paddle', # Build SWIG Extension.
['Paddle_wrap.cxx'],
language = "c++",
include_dirs = include_dirs,
extra_link_args = extra_links,
**attr
extra_compile_args = extra_comps
)
],
packages=['py_paddle'],
......
......@@ -33,8 +33,8 @@ namespace paddle {
because at the current moment, the merging on CPU is happening on the
main thread, and the its parameter size can be much larger than the one GPU.
Thus, for GPU, the parameter updates happens in updateImpl() function, which
is called by gradient machines as a callback function as a callback function
supplied to backward() and forwardBackward().
is called by gradient machines as a callback function supplied to backward()
and forwardBackward().
For CPU, the parameter updates happens in separate threads maintained by this
class.
*/
......
......@@ -11,7 +11,7 @@ limitations under the License. */
#pragma once
#include "DisableCopy.h"
#include "common.h"
namespace paddle {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
/**
* Disable copy macro.
*/
#define DISABLE_COPY(CLASS_NAME) \
CLASS_NAME(CLASS_NAME &&) = delete; \
CLASS_NAME(const CLASS_NAME &other) = delete; \
CLASS_NAME &operator=(const CLASS_NAME &other) = delete
......@@ -19,7 +19,7 @@ limitations under the License. */
#include <condition_variable>
#include <mutex>
#include "DisableCopy.h"
#include "common.h"
namespace paddle {
......
......@@ -26,12 +26,11 @@ limitations under the License. */
#include <unordered_map>
#include <vector>
#include "DisableCopy.h"
#include "Logging.h"
#include "TrainerConfig.pb.h"
#include "common.h"
#include "Flags.h"
#include "TypeDefs.h"
#include "hl_gpu.h"
/**
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once
#include <stddef.h>
#include <iostream>
#include "TypeDefs.h"
#include "common.h"
namespace paddle {
......
......@@ -15,12 +15,19 @@ limitations under the License. */
#pragma once
namespace paddle {
/**
* Disable copy macro.
*/
#define DISABLE_COPY(class_name) \
class_name(class_name &&) = delete; \
class_name(const class_name &other) = delete; \
class_name &operator=(const class_name &other) = delete
#ifdef PADDLE_TYPE_DOUBLE
typedef double real;
using real = double;
#else
typedef float real;
using real = float;
#endif
} // namespace paddle
using paddle::real;
......@@ -21,6 +21,5 @@ from networks import *
from optimizers import *
from attrs import *
from config_parser_utils import *
# This will enable operator overload for LayerOutput
import math as layer_math
import layer_math
......@@ -19,34 +19,34 @@ __all__ = [
def convert_and_compare(x, Type):
"""
Convert x to be the same type as Type and then convert back to
check whether there is a loss of information
:param x: object to be checked
:param Type: target type to check x over
"""
Convert x to be the same type as Type and then convert back to
check whether there is a loss of information
:param x: object to be checked
:param Type: target type to check x over
"""
return type(x)(Type(x)) == x
def is_compatible_with(x, Type):
"""
Check if x has a type compatible with Type
:param x: object to be checked
:param Type: target type to check x over
"""
Check if x has a type compatible with Type
:param x: object to be checked
:param Type: target type to check x over
"""
if type(x) == Type:
return True
try:
if float == Type or int == Type:
# avoid those types that can be converted to float/int but not very
# meaningful and could potentially lead to error
# i.e., str and bool typed value should not be used for initializing float/int variable
# avoid those types that can be converted to float/int but not very
# meaningful and could potentially lead to error
# i.e., str and bool typed value should not be used for initializing float/int variable
if not isinstance(x, str) and not isinstance(x, bool):
return convert_and_compare(x, Type)
elif bool == Type:
# should not use string type to initialize bool variable
# should not use string type to initialize bool variable
if not isinstance(x, str):
return convert_and_compare(x, Type)
else:
......@@ -88,6 +88,10 @@ class ParameterAttribute(object):
:type learning_rate: float or None
:param momentum: The parameter momentum. None means use global value.
:type momentum: float or None
:param gradient_clipping_threshold: gradient clipping threshold. If gradient
value larger than some value, will be
clipped.
:type gradient_clipping_threshold: float
:param sparse_update: Enable sparse update for this parameter. It will
enable both local and remote sparse update.
:type sparse_update: bool
......@@ -104,6 +108,7 @@ class ParameterAttribute(object):
l2_rate=None,
learning_rate=None,
momentum=None,
gradient_clipping_threshold=None,
sparse_update=False):
# initialize strategy.
if is_static:
......@@ -152,6 +157,11 @@ class ParameterAttribute(object):
self.attr['sparse_update'] = True
self.attr['sparse_remote_update'] = True
if gradient_clipping_threshold is not None and \
is_compatible_with(gradient_clipping_threshold, float):
self.attr['gradient_clipping_threshold'] = \
gradient_clipping_threshold
def set_default_parameter_name(self, name):
"""
Set default parameter name. If parameter not set, then will use default
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册