Commit 74701b26 authored by: Q qiaolongfei

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix-prefetch

@@ -19,7 +19,7 @@ BasedOnStyle: Google
IndentWidth: 2
TabWidth: 2
ContinuationIndentWidth: 4
-AccessModifierOffset: -2 # The private/protected/public has no indent in class
+AccessModifierOffset: -1 # The private/protected/public has no indent in class
Standard: Cpp11
AllowAllParametersOfDeclarationOnNextLine: true
BinPackParameters: false
......
@@ -34,6 +34,14 @@ repos:
        entry: bash ./tools/codestyle/cpplint_pre_commit.hook
        language: system
        files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx)$
+-   repo: local
+    hooks:
+    -   id: pylint-doc-string
+        name: pylint
+        description: Check python docstring style using docstring_checker.
+        entry: bash ./tools/codestyle/pylint_pre_commit.hook
+        language: system
+        files: \.(py)$
 -   repo: https://github.com/PaddlePaddle/pre-commit-golang
    sha: 8337620115c25ff8333f1b1a493bd031049bd7c0
    hooks:
......
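For local use, the new docstring hook can be run on its own through pre-commit's standard CLI; the snippet below is only a sketch (the target file is a hypothetical placeholder, and the extra packages mirror the CI and Docker changes further down):

```bash
# Sketch: install the checker's dependencies and run only the new hook.
pip install pre-commit pylint pytest astroid isort
pre-commit run pylint-doc-string --files python/paddle/fluid/layers/nn.py
```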
@@ -18,6 +18,8 @@ env:
addons:
  ssh_known_hosts: 13.229.163.131
before_install:
+  # For pylint dockstring checker
+  - sudo pip install pylint pytest astroid isort
  - |
    function timeout() { perl -e 'alarm shift; exec @ARGV' "$@"; }
script:
......
@@ -79,6 +79,9 @@ RUN pip install pre-commit 'ipython==5.3.0' && \
    pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \
    pip install opencv-python
+#For docstring checker
+RUN pip install pylint pytest astroid isort
COPY ./python/requirements.txt /root/
RUN pip install -r /root/requirements.txt
......
@@ -24,22 +24,22 @@ Currently supported `--model` argument include:
* Run the following command to start a benchmark job locally:
  ```bash
-  python fluid_benchmark.py --model mnist --parallel 1 --device GPU --with_test
+  python fluid_benchmark.py --model mnist --device GPU
  ```
  You can choose to use GPU/CPU training. With GPU training, you can specify
-  `--parallel 1` to run multi GPU training.
+  `--gpus <gpu_num>` to run multi GPU training.
* Run distributed training with parameter servers:
  * start parameter servers:
    ```bash
-    PADDLE_TRAINING_ROLE=PSERVER PADDLE_PSERVER_PORT=7164 PADDLE_PSERVER_IPS=127.0.0.1 PADDLE_TRAINERS=1 PADDLE_CURRENT_IP=127.0.0.1 PADDLE_TRAINER_ID=0 python fluid_benchmark.py --model mnist --parallel 0 --device GPU --update_method pserver
+    PADDLE_TRAINING_ROLE=PSERVER PADDLE_PSERVER_PORT=7164 PADDLE_PSERVER_IPS=127.0.0.1 PADDLE_TRAINERS=1 PADDLE_CURRENT_IP=127.0.0.1 PADDLE_TRAINER_ID=0 python fluid_benchmark.py --model mnist --device GPU --update_method pserver
    ```
  * start trainers:
    ```bash
-    PADDLE_TRAINING_ROLE=PSERVER PADDLE_PSERVER_PORT=7164 PADDLE_PSERVER_IPS=127.0.0.1 PADDLE_TRAINERS=1 PADDLE_CURRENT_IP=127.0.0.1 PADDLE_TRAINER_ID=0 python fluid_benchmark.py --model mnist --parallel 0 --device GPU --update_method pserver
+    PADDLE_TRAINING_ROLE=TRAINER PADDLE_PSERVER_PORT=7164 PADDLE_PSERVER_IPS=127.0.0.1 PADDLE_TRAINERS=1 PADDLE_CURRENT_IP=127.0.0.1 PADDLE_TRAINER_ID=0 python fluid_benchmark.py --model mnist --device GPU --update_method pserver
    ```
* Run distributed training using NCCL2
  ```bash
-  PADDLE_PSERVER_PORT=7164 PADDLE_TRAINER_IPS=192.168.0.2,192.168.0.3 PADDLE_CURRENT_IP=127.0.0.1 PADDLE_TRAINER_ID=0 python fluid_benchmark.py --model mnist --parallel 0 --device GPU --update_method nccl2
+  PADDLE_PSERVER_PORT=7164 PADDLE_TRAINER_IPS=192.168.0.2,192.168.0.3 PADDLE_CURRENT_IP=127.0.0.1 PADDLE_TRAINER_ID=0 python fluid_benchmark.py --model mnist --device GPU --update_method nccl2
  ```
## Run Distributed Benchmark on Kubernetes Cluster
@@ -48,7 +48,7 @@ We provide a script `kube_gen_job.py` to generate Kubernetes yaml files to submit
distributed benchmark jobs to your cluster. To generate a job yaml, just run:
```bash
-python kube_gen_job.py --jobname myjob --pscpu 4 --cpu 8 --gpu 8 --psmemory 20 --memory 40 --pservers 4 --trainers 4 --entry "python fluid_benchmark.py --model mnist --parallel 1 --device GPU --update_method pserver --with_test" --disttype pserver
+python kube_gen_job.py --jobname myjob --pscpu 4 --cpu 8 --gpu 8 --psmemory 20 --memory 40 --pservers 4 --trainers 4 --entry "python fluid_benchmark.py --model mnist --parallel 1 --device GPU --update_method pserver " --disttype pserver
```
Then the yaml files are generated under directory `myjob`, you can run:
......
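As a concrete illustration of the renamed flag in the README hunk above, a single-machine multi-GPU run might look like the sketch below (the GPU count is an assumption for the example, not something this commit prescribes):

```bash
# Sketch: run the mnist benchmark locally on 4 GPUs.
# --gpus <gpu_num> replaces the old --parallel flag removed in this change.
python fluid_benchmark.py --model mnist --device GPU --gpus 4
```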
@@ -56,24 +56,28 @@ set(dst_dir "${FLUID_INSTALL_DIR}/third_party/eigen3")
copy(eigen3_lib
  SRCS ${EIGEN_INCLUDE_DIR}/Eigen/Core ${EIGEN_INCLUDE_DIR}/Eigen/src ${EIGEN_INCLUDE_DIR}/unsupported/Eigen
  DSTS ${dst_dir}/Eigen ${dst_dir}/Eigen ${dst_dir}/unsupported
+  DEPS eigen3
)
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/gflags")
copy(gflags_lib
  SRCS ${GFLAGS_INCLUDE_DIR} ${GFLAGS_LIBRARIES}
  DSTS ${dst_dir} ${dst_dir}/lib
+  DEPS gflags
)
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/glog")
copy(glog_lib
  SRCS ${GLOG_INCLUDE_DIR} ${GLOG_LIBRARIES}
  DSTS ${dst_dir} ${dst_dir}/lib
+  DEPS glog
)
set(dst_dir "${FLUID_INSTALL_DIR}/third_party/boost/")
copy(boost_lib
  SRCS ${BOOST_INCLUDE_DIR}/boost
  DSTS ${dst_dir}
+  DEPS boost
)
if(NOT PROTOBUF_FOUND)
@@ -81,6 +85,7 @@ if(NOT PROTOBUF_FOUND)
  copy(protobuf_lib
    SRCS ${PROTOBUF_INCLUDE_DIR} ${PROTOBUF_LIBRARY}
    DSTS ${dst_dir} ${dst_dir}/lib
+    DEPS extern_protobuf
  )
endif()
@@ -89,12 +94,14 @@ if(NOT CBLAS_FOUND)
  copy(openblas_lib
    SRCS ${CBLAS_INSTALL_DIR}/lib ${CBLAS_INSTALL_DIR}/include
    DSTS ${dst_dir} ${dst_dir}
+    DEPS extern_openblas
  )
elseif (WITH_MKLML)
  set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/mklml")
  copy(mklml_lib
    SRCS ${MKLML_LIB} ${MKLML_IOMP_LIB} ${MKLML_INC_DIR}
    DSTS ${dst_dir}/lib ${dst_dir}/lib ${dst_dir}
+    DEPS mklml
  )
endif()
@@ -103,6 +110,7 @@ if(WITH_MKLDNN)
  copy(mkldnn_lib
    SRCS ${MKLDNN_INC_DIR} ${MKLDNN_SHARED_LIB}
    DSTS ${dst_dir} ${dst_dir}/lib
+    DEPS mkldnn
  )
endif()
@@ -110,17 +118,20 @@ if(NOT MOBILE_INFERENCE AND NOT RPI)
  set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/snappy")
  copy(snappy_lib
    SRCS ${SNAPPY_INCLUDE_DIR} ${SNAPPY_LIBRARIES}
-    DSTS ${dst_dir} ${dst_dir}/lib)
+    DSTS ${dst_dir} ${dst_dir}/lib
+    DEPS snappy)
  set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/snappystream")
  copy(snappystream_lib
    SRCS ${SNAPPYSTREAM_INCLUDE_DIR} ${SNAPPYSTREAM_LIBRARIES}
-    DSTS ${dst_dir} ${dst_dir}/lib)
+    DSTS ${dst_dir} ${dst_dir}/lib
+    DEPS snappystream)
  set(dst_dir "${FLUID_INSTALL_DIR}/third_party/install/zlib")
  copy(zlib_lib
    SRCS ${ZLIB_INCLUDE_DIR} ${ZLIB_LIBRARIES}
-    DSTS ${dst_dir} ${dst_dir}/lib)
+    DSTS ${dst_dir} ${dst_dir}/lib
+    DEPS zlib)
endif()
# paddle fluid module
......
@@ -94,7 +94,7 @@ void UpdateCallback::apply(Parameter* p) {
}
class UpdateCallbackWrapper {
-public:
+ public:
  explicit UpdateCallbackWrapper(const UpdateCallback& callback)
      : callback(const_cast<UpdateCallback&>(callback)) {}
@@ -105,7 +105,7 @@ public:
    delete p;
  }
-private:
+ private:
  UpdateCallback& callback;
};
......
@@ -59,9 +59,10 @@ class RangeError {};
/// Not support Error, such as access GPU memory directly, etc.
class UnsupportError : public std::runtime_error {
-public:
-  UnsupportError() : std::runtime_error(" "){};
-  UnsupportError(const std::string& message) : std::runtime_error(message){};
+ public:
+  UnsupportError() : std::runtime_error(" ") {}
+  explicit UnsupportError(const std::string& message)
+      : std::runtime_error(message) {}
};
/// This type will map to python's list of float.
@@ -105,7 +106,7 @@ class Matrix {
  DISABLE_COPY(Matrix);
  static Matrix* createByPaddleMatrixPtr(void* sharedPtr);
-public:
+ public:
  virtual ~Matrix();
  /**
@@ -231,7 +232,7 @@ public:
  bool isGpu() const;
-private:
+ private:
  void* getSharedPtr() const;
  MatrixPrivate* m;
@@ -248,7 +249,7 @@ class Vector {
  void* getSharedPtr();
-public:
+ public:
  ~Vector();
  /// Create Vector filled with zero.
@@ -310,10 +311,10 @@ public:
  /// __len__ in python
  size_t getSize() const;
-private:
+ private:
  VectorPrivate* m;
-private:
+ private:
  friend class Parameter;
  friend class ParameterOptimizer;
  friend struct ParameterTraverseCallbackPrivate;
@@ -325,7 +326,7 @@ class IVector {
  DISABLE_COPY(IVector);
  static IVector* createByPaddleVectorPtr(void* ptr);
-public:
+ public:
  /// Create IVector filled with zero
  static IVector* createZero(size_t sz, bool useGpu = isUsingGpu());
@@ -389,7 +390,7 @@ public:
  /// This method will map to python __len__();
  size_t getSize() const;
-private:
+ private:
  void* getSharedPtr() const;
  friend class Arguments;
@@ -400,11 +401,11 @@ struct ArgumentsPrivate;
/// The Arguments is actual a std::vector<paddle::Argument> in paddle.
class Arguments {
-private:
+ private:
  Arguments(); // Internal Create.
  DISABLE_COPY(Arguments);
-public:
+ public:
  /**
   * Create a arguments with size.
   * Note that it can be zero.
@@ -475,12 +476,12 @@ public:
  float sum() const;
-private:
+ private:
  static Arguments* createByPaddleArgumentVector(void* ptr);
  static Arguments* createByPaddleArgument(const void* ptr);
  void* getInternalArgumentsPtr() const;
-private:
+ private:
  ArgumentsPrivate* m;
  friend class Trainer;
  friend class GradientMachine;
@@ -507,7 +508,7 @@ class ParameterConfig {
  static ParameterConfig* createParameterConfigFromParameterPtr(void* ptr);
  void* getRawPtr();
-public:
+ public:
  ~ParameterConfig();
  /**
@@ -515,10 +516,10 @@ public:
   */
  std::string toProtoString() const;
-private:
+ private:
  ParameterConfigPrivate* m;
-private:
+ private:
  friend class Parameter;
  friend class ParameterOptimizer;
  friend struct ParameterTraverseCallbackPrivate;
@@ -529,7 +530,7 @@ class OptimizationConfig {
  DISABLE_COPY(OptimizationConfig);
  OptimizationConfig();
-public:
+ public:
  static OptimizationConfig* createFromProtoString(const std::string& str);
  ~OptimizationConfig();
@@ -538,7 +539,7 @@ public:
   */
  std::string toProtoString();
-private:
+ private:
  OptimizationConfigPrivate* m;
  friend class TrainerConfig;
@@ -549,11 +550,11 @@ private:
struct ParameterPrivate;
class Parameter {
-private:
+ private:
  Parameter();
  DISABLE_COPY(Parameter);
-public:
+ public:
  virtual ~Parameter();
  /**
@@ -580,11 +581,11 @@ public:
  size_t getSize() const;
-private:
+ private:
  static Parameter* createFromRawPtr(void* ptr);
  static Parameter* createFromSharedPtr(void* ptr);
-private:
+ private:
  ParameterPrivate* m;
  friend class UpdateCallbackWrapper;
  friend class GradientMachine;
@@ -598,14 +599,14 @@ struct ModelConfigPrivate;
 * It is used by GradientMachine.
 */
class ModelConfig {
-private:
+ private:
  ModelConfig();
  DISABLE_COPY(ModelConfig);
-public:
+ public:
  virtual ~ModelConfig();
-private:
+ private:
  ModelConfigPrivate* m;
  friend class TrainerConfig;
  friend struct TrainerConfigPrivate;
@@ -619,11 +620,11 @@ struct TrainerConfigPrivate;
 * It is used by GradientMachine.
 */
class TrainerConfig {
-private:
+ private:
  TrainerConfig();
  DISABLE_COPY(TrainerConfig);
-public:
+ public:
  virtual ~TrainerConfig();
  static TrainerConfig* createFromTrainerConfigFile(
@@ -634,7 +635,7 @@ public:
  OptimizationConfig* getOptimizationConfig() const;
-private:
+ private:
  TrainerConfigPrivate* m;
  friend class Trainer;
};
@@ -654,7 +655,7 @@ private:
 * @endcode
 */
class UpdateCallback {
-public:
+ public:
  virtual ~UpdateCallback();
  virtual void apply(Parameter* p);
};
@@ -664,14 +665,14 @@ class ParameterTraverseCallback {
  DISABLE_COPY(ParameterTraverseCallback);
  ParameterTraverseCallback();
-public:
+ public:
  ~ParameterTraverseCallback();
  void apply(const std::vector<Vector*>& vecs,
             const ParameterConfig& config,
             size_t sparseId);
-private:
+ private:
  ParameterTraverseCallbackPrivate* m;
  friend class ParameterOptimizer;
};
@@ -686,7 +687,7 @@ class ParameterOptimizer {
  DISABLE_COPY(ParameterOptimizer);
  ParameterOptimizer();
-public:
+ public:
  static ParameterOptimizer* create(OptimizationConfig* config);
  ~ParameterOptimizer();
@@ -710,7 +711,7 @@ public:
  ParameterTraverseCallback* needSpecialTraversal(
      const ParameterConfig& config) const;
-private:
+ private:
  ParameterOptimizerPrivate* m;
};
@@ -718,11 +719,11 @@ class SequenceGenerator;
class Evaluator;
struct GradientMachinePrivate;
class GradientMachine {
-private:
+ private:
  GradientMachine();
  DISABLE_COPY(GradientMachine);
-public:
+ public:
  virtual ~GradientMachine();
  /**
@@ -817,7 +818,7 @@ public:
  void eval(Evaluator* evaluator);
-private:
+ private:
  GradientMachinePrivate* m;
  static GradientMachine* createFromPaddleModelPtr(
@@ -833,10 +834,10 @@ private:
struct ParameterUpdaterPrivate;
class ParameterUpdater {
-private:
+ private:
  ParameterUpdater();
-public:
+ public:
  static ParameterUpdater* createLocalUpdater(OptimizationConfig* config);
  static ParameterUpdater* createRemoteUpdater(OptimizationConfig* config,
                                               int passCount,
@@ -911,17 +912,17 @@ public:
   */
  void catchUpWith();
-private:
+ private:
  ParameterUpdaterPrivate* m;
};
struct EvaluatorPrivate;
class Evaluator {
-private:
+ private:
  Evaluator();
  DISABLE_COPY(Evaluator);
-public:
+ public:
  ~Evaluator();
  /**
@@ -945,7 +946,7 @@ public:
  double getValue(const std::string name) const;
-private:
+ private:
  EvaluatorPrivate* m;
  friend class GradientMachine;
@@ -953,13 +954,13 @@ private:
struct TrainerPrivate;
class Trainer {
-private:
+ private:
  TrainerPrivate* m;
  Trainer();
  Trainer(TrainerConfig* optConfig, GradientMachine* gm);
  DISABLE_COPY(Trainer);
-public:
+ public:
  virtual ~Trainer();
  /// Create A Trainer By TrainerConfig. using paddle command line.
@@ -1002,7 +1003,7 @@ public:
/// the N-Best results generated from one input sequence.
class ISequenceResults {
-public:
+ public:
  virtual ~ISequenceResults();
  /// Number of result.
@@ -1026,7 +1027,7 @@ class SequenceGenerator {
  DISABLE_COPY(SequenceGenerator);
  SequenceGenerator();
-public:
+ public:
  virtual ~SequenceGenerator();
  /**
@@ -1044,10 +1045,10 @@ public:
  void setMaxLength(size_t maxlength);
  void setBeamSize(size_t beamSize);
-private:
+ private:
  static SequenceGenerator* createByGradientMachineSharedPtr(void* ptr);
  friend class GradientMachine;
-private:
+ private:
  SequenceGeneratorPrivate* m;
};
@@ -138,7 +138,7 @@ struct SequenceGeneratorPrivate {
        maxLength(0UL),
        feedback(__create_feedback__()) {}
-private:
+ private:
  static paddle::Argument __create_feedback__() {
    paddle::Argument feedback;
    feedback.ids = paddle::IVector::create(/* size= */ 1, FLAGS_use_gpu);
@@ -157,7 +157,7 @@ SequenceGenerator::~SequenceGenerator() { delete m; }
class PathSequenceResults : public ISequenceResults {
  // ISequenceResults interface
-public:
+ public:
  PathSequenceResults(const std::shared_ptr<std::vector<Path>>& path,
                      const std::shared_ptr<std::vector<std::string>>& dict)
      : path_(path), dict_(dict) {}
@@ -196,7 +196,7 @@ public:
    }
  }
-private:
+ private:
  std::shared_ptr<std::vector<Path>> path_;
  std::shared_ptr<std::vector<std::string>> dict_;
};
......
@@ -26,7 +26,7 @@ enum GradientMatchineCreateMode {
namespace paddle {
class MyNeuralNetwork : public NeuralNetwork {
-public:
+ public:
  MyNeuralNetwork(const std::string& name, NeuralNetwork* network)
      : NeuralNetwork(name, network) {}
};
......
@@ -50,7 +50,7 @@ struct PaddleTensor {
 * TODO(Superjomn) Prepare another API for NLP-related usages.
 */
class PaddlePredictor {
-public:
+ public:
  struct Config;
  PaddlePredictor() = default;
  PaddlePredictor(const PaddlePredictor&) = delete;
@@ -66,6 +66,7 @@ public:
  // be thread-safe.
  virtual std::unique_ptr<PaddlePredictor> Clone() = 0;
+  virtual bool InitShared() { return false; }
  // Destroy the Predictor.
  virtual ~PaddlePredictor() {}
......
@@ -28,7 +28,7 @@ namespace {
// Timer for timer
class Timer {
-public:
+ public:
  double start;
  double startu;
  void tic() {
@@ -135,8 +135,8 @@ bool PaddlePredictorImpl::Run(const std::vector<PaddleTensor> &inputs,
std::unique_ptr<PaddlePredictor> PaddlePredictorImpl::Clone() {
  VLOG(3) << "Predictor::clone";
-  std::unique_ptr<PaddlePredictorImpl> cls(new PaddlePredictorImpl(config_));
-  if (!cls->InitShared(this)) {
+  std::unique_ptr<PaddlePredictor> cls(new PaddlePredictorImpl(config_));
+  if (!cls->InitShared()) {
    LOG(ERROR) << "fail to call InitShared";
    return nullptr;
  }
@@ -144,7 +144,7 @@ std::unique_ptr<PaddlePredictor> PaddlePredictorImpl::Clone() {
}
// TODO(panyx0718): Consider merge with Init()?
-bool PaddlePredictorImpl::InitShared(PaddlePredictorImpl *cls) {
+bool PaddlePredictorImpl::InitShared() {
  VLOG(3) << "Predictor::init_shared";
  // 1. Define place, executor, scope
  if (this->config_.device >= 0) {
......
@@ -41,7 +41,7 @@ struct VisConfig : public PaddlePredictor::Config {
 * Do not use this, just a demo indicating how to customize a Predictor.
 */
class PaddlePredictorImpl : public PaddlePredictor {
-public:
+ public:
  explicit PaddlePredictorImpl(const VisConfig &config) : config_(config) {}
  bool Init();
@@ -53,8 +53,8 @@ public:
  ~PaddlePredictorImpl() override{};
-private:
-  bool InitShared(PaddlePredictorImpl *cls);
+ private:
+  bool InitShared();
  bool SetFeed(const std::vector<PaddleTensor> &input_datas,
               std::vector<paddle::framework::LoDTensor> *feeds);
  bool GetFetch(const std::vector<paddle::framework::LoDTensor> &fetchs,
......
@@ -31,7 +31,7 @@ struct DemoConfig : public PaddlePredictor::Config {
 * Do not use this, just a demo indicating how to customize a Predictor.
 */
class DemoPredictor : public PaddlePredictor {
-public:
+ public:
  explicit DemoPredictor(const DemoConfig &config) {
    LOG(INFO) << "I get other_config " << config.other_config;
  }
......
@@ -44,7 +44,7 @@ TEST(paddle_inference_api_impl, word2vec) {
  VisConfig config;
  config.model_dir = FLAGS_dirname + "word2vec.inference.model";
  LOG(INFO) << "dirname " << config.model_dir;
-  config.fraction_of_gpu_memory = 0.85;
+  config.fraction_of_gpu_memory = 0.15;
  config.device = 0;
  config.share_variables = true;
......
@@ -31,7 +31,7 @@ namespace hppl {
 */
template <class T>
class Active {
-public:
+ public:
  typedef T (*forward)(T);
  typedef T (*backward)(T, T);
};
......
@@ -23,128 +23,128 @@ namespace unary {
template <class T>
class add_scale {
-private:
+ private:
  const T p;
-public:
+ public:
  INLINE add_scale(const T s) : p(s) {}
  INLINE T operator()(const T a) const { return a + p; }
};
template <class T>
class sub_scale {
-private:
+ private:
  const T p;
-public:
+ public:
  INLINE sub_scale(const T s) : p(s) {}
  INLINE T operator()(const T a) const { return a - p; }
};
template <class T>
class mul_scale {
-private:
+ private:
  const T p;
-public:
+ public:
  INLINE mul_scale(const T s) : p(s) {}
  INLINE T operator()(const T a) const { return a * p; }
};
template <class T>
class div_scale {
-private:
+ private:
  const T p;
-public:
+ public:
  INLINE div_scale(const T s) : p(s) {}
  INLINE T operator()(const T a) const { return a / p; }
};
template <class T>
class neg {
-public:
+ public:
  INLINE T operator()(const T a) const { return -a; }
};
template <class T>
class exp_op {
-public:
+ public:
  INLINE T operator()(const T a) const { return std::exp(a); }
};
template <class T>
class log_op {
-public:
+ public:
  INLINE T operator()(const T a) const { return std::log(a); }
};
template <class T>
class sqrt_op {
-public:
+ public:
  INLINE T operator()(const T a) const { return std::sqrt(a); }
};
template <class T>
class square {
-public:
+ public:
  INLINE T operator()(const T a) const { return a * a; }
};
template <class T>
class reciprocal {
-public:
+ public:
  INLINE T operator()(const T a) const { return T(1) / a; }
};
template <class T>
class abs {
-public:
+ public:
  INLINE T operator()(const T a) const { return a > 0 ? a : -a; }
};
template <class T>
class sign {
-public:
+ public:
  INLINE T operator()(const T a) const { return (a > 0) - (a < 0); }
};
template <class T>
class min {
-private:
+ private:
  const T p;
-public:
+ public:
  INLINE min(const T s) : p(s) {}
  INLINE T operator()(const T a) const { return a > p ? p : a; }
};
template <class T>
class max {
-private:
+ private:
  const T p;
-public:
+ public:
  INLINE max(const T s) : p(s) {}
  INLINE T operator()(const T a) const { return a < p ? p : a; }
};
template <class T>
class pow_op {
-private:
+ private:
  const T p;
-public:
+ public:
  INLINE pow_op(const T s) : p(s) {}
  INLINE T operator()(const T a) const { return std::pow(a, p); }
};
template <class T>
class constant {
-private:
+ private:
  const T p;
-public:
+ public:
  INLINE constant(const T s) : p(s) {}
  INLINE T operator()(int i) const { return p; }
  INLINE T operator()(int i, int j) const { return p; }
@@ -152,80 +152,80 @@ public:
template <class T>
class cmp_eq {
-private:
+ private:
  const T p;
-public:
+ public:
  INLINE cmp_eq(const T s) : p(s) {}
  INLINE bool operator()(const T a) const { return a == p; }
};
template <class T>
class cmp_ne {
-private:
+ private:
  const T p;
-public:
+ public:
  INLINE cmp_ne(const T s) : p(s) {}
  INLINE bool operator()(const T a) const { return a != p; }
};
template <class T>
class cmp_le {
-private:
+ private:
  const T p;
-public:
+ public:
  INLINE cmp_le(const T s) : p(s) {}
  INLINE bool operator()(const T a) const { return a <= p; }
};
template <class T>
class cmp_lt {
-private:
+ private:
  const T p;
-public:
+ public:
  INLINE cmp_lt(const T s) : p(s) {}
  INLINE bool operator()(const T a) const { return a < p; }
};
template <class T>
class cmp_ge {
-private:
+ private:
  const T p;
-public:
+ public:
  INLINE cmp_ge(const T s) : p(s) {}
  INLINE bool operator()(const T a) const { return a >= p; }
};
template <class T>
class cmp_gt {
-private:
+ private:
  const T p;
-public:
+ public:
  INLINE cmp_gt(const T s) : p(s) {}
  INLINE bool operator()(const T a) const { return a > p; }
};
template <class T>
class and_op {
-private:
+ private:
  const T p;
-public:
+ public:
  INLINE and_op(const T s) : p(s) {}
  INLINE bool operator()(const T a) const { return a && p; }
};
template <class T>
class or_op {
-private:
+ private:
  const T p;
-public:
+ public:
  INLINE or_op(const T s) : p(s) {}
  INLINE bool operator()(const T a) const { return a || p; }
@@ -235,96 +235,96 @@ public:
namespace binary {
template <class T>
class add {
-public:
+ public:
  INLINE T operator()(const T a, const T b) const { return a + b; }
};
template <class T>
class add_scale {
-private:
+ private:
  const T p1;
  const T p2;
-public:
+ public:
  INLINE add_scale(const T s1, const T s2) : p1(s1), p2(s2) {}
  INLINE T operator()(const T a, const T b) const { return p1 * a + p2 * b; }
};
template <class T>
class sub {
-public:
+ public:
  INLINE T operator()(const T a, const T b) const { return a - b; }
};
template <class T>
class mul {
-public:
+ public:
  INLINE T operator()(const T a, const T b) const { return a * b; }
};
template <class T>
class div {
-public:
+ public:
  INLINE T operator()(const T a, const T b) const { return a / b; }
};
template <class T>
class cmp_eq {
-public:
+ public:
  INLINE bool operator()(const T a, const T b) const { return a == b; }
};
template <class T>
class cmp_ne {
-public:
+ public:
  INLINE bool operator()(const T a, const T b) const { return a != b; }
};
template <class T>
class cmp_le {
-public:
+ public:
  INLINE bool operator()(const T a, const T b) const { return a <= b; }
};
template <class T>
class cmp_lt {
-public:
+ public:
  INLINE bool operator()(const T a, const T b) const { return a < b; }
};
template <class T>
class cmp_ge {
-public:
+ public:
  INLINE bool operator()(const T a, const T b) const { return a >= b; }
};
template <class T>
class cmp_gt {
-public:
+ public:
  INLINE bool operator()(const T a, const T b) const { return a > b; }
};
template <class T>
class and_op {
-public:
+ public:
  INLINE bool operator()(const T a, const T b) const { return a && b; }
};
template <class T>
class or_op {
-public:
+ public:
  INLINE bool operator()(const T a, const T b) const { return a || b; }
};
template <class T>
class min {
-public:
+ public:
  INLINE T operator()(const T a, const T b) const { return a > b ? b : a; }
};
template <class T>
class max {
-public:
+ public:
  INLINE T operator()(const T a, const T b) const { return a < b ? b : a; }
};
@@ -332,7 +332,7 @@ public:
#ifndef PADDLE_TYPE_DOUBLE
template <>
class add<__m128> {
-public:
+ public:
  INLINE __m128 operator()(const __m128 a, const __m128 b) const {
    return _mm_add_ps(a, b);
  }
@@ -340,11 +340,11 @@ public:
template <>
class add_scale<__m128> {
-private:
+ private:
  const __m128 p1;
  const __m128 p2;
-public:
+ public:
  INLINE add_scale(const __m128 s1, const __m128 s2) : p1(s1), p2(s2) {}
  INLINE __m128 operator()(const __m128 a, const __m128 b) const {
    return _mm_add_ps(_mm_mul_ps(p1, a), _mm_mul_ps(p2, b));
@@ -353,7 +353,7 @@ public:
template <>
class sub<__m128> {
-public:
+ public:
  INLINE __m128 operator()(const __m128 a, const __m128 b) const {
    return _mm_sub_ps(a, b);
  }
@@ -361,7 +361,7 @@ public:
template <>
class mul<__m128> {
-public:
+ public:
  INLINE __m128 operator()(const __m128 a, const __m128 b) const {
    return _mm_mul_ps(a, b);
  }
@@ -369,7 +369,7 @@ public:
template <>
class div<__m128> {
-public:
+ public:
  INLINE __m128 operator()(const __m128 a, const __m128 b) const {
    return _mm_div_ps(a, b);
  }
@@ -377,7 +377,7 @@ public:
template <>
class min<__m128> {
-public:
+ public:
  INLINE __m128 operator()(const __m128 a, const __m128 b) const {
    return _mm_min_ps(a, b);
  }
@@ -385,7 +385,7 @@ public:
template <>
class max<__m128> {
-public:
+ public:
  INLINE __m128 operator()(const __m128 a, const __m128 b) const {
    return _mm_max_ps(a, b);
  }
@@ -393,7 +393,7 @@ public:
#else
template <>
class add<__m128d> {
-public:
+ public:
  INLINE __m128d operator()(const __m128d a, const __m128d b) const {
    return _mm_add_pd(a, b);
  }
@@ -401,11 +401,11 @@ public:
template <>
class add_scale<__m128d> {
-private:
+ private:
  const __m128d p1;
  const __m128d p2;
-public:
+ public:
  INLINE add_scale(const __m128d s1, const __m128d s2) : p1(s1), p2(s2) {}
  INLINE __m128d operator()(const __m128d a, const __m128d b) const {
    return _mm_add_pd(_mm_mul_pd(p1, a), _mm_mul_pd(p2, b));
@@ -414,7 +414,7 @@ public:
template <>
class sub<__m128d> {
-public:
+ public:
  INLINE __m128d operator()(const __m128d a, const __m128d b) const {
    return _mm_sub_pd(a, b);
  }
@@ -422,7 +422,7 @@ public:
template <>
class mul<__m128d> {
-public:
+ public:
  INLINE __m128d operator()(const __m128d a, const __m128d b) const {
    return _mm_mul_pd(a, b);
  }
@@ -430,7 +430,7 @@ public:
template <>
class div<__m128d> {
-public:
+ public:
  INLINE __m128d operator()(const __m128d a, const __m128d b) const {
    return _mm_div_pd(a, b);
  }
@@ -438,7 +438,7 @@ public:
template <>
class min<__m128d> {
-public:
+ public:
  INLINE __m128d operator()(const __m128d a, const __m128d b) const {
    return _mm_min_pd(a, b);
  }
@@ -446,7 +446,7 @@ public:
template <>
class max<__m128d> {
-public:
+ public:
  INLINE __m128d operator()(const __m128d a, const __m128d b) const {
    return _mm_max_pd(a, b);
  }
@@ -458,7 +458,7 @@ public:
#ifndef PADDLE_TYPE_DOUBLE
template <>
class add<float32x4_t> {
-public:
+ public:
  INLINE float32x4_t operator()(const float32x4_t a,
                                const float32x4_t b) const {
    return vaddq_f32(a, b);
@@ -467,11 +467,11 @@ public:
template <>
class add_scale<float32x4_t> {
-private:
+ private:
  const float32x4_t p1;
  const float32x4_t p2;
-public:
+ public:
  INLINE add_scale(const float32x4_t s1, const float32x4_t s2)
      : p1(s1), p2(s2) {}
  INLINE float32x4_t operator()(const float32x4_t a,
@@ -482,7 +482,7 @@ public:
template <>
class sub<float32x4_t> {
-public:
+ public:
  INLINE float32x4_t operator()(const float32x4_t a,
                                const float32x4_t b) const {
    return vsubq_f32(a, b);
@@ -491,7 +491,7 @@ public:
template <>
class mul<float32x4_t> {
-public:
+ public:
  INLINE float32x4_t operator()(const float32x4_t a,
                                const float32x4_t b) const {
    return vmulq_f32(a, b);
@@ -500,7 +500,7 @@ public:
template <>
class div<float32x4_t> {
-public:
+ public:
  INLINE float32x4_t operator()(const float32x4_t a,
                                const float32x4_t b) const {
    float32x4_t tmp = vrecpeq_f32(b);
@@ -510,7 +510,7 @@ public:
template <>
class min<float32x4_t> {
-public:
+ public:
  INLINE float32x4_t operator()(const float32x4_t a,
                                const float32x4_t b) const {
    return vminq_f32(a, b);
@@ -519,7 +519,7 @@ public:
template <>
class max<float32x4_t> {
-public:
+ public:
  INLINE float32x4_t operator()(const float32x4_t a,
                                const float32x4_t b) const {
    return vmaxq_f32(a, b);
......
@@ -30,7 +30,7 @@ bool hl_lstm_sequence_parallel(int frameSize) {
}
class frameValue {
-public:
+ public:
  real *value_;
  __device__ frameValue(real *value) : value_(value) {}
  template <int reversed, int frameSize>
......
@@ -3,7 +3,7 @@ cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
-cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry)
+cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
@@ -26,7 +26,7 @@ endif()
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
-        scale_loss_grad_op_handle send_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
+        scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
......
@@ -12,12 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
+#include <fstream>
#include <utility>
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
+#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
-#include "paddle/fluid/framework/details/send_op_handle.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h"
@@ -28,6 +29,10 @@
#include <string>
#include <vector>
+DEFINE_string(ssa_graph_path, "/tmp/ssa_graph.dot",
+              "the ssa graph path only print with GLOG_v=10,"
+              "default /tmp/graph.dot");
namespace paddle {
namespace framework {
namespace details {
@@ -79,9 +84,44 @@ void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
  }
}
-bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
-                                            OpDesc *send_op) const {
-  if (send_op == nullptr) {
+std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainSendVars(
+    const ProgramDesc &program) const {
+  std::vector<std::string> send_vars;
+  // since parameters are all in block 0,
+  // it's enough to only scan send ops in block 0
+  for (auto *op : program.Block(0).AllOps()) {
+    // TODO(Yancey1989): use a graceful method to find send op,
+    // instead of the the hard code string
+    if (op->Type() == "send_vars") {
+      auto op_vars = op->InputArgumentNames();
+      send_vars.reserve(send_vars.size() +
+                        std::distance(op_vars.begin(), op_vars.end()));
+      send_vars.insert(send_vars.end(), op_vars.begin(), op_vars.end());
+    }
+  }
+  return send_vars;
+}
+std::vector<std::string> MultiDevSSAGraphBuilder::FindDistTrainRecvVars(
+    const ProgramDesc &program) const {
+  std::vector<std::string> recv_vars;
+  for (auto *op : program.Block(0).AllOps()) {
+    // TODO(Yancey1989): use a graceful method to find recv op,
+    // instead of the hard code string
+    if (op->Type() == "recv") {
+      auto op_vars = op->OutputArgumentNames();
+      recv_vars.reserve(recv_vars.size() +
+                        std::distance(op_vars.begin(), op_vars.end()));
+      recv_vars.insert(recv_vars.end(), op_vars.begin(), op_vars.end());
+    }
+  }
+  return recv_vars;
+}
+bool MultiDevSSAGraphBuilder::IsDistTrainOp(
+    const OpDesc &op, const std::vector<std::string> &send_vars,
+    const std::vector<std::string> &recv_vars) const {
+  if (send_vars.size() == 0 || recv_vars.size() == 0) {
    return false;
  }
@@ -89,22 +129,21 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op,
   * Check any of opvars contains `.block` and in sendvars
   */
  auto checker = [](const std::vector<std::string> &opvars,
-                    const std::vector<std::string> &sendvars) -> bool {
+                    const std::vector<std::string> &rpc_vars) -> bool {
    for (auto &var : opvars) {
+      // a variable name with the suffix `.block` means it's a splited
+      // variable by (DistributeTranspiler)
+      // [python/paddle/fluid/transpiler/distribute_transpiler.py]
      if (var.find(".block") != std::string::npos &&
-          std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) {
+          std::find(rpc_vars.begin(), rpc_vars.end(), var) != rpc_vars.end()) {
        return true;
      }
    }
    return false;
  };
-  if (op.Type() == "split" || op.Type() == "split_byref") {
-    return checker(op.OutputArgumentNames(), send_op->InputArgumentNames());
-  } else if (op.Type() == "concat") {
-    return checker(op.InputArgumentNames(), send_op->OutputArgumentNames());
-  }
-  return false;
+  return checker(op.OutputArgumentNames(), send_vars) ||
+         checker(op.InputArgumentNames(), recv_vars);
}
std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
...@@ -123,8 +162,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -123,8 +162,10 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>( std::unordered_map<std::string, std::vector<std::unique_ptr<VarHandle>>>>(
places_.size()); places_.size());
// Find "send" op first for split is in front of send. // find send/recv vars so that we can place the distributed training
OpDesc *send_op = GetSendOpDesc(program); // related op on place 0
auto send_vars = FindDistTrainSendVars(program);
auto recv_vars = FindDistTrainRecvVars(program);
size_t cur_device_id = 0; size_t cur_device_id = 0;
std::vector<std::unordered_set<std::string>> var_name_on_devices; std::vector<std::unordered_set<std::string>> var_name_on_devices;
...@@ -134,12 +175,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -134,12 +175,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
bool is_forwarding = true; bool is_forwarding = true;
for (auto *op : program.Block(0).AllOps()) { for (auto *op : program.Block(0).AllOps()) {
if (op->Type() == "send") { if (boost::get<int>(
// append send op if program is distributed trainer main program. op->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
static_cast<int>(OpRole::kRPC)) {
// append rpc op if program is distributed trainer main program.
// always use the first device // always use the first device
CreateSendOp(&result, *op); CreateRPCOp(&result, *op);
} else if (IsDistTrainOp(*op, send_op)) { } else if (IsDistTrainOp(*op, send_vars, recv_vars)) {
CreateComputationalOps(&result, *op, 1); CreateDistTrainOp(&result, *op);
} else if (IsScaleLossOp(*op)) { } else if (IsScaleLossOp(*op)) {
// user can customize loss@grad if not use_default_grad_scale_ // user can customize loss@grad if not use_default_grad_scale_
if (strategy_.gradient_scale_ != if (strategy_.gradient_scale_ !=
...@@ -218,9 +261,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build( ...@@ -218,9 +261,8 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
AddOutputToLeafOps(&result); AddOutputToLeafOps(&result);
if (VLOG_IS_ON(10)) { if (VLOG_IS_ON(10)) {
std::ostringstream sout; std::ofstream fout(FLAGS_ssa_graph_path);
PrintGraphviz(*graph, sout); PrintGraphviz(*graph, fout);
VLOG(10) << sout.str();
} }
return std::unique_ptr<SSAGraph>(graph); return std::unique_ptr<SSAGraph>(graph);
...@@ -270,15 +312,6 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result, ...@@ -270,15 +312,6 @@ void MultiDevSSAGraphBuilder::CreateComputationalOp(SSAGraph *result,
CreateOpHandleIOs(result, op, dev_id); CreateOpHandleIOs(result, op, dev_id);
} }
OpDesc *MultiDevSSAGraphBuilder::GetSendOpDesc(
const ProgramDesc &program) const {
for (auto *op : program.Block(0).AllOps()) {
if (op->Type() == "send") {
return op;
}
}
return nullptr;
}
void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp( void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp(
SSAGraph *result, const std::string &og) const { SSAGraph *result, const std::string &og) const {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -401,14 +434,48 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result, ...@@ -401,14 +434,48 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
return var; return var;
} }
void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result, void MultiDevSSAGraphBuilder::ConnectOp(SSAGraph *result, OpHandleBase *op,
const OpDesc &op) const { const std::string &prev_op_name) const {
for (auto &prev_op : result->ops_) {
if (prev_op->Name() == prev_op_name) {
auto *dep_var = new DummyVarHandle();
prev_op->AddOutput(dep_var);
result->dep_vars_.emplace(dep_var);
op->AddInput(dep_var);
}
}
}
void MultiDevSSAGraphBuilder::CreateDistTrainOp(SSAGraph *result,
const OpDesc &op) const {
CreateComputationalOp(result, op, 0);
if (op.Type() == "concat") {
ConnectOp(result, result->ops_.back().get(), "fetch_barrier");
}
}
void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
const OpDesc &op) const {
auto &p = places_[0]; auto &p = places_[0];
auto *s = local_scopes_[0]; auto *s = local_scopes_[0];
// FIXME(wuyi): send op always copy from GPU 0 result->ops_.emplace_back(new RPCOpHandle(op, s, p, op.Type()));
result->ops_.emplace_back(new SendOpHandle(op, s, p));
// Create inputs for output on original place and no ssa output if (op.Type() == "send_barrier") {
// is created for send op. ConnectOp(result, result->ops_.back().get(), "send_vars");
} else if (op.Type() == "recv") {
ConnectOp(result, result->ops_.back().get(), "send_barrier");
} else if (op.Type() == "fetch_barrier") {
ConnectOp(result, result->ops_.back().get(), "recv");
} else if (op.Type() == "send_vars") {
// do nothing
} else {
PADDLE_THROW(
"rpc op should be in ["
"send_vars, send_barrier. recv, fetch_barrier]");
}
// TODO(Yancey1989): scheduling rpc ops on different places may
// increase throughput
CreateOpHandleIOs(result, op, 0); CreateOpHandleIOs(result, op, 0);
} }
......
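To make the ordering enforced by CreateDistTrainOp and CreateRPCOp above easier to see, here is a minimal sketch of the same dependency chain as a standalone table; the op names come from the diff, while the map itself is only illustrative. Each RPC-related op gets a dummy-variable dependency on the op named as its predecessor, so execution follows send_vars -> send_barrier -> recv -> fetch_barrier, with concat additionally waiting on fetch_barrier.

#include <iostream>
#include <map>
#include <string>

int main() {
  // Predecessor each op waits on, mirroring the ConnectOp calls above;
  // an empty string means no extra dependency is inserted.
  std::map<std::string, std::string> waits_on = {
      {"send_vars", ""},
      {"send_barrier", "send_vars"},
      {"recv", "send_barrier"},
      {"fetch_barrier", "recv"},
      {"concat", "fetch_barrier"},
  };
  for (const auto &kv : waits_on) {
    std::cout << kv.first << " waits on "
              << (kv.second.empty() ? "<nothing>" : kv.second) << std::endl;
  }
  return 0;
}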
...@@ -64,12 +64,24 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -64,12 +64,24 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
bool IsScaleLossOp(const OpDesc &op) const; bool IsScaleLossOp(const OpDesc &op) const;
void CreateSendOp(SSAGraph *result, const OpDesc &op) const; void CreateRPCOp(SSAGraph *result, const OpDesc &op) const;
void CreateDistTrainOp(SSAGraph *result, const OpDesc &op) const;
/**
* Is this operator an end-point operator before/after the send operator?
*/
bool IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const; bool IsDistTrainOp(const OpDesc &op,
const std::vector<std::string> &send_vars,
const std::vector<std::string> &recv_vars) const;
std::vector<std::string> FindDistTrainSendVars(
const ProgramDesc &program) const;
std::vector<std::string> FindDistTrainRecvVars(
const ProgramDesc &program) const;
void ConnectOp(SSAGraph *result, OpHandleBase *op,
const std::string &prev_op_name) const;
void CreateComputationalOps(SSAGraph *result, const OpDesc &op, void CreateComputationalOps(SSAGraph *result, const OpDesc &op,
size_t num_places) const; size_t num_places) const;
...@@ -93,12 +105,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder { ...@@ -93,12 +105,6 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
void CreateBroadcastOp(SSAGraph *result, const std::string &p_name, void CreateBroadcastOp(SSAGraph *result, const std::string &p_name,
size_t src_dev_id) const; size_t src_dev_id) const;
/**
* Get send op in the global block of program.
* nullptr if not found.
*/
OpDesc *GetSendOpDesc(const ProgramDesc &program) const;
bool IsSparseGradient( bool IsSparseGradient(
const std::unordered_map<std::string, proto::VarType::Type> &var_types, const std::unordered_map<std::string, proto::VarType::Type> &var_types,
const std::string &og) const; const std::string &og) const;
......
...@@ -12,24 +12,26 @@ ...@@ -12,24 +12,26 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/send_op_handle.h" #include "paddle/fluid/framework/details/rpc_op_handle.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc, RPCOpHandle::RPCOpHandle(const framework::OpDesc &op_desc,
const Scope *local_scope, const Scope *local_scope, const platform::Place &place,
const platform::Place &place) const std::string &name)
: op_(framework::OpRegistry::CreateOp(op_desc)), : op_(framework::OpRegistry::CreateOp(op_desc)),
local_scope_(local_scope), local_scope_(local_scope),
place_(place) {} place_(place),
name_(name) {}
void SendOpHandle::RunImpl() { void RPCOpHandle::RunImpl() {
// TODO(wuyi): need further analysis on whether to wait for VarDummyHandle.
// Wait for inputs to be done.
for (auto *in : inputs_) { for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_; auto &p = static_cast<VarHandle *>(in)->place_;
// FIXME(Yancey1989): need a better solution than using DebugString()
if (in->DebugString() == "dummy") { // HACK if (in->DebugString() == "dummy") { // HACK
continue; continue;
} }
...@@ -43,7 +45,7 @@ void SendOpHandle::RunImpl() { ...@@ -43,7 +45,7 @@ void SendOpHandle::RunImpl() {
op_->Run(*tmp_scope, place_); op_->Run(*tmp_scope, place_);
} }
std::string SendOpHandle::Name() const { return "send"; } std::string RPCOpHandle::Name() const { return name_; }
} // namespace details } // namespace details
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -27,9 +27,9 @@ namespace paddle { ...@@ -27,9 +27,9 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
struct SendOpHandle : public OpHandleBase { struct RPCOpHandle : public OpHandleBase {
SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope, RPCOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
const platform::Place& place); const platform::Place& place, const std::string& name);
std::string Name() const override; std::string Name() const override;
...@@ -44,6 +44,7 @@ struct SendOpHandle : public OpHandleBase { ...@@ -44,6 +44,7 @@ struct SendOpHandle : public OpHandleBase {
std::unique_ptr<OperatorBase> op_; std::unique_ptr<OperatorBase> op_;
const Scope* local_scope_; const Scope* local_scope_;
const platform::Place& place_; const platform::Place& place_;
const std::string name_;
}; };
} // namespace details } // namespace details
......
...@@ -66,7 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto, ...@@ -66,7 +66,7 @@ void OpProtoAndCheckerMaker::operator()(proto::OpProto* proto,
.InEnum( .InEnum(
{static_cast<int>(OpRole::kForward), {static_cast<int>(OpRole::kForward),
static_cast<int>(OpRole::kBackward), static_cast<int>(OpRole::kBackward),
static_cast<int>(OpRole::kOptimize), static_cast<int>(OpRole::kOptimize), static_cast<int>(OpRole::kRPC),
static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward), static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kForward),
static_cast<int>(OpRole::kLoss) | static_cast<int>(OpRole::kLoss) |
static_cast<int>(OpRole::kBackward), static_cast<int>(OpRole::kBackward),
......
...@@ -24,6 +24,7 @@ enum class OpRole { ...@@ -24,6 +24,7 @@ enum class OpRole {
kForward = 0x0000, kForward = 0x0000,
kBackward = 0x0001, kBackward = 0x0001,
kOptimize = 0x0002, kOptimize = 0x0002,
kRPC = 0x0003,
kLoss = 0x0100, kLoss = 0x0100,
// The default value of op's role. This should be only used for unittests and // The default value of op's role. This should be only used for unittests and
......
...@@ -35,7 +35,7 @@ TEST(DataFlowGraph, BFS) { ...@@ -35,7 +35,7 @@ TEST(DataFlowGraph, BFS) {
GraphTraits<DataFlowGraph> trait(&dfg); GraphTraits<DataFlowGraph> trait(&dfg);
auto nodes = trait.nodes(); auto nodes = trait.nodes();
int count = 0; size_t count = 0;
for (auto it = nodes.begin(); it != nodes.end(); ++it) { for (auto it = nodes.begin(); it != nodes.end(); ++it) {
LOG(INFO) << "visiting " << it->name(); LOG(INFO) << "visiting " << it->name();
++count; ++count;
...@@ -49,7 +49,7 @@ TEST(DataFlowGraph, DFS) { ...@@ -49,7 +49,7 @@ TEST(DataFlowGraph, DFS) {
dfg.Build(); dfg.Build();
GraphTraits<DataFlowGraph> trait(&dfg); GraphTraits<DataFlowGraph> trait(&dfg);
auto nodes = trait.nodes_in_DFS(); auto nodes = trait.nodes_in_DFS();
int count = 0; size_t count = 0;
for (auto it = nodes.begin(); it != nodes.end(); ++it) { for (auto it = nodes.begin(); it != nodes.end(); ++it) {
LOG(INFO) << "visiting " << it->name(); LOG(INFO) << "visiting " << it->name();
++count; ++count;
......
...@@ -24,6 +24,15 @@ namespace paddle { ...@@ -24,6 +24,15 @@ namespace paddle {
namespace inference { namespace inference {
namespace analysis { namespace analysis {
template <typename Vec>
int AccuDims(Vec &&vec, int size) {
int res = 1;
for (int i = 0; i < size; i++) {
res *= std::forward<Vec>(vec)[i];
}
return res;
}
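AccuDims simply multiplies the first `size` entries of any indexable dims container. A small standalone usage example follows; the vector contents are made up.

#include <iostream>
#include <utility>
#include <vector>

template <typename Vec>
int AccuDims(Vec &&vec, int size) {
  int res = 1;
  for (int i = 0; i < size; i++) {
    res *= std::forward<Vec>(vec)[i];
  }
  return res;
}

int main() {
  std::vector<int> dims = {2, 3, 4};
  std::cout << AccuDims(dims, 3) << std::endl;  // prints 2 * 3 * 4 = 24
  return 0;
}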
#define SET_TYPE(type__) dic_[typeid(type__).hash_code()] = #type__; #define SET_TYPE(type__) dic_[typeid(type__).hash_code()] = #type__;
/* /*
* Map typeid to representation. * Map typeid to representation.
...@@ -101,7 +110,5 @@ class OrderedRegistry { ...@@ -101,7 +110,5 @@ class OrderedRegistry {
} // namespace paddle } // namespace paddle
#define PADDLE_DISALLOW_COPY_AND_ASSIGN(type__) \ #define PADDLE_DISALLOW_COPY_AND_ASSIGN(type__) \
\
type__(const type__ &) = delete; \ type__(const type__ &) = delete; \
\
void operator=(const type__ &) = delete; void operator=(const type__ &) = delete;
nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES}) # Add TRT tests
nv_test(test_op_converter SRCS test_op_converter.cc mul_op.cc conv2d_op.cc DEPS ${FLUID_CORE_MODULES} tensorrt_engine)
# This test is not stable # This test is not stable
# See https://paddleci.ngrok.io/viewLog.html?tab=buildLog&buildTypeId=Paddle_PrCi2&buildId=36834&_focus=8828 # See https://paddleci.ngrok.io/viewLog.html?tab=buildLog&buildTypeId=Paddle_PrCi2&buildId=36834&_focus=8828
#nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc io_converter.cc #nv_test(test_trt_activation_op SRCS test_activation_op.cc activation_op.cc io_converter.cc
# DEPS ${FLUID_CORE_MODULES} activation_op tensorrt_engine # DEPS ${FLUID_CORE_MODULES} activation_op tensorrt_engine
# SERIAL) # SERIAL)
nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor) nv_test(test_io_converter SRCS test_io_converter.cc io_converter.cc DEPS dynload_cuda dynamic_loader lod_tensor)
nv_test(test_trt_mul_op SRCS test_mul_op.cc mul_op.cc
DEPS ${FLUID_CORE_MODULES} tensorrt_engine mul_op SERIAL)
...@@ -18,11 +18,25 @@ namespace paddle { ...@@ -18,11 +18,25 @@ namespace paddle {
namespace inference { namespace inference {
namespace tensorrt { namespace tensorrt {
/*
* MulOp, IMatrixMultiplyLayer in TRT. This layer doesn't have weights.
*/
class MulOpConverter : public OpConverter { class MulOpConverter : public OpConverter {
public: public:
MulOpConverter() {} MulOpConverter() {}
void operator()(const framework::proto::OpDesc& op) override { void operator()(const framework::proto::OpDesc& op) override {
LOG(INFO) << "convert a fluid mul op to tensorrt fc layer without bias"; VLOG(4) << "convert a fluid mul op to tensorrt fc layer without bias";
framework::OpDesc op_desc(op, nullptr, nullptr);
// Declare inputs
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
auto* input2 = engine_->GetITensor(op_desc.Input("Y")[0]);
// Neither input1 nor input2 needs to be transposed.
auto* layer = TRT_ENGINE_ADD_LAYER(
engine_, MatrixMultiply, *const_cast<nvinfer1::ITensor*>(input1), false,
*const_cast<nvinfer1::ITensor*>(input2), false);
engine_->DeclareOutput(layer, 0, op_desc.Output("Out")[0]);
} }
}; };
......
...@@ -102,3 +102,5 @@ TEST(OpConverter, ConvertRelu) { ...@@ -102,3 +102,5 @@ TEST(OpConverter, ConvertRelu) {
} // namespace tensorrt } // namespace tensorrt
} // namespace inference } // namespace inference
} // namespace paddle } // namespace paddle
USE_OP(activation);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/tensorrt/convert/ut_helper.h"
namespace paddle {
namespace inference {
namespace tensorrt {
TEST(MulOpConverter, main) {
TRTConvertValidation validator(10, 1000);
validator.DeclInputVar("mul-X", nvinfer1::Dims2(10, 6));
validator.DeclInputVar("mul-Y", nvinfer1::Dims2(6, 10));
validator.DeclOutputVar("mul-Out", nvinfer1::Dims2(10, 10));
// Prepare Op description
framework::OpDesc desc;
desc.SetType("mul");
desc.SetInput("X", {"mul-X"});
desc.SetInput("Y", {"mul-Y"});
desc.SetOutput("Out", {"mul-Out"});
LOG(INFO) << "set OP";
validator.SetOp(*desc.Proto());
LOG(INFO) << "execute";
validator.Execute(10);
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
USE_OP(mul);
...@@ -23,8 +23,6 @@ namespace tensorrt { ...@@ -23,8 +23,6 @@ namespace tensorrt {
TEST(OpConverter, ConvertBlock) { TEST(OpConverter, ConvertBlock) {
framework::ProgramDesc prog; framework::ProgramDesc prog;
auto* block = prog.MutableBlock(0); auto* block = prog.MutableBlock(0);
auto* mul_op = block->AppendOp();
mul_op->SetType("mul");
auto* conv2d_op = block->AppendOp(); auto* conv2d_op = block->AppendOp();
conv2d_op->SetType("conv2d"); conv2d_op->SetType("conv2d");
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
/*
* This file implements a UT framework to validate the transformation of a
* Fluid Op into a TRT Layer.
*/
#pragma once
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
namespace paddle {
namespace inference {
namespace tensorrt {
/*
* Get a random float value between [low, high]
*/
float random(float low, float high) {
static std::random_device rd;
static std::mt19937 mt(rd());
std::uniform_real_distribution<float> dist(low, high);
return dist(mt);
}
void RandomizeTensor(framework::LoDTensor* tensor, const platform::Place& place,
const platform::DeviceContext& ctx) {
auto dims = tensor->dims();
size_t num_elements = analysis::AccuDims(dims, dims.size());
PADDLE_ENFORCE_GT(num_elements, 0);
auto* data = tensor->mutable_data<float>(place);
for (size_t i = 0; i < num_elements; i++) {
*(data + i) = random(0., 1.);
}
}
/*
* Help to validate the correctness between Fluid Op and the corresponding TRT
* layer.
*/
class TRTConvertValidation {
public:
TRTConvertValidation() = delete;
TRTConvertValidation(int batch_size, int workspace_size = 1 << 10) {
// Create the CUDA stream first, then the engine that records into it,
// honoring the requested batch size and workspace size.
PADDLE_ENFORCE_EQ(cudaStreamCreate(&stream_), 0);
engine_.reset(new TensorRTEngine(batch_size, workspace_size, &stream_));
engine_->InitNetwork();
}
// Declare a Variable as input with random initialization.
void DeclInputVar(const std::string& name, const nvinfer1::Dims& dims) {
DeclVar(name, dims);
// Declare TRT inputs.
engine_->DeclareInput(name, nvinfer1::DataType::kFLOAT, dims);
}
void DeclOutputVar(const std::string& name, const nvinfer1::Dims& dims) {
DeclVar(name, dims);
}
void DeclVar(const std::string& name, const nvinfer1::Dims& dims) {
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
// Init Fluid tensor.
std::vector<int> dim_vec(dims.nbDims);
for (int i = 0; i < dims.nbDims; i++) {
dim_vec[i] = dims.d[i];
}
auto* x = scope_.Var(name);
auto* x_tensor = x->GetMutable<framework::LoDTensor>();
x_tensor->Resize(framework::make_ddim(dim_vec));
RandomizeTensor(x_tensor, place, ctx);
}
void SetOp(const framework::proto::OpDesc& desc) {
op_ = framework::OpRegistry::CreateOp(desc);
OpConverter op_converter;
op_converter.ConvertOp(desc, engine_.get());
engine_->FreezeNetwork();
// Declare outputs.
op_desc_.reset(new framework::OpDesc(desc, nullptr, nullptr));
// Set Inputs.
for (const auto& input : op_desc_->InputArgumentNames()) {
auto* var = scope_.FindVar(input);
PADDLE_ENFORCE(var);
auto tensor = var->GetMutable<framework::LoDTensor>();
engine_->SetInputFromCPU(
input, static_cast<void*>(tensor->data<float>()),
sizeof(float) *
analysis::AccuDims(tensor->dims(), tensor->dims().size()));
}
}
void Execute(int batch_size) {
// Execute Fluid Op
// Execute TRT
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
engine_->Execute(batch_size);
op_->Run(scope_, place);
ASSERT_FALSE(op_desc_->OutputArgumentNames().empty());
for (const auto& output : op_desc_->OutputArgumentNames()) {
std::vector<float> fluid_out;
std::vector<float> trt_out(200);
engine_->GetOutputInCPU(output, &trt_out[0], 200 * sizeof(float));
auto* var = scope_.FindVar(output);
auto tensor = var->GetMutable<framework::LoDTensor>();
framework::TensorToVector(*tensor, ctx, &fluid_out);
// Compare two output
ASSERT_FALSE(fluid_out.empty());
for (size_t i = 0; i < fluid_out.size(); i++) {
EXPECT_LT(std::abs(fluid_out[i] - trt_out[i]), 0.001);
}
}
}
framework::Scope& scope() { return scope_; }
private:
std::unique_ptr<TensorRTEngine> engine_;
cudaStream_t stream_;
framework::Scope scope_;
std::unique_ptr<framework::OperatorBase> op_;
std::unique_ptr<framework::OpDesc> op_desc_;
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
...@@ -18,6 +18,7 @@ limitations under the License. */ ...@@ -18,6 +18,7 @@ limitations under the License. */
#include <cuda.h> #include <cuda.h>
#include <glog/logging.h> #include <glog/logging.h>
#include <string> #include <string>
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/tensorrt/helper.h" #include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/enforce.h"
...@@ -71,9 +72,10 @@ void TensorRTEngine::FreezeNetwork() { ...@@ -71,9 +72,10 @@ void TensorRTEngine::FreezeNetwork() {
for (auto& item : buffer_sizes_) { for (auto& item : buffer_sizes_) {
if (item.second == 0) { if (item.second == 0) {
auto slot_offset = infer_engine_->getBindingIndex(item.first.c_str()); auto slot_offset = infer_engine_->getBindingIndex(item.first.c_str());
auto dims = infer_engine_->getBindingDimensions(slot_offset);
item.second = kDataTypeSize[static_cast<int>( item.second = kDataTypeSize[static_cast<int>(
infer_engine_->getBindingDataType(slot_offset))] * infer_engine_->getBindingDataType(slot_offset))] *
AccumDims(infer_engine_->getBindingDimensions(slot_offset)); analysis::AccuDims(dims.d, dims.nbDims);
} }
auto& buf = buffer(item.first); auto& buf = buffer(item.first);
CHECK(buf.buffer == nullptr); // buffer should be allocated only once. CHECK(buf.buffer == nullptr); // buffer should be allocated only once.
...@@ -85,14 +87,15 @@ void TensorRTEngine::FreezeNetwork() { ...@@ -85,14 +87,15 @@ void TensorRTEngine::FreezeNetwork() {
nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name, nvinfer1::ITensor* TensorRTEngine::DeclareInput(const std::string& name,
nvinfer1::DataType dtype, nvinfer1::DataType dtype,
const nvinfer1::Dims& dim) { const nvinfer1::Dims& dims) {
PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate input name %s", PADDLE_ENFORCE_EQ(0, buffer_sizes_.count(name), "duplicate input name %s",
name); name);
PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first"); PADDLE_ENFORCE(infer_network_ != nullptr, "should initnetwork first");
auto* input = infer_network_->addInput(name.c_str(), dtype, dim); auto* input = infer_network_->addInput(name.c_str(), dtype, dims);
PADDLE_ENFORCE(input, "infer network add input %s failed", name); PADDLE_ENFORCE(input, "infer network add input %s failed", name);
buffer_sizes_[name] = kDataTypeSize[static_cast<int>(dtype)] * AccumDims(dim); buffer_sizes_[name] = kDataTypeSize[static_cast<int>(dtype)] *
analysis::AccuDims(dims.d, dims.nbDims);
TensorRTEngine::SetITensor(name, input); TensorRTEngine::SetITensor(name, input);
return input; return input;
} }
...@@ -162,13 +165,13 @@ void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data, ...@@ -162,13 +165,13 @@ void TensorRTEngine::SetInputFromCPU(const std::string& name, void* data,
void TensorRTEngine::SetITensor(const std::string& name, void TensorRTEngine::SetITensor(const std::string& name,
nvinfer1::ITensor* tensor) { nvinfer1::ITensor* tensor) {
PADDLE_ENFORCE(tensor != nullptr); PADDLE_ENFORCE(tensor != nullptr);
PADDLE_ENFORCE_EQ(0, itensor_map_.count(name), "duplicate itensor name %s", PADDLE_ENFORCE_EQ(0, itensor_map_.count(name), "duplicate ITensor name %s",
name); name);
itensor_map_[name] = tensor; itensor_map_[name] = tensor;
} }
nvinfer1::ITensor* TensorRTEngine::GetITensor(const std::string& name) { nvinfer1::ITensor* TensorRTEngine::GetITensor(const std::string& name) {
PADDLE_ENFORCE(itensor_map_.count(name), "no itensor %s", name); PADDLE_ENFORCE(itensor_map_.count(name), "no ITensor %s", name);
return itensor_map_[name]; return itensor_map_[name];
} }
......
...@@ -26,15 +26,6 @@ namespace tensorrt { ...@@ -26,15 +26,6 @@ namespace tensorrt {
namespace dy = paddle::platform::dynload; namespace dy = paddle::platform::dynload;
static size_t AccumDims(nvinfer1::Dims dims) {
size_t num = dims.nbDims == 0 ? 0 : 1;
for (int i = 0; i < dims.nbDims; i++) {
PADDLE_ENFORCE_GT(dims.d[i], 0);
num *= dims.d[i];
}
return num;
}
// TensorRT data type to size // TensorRT data type to size
const int kDataTypeSize[] = { const int kDataTypeSize[] = {
4, // kFLOAT 4, // kFLOAT
......
...@@ -200,7 +200,9 @@ if(WITH_DISTRIBUTE) ...@@ -200,7 +200,9 @@ if(WITH_DISTRIBUTE)
op_library(send_vars_op DEPS ${DISTRIBUTE_DEPS}) op_library(send_vars_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(send_vars_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(send_vars_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
op_library(send_barrier_op DEPS ${DISTRIBUTE_DEPS}) op_library(send_barrier_op DEPS ${DISTRIBUTE_DEPS})
op_library(fetch_barrier_op DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties(send_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(fetch_barrier_op.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
#set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) #set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
#cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op #cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op
# listen_and_serv_op sum_op executor SERIAL) # listen_and_serv_op sum_op executor SERIAL)
...@@ -214,7 +216,7 @@ if(WITH_DISTRIBUTE) ...@@ -214,7 +216,7 @@ if(WITH_DISTRIBUTE)
set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op) set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
endif() endif()
else() else()
set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op gen_nccl_id_op) set(DEPS_OPS ${DEPS_OPS} send_op prefetch_op recv_op listen_and_serv_op send_vars_op send_barrier_op fetch_barrier_op gen_nccl_id_op)
endif() endif()
op_library(cross_entropy_op DEPS cross_entropy) op_library(cross_entropy_op DEPS cross_entropy)
......
...@@ -25,6 +25,21 @@ namespace paddle { ...@@ -25,6 +25,21 @@ namespace paddle {
namespace operators { namespace operators {
namespace detail { namespace detail {
std::once_flag RPCClient::init_flag_;
std::unique_ptr<RPCClient> RPCClient::rpc_client_(nullptr);
RPCClient* RPCClient::GetInstance() {
std::call_once(init_flag_, &RPCClient::Init);
return rpc_client_.get();
}
void RPCClient::Init() {
if (rpc_client_.get() == nullptr) {
rpc_client_.reset(new RPCClient());
}
}
bool RPCClient::AsyncSendVariable(const std::string& ep, bool RPCClient::AsyncSendVariable(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
...@@ -60,7 +75,6 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, ...@@ -60,7 +75,6 @@ bool RPCClient::AsyncSendVariable(const std::string& ep,
call->StartCall(); call->StartCall();
call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s)); call->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
}); });
req_count_++; req_count_++;
return true; return true;
...@@ -249,8 +263,9 @@ bool RPCClient::Proceed() { ...@@ -249,8 +263,9 @@ bool RPCClient::Proceed() {
delete c; delete c;
return true; return true;
} }
std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) { std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
// TODO(Yancey1989): make grpc client completely thread-safe
std::unique_lock<std::mutex> lock(mutex_);
auto it = channels_.find(ep); auto it = channels_.find(ep);
if (it != channels_.end()) { if (it != channels_.end()) {
return it->second; return it->second;
...@@ -263,7 +278,6 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) { ...@@ -263,7 +278,6 @@ std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
auto ch = auto ch =
grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args); grpc::CreateCustomChannel(ep, grpc::InsecureChannelCredentials(), args);
channels_[ep] = ch; channels_[ep] = ch;
return ch; return ch;
} }
......
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <map> #include <map>
#include <mutex> // NOLINT
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -35,6 +36,7 @@ limitations under the License. */ ...@@ -35,6 +36,7 @@ limitations under the License. */
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h" #include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/platform/macros.h" // for DISABLE_COPY_AND_ASSIGN
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -161,6 +163,10 @@ class FetchBarrierProcessor : public BaseProcessor { ...@@ -161,6 +163,10 @@ class FetchBarrierProcessor : public BaseProcessor {
class RPCClient { class RPCClient {
public: public:
RPCClient() {}
static RPCClient* GetInstance();
bool AsyncSendVariable(const std::string& ep, bool AsyncSendVariable(const std::string& ep,
const platform::DeviceContext& ctx, const platform::DeviceContext& ctx,
const framework::Scope& scope, const framework::Scope& scope,
...@@ -191,11 +197,17 @@ class RPCClient { ...@@ -191,11 +197,17 @@ class RPCClient {
private: private:
bool Proceed(); bool Proceed();
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep); std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
// Init is called by GetInstance.
static void Init();
private: private:
grpc::CompletionQueue cq_; grpc::CompletionQueue cq_;
std::map<std::string, std::shared_ptr<grpc::Channel>> channels_; std::map<std::string, std::shared_ptr<grpc::Channel>> channels_;
int64_t req_count_ = 0; std::atomic<int64_t> req_count_{0};
std::mutex mutex_;
static std::unique_ptr<RPCClient> rpc_client_;
static std::once_flag init_flag_;
DISABLE_COPY_AND_ASSIGN(RPCClient);
}; };
} // namespace detail } // namespace detail
......
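The RPCClient changes above replace the per-op RPCClient scope variable with a process-wide instance guarded by std::call_once. A minimal, self-contained sketch of that pattern follows; the class and member names here are illustrative, not the Paddle ones.

#include <iostream>
#include <memory>
#include <mutex>

class Client {
 public:
  static Client* GetInstance() {
    std::call_once(init_flag_, &Client::Init);
    return instance_.get();
  }

 private:
  static void Init() { instance_.reset(new Client()); }
  static std::once_flag init_flag_;
  static std::unique_ptr<Client> instance_;
};

std::once_flag Client::init_flag_;
std::unique_ptr<Client> Client::instance_;

int main() {
  // Every caller gets the same instance, created exactly once even when
  // GetInstance() races across threads.
  std::cout << (Client::GetInstance() == Client::GetInstance()) << std::endl;  // 1
  return 0;
}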
...@@ -121,10 +121,10 @@ TEST(PREFETCH, DISABLED_CPU) { ...@@ -121,10 +121,10 @@ TEST(PREFETCH, DISABLED_CPU) {
std::string in_var_name("ids"); std::string in_var_name("ids");
std::string out_var_name("out"); std::string out_var_name("out");
detail::RPCClient client; auto client = detail::RPCClient::GetInstance();
client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name, client->AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
out_var_name); out_var_name);
client.Wait(); client->Wait();
auto var = scope.Var(out_var_name); auto var = scope.Var(out_var_name);
auto value = var->GetMutable<framework::SelectedRows>()->value(); auto value = var->GetMutable<framework::SelectedRows>()->value();
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
namespace operators {
class FetchBarrierOp : public framework::OperatorBase {
public:
FetchBarrierOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
auto rpc_client = detail::RPCClient::GetInstance();
PADDLE_ENFORCE(rpc_client->Wait());
for (auto& ep : eps) {
VLOG(3) << "fetch barrier, ep: " << ep;
rpc_client->AsyncSendFetchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait());
}
};
class FetchBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddComment(R"DOC(
FetchBarrier operator

This operator will send a fetch barrier signal to the listen_and_serv op, so
that the Parameter Server knows all variables have been fetched.
)DOC");
AddAttr<std::vector<std::string>>("endpoints",
"(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to.")
.SetDefault({"127.0.0.1:6164"});
}
};
class FetchBarrierOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fetch_barrier, ops::FetchBarrierOp,
paddle::framework::EmptyGradOpMaker, ops::FetchBarrierOpMaker,
ops::FetchBarrierOpShapeInference);
...@@ -24,6 +24,14 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> { ...@@ -24,6 +24,14 @@ class FillConstantBatchSizeLikeOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* out = ctx.Output<framework::Tensor>("Out"); auto* out = ctx.Output<framework::Tensor>("Out");
auto* in = ctx.Input<framework::LoDTensor>("Input");
if (in->lod().size() && ctx.Attr<int>("input_dim_idx") == 0) {
// set the correct batch size for the LoDTensor.
auto odims = out->dims();
int output_dim_idx = ctx.Attr<int>("output_dim_idx");
odims[output_dim_idx] = static_cast<int>(in->lod().back().size()) - 1;
out->mutable_data<T>(odims, ctx.GetPlace());
}
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
auto value = ctx.Attr<float>("value"); auto value = ctx.Attr<float>("value");
......
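The new fill_constant_batch_size_like branch above derives the batch size from the LoD rather than from the padded first dimension. A short worked example with made-up numbers and a single LoD level:

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // A LoDTensor holding 3 sequences of lengths 2, 3 and 4 has this LoD.
  std::vector<std::vector<size_t>> lod = {{0, 2, 5, 9}};
  // lod.back().size() - 1 == 3 is the real batch size used for the output dims.
  int batch_size = static_cast<int>(lod.back().size()) - 1;
  std::cout << batch_size << std::endl;  // prints 3
  return 0;
}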
...@@ -46,7 +46,10 @@ class CrossEntropyFunctor<platform::CPUDeviceContext, T> { ...@@ -46,7 +46,10 @@ class CrossEntropyFunctor<platform::CPUDeviceContext, T> {
const int64_t* label_data = labels->data<int64_t>(); const int64_t* label_data = labels->data<int64_t>();
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
int index = i * class_num + label_data[i]; int lbl = label_data[i];
PADDLE_ENFORCE_GE(lbl, 0);
PADDLE_ENFORCE_LT(lbl, class_num);
int index = i * class_num + lbl;
loss_data[i] = -math::TolerableValue<T>()(std::log(prob_data[index])); loss_data[i] = -math::TolerableValue<T>()(std::log(prob_data[index]));
} }
} }
......
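The new PADDLE_ENFORCE_GE/LT checks above guard the row indexing in the cross-entropy loop. A small numeric illustration, with made-up values, of what an out-of-range label would otherwise do:

#include <cstdio>

int main() {
  const int class_num = 10;
  const int i = 2;  // batch row
  int lbl = 3;      // a valid label, 0 <= lbl < class_num
  std::printf("valid index  = %d\n", i * class_num + lbl);  // 23, inside row 2
  lbl = 12;         // a corrupt label that the new check now rejects
  std::printf("broken index = %d\n", i * class_num + lbl);  // 32, past row 2
  return 0;
}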
...@@ -41,12 +41,7 @@ class PrefetchOp : public framework::OperatorBase { ...@@ -41,12 +41,7 @@ class PrefetchOp : public framework::OperatorBase {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
auto client_var_name = Output("RPCClient"); auto rpc_client = detail::RPCClient::GetInstance();
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
...@@ -66,9 +61,6 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -66,9 +61,6 @@ class PrefetchOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() { void Make() {
AddInput("X", "(LoDTensor) Input Id variables to be sent").AsDuplicable(); AddInput("X", "(LoDTensor) Input Id variables to be sent").AsDuplicable();
AddOutput("RPCClient",
"(RPCClient) The RPC client object which will be"
"initialized at most once.");
AddOutput("Out", AddOutput("Out",
"(LoDTensor) result " "(LoDTensor) result "
"to be fetched from parameter server") "to be fetched from parameter server")
...@@ -87,17 +79,6 @@ the parameter server and fetch result back. ...@@ -87,17 +79,6 @@ the parameter server and fetch result back.
} }
}; };
class PrefetchOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("RPCClient").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class PrefetchOpShapeInference : public framework::InferShapeBase { class PrefetchOpShapeInference : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext* ctx) const override {} void operator()(framework::InferShapeContext* ctx) const override {}
...@@ -110,5 +91,4 @@ namespace ops = paddle::operators; ...@@ -110,5 +91,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(prefetch, ops::PrefetchOp, REGISTER_OPERATOR(prefetch, ops::PrefetchOp,
paddle::framework::EmptyGradOpMaker, ops::PrefetchOpMaker, paddle::framework::EmptyGradOpMaker, ops::PrefetchOpMaker,
ops::PrefetchOpVarTypeInference,
ops::PrefetchOpShapeInference); ops::PrefetchOpShapeInference);
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -36,19 +37,23 @@ class RecvOp : public framework::OperatorBase { ...@@ -36,19 +37,23 @@ class RecvOp : public framework::OperatorBase {
const platform::Place& place) const override { const platform::Place& place) const override {
auto outs = Outputs("Out"); auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap"); std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
int sync_mode = Attr<int>("sync_mode");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
// For profiling
platform::RecordEvent record_event(Type(), &ctx);
auto rpc_client = detail::RPCClient::GetInstance();
for (size_t i = 0; i < outs.size(); i++) { for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i]; VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]); rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
}
if (sync_mode) {
PADDLE_ENFORCE(rpc_client->Wait());
} }
PADDLE_ENFORCE(client_.Wait());
} }
private:
mutable detail::RPCClient client_;
}; };
class RecvOpMaker : public framework::OpProtoAndCheckerMaker { class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -65,6 +70,10 @@ This operator can get variables from server side. ...@@ -65,6 +70,10 @@ This operator can get variables from server side.
"Server endpoints in the order of input " "Server endpoints in the order of input "
"variables for mapping") "variables for mapping")
.SetDefault({}); .SetDefault({});
AddAttr<int>("sync_mode",
"(int, default 0)"
"sync recv or async recv.")
.SetDefault(0);
} }
}; };
......
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -36,31 +37,30 @@ class SendBarrierOp : public framework::OperatorBase { ...@@ -36,31 +37,30 @@ class SendBarrierOp : public framework::OperatorBase {
void RunImpl(const framework::Scope& scope, void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override { const platform::Place& place) const override {
std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints"); std::vector<std::string> eps = Attr<std::vector<std::string>>("endpoints");
bool sync_mode = Attr<bool>("sync_mode");
auto client_var_name = Output("RPCClient"); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name), auto& ctx = *pool.Get(place);
"Can not find variable '%s' in the scope.", // For profiling
client_var_name); platform::RecordEvent record_event(Type(), &ctx);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>(); auto rpc_client = detail::RPCClient::GetInstance();
// need to wait before sending send_barrier message // need to wait before sending send_barrier message
PADDLE_ENFORCE(rpc_client->Wait()); PADDLE_ENFORCE(rpc_client->Wait());
if (sync_mode) {
for (auto& ep : eps) { for (auto& ep : eps) {
VLOG(3) << "send barrier, ep: " << ep; VLOG(3) << "send barrier, ep: " << ep;
rpc_client->AsyncSendBatchBarrier(ep); rpc_client->AsyncSendBatchBarrier(ep);
}
PADDLE_ENFORCE(rpc_client->Wait());
} }
PADDLE_ENFORCE(rpc_client->Wait());
} }
}; };
class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker { class SendBarrierOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() { void Make() {
AddOutput("RPCClient",
"(RPCClient) The RPC client object which is"
"initialized at most once.");
AddComment(R"DOC( AddComment(R"DOC(
SendBarrier operator SendBarrier operator
...@@ -72,17 +72,7 @@ the Parameter Server would knew all variables have been sent. ...@@ -72,17 +72,7 @@ the Parameter Server would knew all variables have been sent.
"(string vector, default 127.0.0.1:6164)" "(string vector, default 127.0.0.1:6164)"
"Server endpoints to send variables to.") "Server endpoints to send variables to.")
.SetDefault({"127.0.0.1:6164"}); .SetDefault({"127.0.0.1:6164"});
} AddAttr<bool>("sync_mode", "work in sync_mode or not").SetDefault(true);
};
class SendBarrierOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("RPCClient").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
} }
}; };
...@@ -98,5 +88,4 @@ namespace ops = paddle::operators; ...@@ -98,5 +88,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(send_barrier, ops::SendBarrierOp, REGISTER_OPERATOR(send_barrier, ops::SendBarrierOp,
paddle::framework::EmptyGradOpMaker, ops::SendBarrierOpMaker, paddle::framework::EmptyGradOpMaker, ops::SendBarrierOpMaker,
ops::SendBarrierOpVarTypeInference,
ops::SendBarrierOpShapeInference); ops::SendBarrierOpShapeInference);
...@@ -49,12 +49,7 @@ class SendOp : public framework::OperatorBase { ...@@ -49,12 +49,7 @@ class SendOp : public framework::OperatorBase {
// For profiling // For profiling
platform::RecordEvent record_event(Type(), &ctx); platform::RecordEvent record_event(Type(), &ctx);
auto client_var_name = Output("RPCClient"); auto rpc_client = detail::RPCClient::GetInstance();
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name),
"Can not find variable '%s' in the scope.",
client_var_name);
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
...@@ -96,9 +91,6 @@ class SendOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -96,9 +91,6 @@ class SendOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable(); AddInput("X", "(Tensor) Input tensor to be sent").AsDuplicable();
AddOutput("Out", "(Tensor) Output tensor to be received from server") AddOutput("Out", "(Tensor) Output tensor to be received from server")
.AsDuplicable(); .AsDuplicable();
AddOutput("RPCClient",
"(RPCClient) The RPC client object which is"
"initialized at most once.");
AddComment(R"DOC( AddComment(R"DOC(
Send operator Send operator
...@@ -119,17 +111,6 @@ This operator will send tensor to recv_op at the parameter server. ...@@ -119,17 +111,6 @@ This operator will send tensor to recv_op at the parameter server.
} }
}; };
class SendOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("RPCClient").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class SendOpShapeInference : public framework::InferShapeBase { class SendOpShapeInference : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext* ctx) const override {} void operator()(framework::InferShapeContext* ctx) const override {}
...@@ -141,5 +122,4 @@ class SendOpShapeInference : public framework::InferShapeBase { ...@@ -141,5 +122,4 @@ class SendOpShapeInference : public framework::InferShapeBase {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(send, ops::SendOp, paddle::framework::EmptyGradOpMaker, REGISTER_OPERATOR(send, ops::SendOp, paddle::framework::EmptyGradOpMaker,
ops::SendOpMaker, ops::SendOpVarTypeInference, ops::SendOpMaker, ops::SendOpShapeInference);
ops::SendOpShapeInference);
...@@ -156,6 +156,7 @@ TEST(SendRecvOp, CPUDense) { ...@@ -156,6 +156,7 @@ TEST(SendRecvOp, CPUDense) {
std::thread server_thread(StartServerNet, false, &initialized); std::thread server_thread(StartServerNet, false, &initialized);
while (!initialized) { while (!initialized) {
} }
static_cast<paddle::operators::ListenAndServOp *>(listen_and_serv_op.get()) static_cast<paddle::operators::ListenAndServOp *>(listen_and_serv_op.get())
->WaitServerReady(); ->WaitServerReady();
...@@ -175,9 +176,10 @@ TEST(SendRecvOp, CPUDense) { ...@@ -175,9 +176,10 @@ TEST(SendRecvOp, CPUDense) {
std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port); std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port);
attrs.insert({"endpoints", std::vector<std::string>({endpoint})}); attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
attrs.insert({"epmap", std::vector<std::string>({endpoint})}); attrs.insert({"epmap", std::vector<std::string>({endpoint})});
auto send_op = f::OpRegistry::CreateOp( const f::VariableNameMap &inputs = {{"X", {"x1"}}};
"send", {{"X", {"x1"}}}, const f::VariableNameMap &outputs = {{"Out", {"Out"}}};
{{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs);
auto send_op = f::OpRegistry::CreateOp("send", inputs, outputs, attrs);
send_op->Run(scope, place); send_op->Run(scope, place);
auto in_var = scope.Var("x1"); auto in_var = scope.Var("x1");
...@@ -220,9 +222,8 @@ TEST(SendRecvOp, CPUSparse) { ...@@ -220,9 +222,8 @@ TEST(SendRecvOp, CPUSparse) {
std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port); std::string endpoint = paddle::string::Sprintf("127.0.0.1:%d", selected_port);
attrs.insert({"endpoints", std::vector<std::string>({endpoint})}); attrs.insert({"endpoints", std::vector<std::string>({endpoint})});
attrs.insert({"epmap", std::vector<std::string>({endpoint})}); attrs.insert({"epmap", std::vector<std::string>({endpoint})});
auto send_op = f::OpRegistry::CreateOp( auto send_op = f::OpRegistry::CreateOp("send", {{"X", {"x1"}}},
"send", {{"X", {"x1"}}}, {{"Out", {"Out"}}}, attrs);
{{"Out", {"Out"}}, {"RPCClient", {"RPC_CLIENT_VAR"}}}, attrs);
send_op->Run(scope, place); send_op->Run(scope, place);
auto x0 = scope.Var("x0")->GetMutable<f::SelectedRows>(); auto x0 = scope.Var("x0")->GetMutable<f::SelectedRows>();
......
...@@ -20,6 +20,9 @@ namespace operators { ...@@ -20,6 +20,9 @@ namespace operators {
inline bool NeedSend(const framework::Scope& scope, inline bool NeedSend(const framework::Scope& scope,
const std::string& varname) { const std::string& varname) {
// The dummy variable is only used in the parallel executor to represent
// a dependency relationship; we don't need to send/recv it.
if (varname == "dummy") return false;
auto* var = scope.FindVar(varname); auto* var = scope.FindVar(varname);
PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.", PADDLE_ENFORCE_NOT_NULL(var, "Can not find variable '%s' in the send side.",
varname); varname);
......
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/grpc_client.h" #include "paddle/fluid/operators/detail/grpc_client.h"
#include "paddle/fluid/operators/send_recv_util.h" #include "paddle/fluid/operators/send_recv_util.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -41,12 +42,10 @@ class SendVarsOp : public framework::OperatorBase { ...@@ -41,12 +42,10 @@ class SendVarsOp : public framework::OperatorBase {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place); auto& ctx = *pool.Get(place);
auto client_var_name = Output("RPCClient"); // For profiling
PADDLE_ENFORCE_NOT_NULL(scope.FindVar(client_var_name), platform::RecordEvent record_event(Type(), &ctx);
"Can not find variable '%s' in the scope.",
client_var_name); auto rpc_client = detail::RPCClient::GetInstance();
auto* client_var = scope.FindVar(client_var_name);
detail::RPCClient* rpc_client = client_var->GetMutable<detail::RPCClient>();
for (size_t i = 0; i < ins.size(); i++) { for (size_t i = 0; i < ins.size(); i++) {
if (NeedSend(scope, ins[i])) { if (NeedSend(scope, ins[i])) {
...@@ -69,9 +68,6 @@ class SendVarsOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -69,9 +68,6 @@ class SendVarsOpMaker : public framework::OpProtoAndCheckerMaker {
void Make() { void Make() {
AddInput("X", "(Tensor, SelectedRows) Input variables to be sent") AddInput("X", "(Tensor, SelectedRows) Input variables to be sent")
.AsDuplicable(); .AsDuplicable();
AddOutput("RPCClient",
"(RPCClient) The RPC client object which will be"
"initialized at most once.");
AddComment(R"DOC( AddComment(R"DOC(
Send operator Send operator
...@@ -89,17 +85,6 @@ This operator will send variables to listen_and_serve op at the parameter server ...@@ -89,17 +85,6 @@ This operator will send variables to listen_and_serve op at the parameter server
} }
}; };
class SendVarsOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
auto out_var_name = op_desc.Output("RPCClient").front();
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class SendVarsOpShapeInference : public framework::InferShapeBase { class SendVarsOpShapeInference : public framework::InferShapeBase {
public: public:
void operator()(framework::InferShapeContext* ctx) const override {} void operator()(framework::InferShapeContext* ctx) const override {}
...@@ -112,5 +97,4 @@ namespace ops = paddle::operators; ...@@ -112,5 +97,4 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(send_vars, ops::SendVarsOp, REGISTER_OPERATOR(send_vars, ops::SendVarsOp,
paddle::framework::EmptyGradOpMaker, ops::SendVarsOpMaker, paddle::framework::EmptyGradOpMaker, ops::SendVarsOpMaker,
ops::SendVarsOpVarTypeInference,
ops::SendVarsOpShapeInference); ops::SendVarsOpShapeInference);
...@@ -32,7 +32,8 @@ void BindConstValue(pybind11::module* m) { ...@@ -32,7 +32,8 @@ void BindConstValue(pybind11::module* m) {
.value("Forward", framework::OpRole::kForward) .value("Forward", framework::OpRole::kForward)
.value("Backward", framework::OpRole::kBackward) .value("Backward", framework::OpRole::kBackward)
.value("Optimize", framework::OpRole::kOptimize) .value("Optimize", framework::OpRole::kOptimize)
.value("Loss", framework::OpRole::kLoss); .value("Loss", framework::OpRole::kLoss)
.value("RPC", framework::OpRole::kRPC);
op_proto_and_checker_maker.def( op_proto_and_checker_maker.def(
"kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName); "kOpRoleAttrName", framework::OpProtoAndCheckerMaker::OpRoleAttrName);
......
...@@ -33,7 +33,7 @@ namespace paddle { ...@@ -33,7 +33,7 @@ namespace paddle {
* \param outputs[0] Image data of NCHW format. * \param outputs[0] Image data of NCHW format.
*/ */
class BlockExpandFunction : public FunctionBase { class BlockExpandFunction : public FunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
// function arguments // function arguments
strides_ = config.get<std::vector<size_t>>("strides"); strides_ = config.get<std::vector<size_t>>("strides");
...@@ -81,7 +81,7 @@ public: ...@@ -81,7 +81,7 @@ public:
(size_t)blockW()}); (size_t)blockW()});
} }
protected: protected:
std::vector<size_t> strides_; std::vector<size_t> strides_;
std::vector<size_t> paddings_; std::vector<size_t> paddings_;
std::vector<size_t> blocks_; std::vector<size_t> blocks_;
...@@ -101,7 +101,7 @@ protected: ...@@ -101,7 +101,7 @@ protected:
template <DeviceType Device> template <DeviceType Device>
class BlockExpandForward : public BlockExpandFunction { class BlockExpandForward : public BlockExpandFunction {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
BlockExpandFunction::init(config); BlockExpandFunction::init(config);
} }
...@@ -149,7 +149,7 @@ public: ...@@ -149,7 +149,7 @@ public:
template <DeviceType Device> template <DeviceType Device>
class BlockExpandBackward : public BlockExpandFunction { class BlockExpandBackward : public BlockExpandFunction {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
BlockExpandFunction::init(config); BlockExpandFunction::init(config);
} }
......
...@@ -63,12 +63,12 @@ enum ArgType { ...@@ -63,12 +63,12 @@ enum ArgType {
ADD_TO = 2, ADD_TO = 2,
}; };
class BufferArg { class BufferArg {
public: public:
void setArgType(ArgType argType) { argType_ = argType; } void setArgType(ArgType argType) { argType_ = argType; }
ArgType getArgType() const { return argType_; } ArgType getArgType() const { return argType_; }
public: public:
BufferArg(ValueType valueType, BufferArg(ValueType valueType,
const TensorShape& shape, const TensorShape& shape,
ArgType argType = UNSPECIFIED) ArgType argType = UNSPECIFIED)
...@@ -169,7 +169,7 @@ public: ...@@ -169,7 +169,7 @@ public:
const SequenceArg& sequence() const; const SequenceArg& sequence() const;
const SparseMatrixArg& sparse() const; const SparseMatrixArg& sparse() const;
protected: protected:
void* buf_; void* buf_;
ValueType valueType_; ValueType valueType_;
TensorShape shape_; TensorShape shape_;
...@@ -185,7 +185,7 @@ protected: ...@@ -185,7 +185,7 @@ protected:
// valueType_ = int32 // valueType_ = int32
// if a < b then value_.buf_[a] < value_.buf_[b] // if a < b then value_.buf_[a] < value_.buf_[b]
class SequenceIdArg : public BufferArg { class SequenceIdArg : public BufferArg {
public: public:
SequenceIdArg(const TensorShape& shape, ArgType argType = UNSPECIFIED) SequenceIdArg(const TensorShape& shape, ArgType argType = UNSPECIFIED)
: BufferArg(VALUE_TYPE_INT32, shape, argType) { : BufferArg(VALUE_TYPE_INT32, shape, argType) {
bufferType_ = TENSOR_SEQUENCE_ID; bufferType_ = TENSOR_SEQUENCE_ID;
...@@ -212,7 +212,7 @@ public: ...@@ -212,7 +212,7 @@ public:
size_t numSeqs() const { return numSeqs_; } size_t numSeqs() const { return numSeqs_; }
private: private:
size_t numSeqs_; size_t numSeqs_;
}; };
...@@ -222,7 +222,7 @@ private: ...@@ -222,7 +222,7 @@ private:
// SequenceArg can be used to represent sequences that contain multiple // SequenceArg can be used to represent sequences that contain multiple
// unequal lengths. // unequal lengths.
class SequenceArg : public BufferArg { class SequenceArg : public BufferArg {
public: public:
SequenceArg(ValueType valueType, SequenceArg(ValueType valueType,
const TensorShape& shape, const TensorShape& shape,
ArgType argType = UNSPECIFIED) ArgType argType = UNSPECIFIED)
...@@ -255,7 +255,7 @@ public: ...@@ -255,7 +255,7 @@ public:
SequenceIdArg& getSequenceId() { return startPositions_; } SequenceIdArg& getSequenceId() { return startPositions_; }
const SequenceIdArg& getSequenceId() const { return startPositions_; } const SequenceIdArg& getSequenceId() const { return startPositions_; }
private: private:
SequenceIdArg startPositions_; SequenceIdArg startPositions_;
}; };
...@@ -263,7 +263,7 @@ private: ...@@ -263,7 +263,7 @@ private:
// valueType_ == float or double // valueType_ == float or double
// shape_.ndims() == 2 // shape_.ndims() == 2
class SparseMatrixArg : public BufferArg { class SparseMatrixArg : public BufferArg {
public: public:
SparseMatrixArg(void* buf, SparseMatrixArg(void* buf,
ValueType valueType, ValueType valueType,
const TensorShape& shape, const TensorShape& shape,
...@@ -353,7 +353,7 @@ public: ...@@ -353,7 +353,7 @@ public:
SparseDataType dataType() const { return type_; } SparseDataType dataType() const { return type_; }
private: private:
BufferArg row_; BufferArg row_;
BufferArg col_; BufferArg col_;
size_t nnz_; size_t nnz_;
......
...@@ -100,7 +100,7 @@ void ContextProjectionForward<DEVICE_TYPE_CPU>(CpuMatrix& out_mat, ...@@ -100,7 +100,7 @@ void ContextProjectionForward<DEVICE_TYPE_CPU>(CpuMatrix& out_mat,
*/ */
template <DeviceType Device> template <DeviceType Device>
class ContextProjectionForwardFunc : public FunctionBase { class ContextProjectionForwardFunc : public FunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
context_length_ = config.get<size_t>("context_length"); context_length_ = config.get<size_t>("context_length");
context_start_ = config.get<int>("context_start"); context_start_ = config.get<int>("context_start");
...@@ -146,7 +146,7 @@ public: ...@@ -146,7 +146,7 @@ public:
begin_pad_); begin_pad_);
} }
private: private:
size_t context_length_; size_t context_length_;
int context_start_; int context_start_;
size_t begin_pad_; size_t begin_pad_;
...@@ -223,7 +223,7 @@ void ContextProjectionBackward<DEVICE_TYPE_CPU>(const CpuMatrix& out_grad_mat, ...@@ -223,7 +223,7 @@ void ContextProjectionBackward<DEVICE_TYPE_CPU>(const CpuMatrix& out_grad_mat,
*/ */
template <DeviceType Device> template <DeviceType Device>
class ContextProjectionBackwardFunc : public FunctionBase { class ContextProjectionBackwardFunc : public FunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
context_length_ = config.get<size_t>("context_length"); context_length_ = config.get<size_t>("context_length");
context_start_ = config.get<int>("context_start"); context_start_ = config.get<int>("context_start");
...@@ -278,7 +278,7 @@ public: ...@@ -278,7 +278,7 @@ public:
total_pad_); total_pad_);
} }
private: private:
size_t context_length_; size_t context_length_;
int context_start_; int context_start_;
size_t begin_pad_; size_t begin_pad_;
...@@ -299,7 +299,7 @@ private: ...@@ -299,7 +299,7 @@ private:
*/ */
template <DeviceType Device> template <DeviceType Device>
class ContextProjectionBackwardDataFunc : public FunctionBase { class ContextProjectionBackwardDataFunc : public FunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
context_length_ = config.get<size_t>("context_length"); context_length_ = config.get<size_t>("context_length");
context_start_ = config.get<int>("context_start"); context_start_ = config.get<int>("context_start");
...@@ -331,7 +331,7 @@ public: ...@@ -331,7 +331,7 @@ public:
out_grad_mat, in_grad_mat, seq_vec, context_length_, context_start_); out_grad_mat, in_grad_mat, seq_vec, context_length_, context_start_);
} }
private: private:
size_t context_length_; size_t context_length_;
int context_start_; int context_start_;
}; };
...@@ -348,7 +348,7 @@ private: ...@@ -348,7 +348,7 @@ private:
*/ */
template <DeviceType Device> template <DeviceType Device>
class ContextProjectionBackwardWeightFunc : public FunctionBase { class ContextProjectionBackwardWeightFunc : public FunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
context_length_ = config.get<size_t>("context_length"); context_length_ = config.get<size_t>("context_length");
context_start_ = config.get<int>("context_start"); context_start_ = config.get<int>("context_start");
...@@ -382,7 +382,7 @@ public: ...@@ -382,7 +382,7 @@ public:
begin_pad_); begin_pad_);
} }
private: private:
size_t context_length_; size_t context_length_;
int context_start_; int context_start_;
size_t begin_pad_; size_t begin_pad_;
......
...@@ -56,7 +56,7 @@ namespace paddle { ...@@ -56,7 +56,7 @@ namespace paddle {
* H and W is height and width of filter. * H and W is height and width of filter.
*/ */
class ConvFunctionBase : public FunctionBase { class ConvFunctionBase : public FunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
// function arguments // function arguments
strides_ = config.get<std::vector<size_t>>("strides"); strides_ = config.get<std::vector<size_t>>("strides");
...@@ -101,7 +101,7 @@ public: ...@@ -101,7 +101,7 @@ public:
} }
} }
protected: protected:
size_t getFilterHeight(const TensorShape& filter) const { size_t getFilterHeight(const TensorShape& filter) const {
return filter[filter.ndims() - 2]; return filter[filter.ndims() - 2];
} }
......
...@@ -97,7 +97,7 @@ class CosSimForwardFunc : public FunctionBase { ...@@ -97,7 +97,7 @@ class CosSimForwardFunc : public FunctionBase {
CosSimForward<Device>(out_mat, in1_mat, in2_mat, scale_); CosSimForward<Device>(out_mat, in1_mat, in2_mat, scale_);
} }
private: private:
real scale_; real scale_;
}; };
...@@ -227,7 +227,7 @@ class CosSimBackwardFunc : public FunctionBase { ...@@ -227,7 +227,7 @@ class CosSimBackwardFunc : public FunctionBase {
out_grad, out_val, in1_val, in2_val, in1_grad, in2_grad, scale_); out_grad, out_val, in1_val, in2_val, in1_grad, in2_grad, scale_);
} }
private: private:
real scale_; real scale_;
}; };
......
...@@ -112,7 +112,7 @@ void CropGrad<DEVICE_TYPE_CPU>(const real* inGrad, ...@@ -112,7 +112,7 @@ void CropGrad<DEVICE_TYPE_CPU>(const real* inGrad,
*/ */
template <DeviceType Device> template <DeviceType Device>
class CropFunc : public FunctionBase { class CropFunc : public FunctionBase {
public: public:
void init(const FuncConfig& config) override { conf_ = config; } void init(const FuncConfig& config) override { conf_ = config; }
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
...@@ -130,7 +130,7 @@ public: ...@@ -130,7 +130,7 @@ public:
conf_); conf_);
} }
private: private:
FuncConfig conf_; FuncConfig conf_;
}; };
...@@ -145,7 +145,7 @@ private: ...@@ -145,7 +145,7 @@ private:
template <DeviceType Device> template <DeviceType Device>
class CropGradFunc : public FunctionBase { class CropGradFunc : public FunctionBase {
public: public:
void init(const FuncConfig& config) override { conf_ = config; } void init(const FuncConfig& config) override { conf_ = config; }
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
...@@ -163,7 +163,7 @@ public: ...@@ -163,7 +163,7 @@ public:
conf_); conf_);
} }
private: private:
FuncConfig conf_; FuncConfig conf_;
}; };
......
...@@ -160,7 +160,7 @@ void CrossMapNormalGrad<DEVICE_TYPE_CPU>(real* inputsGrad, ...@@ -160,7 +160,7 @@ void CrossMapNormalGrad<DEVICE_TYPE_CPU>(real* inputsGrad,
*/ */
template <DeviceType Device> template <DeviceType Device>
class CrossMapNormalFunc : public FunctionBase { class CrossMapNormalFunc : public FunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
// function arguments // function arguments
size_ = config.get<size_t>("size"); size_ = config.get<size_t>("size");
...@@ -220,7 +220,7 @@ public: ...@@ -220,7 +220,7 @@ public:
return ops; return ops;
} }
private: private:
size_t size_; size_t size_;
real scale_; real scale_;
real pow_; real pow_;
...@@ -260,7 +260,7 @@ private: ...@@ -260,7 +260,7 @@ private:
*/ */
template <DeviceType Device> template <DeviceType Device>
class CrossMapNormalGradFunc : public FunctionBase { class CrossMapNormalGradFunc : public FunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
// function arguments // function arguments
size_ = config.get<size_t>("size"); size_ = config.get<size_t>("size");
...@@ -328,7 +328,7 @@ public: ...@@ -328,7 +328,7 @@ public:
return ops; return ops;
} }
private: private:
size_t size_; size_t size_;
real scale_; real scale_;
real pow_; real pow_;
......
...@@ -19,7 +19,7 @@ namespace paddle { ...@@ -19,7 +19,7 @@ namespace paddle {
template <class T> template <class T>
class DepthwiseConvFunctor<DEVICE_TYPE_CPU, T> { class DepthwiseConvFunctor<DEVICE_TYPE_CPU, T> {
public: public:
void operator()(const T* inputData, void operator()(const T* inputData,
const T* filterData, const T* filterData,
int batchSize, int batchSize,
...@@ -43,7 +43,7 @@ public: ...@@ -43,7 +43,7 @@ public:
template <class T> template <class T>
class DepthwiseConvGradInputFunctor<DEVICE_TYPE_CPU, T> { class DepthwiseConvGradInputFunctor<DEVICE_TYPE_CPU, T> {
public: public:
void operator()(const T* outputGrad, void operator()(const T* outputGrad,
const T* filterData, const T* filterData,
int batchSize, int batchSize,
...@@ -66,7 +66,7 @@ public: ...@@ -66,7 +66,7 @@ public:
template <class T> template <class T>
class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_CPU, T> { class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_CPU, T> {
public: public:
void operator()(const T* outputGrad, void operator()(const T* outputGrad,
const T* inputData, const T* inputData,
int batchSize, int batchSize,
...@@ -93,7 +93,7 @@ public: ...@@ -93,7 +93,7 @@ public:
*/ */
template <DeviceType Device> template <DeviceType Device>
class DepthwiseConvFunction : public ConvFunctionBase { class DepthwiseConvFunction : public ConvFunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
ConvFunctionBase::init(config); ConvFunctionBase::init(config);
} }
...@@ -156,7 +156,7 @@ public: ...@@ -156,7 +156,7 @@ public:
*/ */
template <DeviceType Device> template <DeviceType Device>
class DepthwiseConvGradInputFunction : public ConvFunctionBase { class DepthwiseConvGradInputFunction : public ConvFunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
ConvFunctionBase::init(config); ConvFunctionBase::init(config);
} }
...@@ -220,7 +220,7 @@ public: ...@@ -220,7 +220,7 @@ public:
*/ */
template <DeviceType Device> template <DeviceType Device>
class DepthwiseConvGradFilterFunction : public ConvFunctionBase { class DepthwiseConvGradFilterFunction : public ConvFunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
ConvFunctionBase::init(config); ConvFunctionBase::init(config);
} }
......
...@@ -44,7 +44,7 @@ namespace paddle { ...@@ -44,7 +44,7 @@ namespace paddle {
*/ */
template <DeviceType Device, class T> template <DeviceType Device, class T>
class DepthwiseConvFunctor { class DepthwiseConvFunctor {
public: public:
void operator()(const T* inputData, void operator()(const T* inputData,
const T* filterData, const T* filterData,
int batchSize, int batchSize,
...@@ -89,7 +89,7 @@ public: ...@@ -89,7 +89,7 @@ public:
*/ */
template <DeviceType Device, class T> template <DeviceType Device, class T>
class DepthwiseConvGradInputFunctor { class DepthwiseConvGradInputFunctor {
public: public:
void operator()(const T* outputGrad, void operator()(const T* outputGrad,
const T* filterData, const T* filterData,
int batchSize, int batchSize,
...@@ -135,7 +135,7 @@ public: ...@@ -135,7 +135,7 @@ public:
*/ */
template <DeviceType Device, class T> template <DeviceType Device, class T>
class DepthwiseConvGradFilterFunctor { class DepthwiseConvGradFilterFunctor {
public: public:
void operator()(const T* outputGrad, void operator()(const T* outputGrad,
const T* inputData, const T* inputData,
int batchSize, int batchSize,
......
...@@ -199,7 +199,7 @@ __global__ void ConvolutionDepthwiseFilterBackward(const int num_i, ...@@ -199,7 +199,7 @@ __global__ void ConvolutionDepthwiseFilterBackward(const int num_i,
template <class T> template <class T>
class DepthwiseConvFunctor<DEVICE_TYPE_GPU, T> { class DepthwiseConvFunctor<DEVICE_TYPE_GPU, T> {
public: public:
void operator()(const T* inputData, void operator()(const T* inputData,
const T* filterData, const T* filterData,
int batchSize, int batchSize,
...@@ -249,7 +249,7 @@ public: ...@@ -249,7 +249,7 @@ public:
template <class T> template <class T>
class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, T> { class DepthwiseConvGradInputFunctor<DEVICE_TYPE_GPU, T> {
public: public:
void operator()(const T* outputGrad, void operator()(const T* outputGrad,
const T* filterData, const T* filterData,
int batchSize, int batchSize,
...@@ -300,7 +300,7 @@ public: ...@@ -300,7 +300,7 @@ public:
template <class T> template <class T>
class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, T> { class DepthwiseConvGradFilterFunctor<DEVICE_TYPE_GPU, T> {
public: public:
void operator()(const T* outputGrad, void operator()(const T* outputGrad,
const T* inputData, const T* inputData,
int batchSize, int batchSize,
......
...@@ -46,7 +46,7 @@ int GetCpuCount() { return 1; } ...@@ -46,7 +46,7 @@ int GetCpuCount() { return 1; }
#endif #endif
class EigenDeviceWarpper { class EigenDeviceWarpper {
public: // NOLINT public: // NOLINT
#if EIGEN_USE_THREADS #if EIGEN_USE_THREADS
static Eigen::ThreadPoolDevice* device() { static Eigen::ThreadPoolDevice* device() {
const int num_cpus = GetCpuCount(); const int num_cpus = GetCpuCount();
......
...@@ -29,7 +29,7 @@ namespace paddle { ...@@ -29,7 +29,7 @@ namespace paddle {
* The argument type of Function::init. * The argument type of Function::init.
*/ */
class FuncConfig { class FuncConfig {
public: public:
template <typename T> template <typename T>
T get(const std::string& key, Error* err = nullptr) const { T get(const std::string& key, Error* err = nullptr) const {
try { try {
...@@ -59,7 +59,7 @@ public: ...@@ -59,7 +59,7 @@ public:
return *this; return *this;
} }
protected: protected:
mutable std::unordered_map<std::string, any> valueMap_; mutable std::unordered_map<std::string, any> valueMap_;
}; };
...@@ -77,7 +77,7 @@ protected: ...@@ -77,7 +77,7 @@ protected:
* in the BufferArgs life time. * in the BufferArgs life time.
*/ */
class BufferArgs { class BufferArgs {
public: public:
BufferArgs() {} BufferArgs() {}
~BufferArgs() { ~BufferArgs() {
...@@ -137,7 +137,7 @@ public: ...@@ -137,7 +137,7 @@ public:
void addArg(SparseMatrixArg& arg) { args_.push_back(&arg); } void addArg(SparseMatrixArg& arg) { args_.push_back(&arg); }
private: private:
std::vector<BufferArg*> args_; std::vector<BufferArg*> args_;
// The BufferArg object is constructed and freed by BufferArgs. // The BufferArg object is constructed and freed by BufferArgs.
std::vector<BufferArg*> _args_; std::vector<BufferArg*> _args_;
...@@ -163,7 +163,7 @@ private: ...@@ -163,7 +163,7 @@ private:
* If Function has more than one output, each output can have different modes. * If Function has more than one output, each output can have different modes.
*/ */
class FunctionBase { class FunctionBase {
public: public:
virtual ~FunctionBase() {} virtual ~FunctionBase() {}
virtual void init(const FuncConfig& config) {} virtual void init(const FuncConfig& config) {}
...@@ -192,7 +192,7 @@ public: ...@@ -192,7 +192,7 @@ public:
static ClassRegistrar<FunctionBase> funcRegistrar_; static ClassRegistrar<FunctionBase> funcRegistrar_;
protected: protected:
// numInputs_ and numOutputs_ represents the maximum // numInputs_ and numOutputs_ represents the maximum
// input and output supported by Function. // input and output supported by Function.
// Some functions are optimized for input and output, // Some functions are optimized for input and output,
......
...@@ -39,7 +39,7 @@ struct Allocator<DEVICE_TYPE_GPU> { ...@@ -39,7 +39,7 @@ struct Allocator<DEVICE_TYPE_GPU> {
// Copy argument1 to argument2 // Copy argument1 to argument2
template <DeviceType DType1, DeviceType DType2> template <DeviceType DType1, DeviceType DType2>
class CopyArgument { class CopyArgument {
public: public:
void operator()(const BufferArg& arg1, BufferArg& arg2) { void operator()(const BufferArg& arg1, BufferArg& arg2) {
CHECK_EQ(arg1.valueType(), arg2.valueType()); CHECK_EQ(arg1.valueType(), arg2.valueType());
CHECK_LE(arg1.shape().getElements(), arg2.shape().getElements()); CHECK_LE(arg1.shape().getElements(), arg2.shape().getElements());
...@@ -95,7 +95,7 @@ public: ...@@ -95,7 +95,7 @@ public:
*/ */
template <DeviceType DType1, DeviceType DType2> template <DeviceType DType1, DeviceType DType2>
class Compare2Function { class Compare2Function {
public: public:
typedef typename test::Allocator<DType1>::type Allocator1; typedef typename test::Allocator<DType1>::type Allocator1;
typedef typename test::Allocator<DType2>::type Allocator2; typedef typename test::Allocator<DType2>::type Allocator2;
typedef typename Tensor<real, DType1>::Vector Vector1; typedef typename Tensor<real, DType1>::Vector Vector1;
...@@ -305,7 +305,7 @@ public: ...@@ -305,7 +305,7 @@ public:
std::shared_ptr<FunctionBase> getFunction2() const { return function2_; } std::shared_ptr<FunctionBase> getFunction2() const { return function2_; }
protected: protected:
// only init cpu argument, gpu argument copy from cpu argument. // only init cpu argument, gpu argument copy from cpu argument.
void initArg(BufferArg& arg) { void initArg(BufferArg& arg) {
Vector1 vector(arg.shape().getElements(), (real*)arg.data()); Vector1 vector(arg.shape().getElements(), (real*)arg.data());
...@@ -381,7 +381,7 @@ protected: ...@@ -381,7 +381,7 @@ protected:
} }
} }
protected: protected:
std::shared_ptr<FunctionBase> function1_; std::shared_ptr<FunctionBase> function1_;
std::shared_ptr<FunctionBase> function2_; std::shared_ptr<FunctionBase> function2_;
std::vector<std::shared_ptr<Allocator1>> func1Memory_; std::vector<std::shared_ptr<Allocator1>> func1Memory_;
...@@ -400,7 +400,7 @@ protected: ...@@ -400,7 +400,7 @@ protected:
class CpuGpuFuncCompare class CpuGpuFuncCompare
: public Compare2Function<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> { : public Compare2Function<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> {
public: public:
CpuGpuFuncCompare(const std::string& name, const FuncConfig& config) CpuGpuFuncCompare(const std::string& name, const FuncConfig& config)
: Compare2Function(name + "-CPU", name + "-GPU", config) {} : Compare2Function(name + "-CPU", name + "-GPU", config) {}
......
...@@ -24,7 +24,7 @@ namespace paddle { ...@@ -24,7 +24,7 @@ namespace paddle {
*/ */
template <DeviceType Device> template <DeviceType Device>
class GemmConvFunction : public ConvFunctionBase { class GemmConvFunction : public ConvFunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
ConvFunctionBase::init(config); ConvFunctionBase::init(config);
} }
...@@ -136,7 +136,7 @@ public: ...@@ -136,7 +136,7 @@ public:
*/ */
template <DeviceType Device> template <DeviceType Device>
class GemmConvMobileFunction : public ConvFunctionBase { class GemmConvMobileFunction : public ConvFunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
ConvFunctionBase::init(config); ConvFunctionBase::init(config);
} }
...@@ -297,7 +297,7 @@ public: ...@@ -297,7 +297,7 @@ public:
*/ */
template <DeviceType Device> template <DeviceType Device>
class GemmConvGradInputFunction : public ConvFunctionBase { class GemmConvGradInputFunction : public ConvFunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
ConvFunctionBase::init(config); ConvFunctionBase::init(config);
} }
...@@ -404,7 +404,7 @@ public: ...@@ -404,7 +404,7 @@ public:
*/ */
template <DeviceType Device> template <DeviceType Device>
class GemmConvGradFilterFunction : public ConvFunctionBase { class GemmConvGradFilterFunction : public ConvFunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
ConvFunctionBase::init(config); ConvFunctionBase::init(config);
} }
......
...@@ -70,7 +70,7 @@ enum ColFormat { kCFO = 0, kOCF = 1 }; ...@@ -70,7 +70,7 @@ enum ColFormat { kCFO = 0, kOCF = 1 };
*/ */
template <ColFormat Format, DeviceType Device, class T> template <ColFormat Format, DeviceType Device, class T>
class Im2ColFunctor { class Im2ColFunctor {
public: public:
void operator()(const T* imData, void operator()(const T* imData,
const TensorShape& imShape, const TensorShape& imShape,
T* colData, T* colData,
...@@ -85,7 +85,7 @@ public: ...@@ -85,7 +85,7 @@ public:
template <ColFormat Format, DeviceType Device, class T> template <ColFormat Format, DeviceType Device, class T>
class Col2ImFunctor { class Col2ImFunctor {
public: public:
void operator()(T* imData, void operator()(T* imData,
const TensorShape& imShape, const TensorShape& imShape,
const T* colData, const T* colData,
...@@ -100,7 +100,7 @@ public: ...@@ -100,7 +100,7 @@ public:
template <class T> template <class T>
class Im2ColMobileFunctor { class Im2ColMobileFunctor {
public: public:
void operator()(const T* imData, void operator()(const T* imData,
const TensorShape& imShape, const TensorShape& imShape,
T* colData, T* colData,
......
...@@ -23,7 +23,7 @@ namespace paddle { ...@@ -23,7 +23,7 @@ namespace paddle {
*/ */
template <class T> template <class T>
class Im2ColFunctor<kCFO, DEVICE_TYPE_CPU, T> { class Im2ColFunctor<kCFO, DEVICE_TYPE_CPU, T> {
public: public:
void operator()(const T* imData, void operator()(const T* imData,
const TensorShape& imShape, const TensorShape& imShape,
T* colData, T* colData,
...@@ -75,7 +75,7 @@ public: ...@@ -75,7 +75,7 @@ public:
*/ */
template <class T> template <class T>
class Col2ImFunctor<kCFO, DEVICE_TYPE_CPU, T> { class Col2ImFunctor<kCFO, DEVICE_TYPE_CPU, T> {
public: public:
void operator()(T* imData, void operator()(T* imData,
const TensorShape& imShape, const TensorShape& imShape,
const T* colData, const T* colData,
...@@ -130,7 +130,7 @@ template class Col2ImFunctor<kCFO, DEVICE_TYPE_CPU, double>; ...@@ -130,7 +130,7 @@ template class Col2ImFunctor<kCFO, DEVICE_TYPE_CPU, double>;
*/ */
template <class T> template <class T>
class Im2ColFunctor<kOCF, DEVICE_TYPE_CPU, T> { class Im2ColFunctor<kOCF, DEVICE_TYPE_CPU, T> {
public: public:
void operator()(const T* imData, void operator()(const T* imData,
const TensorShape& imShape, const TensorShape& imShape,
T* colData, T* colData,
...@@ -188,7 +188,7 @@ public: ...@@ -188,7 +188,7 @@ public:
*/ */
template <class T> template <class T>
class Col2ImFunctor<kOCF, DEVICE_TYPE_CPU, T> { class Col2ImFunctor<kOCF, DEVICE_TYPE_CPU, T> {
public: public:
void operator()(T* imData, void operator()(T* imData,
const TensorShape& imShape, const TensorShape& imShape,
const T* colData, const T* colData,
......
...@@ -71,7 +71,7 @@ __global__ void im2col(const T* data_im, ...@@ -71,7 +71,7 @@ __global__ void im2col(const T* data_im,
*/ */
template <class T> template <class T>
class Im2ColFunctor<kCFO, DEVICE_TYPE_GPU, T> { class Im2ColFunctor<kCFO, DEVICE_TYPE_GPU, T> {
public: public:
void operator()(const T* imData, void operator()(const T* imData,
const TensorShape& imShape, const TensorShape& imShape,
T* colData, T* colData,
...@@ -184,7 +184,7 @@ __global__ void col2im(size_t n, ...@@ -184,7 +184,7 @@ __global__ void col2im(size_t n,
*/ */
template <class T> template <class T>
class Col2ImFunctor<kCFO, DEVICE_TYPE_GPU, T> { class Col2ImFunctor<kCFO, DEVICE_TYPE_GPU, T> {
public: public:
void operator()(T* imData, void operator()(T* imData,
const TensorShape& imShape, const TensorShape& imShape,
const T* colData, const T* colData,
...@@ -292,7 +292,7 @@ __global__ void im2colOCF(const T* imData, ...@@ -292,7 +292,7 @@ __global__ void im2colOCF(const T* imData,
*/ */
template <class T> template <class T>
class Im2ColFunctor<kOCF, DEVICE_TYPE_GPU, T> { class Im2ColFunctor<kOCF, DEVICE_TYPE_GPU, T> {
public: public:
void operator()(const T* imData, void operator()(const T* imData,
const TensorShape& imShape, const TensorShape& imShape,
T* colData, T* colData,
...@@ -399,7 +399,7 @@ __global__ void col2imOCF(T* imData, ...@@ -399,7 +399,7 @@ __global__ void col2imOCF(T* imData,
*/ */
template <class T> template <class T>
class Col2ImFunctor<kOCF, DEVICE_TYPE_GPU, T> { class Col2ImFunctor<kOCF, DEVICE_TYPE_GPU, T> {
public: public:
void operator()(T* imData, void operator()(T* imData,
const TensorShape& imShape, const TensorShape& imShape,
const T* colData, const T* colData,
......
...@@ -240,7 +240,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out, ...@@ -240,7 +240,7 @@ void MulOp<DEVICE_TYPE_CPU>(CpuMatrix& out,
*/ */
template <DeviceType Device> template <DeviceType Device>
class MulFunc : public FunctionBase { class MulFunc : public FunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
aTrans_ = config.get<bool>("aTrans"); aTrans_ = config.get<bool>("aTrans");
bTrans_ = config.get<bool>("bTrans"); bTrans_ = config.get<bool>("bTrans");
...@@ -335,7 +335,7 @@ public: ...@@ -335,7 +335,7 @@ public:
} }
} }
private: private:
bool aTrans_; bool aTrans_;
bool bTrans_; bool bTrans_;
}; };
......
...@@ -24,7 +24,7 @@ namespace paddle { ...@@ -24,7 +24,7 @@ namespace paddle {
*/ */
template <class T> template <class T>
class NaiveConvFunctor { class NaiveConvFunctor {
public: public:
void operator()(const T* inputData, void operator()(const T* inputData,
size_t batchSize, size_t batchSize,
size_t inputChannels, size_t inputChannels,
...@@ -85,7 +85,7 @@ public: ...@@ -85,7 +85,7 @@ public:
template <DeviceType Device> template <DeviceType Device>
class NaiveConvFunction : public ConvFunctionBase { class NaiveConvFunction : public ConvFunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
ConvFunctionBase::init(config); ConvFunctionBase::init(config);
} }
......
...@@ -132,7 +132,7 @@ static inline PadConf castToPadConf(const FuncConfig& conf) { ...@@ -132,7 +132,7 @@ static inline PadConf castToPadConf(const FuncConfig& conf) {
template <DeviceType Device> template <DeviceType Device>
class PadFunc : public FunctionBase { class PadFunc : public FunctionBase {
public: public:
void init(const FuncConfig& config) override { pad_ = castToPadConf(config); } void init(const FuncConfig& config) override { pad_ = castToPadConf(config); }
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
...@@ -157,7 +157,7 @@ public: ...@@ -157,7 +157,7 @@ public:
pad_); pad_);
} }
private: private:
PadConf pad_; PadConf pad_;
}; };
...@@ -173,7 +173,7 @@ private: ...@@ -173,7 +173,7 @@ private:
template <DeviceType Device> template <DeviceType Device>
class PadGradFunc : public FunctionBase { class PadGradFunc : public FunctionBase {
public: public:
void init(const FuncConfig& config) override { pad_ = castToPadConf(config); } void init(const FuncConfig& config) override { pad_ = castToPadConf(config); }
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
...@@ -201,7 +201,7 @@ public: ...@@ -201,7 +201,7 @@ public:
pad_); pad_);
} }
private: private:
PadConf pad_; PadConf pad_;
}; };
......
...@@ -129,7 +129,7 @@ void RowConvGrad<DEVICE_TYPE_CPU>(const CpuMatrix& outG, ...@@ -129,7 +129,7 @@ void RowConvGrad<DEVICE_TYPE_CPU>(const CpuMatrix& outG,
template <DeviceType Device> template <DeviceType Device>
class RowConvFunc : public FunctionBase { class RowConvFunc : public FunctionBase {
public: public:
void init(const FuncConfig& config) override {} void init(const FuncConfig& config) override {}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
...@@ -176,7 +176,7 @@ public: ...@@ -176,7 +176,7 @@ public:
template <DeviceType Device> template <DeviceType Device>
class RowConvGradFunc : public FunctionBase { class RowConvGradFunc : public FunctionBase {
// TODO(qingqing): split into RowConvDataFunc and RowConvWeightFunc // TODO(qingqing): split into RowConvDataFunc and RowConvWeightFunc
public: public:
void init(const FuncConfig& config) override {} void init(const FuncConfig& config) override {}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
......
...@@ -92,7 +92,7 @@ void ScaleSubRegionGrad<DEVICE_TYPE_CPU>(const real* inGrad, ...@@ -92,7 +92,7 @@ void ScaleSubRegionGrad<DEVICE_TYPE_CPU>(const real* inGrad,
*/ */
template <DeviceType Device> template <DeviceType Device>
class ScaleSubRegionFunc : public FunctionBase { class ScaleSubRegionFunc : public FunctionBase {
public: public:
void init(const FuncConfig& config) override { conf_ = config; } void init(const FuncConfig& config) override { conf_ = config; }
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
...@@ -109,7 +109,7 @@ public: ...@@ -109,7 +109,7 @@ public:
conf_); conf_);
} }
private: private:
FuncConfig conf_; FuncConfig conf_;
}; };
...@@ -124,7 +124,7 @@ private: ...@@ -124,7 +124,7 @@ private:
template <DeviceType Device> template <DeviceType Device>
class ScaleSubRegionGradFunc : public FunctionBase { class ScaleSubRegionGradFunc : public FunctionBase {
public: public:
void init(const FuncConfig& config) override { conf_ = config; } void init(const FuncConfig& config) override { conf_ = config; }
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
...@@ -141,7 +141,7 @@ public: ...@@ -141,7 +141,7 @@ public:
conf_); conf_);
} }
private: private:
FuncConfig conf_; FuncConfig conf_;
}; };
......
...@@ -75,7 +75,7 @@ void NHWC2NCHW<DEVICE_TYPE_CPU>(real* outputs, ...@@ -75,7 +75,7 @@ void NHWC2NCHW<DEVICE_TYPE_CPU>(real* outputs,
*/ */
template <DeviceType Device> template <DeviceType Device>
class NCHW2NHWCFunc : public FunctionBase { class NCHW2NHWCFunc : public FunctionBase {
public: public:
void init(const FuncConfig& config) override {} void init(const FuncConfig& config) override {}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
...@@ -108,7 +108,7 @@ public: ...@@ -108,7 +108,7 @@ public:
*/ */
template <DeviceType Device> template <DeviceType Device>
class NHWC2NCHWFunc : public FunctionBase { class NHWC2NCHWFunc : public FunctionBase {
public: public:
void init(const FuncConfig& config) override {} void init(const FuncConfig& config) override {}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
......
...@@ -22,7 +22,7 @@ namespace paddle { ...@@ -22,7 +22,7 @@ namespace paddle {
* TensorShape used to represent shape of normal tensor. * TensorShape used to represent shape of normal tensor.
*/ */
class TensorShape { class TensorShape {
public: public:
TensorShape() : ndims_(0), nelements_(0) { initDims(0); } TensorShape() : ndims_(0), nelements_(0) { initDims(0); }
TensorShape(size_t ndims) : ndims_(ndims), nelements_(1) { initDims(ndims); }; TensorShape(size_t ndims) : ndims_(ndims), nelements_(1) { initDims(ndims); };
...@@ -80,7 +80,7 @@ public: ...@@ -80,7 +80,7 @@ public:
bool operator!=(const TensorShape& t) const { return !(*this == t); } bool operator!=(const TensorShape& t) const { return !(*this == t); }
private: private:
// compute number of elements // compute number of elements
void numElements() { void numElements() {
nelements_ = 1; nelements_ = 1;
......
...@@ -21,7 +21,7 @@ namespace paddle { ...@@ -21,7 +21,7 @@ namespace paddle {
template <DeviceType Device> template <DeviceType Device>
class NeonDepthwiseConvFunction : public ConvFunctionBase { class NeonDepthwiseConvFunction : public ConvFunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
ConvFunctionBase::init(config); ConvFunctionBase::init(config);
} }
......
...@@ -21,7 +21,7 @@ namespace paddle { ...@@ -21,7 +21,7 @@ namespace paddle {
template <DeviceType Device> template <DeviceType Device>
class NeonDepthwiseConvTransposeFunction : public ConvFunctionBase { class NeonDepthwiseConvTransposeFunction : public ConvFunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
ConvFunctionBase::init(config); ConvFunctionBase::init(config);
} }
......
...@@ -46,7 +46,7 @@ nnp_convolution_algorithm get_nnp_convolution_algorithm( ...@@ -46,7 +46,7 @@ nnp_convolution_algorithm get_nnp_convolution_algorithm(
template <DeviceType Device> template <DeviceType Device>
class NNPACKConvFunction : public ConvFunctionBase { class NNPACKConvFunction : public ConvFunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
ConvFunctionBase::init(config); ConvFunctionBase::init(config);
algorithm_ = get_nnp_convolution_algorithm(config.get<std::string>("algo")); algorithm_ = get_nnp_convolution_algorithm(config.get<std::string>("algo"));
...@@ -231,7 +231,7 @@ public: ...@@ -231,7 +231,7 @@ public:
} }
} }
private: private:
nnp_convolution_algorithm algorithm_; nnp_convolution_algorithm algorithm_;
nnp_convolution_transform_strategy transform_strategy_; nnp_convolution_transform_strategy transform_strategy_;
void* workspaceBuffer_; void* workspaceBuffer_;
......
(Diffs for the remaining files in this commit are collapsed and not shown here.)