diff --git a/.clang_format.hook b/.clang_format.hook new file mode 100755 index 0000000000000000000000000000000000000000..1d928216867c0ba3897d71542fea44debf8d72a0 --- /dev/null +++ b/.clang_format.hook @@ -0,0 +1,15 @@ +#!/bin/bash +set -e + +readonly VERSION="3.8" + +version=$(clang-format -version) + +if ! [[ $version == *"$VERSION"* ]]; then + echo "clang-format version check failed." + echo "a version contains '$VERSION' is needed, but get '$version'" + echo "you can install the right version, and make an soft-link to '\$PATH' env" + exit -1 +fi + +clang-format $@ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bb8c88787d37faf9ce4d7d856a307c11f1085d98..a772125df64aaf2eafe6cb9e022f62cc29043eb7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,10 +19,10 @@ - id: end-of-file-fixer - repo: local hooks: - - id: clang-format + - id: clang-format-with-version-check name: clang-format description: Format files with ClangFormat. - entry: clang-format -i + entry: ./.clang_format.hook -i language: system files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto)$ - repo: https://github.com/PaddlePaddle/pre-commit-golang diff --git a/CMakeLists.txt b/CMakeLists.txt index c75b83e50cf9cef8290c37f88b38cdc3d77df39c..dcd1218a5b0b62f2739b727391aca31b48ed9ccb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,8 +36,8 @@ include(simd) ################################ Configurations ####################################### option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND}) option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND}) -option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." OFF) -option(WITH_MKLML "Compile PaddlePaddle with mklml package." OFF) +option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." ${AVX_FOUND}) +option(WITH_MKLML "Compile PaddlePaddle with mklml package." ${AVX_FOUND}) option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON) option(WITH_TESTING "Compile PaddlePaddle with unit testing" ON) option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON) diff --git a/Dockerfile b/Dockerfile index 41b6729124228cec16be35d9b26da8042824b0b0..98f61ba586a681e53b435d592c8e43b1cc964139 100644 --- a/Dockerfile +++ b/Dockerfile @@ -34,9 +34,6 @@ RUN apt-get update && \ net-tools && \ apt-get clean -y -# paddle is using numpy.flip, which is introduced since 1.12.0 -RUN pip --no-cache-dir install 'numpy>=1.12.0' - # Install Go and glide RUN wget -qO- https://storage.googleapis.com/golang/go1.8.1.linux-amd64.tar.gz | \ tar -xz -C /usr/local && \ @@ -58,33 +55,22 @@ RUN localedef -i en_US -f UTF-8 en_US.UTF-8 # FIXME: due to temporary ipykernel dependency issue, specify ipykernel jupyter # version util jupyter fixes this issue. RUN pip install --upgrade pip && \ - pip install -U 'protobuf==3.1.0' && \ - pip install -U wheel pillow BeautifulSoup && \ + pip install -U wheel && \ pip install -U docopt PyYAML sphinx && \ - pip install -U sphinx-rtd-theme==0.1.9 recommonmark && \ - pip install pre-commit 'requests==2.9.2' 'ipython==5.3.0' && \ + pip install -U sphinx-rtd-theme==0.1.9 recommonmark + +RUN pip install pre-commit 'ipython==5.3.0' && \ pip install 'ipykernel==4.6.0' 'jupyter==1.0.0' && \ - pip install opencv-python rarfile 'scipy>=0.19.0' 'nltk>=3.2.2' + pip install opencv-python + +COPY ./python/requirements.txt /root/ +RUN pip install -r /root/requirements.txt # To fix https://github.com/PaddlePaddle/Paddle/issues/1954, we use # the solution in https://urllib3.readthedocs.io/en/latest/user-guide.html#ssl-py2 RUN apt-get install -y libssl-dev libffi-dev RUN pip install certifi urllib3[secure] -# TODO(qijun) The template library Eigen doesn't work well with GCC 5 -# coming with the default Docker image, so we switch to use GCC 4.8 -# by default. And I will check Eigen library later. - -RUN ln -sf gcc-4.8 /usr/bin/gcc && \ - ln -sf gcc-ar-4.8 /usr/bin/gcc-ar && \ - ln -sf gcc-nm-4.8 /usr/bin/gcc-nm && \ - ln -sf gcc-ranlib-4.8 /usr/bin/gcc-ranlib && \ - ln -sf gcc-4.8 /usr/bin/x86_64-linux-gnu-gcc && \ - ln -sf gcc-ar-4.8 /usr/bin/x86_64-linux-gnu-gcc-ar && \ - ln -sf gcc-nm-4.8 /usr/bin/x86_64-linux-gnu-gcc-nm && \ - ln -sf gcc-ranlib-4.8 /usr/bin/x86_64-linux-gnu-gcc-ranlib && \ - ln -sf g++-4.8 /usr/bin/g++ && \ - ln -sf g++-4.8 /usr/bin/x86_64-linux-gnu-g++ # Install woboq_codebrowser to /woboq RUN git clone https://github.com/woboq/woboq_codebrowser /woboq && \ diff --git a/cmake/external/openblas.cmake b/cmake/external/openblas.cmake index db09232c0e69016bf18c1d981e4620e9e804ff7c..0eeccbf7d8a1df17351c8914df6dabf005802787 100644 --- a/cmake/external/openblas.cmake +++ b/cmake/external/openblas.cmake @@ -73,10 +73,18 @@ INCLUDE_DIRECTORIES(${CBLAS_INC_DIR}) # linear algebra libraries for cc_library(xxx SRCS xxx.c DEPS cblas) SET(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/cblas_dummy.c) FILE(WRITE ${dummyfile} "const char * dummy = \"${dummyfile}\";") -ADD_LIBRARY(cblas STATIC ${dummyfile}) +IF(${CBLAS_PROVIDER} MATCHES MKL) + ADD_LIBRARY(cblas SHARED ${dummyfile}) +ELSE() + ADD_LIBRARY(cblas STATIC ${dummyfile}) +ENDIF() TARGET_LINK_LIBRARIES(cblas ${CBLAS_LIBRARIES}) IF(NOT ${CBLAS_FOUND}) ADD_DEPENDENCIES(cblas extern_openblas) LIST(APPEND external_project_dependencies cblas) +ELSE() + IF("${CBLAS_PROVIDER}" STREQUAL "MKLML") + ADD_DEPENDENCIES(cblas mklml) + ENDIF() ENDIF(NOT ${CBLAS_FOUND}) diff --git a/cmake/flags.cmake b/cmake/flags.cmake index b27eb71550b68b5c27e47bf067ae0df329bbd628..ff246b2eb4ed97dd14d45763569b661cefd203c8 100644 --- a/cmake/flags.cmake +++ b/cmake/flags.cmake @@ -9,13 +9,6 @@ function(CheckCompilerCXX11Flag) if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8) message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.") endif() - if(NOT ANDROID) - # TODO(qijun) gcc 4.9 or later versions raise SEGV due to the optimization problem. - # Use Debug mode instead for now. - if(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.9 OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL 4.9) - set(CMAKE_BUILD_TYPE "Debug" CACHE STRING "" FORCE) - endif() - endif() elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang") # cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang" # Apple Clang is a different compiler than upstream Clang which havs different version numbers. @@ -160,7 +153,7 @@ set(CUDA_PROPAGATE_HOST_FLAGS OFF) # Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc. # So, don't set these flags here. -LIST(APPEND CUDA_NVCC_FLAGS -std=c++11 --default-stream per-thread) +LIST(APPEND CUDA_NVCC_FLAGS -std=c++11) LIST(APPEND CUDA_NVCC_FLAGS --use_fast_math) if(CMAKE_BUILD_TYPE STREQUAL "Debug") diff --git a/doc/design/cluster_train/large_model_dist_train.md b/doc/design/cluster_train/large_model_dist_train.md new file mode 100644 index 0000000000000000000000000000000000000000..0c4b5bc24c854b7062d509249bea9c50d42bd5f1 --- /dev/null +++ b/doc/design/cluster_train/large_model_dist_train.md @@ -0,0 +1,101 @@ +# Alalysis of large model distributed training in Paddle + +***NOTE: This is only some note for how we implemeted this scheme in V1, not a new design.*** + +## What is it + +We often encounter cases that the embedding layer parameters(sparse) are so large that we can not store it in the trainer's memory when training. So we need to put them to several servers, and fetch them row by row instead of fetch all of the parameters. + +## How to use + +Specify command-line argument like `--loadsave_parameters_in_pserver=true --ports_num_for_sparse=1 --use_old_updater=1` when starting the paddle trainer. And also add something like `--ports_num_for_sparse=1 --pserver_num_threads=5` when starting pserver processes. + +Accrodingly, configure your embedding layers like: + +```python +SPARSE_REMOTE=True + +w1 = data_layer(name="w1", size=dict_size) +emb1 = embedding_layer(input=w1, size=32, param_attr=ParameterAttribute(sparse_update=SPARSE_REMOTE)) +w2 = data_layer(name="w2", size=dict_size) +emb2 = embedding_layer(input=w2, size=32, param_attr=ParameterAttribute(sparse_update=SPARSE_REMOTE)) +... +``` + +## Implementation details + +```c++ +enum MatType { + MAT_NORMAL, + MAT_NORMAL_SHARED, + MAT_VALUE_SHARED, + MAT_SPARSE_ROW_IDS, + MAT_SPARSE_ROW_AUTO_GROW, + MAT_CACHE_ROW, + MAT_SPARSE_ROW, + MAT_SPARSE_ROW_PREFETCH, + MAT_SPARSE_ROW_PREFETCH_FULL_SIZE, +}; +``` + +`MAT_SPARSE_ROW_PREFETCH` is what we use when configured to fetch only row of matrix when training. + +In `trainer_internal.cpp:L93 trainOneBatch`: + +```c++ + if (config_->getOptConfig().use_sparse_remote_updater()) { + REGISTER_TIMER("prefetch"); + gradientMachine_->prefetch(inArgs); + parameterUpdater_->getParametersRemote(); + } +``` + +When doing actual network forward and backward, at the beginning of each batch, the trainer will try to download one row of data from pserver. + +In `trainer/RemoteParameterUpdater.cpp`: `parameterUpdater_->getParametersRemote();`: + +```c++ +if (fullSize) { + ... +} else { +getParams = [&] { + parameterClient_->getParameterSparse( + /* recvParameterType= */ PARAMETER_VALUE, sendBackParameterType); +}; +applyL1 = [](Parameter& para, real decayRate) { + para.getMat(PARAMETER_VALUE)->applyL1(/*lr=*/1.0f, decayRate); +}; +} +``` + +Calling `parameterClient_->getParameterSparse` will do remote call to pserver's `getParameterSparse`: + +```c++ +void ParameterServer2::getParameterSparse(const SendParameterRequest& request, + std::vector& inputBuffers, + SendParameterResponse* response, + std::vector* outputBuffers) { + (void)inputBuffers; + auto& buffer = *readWriteBuffer_; + size_t numReals = 0; + for (const auto& block : request.blocks()) { + numReals += getParameterConfig(block).dims(1); + } + buffer.resize(numReals); + + VLOG(3) << "pserver: getParameterSparse, numReals=" << numReals; + + ReadLockGuard guard(parameterMutex_); + size_t offset = 0; + for (const auto& block : request.blocks()) { + size_t width = getParameterConfig(block).dims(1); + Buffer buf = {buffer.data() + offset, width}; + int type = request.send_back_parameter_type(); + sendBackParameterSparse(block, type, response, &buf, width, outputBuffers); + offset += width; + } +} +``` + +`getParameterConfig(block).dims(1)` returns the width of the current "parameter block"(a shard of parameter object), +then `getParameterSparse` remote call returns only one row of data to the client. diff --git a/doc/design/mkldnn/README.MD b/doc/design/mkldnn/README.MD index e956994431fbb43438c56dcd96ad8313cf516090..fe8da907d9d45a2164031430ac5b7a3d5523967a 100644 --- a/doc/design/mkldnn/README.MD +++ b/doc/design/mkldnn/README.MD @@ -101,6 +101,7 @@ if use_mkldnn 5. 在**Argument**里添加两个`MkldnnMatrixPtr`,取名为`mkldnnValue`和`mkldnnGrad`,用于存放`MkldnnLayer`会用到的memory buffer。 并且添加函数cvt(会修改为一个更加合适的函数名),用于处理"CPU device"和"MKL-DNN device"之间memory的相互转化。 6. 在父类`Layer`中的`getOutput`函数中添加一段逻辑,用于判断`deviceId`,并针对device在MKL-DNN和CPU之间不统一的情况,做一个前期转换。 也就是调用`Argument`的cvt函数把output统一到需要的device上。 7. 在原来的`FLAGS`中添加一个`use_mkldnn`的flag,用于选择是否使用MKL-DNN的相关功能。 +8. 关于MKLDNN参数的保存。由于MKLDNN参数的格式与PaddlePaddle原有的格式存在不一样的情况,所以需要在保存参数时同时保存该格式信息。目前准备扩展[Header](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/parameter/Parameter.h#L247)里面的`int32_t version`。这个值不管是在v1还是在v2里面,一直保存的是0,所以可以充分利用这个信息,定义一个枚举处理所有MKLDNN的参数格式,从而`MKLDNNLayer`就可以从输入的参数中获取需要的格式信息。 ## References diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 03985260241689a099ae9ebc136bd04831a44167..68304c9fc8b8fa13cb1f99b82517abc87c71496c 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -38,7 +38,7 @@ add_custom_command(TARGET framework_py_proto POST_BUILD WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) cc_library(backward SRCS backward.cc DEPS net_op) -cc_test(backward_test SRCS backward_test.cc DEPS backward) +cc_test(backward_test SRCS backward_test.cc DEPS backward recurrent_op device_context) if(WITH_PYTHON) cc_library(paddle_pybind SHARED diff --git a/paddle/framework/backward.cc b/paddle/framework/backward.cc index 315bdde76d3ffe57b656aa69688def6d274f592c..c226e4e3d2a58d1a647e204c4cd26f4eb6bcd968 100644 --- a/paddle/framework/backward.cc +++ b/paddle/framework/backward.cc @@ -17,12 +17,13 @@ #include #include "paddle/framework/op_registry.h" #include "paddle/operators/net_op.h" +#include "paddle/operators/recurrent_op.h" namespace paddle { namespace framework { template -static void ForEachVarName(Map& names, T callback) { +static void ForEachVarName(const Map& names, T callback) { for (auto& name : names) { for (auto& n : name.second) { if (callback(n)) return; @@ -30,6 +31,7 @@ static void ForEachVarName(Map& names, T callback) { } } +// return whether all the names + suffixes in the set static bool AllInSet( const std::map>& names, const std::string& suffix, const std::unordered_set& set) { @@ -43,12 +45,12 @@ static bool AllInSet( static std::shared_ptr NOP() { auto net_op = std::make_shared(); - net_op->type_ = "@NOP@"; + net_op->SetType("@NOP@"); net_op->CompleteAddOp(); return net_op; } -// Get backward operator from a forward operator, recursively implementation. +// Get backward operator from a forward operator, a recursive implementation. // // no_grad_names the gradient variable names without gradient calculating. // @@ -56,28 +58,31 @@ static std::shared_ptr NOP() { // BackwardRecursive. use `uid = uniq_id++;` to get the unique index, and // pass `uniq_id` through recursive calling. // -// returns The backward operator. For simple situation, it is a simple -// operator. For complex situation, it is a NetOp. +// returns The backward operator. In a simple situation, it may be a simple +// operator, in a complex situation, it maybe a NetOp. // // See Backward.h for details static std::shared_ptr BackwardRecursive( const OperatorBase& forwardOp, std::unordered_set& no_grad_names, size_t& uniq_id); + std::shared_ptr BackwardRecursive( const OperatorBase& forwardOp, std::unordered_set& no_grad_names, size_t& uniq_id) { // If all input gradients of forwarding operator do not need to calculate, // just return an NOP. Not return null ptr because NOP does not take // too much time for calculation, but it is useful for simplifying logic. - if (AllInSet(forwardOp.inputs_, kGradVarSuffix, no_grad_names)) { + if (AllInSet(forwardOp.Inputs() /*names*/, kGradVarSuffix /*suffix*/, + no_grad_names /*set*/)) { return NOP(); } // All output gradients of forwarding operator do not need to calculate. // Then all input gradients cannot be computed at all, and we put them into // `no_grad_names` set. Return an NOP. - if (AllInSet(forwardOp.outputs_, kGradVarSuffix, no_grad_names)) { - ForEachVarName(forwardOp.inputs_, + if (AllInSet(forwardOp.Outputs() /*names*/, kGradVarSuffix /*suffix*/, + no_grad_names /*set*/)) { + ForEachVarName(forwardOp.Inputs(), [&no_grad_names](const std::string& name) -> bool { no_grad_names.insert(GradVarName(name)); return false; @@ -93,17 +98,17 @@ std::shared_ptr BackwardRecursive( auto& forwardNet = static_cast(forwardOp); // Map from output gradient variable name to operator's indices in - // backward net. That operator generates that variable. + // backward net's ops_. That operator generates that variable. std::unordered_map> dup_output_ops; size_t local_op_id = 0; - // reversely travel forwardNet + // reversely travel forwardNet and collect all duplicate outputs. for (auto it = forwardNet.ops_.rbegin(); it != forwardNet.ops_.rend(); ++it, ++local_op_id) { auto fwd = *it; auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id); net->AddOp(bwd); - ForEachVarName(bwd->outputs_, + ForEachVarName(bwd->Outputs(), [&dup_output_ops, local_op_id](const std::string& out) { dup_output_ops[out].emplace_back(local_op_id); return false; @@ -112,45 +117,51 @@ std::shared_ptr BackwardRecursive( // Get unique ID for this method. auto uid = uniq_id++; // TODO(dzh): more comment + // multiple operators which have the same output (y for example) may + // overwrite the same y variable when backward, special operations are token + // to handle this case. For each duplicate output, rename it to an alias + // (original name with a offset), append an `add` op for its operator, + // and finally sum all the alias variable to the final output variable y. using Pos = std::pair>; std::list insert_position; for (auto& dup_output_op : dup_output_ops) { const std::string& name = dup_output_op.first; auto& dup_op = dup_output_op.second; + // no duplicate output if (dup_op.size() == 1) continue; - std::vector dup_outputs; + // process the duplicate outputs + std::vector dup_outputs; for (size_t i = 0; i < dup_op.size(); ++i) { + // rename each duplicate output to an alias auto op_offset = dup_op[i]; dup_outputs.push_back(name + "@RENAME@" + std::to_string(uid) + "@" + std::to_string(i)); net->ops_[op_offset]->Rename(name, dup_outputs.back()); } + // collect all the offset to append `add` op for each alias insert_position.push_back( - {dup_op.back(), - OpRegistry::CreateOp( - "add", {{"X", {dup_outputs}}}, {{"Out", {name}}}, - {{"input_format", - std::vector{0, static_cast(dup_outputs.size())}}})}); + {dup_op.back(), OpRegistry::CreateOp("add", {{"X", {dup_outputs}}}, + {{"Out", {name}}}, {})}); } + // make sure the inserted `add` ops follow the BFS order. insert_position.sort( [](const Pos& l, const Pos& r) { return l.first > r.first; }); for (auto& pos : insert_position) { net->InsertOp(pos.first + 1, pos.second); } - } else { std::shared_ptr grad_op = OpRegistry::CreateGradOp(forwardOp); - ForEachVarName(grad_op->inputs_, [&no_grad_names, - &net](std::string& grad_input) { + ForEachVarName(grad_op->Inputs(), [&no_grad_names, &net, + grad_op](const std::string& grad_input) { if (no_grad_names.count(grad_input)) { // +1 for \0 std::string prefix = grad_input.substr( 0, grad_input.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1); - grad_input = prefix + kZeroVarSuffix; + grad_op->Rename(grad_input, prefix + kZeroVarSuffix); // If part of input gradient of that operator is not calculated, fill // zero variables to that input gradient. @@ -160,23 +171,39 @@ std::shared_ptr BackwardRecursive( return false; }); - ForEachVarName(grad_op->outputs_, - [&no_grad_names](std::string& grad_output) { + ForEachVarName(grad_op->Outputs(), + [&no_grad_names, &grad_op](const std::string& grad_output) { if (no_grad_names.count(grad_output)) { - grad_output = kEmptyVarName; + grad_op->Rename(grad_output, kEmptyVarName); } return false; }); + // process recurrent gradient op as a special operator. + if (forwardOp.Type() == "recurrent_op") { + // NOTE clean up cycle call somewhere (RNN's stepnet constains itself), or + // this will result in infinite loop. + const auto& rnnop = + *static_cast(&forwardOp); + auto rnn_grad_op = + static_cast(grad_op.get()); + const auto& stepnet_op = + *static_cast(&rnnop.stepnet()); + // create stepnet's gradient op + auto grad_stepnet = BackwardRecursive(stepnet_op, no_grad_names, uniq_id); + rnn_grad_op->set_stepnet( + std::static_pointer_cast(grad_stepnet)); + } + if (net->ops_.empty()) { // Current no aux op is added to network return grad_op; } net->AddOp(grad_op); } - net->type_ = "@GENERATED_BACKWARD@"; + net->SetType("@GENERATED_BACKWARD@"); net->CompleteAddOp(); return net; -} +} // namespace framework // See header for comments std::shared_ptr Backward( diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index ebe52d5f284a8d271b666483001544a805d598ac..d942604bf05998ab9e1ee147b28586e7e4e9777d 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -28,13 +28,6 @@ using OpAttrChecker = framework::OpAttrChecker; using Scope = framework::Scope; using DeviceContext = platform::DeviceContext; -class EmptyOp : public OperatorBase { - public: - using OperatorBase::OperatorBase; - void InferShape(const Scope &scope) const override {} - void Run(const Scope &scope, const DeviceContext &dev_ctx) const override {} -}; - class RowWiseAddOpMaker : public OpProtoAndCheckerMaker { public: RowWiseAddOpMaker(OpProto *proto, OpAttrChecker *op_checker) @@ -155,27 +148,24 @@ class AddOpMaker : public OpProtoAndCheckerMaker { namespace f = paddle::framework; namespace ops = paddle::operators; using EnforceNotMet = paddle::platform::EnforceNotMet; -REGISTER_OP(rowwise_add, f::EmptyOp, f::RowWiseAddOpMaker); -REGISTER_GRADIENT_OP(rowwise_add, rowwise_add_grad, f::EmptyOp); -REGISTER_OP(mul, f::EmptyOp, f::MulOpMaker); -REGISTER_GRADIENT_OP(mul, mul_grad, f::EmptyOp); -REGISTER_OP(sigmoid, f::EmptyOp, f::SigmoidOpMaker); -REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, f::EmptyOp); -REGISTER_OP(nograd, f::EmptyOp, f::NoGradOpMaker); -REGISTER_OP(fill_zeros_like, f::EmptyOp, f::FillZeroOpMaker); -REGISTER_OP(add, f::EmptyOp, f::AddOpMaker); -REGISTER_GRADIENT_OP(add, add_grad, f::EmptyOp); -REGISTER_OP(fc, f::FcOp, f::FcOpMaker); -REGISTER_OP(many_output_op, f::EmptyOp, f::ManyOutputOpMaker); -REGISTER_GRADIENT_OP(many_output_op, many_output_op_grad, f::EmptyOp); +REGISTER_OP(rowwise_add, f::NOP, f::RowWiseAddOpMaker, rowwise_add_grad, + f::NOP); +REGISTER_OP(mul, f::NOP, f::MulOpMaker, mul_grad, f::NOP); +REGISTER_OP(sigmoid, f::NOP, f::SigmoidOpMaker, sigmoid_grad, f::NOP); +REGISTER_OP_WITHOUT_GRADIENT(nograd, f::NOP, f::NoGradOpMaker); +REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, f::NOP, f::FillZeroOpMaker); +REGISTER_OP(add, f::NOP, f::AddOpMaker, add_grad, f::NOP); +REGISTER_OP_WITHOUT_GRADIENT(fc, f::FcOp, f::FcOpMaker); +REGISTER_OP(many_output_op, f::NOP, f::ManyOutputOpMaker, many_output_op_grad, + f::NOP); TEST(Backward, simple_op_grad) { auto fwd = f::OpRegistry::CreateOp( "rowwise_add", {{"X", {"x"}}, {"b", {"b"}}}, {{"Out", {"out"}}}, {}); ASSERT_NE(fwd, nullptr); auto gop = f::OpRegistry::CreateGradOp(*fwd); - ASSERT_EQ(1UL, gop->inputs_.size()); - ASSERT_EQ("rowwise_add_grad", gop->type_); + ASSERT_EQ(1UL, gop->Inputs().size()); + ASSERT_EQ("rowwise_add_grad", gop->Type()); ASSERT_EQ(f::GradVarName("x"), gop->Output(f::GradVarName("X"))); ASSERT_EQ(f::GradVarName("b"), gop->Output(f::GradVarName("b"))); } @@ -211,13 +201,13 @@ TEST(Backward, net_fc_backward_normal) { ASSERT_EQ(3UL, net->ops_.size()); f::OperatorBase &d_sigmoid = *net->ops_[0]; - ASSERT_EQ("sigmoid_grad", d_sigmoid.type_); + ASSERT_EQ("sigmoid_grad", d_sigmoid.Type()); f::OperatorBase &d_add = *net->ops_[1]; - ASSERT_EQ("rowwise_add_grad", d_add.type_); + ASSERT_EQ("rowwise_add_grad", d_add.Type()); f::OperatorBase &d_mul = *net->ops_[2]; - ASSERT_EQ("mul_grad", d_mul.type_); + ASSERT_EQ("mul_grad", d_mul.Type()); } TEST(Backward, net_fc_backward_not_have_b) { @@ -237,10 +227,10 @@ TEST(Backward, net_fc_backward_not_have_b) { ASSERT_EQ(2UL, net->ops_.size()); f::OperatorBase &d_sigmoid = *net->ops_[0]; - ASSERT_EQ("sigmoid_grad", d_sigmoid.type_); + ASSERT_EQ("sigmoid_grad", d_sigmoid.Type()); f::OperatorBase &d_mul = *net->ops_[1]; - ASSERT_EQ("mul_grad", d_mul.type_); + ASSERT_EQ("mul_grad", d_mul.Type()); } TEST(Backward, net_input_of_network_not_need_grad) { @@ -294,7 +284,7 @@ TEST(Backward, net_shared_weight) { ASSERT_TRUE(bwd->IsNetOp()); auto bwd_net = static_cast(bwd.get()); ASSERT_EQ(3UL, bwd_net->ops_.size()); - ASSERT_EQ("add", bwd_net->ops_[2]->type_); + ASSERT_EQ("add", bwd_net->ops_[2]->Type()); } TEST(Backward, op_register_grad_not_for_network) { @@ -335,15 +325,15 @@ TEST(Backward, op_part_of_output_are_not_need) { ASSERT_EQ(net->ops_.size(), 2UL); auto &fill_zero = *net->ops_[0]; - ASSERT_EQ("fill_zeros_like", fill_zero.type_); + ASSERT_EQ("fill_zeros_like", fill_zero.Type()); ASSERT_EQ(1UL, fill_zero.Inputs("Src").size()); ASSERT_EQ("Z", fill_zero.Input("Src")); ASSERT_EQ(1UL, fill_zero.Outputs("Dst").size()); ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix, fill_zero.Output("Dst")); auto &d_many_out = *net->ops_[1]; - ASSERT_EQ("many_output_op_grad", d_many_out.type_); - ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.inputs_.size()); // I/O/OG + ASSERT_EQ("many_output_op_grad", d_many_out.Type()); + ASSERT_EQ(1UL + 2UL + 2UL, d_many_out.Inputs().size()); // I/O/OG ASSERT_EQ(std::string("Z") + f::kZeroVarSuffix, d_many_out.Input(f::GradVarName("z"))); ASSERT_EQ(f::GradVarName("Y"), d_many_out.Input(f::GradVarName("y"))); @@ -355,9 +345,9 @@ TEST(Backward, op_part_of_input_are_not_need) { {{"Out", {"out"}}}, {}); auto backward = f::Backward(*fwd, {"a"}); auto &grad_mul = *backward; - ASSERT_EQ(grad_mul.type_, "mul_grad"); - ASSERT_EQ(grad_mul.inputs_.size(), 2UL + 1UL + 1UL); - ASSERT_EQ(grad_mul.outputs_.size(), 2UL); + ASSERT_EQ(grad_mul.Type(), "mul_grad"); + ASSERT_EQ(grad_mul.Inputs().size(), 2UL + 1UL + 1UL); + ASSERT_EQ(grad_mul.Outputs().size(), 2UL); ASSERT_EQ(grad_mul.Output(f::GradVarName("X")), f::kEmptyVarName); ASSERT_EQ(grad_mul.Output(f::GradVarName("Y")), f::GradVarName("b")); ASSERT_EQ(grad_mul.Input(f::GradVarName("Out")), f::GradVarName("out")); @@ -395,18 +385,18 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { auto &grad_fc = *bwd_net->ops_[0]; const char *all = paddle::operators::NetOp::kAll; - EXPECT_EQ(grad_fc.inputs_[all].size(), + EXPECT_EQ(grad_fc.Inputs(all).size(), 2UL /* external input number */ + 1UL /* external output number*/ + 1UL /* number of gradient of external output*/ + 2U /* internal variable number*/); - EXPECT_EQ(grad_fc.outputs_[all].size(), + EXPECT_EQ(grad_fc.Outputs(all).size(), 2UL /* input number of mul*/ + 2UL /* input number of rowwise_add */ + 1UL /* input number of sigmod */); - EXPECT_EQ(bwd_net->ops_[1]->inputs_[all].size(), 0UL); - EXPECT_EQ(bwd_net->ops_[1]->outputs_[all].size(), 0UL); - EXPECT_EQ(bwd_net->ops_[2]->inputs_[all].size(), 0UL); - EXPECT_EQ(bwd_net->ops_[2]->outputs_[all].size(), 0UL); + EXPECT_EQ(bwd_net->ops_[1]->Inputs(all).size(), 0UL); + EXPECT_EQ(bwd_net->ops_[1]->Outputs(all).size(), 0UL); + EXPECT_EQ(bwd_net->ops_[2]->Inputs(all).size(), 0UL); + EXPECT_EQ(bwd_net->ops_[2]->Outputs(all).size(), 0UL); } diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc index 21bc30d1fbdae31548547bccf39e78fe16eedfaa..b73dac22d029876de9a012de533647db3dd74cbb 100644 --- a/paddle/framework/grad_op_builder.cc +++ b/paddle/framework/grad_op_builder.cc @@ -13,23 +13,20 @@ express or implied. See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/framework/grad_op_builder.h" -#include "paddle/framework/framework.pb.h" #include "paddle/framework/op_registry.h" namespace paddle { namespace framework { enum class OpArgType { IN, OUT }; -static void TransOpArg(const OperatorBase* src_op, - OperatorBase::VarNameMap* vars, - const OpArgType& src_type, bool is_grad) { +static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type, + bool is_grad, OperatorBase::VarNameMap* vars) { const auto& src_inout = - src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_; + src_type == OpArgType::IN ? src_op->Inputs() : src_op->Outputs(); auto& dst_inout = *vars; - - const OpProto& proto = OpProtos().at(src_op->type_); + const OpProto* proto = OpRegistry::op_info_map().at(src_op->Type()).proto_; const auto& src_arg_list = - src_type == OpArgType::IN ? proto.inputs() : proto.outputs(); + src_type == OpArgType::IN ? proto->inputs() : proto->outputs(); for (const auto& arg : src_arg_list) { if (arg.no_gradient() && !is_grad) continue; const std::string src_name = arg.name(); @@ -43,22 +40,26 @@ static void TransOpArg(const OperatorBase* src_op, } OperatorBase* BuildGradOp(const OperatorBase* op) { - auto gop_type_it = OpRegistry::grad_ops().find(op->type_); - PADDLE_ENFORCE(gop_type_it != OpRegistry::grad_ops().end(), - "Operator %s do not register gradient type", op->type_); - auto& grad_op_type = gop_type_it->second; + auto it = OpRegistry::op_info_map().find(op->Type()); + PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(), + "'%s' has not been registered.", op->Type()); + PADDLE_ENFORCE(it->second.proto_ != nullptr, "'%s' has no OpProto.", + op->Type()); + std::string grad_op_type = it->second.grad_op_type_; + PADDLE_ENFORCE(!grad_op_type.empty(), "'%s' has no gradient operator.", + op->Type()); + OperatorBase::VarNameMap inputs; OperatorBase::VarNameMap outputs; - TransOpArg(op, &inputs, OpArgType::IN, false); // I - TransOpArg(op, &inputs, OpArgType::OUT, false); // O - TransOpArg(op, &inputs, OpArgType::OUT, true); // OG - TransOpArg(op, &outputs, OpArgType::IN, true); // IG - auto gop_it = OpRegistry::op_creators().find(grad_op_type); - PADDLE_ENFORCE(gop_it != OpRegistry::op_creators().end(), - "Operator %s 's Gradient %s's creator cannot be found", - op->type_, grad_op_type); + TransOpArg(op, OpArgType::IN, false, &inputs); // I + TransOpArg(op, OpArgType::OUT, false, &inputs); // O + TransOpArg(op, OpArgType::OUT, true, &inputs); // OG + TransOpArg(op, OpArgType::IN, true, &outputs); // IG - return gop_it->second(grad_op_type, inputs, outputs, op->attrs_); + it = OpRegistry::op_info_map().find(grad_op_type); + PADDLE_ENFORCE(it != OpRegistry::op_info_map().end(), + "'%s' has not been registered.", grad_op_type); + return it->second.creator_(grad_op_type, inputs, outputs, op->Attrs()); } } // namespace framework diff --git a/paddle/framework/grad_op_builder_test.cc b/paddle/framework/grad_op_builder_test.cc index ebaf84545fce0d281d8821861264cddc8854893d..0c26293fd29d24a7a40c47bdf055d2758846709b 100644 --- a/paddle/framework/grad_op_builder_test.cc +++ b/paddle/framework/grad_op_builder_test.cc @@ -8,14 +8,6 @@ USE_OP(add_two); namespace paddle { namespace framework { -class NOP : public OperatorBase { - public: - using OperatorBase::OperatorBase; - void InferShape(const Scope &scope) const override {} - void Run(const Scope &scope, - const platform::DeviceContext &dev_ctx) const override {} -}; - class MutiInOutOpMaker : public OpProtoAndCheckerMaker { public: MutiInOutOpMaker(OpProto *proto, OpAttrChecker *op_checker) @@ -52,8 +44,8 @@ TEST(GradOpBuilder, AddTwo) { "add_two", {{"X", {"x"}}, {"Y", {"y"}}}, {{"Out", {"out"}}}, {})); std::shared_ptr grad_add_op = f::OpRegistry::CreateGradOp(*add_op); - EXPECT_EQ(grad_add_op->inputs_.size(), 4UL); - EXPECT_EQ(grad_add_op->outputs_.size(), 2UL); + EXPECT_EQ(grad_add_op->Inputs().size(), 4UL); + EXPECT_EQ(grad_add_op->Outputs().size(), 2UL); EXPECT_EQ(grad_add_op->Input("X"), "x"); EXPECT_EQ(grad_add_op->Input("Y"), "y"); EXPECT_EQ(grad_add_op->Input("Out"), "out"); @@ -62,10 +54,8 @@ TEST(GradOpBuilder, AddTwo) { EXPECT_EQ(grad_add_op->Output(f::GradVarName("Y")), f::GradVarName("y")); } -REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker); -REGISTER_GRADIENT_OP(mult_io, mult_io_grad, f::NOP); -REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker); -REGISTER_GRADIENT_OP(io_ignored, io_ignored_grad, f::NOP); +REGISTER_OP(mult_io, f::NOP, f::MutiInOutOpMaker, mult_io_grad, f::NOP); +REGISTER_OP(io_ignored, f::NOP, f::IOIgnoredOpMaker, io_ignored_grad, f::NOP); TEST(GradOpBuilder, MutiInOut) { std::shared_ptr test_op(f::OpRegistry::CreateOp( @@ -76,7 +66,7 @@ TEST(GradOpBuilder, MutiInOut) { std::shared_ptr grad_test_op = f::OpRegistry::CreateGradOp(*test_op); - ASSERT_EQ(grad_test_op->inputs_.size(), 3UL + 2UL + 2UL); + ASSERT_EQ(grad_test_op->Inputs().size(), 3UL + 2UL + 2UL); EXPECT_EQ(grad_test_op->Input("In1"), "in1"); EXPECT_EQ(grad_test_op->Inputs("In2_mult"), std::vector({"in2_1", "in2_2", "in2_3"})); @@ -90,7 +80,7 @@ TEST(GradOpBuilder, MutiInOut) { std::vector( {f::GradVarName("out2_1"), f::GradVarName("out2_2")})); - ASSERT_EQ(grad_test_op->outputs_.size(), 3UL); + ASSERT_EQ(grad_test_op->Outputs().size(), 3UL); EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1")); EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")), std::vector({f::GradVarName("in2_1"), @@ -109,7 +99,7 @@ TEST(GradOpBuilder, IOIgnoredInGradient) { f::OpRegistry::CreateGradOp(*test_op); // 'In2' and 'Out2' are ignored in gradient calculating - ASSERT_EQ(grad_test_op->inputs_.size(), 2UL + 1UL + 2UL); + ASSERT_EQ(grad_test_op->Inputs().size(), 2UL + 1UL + 2UL); EXPECT_EQ(grad_test_op->Input("In1"), "in1"); EXPECT_EQ(grad_test_op->Inputs("In3_mult"), std::vector({"in3_1", "in3_2"})); @@ -121,7 +111,7 @@ TEST(GradOpBuilder, IOIgnoredInGradient) { EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out2")), f::GradVarName("out2")); - ASSERT_EQ(grad_test_op->outputs_.size(), 3UL); + ASSERT_EQ(grad_test_op->Outputs().size(), 3UL); EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1")); EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")), std::vector( diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index c0654b375dd09d7e4e06a5a46ee4bd6be9c5f82b..4fa0a2750b239f7399086a4c5356995c4852eabc 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include +#include #include #include #include "paddle/framework/attribute.h" @@ -119,6 +120,12 @@ class OpProtoAndCheckerMaker { bool validated_{false}; }; +class NOPMaker : public OpProtoAndCheckerMaker { + public: + NOPMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) {} +}; + class OpRegistry { using VarNameMap = OperatorBase::VarNameMap; using OpCreator = std::function; public: - template - static void RegisterOp(const std::string& op_type) { - op_creators()[op_type] = []( - const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, const AttributeMap& attrs) { - return new OpType(type, inputs, outputs, attrs); - }; - OpAttrChecker& op_checker = op_checkers()[op_type]; - OpProto& op_proto = OpProtos()[op_type]; - auto maker = ProtoMakerType(&op_proto, &op_checker); - maker.Validate(); - op_proto.set_type(op_type); - PADDLE_ENFORCE( - op_proto.IsInitialized(), - "Fail to initialize %s's OpProto, because %s is not initialized", - op_type, op_proto.InitializationErrorString()); - } + struct OpInfo { + OpCreator creator_; + std::string grad_op_type_; + OpProto* proto_; + OpAttrChecker* checker_; + }; - template - static void RegisterGradOp(const std::string& op_type, - const std::string& grad_op_type) { - op_creators()[grad_op_type] = []( - const std::string& type, const VarNameMap& inputs, - const VarNameMap& outputs, const AttributeMap& attrs) { - return new GradOpType(type, inputs, outputs, attrs); + template + static void RegisterOp(const std::string& op_type, + const std::string& grad_op_type) { + PADDLE_ENFORCE(op_info_map().count(op_type) == 0, + "'%s' is registered more than once.", op_type); + OpInfo op_info; + op_info.creator_ = [](const std::string& type, const VarNameMap& inputs, + const VarNameMap& outputs, + const AttributeMap& attrs) { + return new OpType(type, inputs, outputs, attrs); }; - grad_ops()[op_type] = grad_op_type; + op_info.grad_op_type_ = grad_op_type; + if (std::type_index(typeid(ProtoMakerType)) != + std::type_index(typeid(NOPMaker))) { + op_info.proto_ = new OpProto; + op_info.checker_ = new OpAttrChecker; + auto maker = ProtoMakerType(op_info.proto_, op_info.checker_); + maker.Validate(); + op_info.proto_->set_type(op_type); + PADDLE_ENFORCE( + op_info.proto_->IsInitialized(), + "Fail to initialize %s's OpProto, because %s is not initialized", + op_type, op_info.proto_->InitializationErrorString()); + } else { + op_info.proto_ = nullptr; + op_info.checker_ = nullptr; + } + op_info_map().insert(std::make_pair(op_type, op_info)); + // register gradient op + if (!grad_op_type.empty()) { + RegisterOp(grad_op_type, ""); + } } static std::shared_ptr CreateOp(const std::string& type, const VarNameMap& inputs, const VarNameMap& outputs, AttributeMap attrs) { - auto op_create_it = op_creators().find(type); - PADDLE_ENFORCE(op_create_it != op_creators().end(), - "Operator %s cannot be found.", type); - op_checkers().at(type).Check(attrs); - - auto op = op_create_it->second(type, inputs, outputs, attrs); - + auto it = op_info_map().find(type); + PADDLE_ENFORCE(it != op_info_map().end(), + "Operator '%s' has not been registered.", type); + it->second.checker_->Check(attrs); + auto op = it->second.creator_(type, inputs, outputs, attrs); return std::shared_ptr(op); } @@ -200,49 +217,32 @@ class OpRegistry { return grad_op; } - static std::unordered_map& grad_ops() { - static std::unordered_map grad_ops_; - return grad_ops_; - } - - static std::unordered_map& op_creators() { - static std::unordered_map op_creators_; - return op_creators_; - } - - private: - static std::unordered_map& op_checkers() { - static std::unordered_map op_checkers_; - return op_checkers_; + static std::unordered_map& op_info_map() { + static std::unordered_map op_info_map_; + return op_info_map_; } }; class Registrar { public: - // In our design, various kinds of classes, e.g., operators and kernels, have - // their corresponding registry and registrar. The action of registration is - // in the constructor of a global registrar variable, which, however, are not - // used in the code that calls package framework, and would be removed from - // the generated binary file by the linker. To avoid such removal, we add - // Touch to all registrar classes and make USE_OP macros to call this - // method. So, as long as the callee code calls USE_OP, the global + // In our design, various kinds of classes, e.g., operators and kernels, + // have their corresponding registry and registrar. The action of + // registration is in the constructor of a global registrar variable, which, + // however, are not used in the code that calls package framework, and would + // be removed from the generated binary file by the linker. To avoid such + // removal, we add Touch to all registrar classes and make USE_OP macros to + // call this method. So, as long as the callee code calls USE_OP, the global // registrar variable won't be removed by the linker. void Touch() {} }; -template +template class OpRegistrar : public Registrar { public: - explicit OpRegistrar(const char* op_type) { - OpRegistry::RegisterOp(op_type); - } -}; - -template -class GradOpRegistrar : public Registrar { - public: - GradOpRegistrar(const char* op_type, const char* grad_op_type) { - OpRegistry::RegisterGradOp(op_type, grad_op_type); + explicit OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); } + OpRegistrar(const char* op_type, const char* grad_op_type) { + OpRegistry::RegisterOp(op_type, + grad_op_type); } }; @@ -268,7 +268,8 @@ class OpKernelRegistrar : public Registrar { /** * Macro to register Operator. */ -#define REGISTER_OP(op_type, op_class, op_maker_class) \ +#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, \ + grad_op_class) \ STATIC_ASSERT_GLOBAL_NAMESPACE( \ __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \ class _OpClass_##op_type##_ : public op_class { \ @@ -276,33 +277,21 @@ class OpKernelRegistrar : public Registrar { DEFINE_OP_CLONE_METHOD(_OpClass_##op_type##_); \ DEFINE_OP_CONSTRUCTOR(_OpClass_##op_type##_, op_class); \ }; \ - static ::paddle::framework::OpRegistrar<_OpClass_##op_type##_, \ - op_maker_class> \ - __op_registrar_##op_type##__(#op_type); \ + class _OpGradClass_##op_type##_ : public grad_op_class { \ + public: \ + DEFINE_OP_CLONE_METHOD(_OpGradClass_##op_type##_); \ + DEFINE_OP_CONSTRUCTOR(_OpGradClass_##op_type##_, grad_op_class); \ + }; \ + static ::paddle::framework::OpRegistrar< \ + _OpClass_##op_type##_, op_maker_class, _OpGradClass_##op_type##_> \ + __op_registrar_##op_type##__(#op_type, #grad_op_type); \ int TouchOpRegistrar_##op_type() { \ __op_registrar_##op_type##__.Touch(); \ return 0; \ } -/** - * Macro to register Gradient Operator. - */ -#define REGISTER_GRADIENT_OP(op_type, grad_op_type, grad_op_class) \ - STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __reg_gradient_op__##op_type##_##grad_op_type, \ - "REGISTER_GRADIENT_OP must be called in global namespace"); \ - class _OpGradClass_##op_type##_ : public grad_op_class { \ - public: \ - DEFINE_OP_CLONE_METHOD(_OpGradClass_##op_type##_); \ - DEFINE_OP_CONSTRUCTOR(_OpGradClass_##op_type##_, grad_op_class); \ - }; \ - static ::paddle::framework::GradOpRegistrar<_OpGradClass_##op_type##_> \ - __op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type, \ - #grad_op_type); \ - int TouchOpGradientRegistrar_##op_type() { \ - __op_gradient_registrar_##op_type##_##grad_op_type##__.Touch(); \ - return 0; \ - } +#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \ + REGISTER_OP(op_type, op_class, op_maker_class, , ::paddle::framework::NOP) /** * Macro to register OperatorKernel. @@ -318,14 +307,6 @@ class OpKernelRegistrar : public Registrar { return 0; \ } -/** - * Macro to Forbid user register Gradient Operator. - */ -#define NO_GRADIENT(op_type) \ - STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __reg_gradient_op__##op_type##_##op_type##_grad, \ - "NO_GRADIENT must be called in global namespace") - #define REGISTER_OP_GPU_KERNEL(op_type, ...) \ REGISTER_OP_KERNEL(op_type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__) @@ -333,7 +314,8 @@ class OpKernelRegistrar : public Registrar { REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__) /** - * Macro to mark what Operator and Kernel we will use and tell the compiler to + * Macro to mark what Operator and Kernel + * we will use and tell the compiler to * link them into target. */ #define USE_OP_ITSELF(op_type) \ @@ -344,23 +326,6 @@ class OpKernelRegistrar : public Registrar { static int use_op_itself_##op_type##_ __attribute__((unused)) = \ TouchOpRegistrar_##op_type() -// TODO(fengjiayi): Most ops' gradient op have not been compeleted. So we use -// `NO_GRAD` to disable micro USE_OP_GRADIENT(op_type). Otherwise the code can't -// be compiled. `NO_GRAD` should be removed after all gradient ops are -// compeleted. -#define NO_GRAD -#ifndef NO_GRAD -#define USE_OP_GRADIENT(op_type) \ - STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __use_op_gradient_##op_type, \ - "USE_OP_GRADIENT must be called in global namespace"); \ - extern int TouchOpGradientRegistrar_##op_type(); \ - static int use_op_gradient_##op_type##_ __attribute__((unused)) = \ - TouchOpGradientRegistrar_##op_type() -#else -#define USE_OP_GRADIENT(op_type) -#endif - #define USE_OP_DEVICE_KERNEL(op_type, DEVICE_TYPE) \ STATIC_ASSERT_GLOBAL_NAMESPACE( \ __use_op_kernel_##op_type##_##DEVICE_TYPE##__, \ @@ -370,7 +335,8 @@ class OpKernelRegistrar : public Registrar { __attribute__((unused)) = \ TouchOpKernelRegistrar_##op_type##_##DEVICE_TYPE() -// TODO(fengjiayi): The following macros seems ugly, do we have better method? +// TODO(fengjiayi): The following macros +// seems ugly, do we have better method? #ifdef PADDLE_ONLY_CPU #define USE_OP_KERNEL(op_type) USE_OP_DEVICE_KERNEL(op_type, CPU) @@ -380,18 +346,13 @@ class OpKernelRegistrar : public Registrar { USE_OP_DEVICE_KERNEL(op_type, GPU) #endif -#define USE_NO_GRAD_OP(op_type) \ - USE_OP_ITSELF(op_type); \ - USE_OP_KERNEL(op_type) - -#define USE_CPU_OP(op_type) \ - USE_OP_ITSELF(op_type); \ - USE_OP_DEVICE_KERNEL(op_type, CPU); \ - USE_OP_GRADIENT(op_type) +#define USE_CPU_ONLY_OP(op_type) \ + USE_OP_ITSELF(op_type); \ + USE_OP_DEVICE_KERNEL(op_type, CPU); -#define USE_OP(op_type) \ - USE_NO_GRAD_OP(op_type); \ - USE_OP_GRADIENT(op_type) +#define USE_OP(op_type) \ + USE_OP_ITSELF(op_type); \ + USE_OP_KERNEL(op_type) } // namespace framework } // namespace paddle diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index 0b8f8289490135b8976c38fa3fb3c2995c50416f..1a85d568350dc04ca1df28129de19cd45b5204b8 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -59,11 +59,10 @@ static void BuildVar(const std::string& param_name, var->add_arguments(arg_name); } } - -REGISTER_OP(cos_sim, paddle::framework::CosineOp, - paddle::framework::CosineOpProtoAndCheckerMaker); -REGISTER_OP(my_test_op, paddle::framework::MyTestOp, - paddle::framework::MyTestOpProtoAndCheckerMaker); +REGISTER_OP_WITHOUT_GRADIENT(cos_sim, paddle::framework::CosineOp, + paddle::framework::CosineOpProtoAndCheckerMaker); +REGISTER_OP_WITHOUT_GRADIENT(my_test_op, paddle::framework::MyTestOp, + paddle::framework::MyTestOpProtoAndCheckerMaker); TEST(OpRegistry, CreateOp) { paddle::framework::OpDesc op_desc; diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 13442a72b9d77a4858b5d91dd7690e089ec7ed49..0daf12e7f5f3539d460ce67d39ca1c06f5aa2237 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -33,14 +33,6 @@ ExecutionContext::GetEigenDevice() const { } #endif -static std::unordered_map* g_op_protos = nullptr; -std::unordered_map& OpProtos() { - if (g_op_protos == nullptr) { - g_op_protos = new std::unordered_map(); - } - return *g_op_protos; -} - const std::string& OperatorBase::Input(const std::string& name) const { auto& ins = Inputs(name); PADDLE_ENFORCE_EQ(ins.size(), 1UL, @@ -149,14 +141,18 @@ std::vector OperatorBase::OutputVars(bool has_intermediate) const { } return ret_val; } - auto it = OpProtos().find(type_); + auto it = OpRegistry::op_info_map().find(type_); PADDLE_ENFORCE( - it != OpProtos().end(), + it != OpRegistry::op_info_map().end(), "Operator %s not registered, cannot figure out intermediate outputs", type_); + PADDLE_ENFORCE( + it->second.proto_ != nullptr, + "Operator %s has no OpProto, cannot figure out intermediate outputs", + type_); // get all OpProto::Var for outputs - for (auto& o : it->second.outputs()) { + for (auto& o : it->second.proto_->outputs()) { // ignore all intermediate output if (o.intermediate()) continue; auto out = outputs_.find(o.name()); diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 9e8aef6f8563bacaa0bc5bc46abc031dadaa39ae..90e30bee0a1ae6d8e70b46981c39579df0a76028 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -50,8 +50,6 @@ inline std::string GradVarName(const std::string& var_name) { return var_name + kGradVarSuffix; } -extern std::unordered_map& OpProtos(); - class OperatorBase; class InferShapeContext; class ExecutionContext; @@ -95,6 +93,8 @@ class OperatorBase { /// rename inputs outputs name void Rename(const std::string& old_name, const std::string& new_name); + const VarNameMap& Inputs() const { return inputs_; } + const VarNameMap& Outputs() const { return outputs_; } //! Get a input with argument's name described in `op_proto` const std::string& Input(const std::string& name) const; //! Get a input which has multiple variables. @@ -108,14 +108,15 @@ class OperatorBase { virtual std::vector OutputVars(bool has_intermediate) const; - std::string Type() const { return type_; } + const std::string& Type() const { return type_; } + void SetType(const std::string& type) { type_ = type; } const AttributeMap& Attrs() const { return attrs_; } // Return a new operator instance, which is as same as this. // Use unique_ptr to prevent caller forget to delete this pointer. virtual std::unique_ptr Clone() const = 0; - public: + protected: std::string type_; // NOTE: in case of OpGrad, inputs_ contains: // I (Inputs)opear @@ -131,7 +132,7 @@ class OperatorBase { // Macro for define a clone method. // If you are writing an kernel operator, `Clone` will be defined when you -// register it. +// register it. i.e. `Clone` method is not needed to define by yourself. #define DEFINE_OP_CLONE_METHOD(CLS) \ std::unique_ptr Clone() const final { \ return std::unique_ptr(new CLS(*this)); \ @@ -146,6 +147,17 @@ class OperatorBase { const VarNameMap& outputs, const paddle::framework::AttributeMap& attrs) \ : PARENT_CLS(type, inputs, outputs, attrs) {} +class NOP : public OperatorBase { + public: + using OperatorBase::OperatorBase; + void InferShape(const Scope& scope) const override {} + void Run(const Scope& scope, + const platform::DeviceContext& dev_ctx) const override {} + std::unique_ptr Clone() const override { + return std::unique_ptr(new NOP(*this)); + } +}; + class InferShapeContext { public: InferShapeContext(const OperatorBase& op, const Scope& scope) @@ -227,7 +239,7 @@ class InferShapeContext { [&](const std::string& sub_name) { auto var = scope_.FindVar(sub_name); PADDLE_ENFORCE_NOT_NULL( - var, "MultiOutput(%s:%s) should not be nullptr", name, + var, "MultiOutput(%s:%s) should not be nullptr.", name, sub_name); return var->GetMutable(); }); diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 8836217126581510e10bf40e411e6fdf96b9d381..2425b87779f6af01b0e8a91b5f574a28385f0efd 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -65,8 +65,9 @@ static void BuildVar(const std::string& param_name, } } -REGISTER_OP(test_operator, paddle::framework::OpWithoutKernelTest, - paddle::framework::OpeWithoutKernelTestProtoAndCheckerMaker); +REGISTER_OP_WITHOUT_GRADIENT( + test_operator, paddle::framework::OpWithoutKernelTest, + paddle::framework::OpeWithoutKernelTestProtoAndCheckerMaker); TEST(OperatorBase, all) { paddle::framework::OpDesc op_desc; @@ -184,8 +185,9 @@ class CPUKernalMultiInputsTest : public OpKernel { } // namespace framework } // namespace paddle -REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest, - paddle::framework::OpKernelTestProtoAndCheckerMaker); +REGISTER_OP_WITHOUT_GRADIENT( + op_with_kernel, paddle::framework::OpWithKernelTest, + paddle::framework::OpKernelTestProtoAndCheckerMaker); REGISTER_OP_CPU_KERNEL(op_with_kernel, paddle::framework::CPUKernelTest); @@ -210,8 +212,9 @@ TEST(OpKernel, all) { ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1); } -REGISTER_OP(op_multi_inputs_with_kernel, paddle::framework::OpWithKernelTest, - paddle::framework::OpKernelTestMultiInputsProtoAndCheckerMaker); +REGISTER_OP_WITHOUT_GRADIENT( + op_multi_inputs_with_kernel, paddle::framework::OpWithKernelTest, + paddle::framework::OpKernelTestMultiInputsProtoAndCheckerMaker); REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel, paddle::framework::CPUKernalMultiInputsTest); diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index 07b42c83717652bdf0120b3004f39ac7f7a98d06..fe0c87bc570825014222807cb90a3bb341b44e8e 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/framework/op_registry.h" #include "paddle/framework/tensor_py.h" #include "paddle/operators/net_op.h" +#include "paddle/operators/recurrent_op.h" #include "paddle/platform/enforce.h" #include "paddle/platform/place.h" #include "paddle/string/to_string.h" @@ -30,8 +31,8 @@ limitations under the License. */ namespace py = pybind11; USE_OP(add_two); -USE_CPU_OP(onehot_cross_entropy); -USE_NO_GRAD_OP(sgd); +USE_CPU_ONLY_OP(onehot_cross_entropy); +USE_OP(sgd); USE_OP(mul); USE_OP(mean); USE_OP(sigmoid); @@ -53,15 +54,15 @@ void ExposeOperator(ClassType &m) { .def("run", &ClassType::type::Run) .def("type", [](const typename ClassType::type &op) -> std::string { - return op.type_; + return op.Type(); }) .def("outputs", [](const typename ClassType::type &op) -> std::map> { - return op.outputs_; + return op.Outputs(); }) .def("inputs", - [](const typename ClassType::type &op) { return op.inputs_; }) + [](const typename ClassType::type &op) { return op.Inputs(); }) .def("__str__", &ClassType::type::DebugString) .def("no_intermediate_outputs", [](const typename ClassType::type &op) { @@ -160,13 +161,16 @@ All parameter, weight, gradient are variables in Paddle. //! @note: Be careful! PyBind will return std::string as an unicode, not //! Python str. If you want a str object, you should cast them in Python. m.def("get_all_op_protos", []() -> std::vector { - auto &protos = OpProtos(); + auto &op_info_map = OpRegistry::op_info_map(); std::vector ret_values; - for (auto it = protos.begin(); it != protos.end(); ++it) { - PADDLE_ENFORCE(it->second.IsInitialized(), - "OpProto must all be initialized"); + for (auto it = op_info_map.begin(); it != op_info_map.end(); ++it) { + const OpProto *proto = it->second.proto_; + if (proto == nullptr) { + continue; + } + PADDLE_ENFORCE(proto->IsInitialized(), "OpProto must all be initialized"); std::string str; - PADDLE_ENFORCE(it->second.SerializeToString(&str), + PADDLE_ENFORCE(proto->SerializeToString(&str), "Serialize OpProto Error. This could be a bug of Paddle."); ret_values.push_back(py::bytes(str)); } @@ -229,7 +233,7 @@ All parameter, weight, gradient are variables in Paddle. net.def_static("create", []() -> std::shared_ptr { auto retv = std::make_shared(); - retv->type_ = "plain_net"; + retv->SetType("plain_net"); return retv; }) .def("add_op", &operators::NetOp::AddOp) @@ -238,6 +242,11 @@ All parameter, weight, gradient are variables in Paddle. const std::shared_ptr &net) -> void { self.AddOp(std::static_pointer_cast(net)); }) + .def("add_op", + [](operators::NetOp &self, + const std::shared_ptr &rnn) -> void { + self.AddOp(std::static_pointer_cast(rnn)); + }) .def("complete_add_op", &operators::NetOp::CompleteAddOp) .def("complete_add_op", [](std::shared_ptr &self) { self->CompleteAddOp(); @@ -245,6 +254,29 @@ All parameter, weight, gradient are variables in Paddle. ExposeOperator(net); + // recurrent_op + py::class_> + rnn(m, "RecurrentOp"); + + rnn.def_static( + "create", + [](py::bytes protobin) -> std::shared_ptr { + OpDesc desc; + PADDLE_ENFORCE(desc.ParsePartialFromString(protobin), + "Cannot parse user input to OpDesc"); + PADDLE_ENFORCE(desc.IsInitialized(), + "User OpDesc is not initialized, reason %s", + desc.InitializationErrorString()); + auto rnn_op = OpRegistry::CreateOp(desc); + return std::dynamic_pointer_cast(rnn_op); + }) + .def("set_stepnet", + [](operators::RecurrentOp &self, + const std::shared_ptr &net) -> void { + self.set_stepnet(net); + }); + ExposeOperator(rnn); + m.def("unique_integer", UniqueIntegerGenerator); m.def("is_compile_gpu", IsCompileGPU); diff --git a/paddle/memory/CMakeLists.txt b/paddle/memory/CMakeLists.txt index 8035d93bfec75b20a54c5af0521ab724cafba8ca..9cc4233e43267472d405c3e4e617f0782e1430ea 100644 --- a/paddle/memory/CMakeLists.txt +++ b/paddle/memory/CMakeLists.txt @@ -1,7 +1,7 @@ add_subdirectory(detail) cc_library(memory SRCS memory.cc) -cc_library(memcpy SRCS memcpy.cc DEPS device_context) +cc_library(memcpy SRCS memcpy.cc) cc_library(paddle_memory DEPS diff --git a/paddle/memory/memcpy.cc b/paddle/memory/memcpy.cc index aaab1142ca18d3319469a4d685fde9d30929113f..a19a3e3675e3e2e7cc0c3594f21191f932d6379f 100644 --- a/paddle/memory/memcpy.cc +++ b/paddle/memory/memcpy.cc @@ -16,8 +16,6 @@ limitations under the License. */ #include // for memcpy -#include "paddle/platform/device_context.h" - namespace paddle { namespace memory { diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 373611cc0ee952de813f01d32d1516e1a8384750..a7c89787e43df6173791bc54b3faffc034867f7d 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -44,6 +44,8 @@ endfunction() add_subdirectory(math) cc_test(gather_test SRCS gather_test.cc DEPS tensor) +cc_test(scatter_test SRCS scatter_test.cc DEPS tensor) + cc_library(net_op SRCS net_op.cc DEPS op_registry) cc_test(net_op_test SRCS net_op_test.cc DEPS net_op) @@ -64,6 +66,5 @@ op_library(sgd_op SRCS sgd_op.cc sgd_op.cu) op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc DEPS framework_proto tensor op_registry operator net_op) -cc_test(recurrent_op_test SRCS recurrent_op_test.cc DEPS recurrent_op gtest mul_op add_op) op_library(uniform_random_op SRCS uniform_random_op.cc uniform_random_op.cu) diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc index c1f647a88e4547d96bbb9143cdb2cb07bc291635..8ab748ed71e9a5dc0ee0259a78a2b886870bec5b 100644 --- a/paddle/operators/add_op.cc +++ b/paddle/operators/add_op.cc @@ -57,8 +57,7 @@ class AddOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker); -REGISTER_GRADIENT_OP(add_two, add_two_grad, ops::AddOpGrad); +REGISTER_OP(add_two, ops::AddOp, ops::AddOpMaker, add_two_grad, ops::AddOpGrad); REGISTER_OP_CPU_KERNEL(add_two, ops::AddKernel); diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc index 597c71d4e042e6b6a752c0b1819b909a7a9faa75..a623c551e1088365ade6f73bc6149977b6ef017e 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/cross_entropy_op.cc @@ -68,12 +68,11 @@ OnehotCrossEntropy Operator. namespace ops = paddle::operators; REGISTER_OP(onehot_cross_entropy, ops::OnehotCrossEntropyOp, - ops::OnehotCrossEntropyOpMaker); + ops::OnehotCrossEntropyOpMaker, onehot_cross_entropy_grad, + ops::OnehotCrossEntropyGradientOp); REGISTER_OP_CPU_KERNEL( onehot_cross_entropy, ops::OnehotCrossEntropyOpKernel); -REGISTER_GRADIENT_OP(onehot_cross_entropy, onehot_cross_entropy_grad, - ops::OnehotCrossEntropyGradientOp); REGISTER_OP_CPU_KERNEL( onehot_cross_entropy_grad, ops::OnehotCrossEntropyGradientOpKernel); diff --git a/paddle/operators/fill_zeros_like_op.cc b/paddle/operators/fill_zeros_like_op.cc index e42e33f1a3759ae26cee987d0b68a55b672e3f94..9d51f6e3a16fe96125599bb440d40237aeb9a028 100644 --- a/paddle/operators/fill_zeros_like_op.cc +++ b/paddle/operators/fill_zeros_like_op.cc @@ -46,7 +46,8 @@ The output will have the same size with input. } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(fill_zeros_like, ops::FillZerosLikeOp, ops::FillZerosLikeOpMaker); +REGISTER_OP_WITHOUT_GRADIENT(fill_zeros_like, ops::FillZerosLikeOp, + ops::FillZerosLikeOpMaker); REGISTER_OP_CPU_KERNEL( fill_zeros_like, ops::FillZerosLikeKernel); diff --git a/paddle/operators/gather.h b/paddle/operators/gather.h index 0c73717d38aca9f3430e66cafc3ecccdd2eec776..d6e6990394e46ba06c4bacfe33ca522f3ff1413a 100644 --- a/paddle/operators/gather.h +++ b/paddle/operators/gather.h @@ -29,7 +29,7 @@ void CPUGather(const T* params, const int* indices, const int slice_size, const int index_size, T* output) { const size_t slice_bytes = slice_size * sizeof(T); - for (size_t i = 0; i < index_size; ++i) { + for (int i = 0; i < index_size; ++i) { int index_ = indices[i]; memcpy(output + i * slice_size, params + index_ * slice_size, slice_bytes); } @@ -60,7 +60,7 @@ void Gather(const platform::Place& place, const paddle::framework::Tensor* src, // slice size int slice_size = 1; - for (size_t i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; + for (int i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i]; // Gathering if (platform::is_cpu_place(place)) { diff --git a/paddle/operators/gather_test.cc b/paddle/operators/gather_test.cc index 5de748ec461e4b1a34b75b57c9cd7d5bc9326059..d24d83f299fdb071e60fa3cc7b223c0228cb29af 100644 --- a/paddle/operators/gather_test.cc +++ b/paddle/operators/gather_test.cc @@ -35,7 +35,7 @@ TEST(Gather, GatherData) { p_src = src->mutable_data(make_ddim({3, 4}), CPUPlace()); p_index = index->mutable_data(make_ddim({2}), CPUPlace()); - for (size_t i = 0; i < 12; ++i) p_src[i] = i; + for (int i = 0; i < 12; ++i) p_src[i] = i; p_index[0] = 1; p_index[1] = 0; @@ -43,6 +43,6 @@ TEST(Gather, GatherData) { Gather(CPUPlace(), src, index, output); - for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4); - for (size_t i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4); + for (int i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4); + for (int i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4); } diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index 75249c08eb00095615fc75eb9261432d64246b2e..f30bbce9586d61063b4b61d98695bb568ef73c8d 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -81,5 +81,6 @@ Use to initialize tensor with gaussian random generator. } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(gaussian_random, ops::GaussianRandomOp, ops::GaussianRandomOpMaker); +REGISTER_OP_WITHOUT_GRADIENT(gaussian_random, ops::GaussianRandomOp, + ops::GaussianRandomOpMaker); REGISTER_OP_CPU_KERNEL(gaussian_random, ops::GaussianRandomKernel); diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index abcaf940ab0128d6948acc620d678632c8f48960..ed51d416ed9497eee45ba826ad672b8fb1ad3678 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,13 +1,8 @@ -if(WITH_MKLML) - set(BLAS_LIB mklml) -else() - set(BLAS_LIB cblas) -endif() if(WITH_GPU) - nv_library(math_function SRCS math_function.cc math_function.cu DEPS ${BLAS_LIB} device_context) + nv_library(math_function SRCS math_function.cc math_function.cu DEPS cblas device_context) else() - cc_library(math_function SRCS math_function.cc DEPS ${BLAS_LIB} device_context) + cc_library(math_function SRCS math_function.cc DEPS cblas device_context) endif() nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc index 35e7212dde210a50285272cfd94118fa34fb7cd9..49d0f43508b1ee3df0c6b5987942970e1649e310 100644 --- a/paddle/operators/mean_op.cc +++ b/paddle/operators/mean_op.cc @@ -54,9 +54,8 @@ class MeanGradOp : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker); +REGISTER_OP(mean, ops::MeanOp, ops::MeanOpMaker, mean_grad, ops::MeanGradOp); REGISTER_OP_CPU_KERNEL(mean, ops::MeanKernel); -REGISTER_GRADIENT_OP(mean, mean_grad, ops::MeanGradOp); REGISTER_OP_CPU_KERNEL(mean_grad, ops::MeanGradKernel); diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 032d234197c12fe107fb195e862c160948ee354c..95d19fb6aad37143e65759b03e12e3e78bce5915 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -70,7 +70,5 @@ class MulOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker); -REGISTER_GRADIENT_OP(mul, mul_grad, ops::MulOpGrad); - +REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad); REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel); diff --git a/paddle/operators/net_op.cc b/paddle/operators/net_op.cc index 77eb07e2f9018d966897f4b6eb1308b28394c03d..a7d710511093dfbe13a13b1222b0230bba0398bd 100644 --- a/paddle/operators/net_op.cc +++ b/paddle/operators/net_op.cc @@ -29,7 +29,7 @@ void NetOp::CompleteAddOp(bool calc) { std::set input_set; std::set output_set; for (auto& op : ops_) { - for (auto& ipt : op->inputs_) { + for (auto& ipt : op->Inputs()) { for (auto& var_name : ipt.second) { if (!Contains(output_set, var_name)) { // Not other op's output input_set.insert(var_name); @@ -39,7 +39,7 @@ void NetOp::CompleteAddOp(bool calc) { } } - for (auto& opt : op->outputs_) { + for (auto& opt : op->Outputs()) { for (auto& var_name : opt.second) { output_set.insert(var_name); } diff --git a/paddle/operators/net_op_test.cc b/paddle/operators/net_op_test.cc index 6d6f8bd354817cd5cabd050b703e16bdeeaa623b..e28d4df6a570968205851c2e5b630a14c0492535 100644 --- a/paddle/operators/net_op_test.cc +++ b/paddle/operators/net_op_test.cc @@ -21,14 +21,6 @@ class TestOp : public framework::OperatorBase { } }; -class EmptyOp : public framework::OperatorBase { - public: - using framework::OperatorBase::OperatorBase; - DEFINE_OP_CLONE_METHOD(EmptyOp); - void InferShape(const Scope& scope) const override {} - void Run(const Scope& scope, const DeviceContext& dev_ctx) const override {} -}; - template void AssertSameVectorWithoutOrder(const std::vector& expected, const std::vector& actual) { @@ -58,8 +50,8 @@ TEST(OpKernel, all) { net->CompleteAddOp(); AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"}, - net->inputs_.at(NetOp::kAll)); - AssertSameVectorWithoutOrder({"y", "z"}, net->outputs_.at(NetOp::kAll)); + net->Inputs(NetOp::kAll)); + AssertSameVectorWithoutOrder({"y", "z"}, net->Outputs(NetOp::kAll)); auto final_outs = net->OutputVars(false); @@ -69,9 +61,9 @@ TEST(OpKernel, all) { TEST(NetOp, insert_op) { NetOp net; - auto op1 = std::shared_ptr( - new EmptyOp("empty", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}}, - {{"Out", {"y"}}}, {})); + auto op1 = std::shared_ptr( + new framework::NOP("empty", {{"X", {"x"}}, {"W", {"w1"}}, {"b", {"b1"}}}, + {{"Out", {"y"}}}, {})); net.AddOp(op1); net.InsertOp(0, op1); ASSERT_EQ(2UL, net.ops_.size()); @@ -81,8 +73,10 @@ TEST(NetOp, insert_op) { TEST(NetOp, Clone) { NetOp net; - net.AddOp(std::shared_ptr(new EmptyOp{"empty", {}, {}, {}})); - net.AddOp(std::shared_ptr(new EmptyOp{"empty2", {}, {}, {}})); + net.AddOp( + std::shared_ptr(new framework::NOP{"empty", {}, {}, {}})); + net.AddOp(std::shared_ptr( + new framework::NOP{"empty2", {}, {}, {}})); net.CompleteAddOp(true); auto new_net_op = net.Clone(); ASSERT_NE(new_net_op, nullptr); diff --git a/paddle/operators/recurrent_op.cc b/paddle/operators/recurrent_op.cc index 5ddee75581824996fd312f8ddf13007759fd9a67..78ce0ba3c0fa4fe380e49a848c2434fe593cd00b 100644 --- a/paddle/operators/recurrent_op.cc +++ b/paddle/operators/recurrent_op.cc @@ -36,15 +36,13 @@ void RecurrentAlgorithm::InferShape(const Scope& scope) const { rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_, true /*infer_shape_mode*/); InitMemories(step_scopes[0], true /*infer_shape_mode*/); - Variable* net = scope.FindVar(arg_->step_net); - PADDLE_ENFORCE(net != nullptr, "failed to get step net"); for (size_t i = 0; i < seq_len_; i++) { if (i > 0) { rnn::LinkMemories(step_scopes, arg_->memories, i, -1, true /*infer_shape_mode*/); } - net->GetMutable()->InferShape(*step_scopes[i]); + (*stepnet_)->InferShape(*step_scopes[i]); } rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, true /*infer_shape_mode*/); @@ -56,7 +54,6 @@ void RecurrentAlgorithm::Run(const Scope& scope, rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_, false /*infer_shape_mode*/); InitMemories(step_scopes[0], false /*infer_shape_mode*/); - Variable* net = scope.FindVar(arg_->step_net); for (size_t step_id = 0; step_id < seq_len_; step_id++) { // create output alias variables @@ -64,7 +61,7 @@ void RecurrentAlgorithm::Run(const Scope& scope, rnn::LinkMemories(step_scopes, arg_->memories, step_id, -1, false /*infer_shape_mode*/); } - net->GetMutable()->Run(*step_scopes[step_id], dev_ctx); + (*stepnet_)->Run(*step_scopes[step_id], dev_ctx); } rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, false /*infer_shape_mode*/); @@ -78,18 +75,16 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { auto step_scopes = step_scopes_var->GetMutable>(); // Now all variables in scope must be created outside of op. - auto net_var = scope.FindVar(arg_->step_net); - PADDLE_ENFORCE(net_var != nullptr, "no stepnet called %s in scope", - arg_->step_net); - auto net_op = net_var->GetMutable(); - PADDLE_ENFORCE(!net_op->outputs_.empty(), "net_op has no outputs"); + PADDLE_ENFORCE_NOT_NULL(stepnet_); + PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "stepnet_ op has no outputs"); + PADDLE_ENFORCE(!(*stepnet_)->Outputs().empty(), "net_op has no outputs"); if (seq_len_ > step_scopes->size()) { for (size_t i = step_scopes->size(); i < seq_len_; ++i) { auto& step_scope = scope.NewScope(); // create step net's temp inputs - for (auto& input : net_op->inputs_) { + for (auto& input : (*stepnet_)->Inputs()) { // the weight are located in parent scope for (auto& var_name : input.second) { if (!step_scope.FindVar(var_name)) { @@ -98,7 +93,7 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { } } // create stepnet's outputs - for (const auto& output : net_op->outputs_) { + for (const auto& output : (*stepnet_)->Outputs()) { for (auto& var_name : output.second) { step_scope.NewVar(var_name); } @@ -140,9 +135,8 @@ RecurrentOp::RecurrentOp(const std::string& type, const framework::OperatorBase::VarNameMap& outputs, const framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) { - std::unique_ptr arg(new rnn::Argument()); - rnn::InitArgument(kArgName, arg.get(), *this); - alg_.Init(std::move(arg)); + rnn::InitArgument(kArgName, &arg_, *this); + alg_.Init(&arg_, &stepnet_); } class RecurrentAlgorithmProtoAndCheckerMaker @@ -158,7 +152,6 @@ class RecurrentAlgorithmProtoAndCheckerMaker .AsDuplicable(); AddInput(name.boot_memories, "variables to initialize memories.") .AsDuplicable(); - AddInput(name.step_net, "network shared by all steps."); AddOutput(name.outlinks, "the outputs that need to concated for all steps.") .AsDuplicable(); @@ -180,14 +173,12 @@ void RecurrentGradientAlgorithm::Run( auto step_scopes = GetStepScopes(scope); rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_, false /*infer_shape_mode*/); - Variable* net = scope.FindVar(arg_->step_net); - PADDLE_ENFORCE(net != nullptr, "failed to get step net"); for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) { if (static_cast(step_id) != seq_len_ - 1) { rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1, false /*infer_shape_mode*/); } - net->GetMutable()->Run(*step_scopes[step_id], dev_ctx); + (*stepnet_)->Run(*step_scopes[step_id], dev_ctx); } LinkBootMemoryGradients(step_scopes[0], false); rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, @@ -219,14 +210,12 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const { auto step_scopes = GetStepScopes(scope); rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_, true /*infer_shape_mode*/); - Variable* net = scope.FindVar(arg_->step_net); - PADDLE_ENFORCE(net != nullptr, "failed to get step net"); for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) { if (static_cast(step_id) != seq_len_ - 1) { rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1, true /*infer_shape_mode*/); } - net->GetMutable()->InferShape(*step_scopes[step_id]); + (*stepnet_)->InferShape(*step_scopes[step_id]); } rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_, true /*infer_shape_mode*/); @@ -238,13 +227,13 @@ RecurrentGradientOp::RecurrentGradientOp( const framework::OperatorBase::VarNameMap& outputs, const framework::AttributeMap& attrs) : OperatorBase(type, inputs, outputs, attrs) { - std::unique_ptr arg(new rnn::Argument()); - rnn::InitArgument(kArgName, arg.get(), *this); - alg_.Init(std::move(arg)); + rnn::InitArgument(kArgName, &arg_, *this); + alg_.Init(&arg_, &stepnet_); } } // namespace operators } // namespace paddle -REGISTER_OP(recurrent_op, paddle::operators::RecurrentOp, - paddle::operators::RecurrentAlgorithmProtoAndCheckerMaker); +REGISTER_OP_WITHOUT_GRADIENT( + recurrent_op, paddle::operators::RecurrentOp, + paddle::operators::RecurrentAlgorithmProtoAndCheckerMaker); diff --git a/paddle/operators/recurrent_op.h b/paddle/operators/recurrent_op.h index cc40eff0cf94c75dd21c0e4e8c2b28ef0ed5c3e4..1d8a6973955cf0b4ab372412fbb5428ff2622a0a 100644 --- a/paddle/operators/recurrent_op.h +++ b/paddle/operators/recurrent_op.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/framework/operator.h" +#include "paddle/operators/net_op.h" #include "paddle/operators/rnn/recurrent_op_utils.h" namespace paddle { @@ -33,7 +34,11 @@ class RecurrentAlgorithm { void Run(const framework::Scope& scope, const platform::DeviceContext& dev_ctx) const; - void Init(std::unique_ptr arg) { arg_ = std::move(arg); } + void Init(rnn::Argument* arg, std::shared_ptr* stepnet) { + PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before."); + arg_ = arg; + stepnet_ = stepnet; + } /** * InferShape must be called before Run. @@ -58,7 +63,8 @@ class RecurrentAlgorithm { void InitMemories(framework::Scope* step_scopes, bool infer_shape_mode) const; private: - std::unique_ptr arg_; + std::shared_ptr* stepnet_; + rnn::Argument* arg_; mutable size_t seq_len_; }; @@ -74,7 +80,11 @@ class RecurrentGradientAlgorithm { * operator. */ public: - void Init(std::unique_ptr arg) { arg_ = std::move(arg); } + void Init(rnn::Argument* arg, std::shared_ptr* stepnet) { + PADDLE_ENFORCE_NOT_NULL(stepnet, "stepnet should be set before."); + arg_ = std::move(arg); + stepnet_ = stepnet; + } void Run(const framework::Scope& scope, const platform::DeviceContext& dev_ctx) const; @@ -95,8 +105,9 @@ class RecurrentGradientAlgorithm { } private: - std::unique_ptr arg_; + rnn::Argument* arg_; mutable size_t seq_len_; + std::shared_ptr* stepnet_; }; class RecurrentOp : public framework::OperatorBase { @@ -122,10 +133,15 @@ class RecurrentOp : public framework::OperatorBase { alg_.Run(scope, dev_ctx); } + void set_stepnet(std::shared_ptr net) { stepnet_ = net; } + const NetOp& stepnet() const { return *stepnet_; } + static const rnn::ArgumentName kArgName; private: RecurrentAlgorithm alg_; + rnn::Argument arg_; + std::shared_ptr stepnet_; }; class RecurrentGradientOp : public framework::OperatorBase { @@ -155,8 +171,13 @@ class RecurrentGradientOp : public framework::OperatorBase { static const rnn::ArgumentName kArgName; + void set_stepnet(const std::shared_ptr& net) { stepnet_ = net; } + const NetOp& stepnet() const { return *stepnet_; } + private: RecurrentGradientAlgorithm alg_; + std::shared_ptr stepnet_; + rnn::Argument arg_; }; } // namespace operators diff --git a/paddle/operators/recurrent_op_test.cc b/paddle/operators/recurrent_op_test.cc deleted file mode 100644 index 2f6eff0720847fdfa6443d2fc233e92dac2d0fab..0000000000000000000000000000000000000000 --- a/paddle/operators/recurrent_op_test.cc +++ /dev/null @@ -1,252 +0,0 @@ -/* - Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. -*/ - -#include "paddle/operators/recurrent_op.h" - -#include -#include - -#include "paddle/framework/ddim.h" -#include "paddle/framework/op_registry.h" -#include "paddle/framework/operator.h" -#include "paddle/framework/tensor.h" -#include "paddle/operators/net_op.h" - -namespace paddle { -namespace operators { - -using namespace paddle::framework; - -class RecurrentGradientAlgorithmTest : public ::testing::Test { - protected: - virtual void SetUp() override { - CreateGlobalVariables(); - CreateStepScopes(); - CreateStepNet(); - CreateRNNGradientAlgorithm(); - - // segment inputs - SegmentInputs(); - // link forward memories - LinkeMemories(); - } - - virtual void TearDown() override {} - - void CreateGlobalVariables() { - // inputs: x - LOG(INFO) << "create global variable x"; - Variable* x = scope_.NewVar("x"); - DDim dims = - make_ddim({10 /*sent size*/, 20 /*batch size*/, 30 /*input dim*/}); - x->GetMutable()->mutable_data(dims, platform::CPUPlace()); - // inputs: h_boot - LOG(INFO) << "create global variable h_boot"; - Variable* h_boot = scope_.NewVar("h_boot"); - h_boot->GetMutable()->mutable_data( - make_ddim({20 /*batch size*/, 30 /*input dim*/}), platform::CPUPlace()); - // inputs: w - LOG(INFO) << "create global variable w"; - Variable* w = scope_.NewVar("rnn/w"); - w->GetMutable()->mutable_data(make_ddim({30, 30}), - platform::CPUPlace()); - // inputs: h_grad - LOG(INFO) << "create variable h_grad"; - Variable* dh = scope_.NewVar("h_grad"); - dh->GetMutable()->mutable_data(make_ddim({10, 20, 30}), - platform::CPUPlace()); - // inputs: step_scopes - LOG(INFO) << "create variable step_scopes"; - scope_.NewVar("step_scopes"); - // inputs: step_net - LOG(INFO) << "create variable step_net"; - scope_.NewVar("step_net"); - // outputs: w_grad - LOG(INFO) << "create global variable w_grad"; - scope_.NewVar("rnn/w_grad"); - // outputs: x_grad - LOG(INFO) << "create global variable x_grad"; - scope_.NewVar("x_grad"); - // outputs: h_boot_grad - LOG(INFO) << "create global variable h_boot_grad"; - scope_.NewVar("h_boot_grad"); - } - - void CreateStepScopes() { - auto step_scopes = - scope_.FindVar("step_scopes")->GetMutable>(); - for (int i = 0; i < 10; ++i) { - auto& scope = scope_.NewScope(); - auto pre_t = scope.NewVar("rnn/pre_h")->GetMutable(); - pre_t->mutable_data({20, 30}, platform::CPUPlace()); - auto tensor = scope.NewVar("rnn/h")->GetMutable(); - tensor->mutable_data({20, 30}, platform::CPUPlace()); - - // for unit test of ConcatOutputs - auto xg = scope.NewVar("rnn/x_grad")->GetMutable(); - xg->mutable_data({20, 30}, platform::CPUPlace()); - - step_scopes->emplace_back(&scope); - } - - // last time step - auto g = (*step_scopes)[9]->NewVar("rnn/h_pre_grad")->GetMutable(); - g->mutable_data({20, 30}, platform::CPUPlace()); - } - - void CreateRNNGradientAlgorithm() { - std::unique_ptr arg(new rnn::Argument()); - arg->step_net = "step_net"; - arg->step_scopes = "step_scopes"; - rnn::Link inlink; - inlink.external = "h_grad"; - inlink.internal = "rnn/h_grad"; - arg->inlinks = std::vector{inlink}; - - rnn::Link outlink; - outlink.external = "x_grad"; - outlink.internal = "rnn/x_grad"; - arg->outlinks = std::vector{outlink}; - - rnn::MemoryAttr mem_attr; - mem_attr.pre_var = "rnn/h_pre_grad"; - mem_attr.var = "rnn/h_grad"; - mem_attr.boot_var = "h_boot_grad"; - arg->memories = std::vector{mem_attr}; - - rnn_grad_algo_.Init(std::move(arg)); - } - - void CreateStepNet() { - LOG(INFO) << "create variable step_net"; - Variable* var = scope_.NewVar("step_net"); - auto net = var->GetMutable(); - // TODO(qingqing) modify backward op create for RNNOp unit test - // and the unit test will be removed to Python. - // net->AddOp(OpRegistry::CreateOp("mul", {"X", {"rnn/h_pre", "rnn/w", - // "rnn/s_grad"}}, {"Y", {"rnn/h_pre_grad", "rnn/w_grad"}}, {})); - - // net->AddOp(OpRegistry::CreateOp("add_two", {"X", {"rnn/h_grad"}}, - // {"Y", {"rnn/x_grad"}}, {"Out", "rnn/s_grad"}}, {})); - net->CompleteAddOp(); - } - - void SegmentInputs() { - LOG(INFO) << "segment inputs"; - std::vector inlinks = {"x"}; - std::vector inlinks_alias = {"rnn/x"}; - - rnn::Link inlink; - inlink.external = "x"; - inlink.internal = "rnn/x"; - auto step_scopes = - scope_.FindVar("step_scopes")->GetMutable>(); - rnn::SegmentInputs(*step_scopes, std::vector{inlink}, 10, - true /*infer_shape_mode*/); - } - - void LinkeMemories() { - LOG(INFO) << "link memories"; - rnn::MemoryAttr mem_attr; - mem_attr.pre_var = "rnn/h_pre"; - mem_attr.var = "rnn/h"; - mem_attr.boot_var = "boot_h"; - std::vector memories; - memories.push_back(mem_attr); - auto step_scopes = - scope_.FindVar("step_scopes")->GetMutable>(); - for (int i = 1; i < 10; ++i) { - rnn::LinkMemories(*step_scopes, memories, i, -1, - true /*infer_shape_mode*/); - } - } - - Scope scope_; - RecurrentGradientAlgorithm rnn_grad_algo_; -}; - -// TEST_F(RecurrentGradientAlgorithmTest, Run) { -// platform::CPUDeviceContext ctx; -// rnn_grad_algo_.Run(scope_, ctx); -// } - -} // namespace operators -} // namespace paddle - -TEST(RecurrentOp, LinkMemories) { - using namespace paddle::framework; - using namespace paddle::platform; - using namespace paddle::operators; - - // create and init step scopes - size_t len = 10; - std::vector step_scopes; - for (size_t i = 0; i < len; ++i) { - auto scope = new Scope(); - scope->NewVar("pre_h"); - auto tensor = scope->NewVar("h")->GetMutable(); - float* data = tensor->mutable_data({15, 20}, CPUPlace()); - for (size_t j = 0; j < 15 * 20; ++j) { - data[j] = rand() * (1. / (double)RAND_MAX); - } - step_scopes.push_back(scope); - } - - // create MemoryAttr - rnn::MemoryAttr mem_attr; - mem_attr.pre_var = "pre_h"; - mem_attr.var = "h"; - mem_attr.boot_var = "boot_h"; - std::vector memories; - memories.push_back(mem_attr); - - for (size_t i = 1; i < len; ++i) { - rnn::LinkMemories(step_scopes, memories, i, -1, false - /*infer_shape_mode*/); - } - // check - for (size_t i = 0; i < len - 1; ++i) { - const float* a = - step_scopes[i]->FindVar("h")->GetMutable()->data(); - const float* b = step_scopes[i + 1] - ->FindVar("pre_h") - ->GetMutable() - ->data(); - for (size_t j = 0; j < 15 * 20; ++j) { - ASSERT_FLOAT_EQ(a[j], b[j]); - } - } - - for (int i = len - 2; i >= 0; --i) { - rnn::LinkMemories(step_scopes, memories, i, 1, false - /*infer_shape_mode*/); - } - // check - for (int i = len - 2; i >= 0; --i) { - const float* a = - step_scopes[i]->FindVar("pre_h")->GetMutable()->data(); - const float* b = - step_scopes[i + 1]->FindVar("h")->GetMutable()->data(); - for (size_t j = 0; j < 15 * 20; ++j) { - ASSERT_FLOAT_EQ(a[j], b[j]); - } - } - - for (auto s : step_scopes) { - delete s; - } -} - -USE_OP(add_two); -USE_OP(mul); -USE_OP_ITSELF(recurrent_op); diff --git a/paddle/operators/rnn/recurrent_op_utils.cc b/paddle/operators/rnn/recurrent_op_utils.cc index 7e4770630ed2a49214194689aa489e6ab8e476da..a9b65c30f25554e54e9fd7103f240946a93566e2 100644 --- a/paddle/operators/rnn/recurrent_op_utils.cc +++ b/paddle/operators/rnn/recurrent_op_utils.cc @@ -106,7 +106,6 @@ void LinkMemories(const std::vector& scopes, void InitArgument(const ArgumentName& name, Argument* arg, const framework::OperatorBase& op) { - arg->step_net = op.Input(name.step_net); arg->step_scopes = op.Output(name.step_scopes); auto inlinks = op.Inputs(name.inlinks); diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc index b4671c293af1c4fed3b441f05bc8f3a5db039b41..8375d988045dc24fa1109646b46ff477e2a78132 100644 --- a/paddle/operators/rowwise_add_op.cc +++ b/paddle/operators/rowwise_add_op.cc @@ -54,6 +54,7 @@ for i in xrange(X.shape[0]): } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(rowwise_add, ops::RowWiseAddOp, ops::RowWiseAddOpMaker); +REGISTER_OP_WITHOUT_GRADIENT(rowwise_add, ops::RowWiseAddOp, + ops::RowWiseAddOpMaker); REGISTER_OP_CPU_KERNEL( rowwise_add, ops::RowWiseAddKernel); diff --git a/paddle/operators/scatter.h b/paddle/operators/scatter.h new file mode 100644 index 0000000000000000000000000000000000000000..6b542675c291607b35f180123cf42fee6a783a85 --- /dev/null +++ b/paddle/operators/scatter.h @@ -0,0 +1,92 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include + +#include "paddle/framework/ddim.h" +#include "paddle/framework/eigen.h" +#include "paddle/framework/tensor.h" +#include "paddle/platform/place.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; + +// Implementation of CPU copy +template +void CPUScatterUpdate(const paddle::framework::Tensor* src, const int* index, + const size_t index_size, + paddle::framework::Tensor* output) { + paddle::framework::DDim output_dims = output->dims(); + + for (size_t i = 0; i < index_size; ++i) { + int index_ = index[i]; + + paddle::framework::Tensor src_ = *src; + paddle::framework::Tensor output_ = *output; + if (index_size > 1) src_ = src->Slice(i, i + 1); + if (output_dims[0] > 1) output_ = output->Slice(index_, index_ + 1); + + auto X = EigenVector::Flatten(src_); + auto Y = EigenVector::Flatten(output_); + + Y = X + Y; + } +} + +// Implementation of GPU scatter: +template +void GPUScatterUpdate(const T* src, const int* index, const int slice_size, + const int index_size, T* output); + +/** + * Return a updated tensor from source tensor, scattered according to index: + * dst[i] += src[index[i]] + * input[src]: type-T source Tensor + * input[index]: type-int index Tensor (1-D) + * return: output tensor + */ +template +void ScatterUpdate(const platform::Place& place, + const paddle::framework::Tensor* src, + const paddle::framework::Tensor* index, + paddle::framework::Tensor* output) { + // check index of shape 1-D + PADDLE_ENFORCE(index->dims().size() == 1); + int index_size = index->dims()[0]; + + auto src_dims = src->dims(); + auto dst_dims = output->dims(); + + // check src shape and dst shape should match + for (int i = 1; i < src_dims.size(); i++) + PADDLE_ENFORCE(src_dims[i] == dst_dims[i]); + + // slice size + size_t slice_size = 1; + for (int i = 0; i < src_dims.size(); ++i) slice_size *= src_dims[i]; + + if (platform::is_cpu_place(place)) { + CPUScatterUpdate(src, index->data(), index_size, output); + } else { + } +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/scatter_test.cc b/paddle/operators/scatter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4449ce6564396f1971506efb7458c00c834db19f --- /dev/null +++ b/paddle/operators/scatter_test.cc @@ -0,0 +1,52 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/scatter.h" +#include "paddle/framework/ddim.h" +#include "paddle/framework/tensor.h" +#include "paddle/platform/place.h" + +#include +#include +#include + +TEST(scatter, ScatterUpdate) { + using namespace paddle::framework; + using namespace paddle::platform; + using namespace paddle::operators; + + Tensor* src = new Tensor(); + Tensor* index = new Tensor(); + Tensor* output = new Tensor(); + + float* p_src = nullptr; + int* p_index = nullptr; + p_src = src->mutable_data(make_ddim({1, 4}), CPUPlace()); + p_index = index->mutable_data(make_ddim({1}), CPUPlace()); + + for (size_t i = 0; i < 4; ++i) p_src[i] = float(i); + p_index[0] = 1; + + float* p_output = output->mutable_data(make_ddim({4, 4}), CPUPlace()); + + ScatterUpdate(CPUPlace(), src, index, output); + + for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], float(0)); + for (size_t i = 0; i < 4; ++i) EXPECT_EQ(output->data()[i], float(0)); + for (size_t i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], float(i - 4)); + for (size_t i = 4; i < 8; ++i) + EXPECT_EQ(output->data()[i], float(i - 4)); + for (size_t i = 8; i < 16; ++i) EXPECT_EQ(p_output[i], float(0)); + for (size_t i = 8; i < 16; ++i) EXPECT_EQ(output->data()[i], float(0)); +} diff --git a/paddle/operators/sgd_op.cc b/paddle/operators/sgd_op.cc index bf76df272b6faaed01ed8d715fe3b547ec7dc4e3..ad267e7f087943ff3b8326a7baf2ce3955fa51c2 100644 --- a/paddle/operators/sgd_op.cc +++ b/paddle/operators/sgd_op.cc @@ -51,6 +51,6 @@ param_out = param - learning_rate * grad; } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(sgd, ops::SGDOp, ops::SGDOpMaker); +REGISTER_OP_WITHOUT_GRADIENT(sgd, ops::SGDOp, ops::SGDOpMaker); REGISTER_OP_CPU_KERNEL(sgd, ops::SGDOpKernel); diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc index a7dfb624e5b779164eb07763eb604c548f6e89e7..d773a4f2d50e82146a729b1cda085ce86ade89cc 100644 --- a/paddle/operators/sigmoid_op.cc +++ b/paddle/operators/sigmoid_op.cc @@ -52,9 +52,8 @@ class SigmoidOpGrad : public framework::OperatorWithKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker); -REGISTER_GRADIENT_OP(sigmoid, sigmoid_grad, ops::SigmoidOpGrad); - +REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, sigmoid_grad, + ops::SigmoidOpGrad); REGISTER_OP_CPU_KERNEL(sigmoid, ops::SigmoidKernel); REGISTER_OP_CPU_KERNEL( diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc index 5d8ece1a254a58990bfb2f919567fa43689335b9..40c51a64c49bc064f55975ef6ced1d54070f1291 100644 --- a/paddle/operators/softmax_op.cc +++ b/paddle/operators/softmax_op.cc @@ -62,9 +62,9 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { namespace ops = paddle::operators; -REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker); +REGISTER_OP(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, softmax_grad, + ops::SoftmaxOpGrad); REGISTER_OP_CPU_KERNEL(softmax, ops::SoftmaxKernel); -REGISTER_GRADIENT_OP(softmax, softmax_grad, ops::SoftmaxOpGrad); REGISTER_OP_CPU_KERNEL( softmax_grad, ops::SoftmaxGradKernel); diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index 9d668e6085b93bc5a3a06683aa4470f62ae47c02..a0a0d4d914b37fca4250e5218a953f573611a086 100644 --- a/paddle/operators/uniform_random_op.cc +++ b/paddle/operators/uniform_random_op.cc @@ -81,7 +81,7 @@ Used to initialize tensor with uniform random generator. } // namespace operators } // namespace paddle -REGISTER_OP(uniform_random, paddle::operators::UniformRandomOp, - paddle::operators::UniformRandomOpMaker); +REGISTER_OP_WITHOUT_GRADIENT(uniform_random, paddle::operators::UniformRandomOp, + paddle::operators::UniformRandomOpMaker); REGISTER_OP_CPU_KERNEL(uniform_random, paddle::operators::CPUUniformRandomKernel); diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index 4154aad15c39119e2f155cb2c7b5177b5aa78022..acfc0639736beb82df41b851664e7bcd079b5eb1 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -16,5 +16,8 @@ ELSE() set(GPU_CTX_DEPS) ENDIF() -cc_library(device_context SRCS device_context.cc DEPS place eigen3 ${GPU_CTX_DEPS}) +# memcpy deoends on device_context, here add deps individually for +# avoiding cycle dependencies +cc_library(device_context SRCS device_context.cc DEPS memory buddy_allocator + system_allocator memory_block meta_data meta_cache place eigen3 ${GPU_CTX_DEPS}) nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_info) diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc index a928e097787db9deebe1c6eab263190caacac7eb..f92c15ae450e94de44d27e77763e791e6bae4426 100644 --- a/paddle/platform/device_context.cc +++ b/paddle/platform/device_context.cc @@ -10,6 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/platform/device_context.h" +#include "paddle/memory/memory.h" namespace paddle { namespace platform { @@ -36,6 +37,59 @@ Place CPUDeviceContext::GetPlace() const { return CPUPlace(); } #ifndef PADDLE_ONLY_CPU +class EigenCudaStreamDevice : public Eigen::StreamInterface { + public: + EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) { + Eigen::initializeDeviceProp(); + } + ~EigenCudaStreamDevice() override {} + + void Reinitialize(const cudaStream_t* cuda_stream, GPUPlace place) { + stream_ = cuda_stream; + place_ = place; + device_prop_ = &Eigen::m_deviceProperties[place.device]; + } + + const cudaStream_t& stream() const override { return *stream_; } + + const cudaDeviceProp& deviceProperties() const override { + return *device_prop_; + } + + void* allocate(size_t num_bytes) const override { + return paddle::memory::Alloc(place_, num_bytes); + } + + void deallocate(void* buffer) const override { + paddle::memory::Free(place_, buffer); + } + + void* scratchpad() const override { + if (scratch_ == NULL) { + scratch_ = allocate(Eigen::kCudaScratchSize + sizeof(unsigned int)); + } + return scratch_; + } + + unsigned int* semaphore() const override { + if (semaphore_ == NULL) { + char* scratch = + static_cast(scratchpad()) + Eigen::kCudaScratchSize; + semaphore_ = reinterpret_cast(scratch); + PADDLE_ENFORCE( + cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_)); + } + return semaphore_; + } + + private: + GPUPlace place_; + const cudaStream_t* stream_; // not owned; + const cudaDeviceProp* device_prop_; // not owned; + mutable void* scratch_; + mutable unsigned int* semaphore_; +}; + template <> Eigen::GpuDevice* DeviceContext::get_eigen_device() const { return reinterpret_cast(this)->eigen_device(); @@ -43,19 +97,9 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device() const { CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) { SetDeviceId(place_.device); - // TODO(qijun) Pass a created cuda stream to Eigen::CudaStreamDevice directly - // here will cause segment fault. We must implement a class derived from - // Eigen::StreamInterface, and reinitialize it with a cuda stream and a gpu id - // later. Please refer to the implementation of class EigenCudaStreamDevice - // in TensorFlow. - // - // We find that CUDA 7 introduces a new option, the per-thread default stream, - // that has two effects. Please refer to https://devblogs.nvidia.com/ - // parallelforall/gpu-pro-tip-cuda-7-streams-simplify-concurrency/ - // - // So, we decide to use default stream and add –default-stream per-thread nvcc - // flag. Than, two threads with two CUDADeviceContexts will run parallelly. - eigen_stream_.reset(new Eigen::CudaStreamDevice()); + PADDLE_ENFORCE(cudaStreamCreate(&stream_)); + eigen_stream_.reset(new EigenCudaStreamDevice()); + eigen_stream_->Reinitialize(&stream_, place); eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get())); } @@ -75,12 +119,13 @@ CUDADeviceContext::~CUDADeviceContext() { } eigen_stream_.reset(); eigen_device_.reset(); + PADDLE_ENFORCE(cudaStreamDestroy(stream_)); } Place CUDADeviceContext::GetPlace() const { return place_; } void CUDADeviceContext::Wait() const { - PADDLE_ENFORCE(cudaStreamSynchronize(0)); + PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); } Eigen::GpuDevice* CUDADeviceContext::eigen_device() const { @@ -91,6 +136,7 @@ cublasHandle_t CUDADeviceContext::cublas_handle() { if (!cublas_handle_) { SetDeviceId(place_.device); PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_)); + PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_)); } return cublas_handle_; } @@ -99,10 +145,13 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() { if (!cudnn_handle_) { SetDeviceId(place_.device); PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_)); + PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_)); } return cudnn_handle_; } +cudaStream_t CUDADeviceContext::stream() { return stream_; } + curandGenerator_t CUDADeviceContext::curand_generator() { if (!curand_generator_) { SetDeviceId(place_.device); @@ -110,6 +159,8 @@ curandGenerator_t CUDADeviceContext::curand_generator() { CURAND_RNG_PSEUDO_DEFAULT)); PADDLE_ENFORCE( dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_)); + + PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_)); } return curand_generator_; } diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 08b5b2cff900cc4239a615fe7d7f6b5faa13510b..c5042ae33e47e04521e59e0d91ddd8d4efffe50a 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -52,6 +52,7 @@ class CPUDeviceContext : public DeviceContext { }; #ifndef PADDLE_ONLY_CPU +class EigenCudaStreamDevice; class CUDADeviceContext : public DeviceContext { public: @@ -76,6 +77,9 @@ class CUDADeviceContext : public DeviceContext { /*! \brief Return curand handle in the device context. */ curandGenerator_t curand_generator(); + + /*! \brief Return cuda stream in the device context. */ + cudaStream_t stream(); // clang-format on private: @@ -83,15 +87,16 @@ class CUDADeviceContext : public DeviceContext { private: std::unique_ptr eigen_device_; - std::unique_ptr eigen_stream_; + std::unique_ptr eigen_stream_; private: uint64_t seed_; // clang-format off - cudnnHandle_t cudnn_handle_ = nullptr; - cublasHandle_t cublas_handle_ = nullptr; - curandGenerator_t curand_generator_ = nullptr; + cudaStream_t stream_{nullptr}; + cudnnHandle_t cudnn_handle_{nullptr}; + cublasHandle_t cublas_handle_{nullptr}; + curandGenerator_t curand_generator_{nullptr}; // clang-format on }; diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc index 65345c433c0a328e7f89038a39312edba35eb8c7..8b764bdcd9d92e6b2203e45160acee35ec110538 100644 --- a/paddle/platform/device_context_test.cc +++ b/paddle/platform/device_context_test.cc @@ -45,6 +45,7 @@ TEST(Device, CUDADeviceContext) { ASSERT_NE(nullptr, cublas_handle); curandGenerator_t curand_handle = device_context->curand_generator(); ASSERT_NE(nullptr, curand_handle); + ASSERT_NE(nullptr, device_context->stream()); delete device_context; } } diff --git a/paddle/platform/enforce.h b/paddle/platform/enforce.h index 337a059fb1494d500be0fd2437e59c863ae1563c..15fdf7a94f462a87f7edae1429eb0c4da0b17a84 100644 --- a/paddle/platform/enforce.h +++ b/paddle/platform/enforce.h @@ -14,14 +14,21 @@ limitations under the License. */ #pragma once -#include +#include // for dladdr +#include // for backtrace #include +#include #include #include #include + #include "paddle/string/printf.h" #include "paddle/string/to_string.h" +#ifdef __GNUC__ +#include // for __cxa_demangle +#endif + #ifndef PADDLE_ONLY_CPU #include "paddle/platform/dynload/cublas.h" @@ -39,6 +46,19 @@ limitations under the License. */ namespace paddle { namespace platform { +namespace { +#ifdef __GNUC__ +inline std::string demangle(std::string name) { + int status = -4; // some arbitrary value to eliminate the compiler warning + std::unique_ptr res{ + abi::__cxa_demangle(name.c_str(), NULL, NULL, &status), std::free}; + return (status == 0) ? res.get() : name; +} +#else +inline std::string demangle(std::string name) { return name; } +#endif +} + struct EnforceNotMet : public std::exception { std::exception_ptr exp_; std::string err_str_; @@ -48,15 +68,29 @@ struct EnforceNotMet : public std::exception { std::rethrow_exception(exp_); } catch (const std::exception& exp) { std::ostringstream sout; + sout << string::Sprintf("%s at [%s:%d]", exp.what(), f, l) << std::endl; - sout << "Call Stacks: " << std::endl; + sout << "PaddlePaddle Call Stacks: " << std::endl; + void* call_stack[TRACE_STACK_LIMIT]; - int sz = backtrace(call_stack, TRACE_STACK_LIMIT); - auto line = backtrace_symbols(call_stack, sz); - for (int i = 0; i < sz; ++i) { - sout << line[i] << std::endl; + auto size = backtrace(call_stack, TRACE_STACK_LIMIT); + auto symbols = backtrace_symbols(call_stack, size); + + Dl_info info; + for (int i = 0; i < size; ++i) { + if (dladdr(call_stack[i], &info)) { + auto demangled = demangle(info.dli_sname); + auto addr_offset = static_cast(call_stack[i]) - + static_cast(info.dli_saddr); + sout << string::Sprintf("%-3d %*0p %s + %zd\n", i, + 2 + sizeof(void*) * 2, call_stack[i], + demangled, addr_offset); + } else { + sout << string::Sprintf("%-3d %*0p %s\n", i, 2 + sizeof(void*) * 2, + call_stack[i]); + } } - free(line); + free(symbols); err_str_ = sout.str(); } } @@ -170,7 +204,7 @@ inline void throw_on_error(T e) { * PADDLE_ENFORCE_EQ(a, b); * * will raise an expression described as follows: - * "enforce a == b failed, 1 != 2" with detailed stack infomation. + * "enforce a == b failed, 1 != 2" with detailed stack information. * * extra messages is also supported, for example: * PADDLE_ENFORCE(a, b, "some simple enforce failed between %d numbers", 2) diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index 2f0205b7702b6d73b5348430f39166ec78f6c143..6c2f5fed405ecc80d5388084c7774d0ffb45f9af 100644 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -82,10 +82,6 @@ EOF fi -# To build documentation, we need to run cmake again after installing -# PaddlePaddle. This awkwardness is due to -# https://github.com/PaddlePaddle/Paddle/issues/1854. It also -# describes a solution. if [[ ${WITH_DOC:-OFF} == "ON" ]]; then cat <=1.12 protobuf==3.1 -recordio +recordio>=0.1.0 matplotlib rarfile scipy>=0.19.0 diff --git a/python/setup.py.in b/python/setup.py.in index 38728aa2fd77cf3c882479ed83e99688b9ffa541..287442e013f91df1eed9c629b7767a660d5e30d7 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -23,6 +23,16 @@ with open('@PADDLE_SOURCE_DIR@/python/requirements.txt') as f: if '${CMAKE_SYSTEM_PROCESSOR}' not in ['arm', 'armv7-a', 'aarch64']: setup_requires+=["opencv-python"] +# the prefix is sys.prefix which should always be usr +paddle_bin_dir = 'local/opt/paddle/bin' +paddle_bins = ['${PADDLE_BINARY_DIR}/paddle/scripts/paddle_usage', + '${PADDLE_BINARY_DIR}/paddle/trainer/paddle_trainer', + '${PADDLE_BINARY_DIR}/paddle/trainer/paddle_merge_model', + '${PADDLE_BINARY_DIR}/paddle/pserver/paddle_pserver_main'] + +paddle_rt_lib_dir = 'local/lib' +paddle_rt_libs = [] if '${MKL_SHARED_LIBS}'== '' else '${MKL_SHARED_LIBS}'.split(';') + setup(name='paddlepaddle', version='${PADDLE_VERSION}', description='Parallel Distributed Deep Learning', @@ -42,9 +52,6 @@ setup(name='paddlepaddle', }, scripts=['${PADDLE_BINARY_DIR}/paddle/scripts/paddle'], distclass=BinaryDistribution, - data_files=[('/usr/local/opt/paddle/bin', - ['${PADDLE_BINARY_DIR}/paddle/scripts/paddle_usage', - '${PADDLE_BINARY_DIR}/paddle/trainer/paddle_trainer', - '${PADDLE_BINARY_DIR}/paddle/trainer/paddle_merge_model', - '${PADDLE_BINARY_DIR}/paddle/pserver/paddle_pserver_main'])] + data_files=[(paddle_bin_dir, paddle_bins), + (paddle_rt_lib_dir, paddle_rt_libs)] )