diff --git a/Dockerfile b/Dockerfile index ed5910d93b41dba8d50b2ba01c59c635797edd29..8cfb16928c95dcbfac08383d32562ff67933d873 100644 --- a/Dockerfile +++ b/Dockerfile @@ -25,7 +25,7 @@ COPY ./paddle/scripts/docker/root/ /root/ RUN apt-get update && \ apt-get install -y \ git python-pip python-dev openssh-server bison \ - wget unzip tar xz-utils bzip2 gzip coreutils ntp \ + wget unzip unrar tar xz-utils bzip2 gzip coreutils ntp \ curl sed grep graphviz libjpeg-dev zlib1g-dev \ python-numpy python-matplotlib gcc g++ \ automake locales clang-format-3.8 swig doxygen cmake \ diff --git a/cmake/configure.cmake b/cmake/configure.cmake index a4f98ec7d4af652d0dd0650f4906696ff3a4efb9..7afab5d5344b704a9329e313a81379032ba0cc97 100644 --- a/cmake/configure.cmake +++ b/cmake/configure.cmake @@ -102,12 +102,19 @@ if(WITH_GOLANG) message(FATAL_ERROR "no glide executeble found: $ENV{GOPATH}/bin/glide") endif() - add_custom_target(go_vendor) - add_custom_command(TARGET go_vendor + # this command will only run when the file it depends is missing + # or has changed, or the output is missing. + add_custom_command(OUTPUT ${CMAKE_BINARY_DIR}/glide COMMAND env GOPATH=${GOPATH} ${GLIDE} install + COMMAND touch ${CMAKE_BINARY_DIR}/glide + DEPENDS ${PROJ_ROOT}/go/glide.lock WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go" - ) - add_dependencies(go_vendor go_path) + ) + + # depends on the custom command which outputs + # ${CMAKE_BINARY_DIR}/glide, the custom command does not need to + # run every time this target is built. + add_custom_target(go_vendor DEPENDS ${CMAKE_BINARY_DIR}/glide go_path) endif() endif(WITH_GOLANG) diff --git a/cmake/cpplint.cmake b/cmake/cpplint.cmake index 6bbcd730e1b5ac49415cac676352e6df00eb6eb5..656e1a0803c6e389d70f37f592c3aa2e95a2bcd4 100644 --- a/cmake/cpplint.cmake +++ b/cmake/cpplint.cmake @@ -27,7 +27,8 @@ set(IGNORE_PATTERN .*cblas\\.h.* .*\\.pb\\.txt .*LtrDataProvider.* - .*MultiDataProvider.*) + .*MultiDataProvider.* + .*pb.*) # add_style_check_target # @@ -52,14 +53,13 @@ macro(add_style_check_target TARGET_NAME) endif() endforeach() if(LINT MATCHES ON) + # cpplint code style get_filename_component(base_filename ${filename} NAME) set(CUR_GEN ${CMAKE_CURRENT_BINARY_DIR}/${base_filename}.cpplint) - add_custom_command(OUTPUT ${CUR_GEN} - PRE_BUILD - COMMAND env ${py_env} "${PYTHON_EXECUTABLE}" "${PROJ_ROOT}/paddle/scripts/cpplint.py" - "--filter=${STYLE_FILTER}" - "--write-success=${CUR_GEN}" ${filename} - DEPENDS ${filename} + add_custom_command(TARGET ${TARGET_NAME} PRE_BUILD + COMMAND "${PYTHON_EXECUTABLE}" "${PROJ_ROOT}/paddle/scripts/cpplint.py" + "--filter=${STYLE_FILTER}" + "--write-success=${CUR_GEN}" ${filename} WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) endif() endforeach() diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 716955c7b42f3d05b3ec8387cf81dd9cb1c409bf..e42e75c12ab1e5133f5ecbdb90ef26e3f8df5133 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -104,6 +104,7 @@ function(merge_static_libs TARGET_NAME) foreach(lib ${libs}) list(APPEND libs_deps ${${lib}_LIB_DEPENDS}) endforeach() + list(REMOVE_DUPLICATES libs_deps) if(APPLE) # Use OSX's libtool to merge archives # To produce a library we need at least one source file. @@ -127,7 +128,7 @@ function(merge_static_libs TARGET_NAME) # Get the file names of the libraries to be merged set(libfiles ${libfiles} $) endforeach() - add_custom_command(TARGET ${TARGET_NAME} POST_BUILD + add_custom_command(TARGET ${TARGET_NAME} POST_BUILD COMMAND rm "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a" COMMAND /usr/bin/libtool -static -o "${CMAKE_CURRENT_BINARY_DIR}/lib${TARGET_NAME}.a" ${libfiles}) else() # general UNIX: use "ar" to extract objects and re-add to a common lib @@ -145,11 +146,11 @@ function(merge_static_libs TARGET_NAME) DEPENDS ${lib} ${objdir} WORKING_DIRECTORY ${objdir}) - # Empty dummy source file that goes into merged library - set(mergebase ${lib}.mergebase.c) - add_custom_command(OUTPUT ${mergebase} - COMMAND ${CMAKE_COMMAND} -E touch ${mergebase} - DEPENDS ${objlistfile}) + # Empty dummy source file that goes into merged library + set(mergebase ${lib}.mergebase.c) + add_custom_command(OUTPUT ${mergebase} + COMMAND ${CMAKE_COMMAND} -E touch ${mergebase} + DEPENDS ${objlistfile}) list(APPEND mergebases "${mergebase}") endforeach() @@ -184,6 +185,10 @@ function(cc_library TARGET_NAME) add_dependencies(${TARGET_NAME} ${cc_library_DEPS}) target_link_libraries(${TARGET_NAME} ${cc_library_DEPS}) endif() + + # cpplint code style + add_style_check_target(${TARGET_NAME} ${cc_library_SRCS}) + else(cc_library_SRCS) if (cc_library_DEPS) merge_static_libs(${TARGET_NAME} ${cc_library_DEPS}) @@ -337,7 +342,7 @@ function(go_test TARGET_NAME) string(REPLACE "${PADDLE_GO_PATH}" "" CMAKE_CURRENT_SOURCE_REL_DIR ${CMAKE_CURRENT_SOURCE_DIR}) add_custom_target(${TARGET_NAME} ALL DEPENDS go_vendor ${go_test_DEPS}) add_custom_command(TARGET ${TARGET_NAME} POST_BUILD - COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} test + COMMAND env GOPATH=${GOPATH} ${CMAKE_Go_COMPILER} test -race -c -o "${CMAKE_CURRENT_BINARY_DIR}/${TARGET_NAME}" ".${CMAKE_CURRENT_SOURCE_REL_DIR}" WORKING_DIRECTORY "${PADDLE_IN_GOPATH}/go") diff --git a/doc/api/v2/config/layer.rst b/doc/api/v2/config/layer.rst index 4f4a9187bcbe8ef902e923622552909808b121d6..daee55b7f9adfffdf709ed2b5b0d957c7ca1aea4 100644 --- a/doc/api/v2/config/layer.rst +++ b/doc/api/v2/config/layer.rst @@ -474,6 +474,11 @@ prelu .. autoclass:: paddle.v2.layer.prelu :noindex: +gated_unit +----------- +.. autoclass:: paddle.v2.layer.gated_unit + :noindex: + Detection output Layer ====================== diff --git a/go/pserver/client/client_test.go b/go/pserver/client/client_test.go index 27f4ff2380b3b5aa01485838eb4a876a8863d901..aab91556b4b91fab6de66322987eabe24f1b0f70 100644 --- a/go/pserver/client/client_test.go +++ b/go/pserver/client/client_test.go @@ -164,7 +164,7 @@ func testClient(t *testing.T, c *client.Client) { wg.Add(1) go func(gs []pserver.Gradient) { - err = c.SendGrads(gs) + err := c.SendGrads(gs) if err != nil { t.Fatal(err) } diff --git a/paddle/api/ConfigParser.cpp b/paddle/api/ConfigParser.cpp index 2f45173bfd401ddda26d61ab7fcfe131d079f710..b6ff6ec7890c0b79d52a2f0784743289c7bc213f 100644 --- a/paddle/api/ConfigParser.cpp +++ b/paddle/api/ConfigParser.cpp @@ -64,11 +64,7 @@ ModelConfig* TrainerConfig::getModelConfig() const { ParameterConfig::ParameterConfig() : m(new ParameterConfigPrivate()) {} -ParameterConfig::~ParameterConfig() { - if (m) { - delete m; - } -} +ParameterConfig::~ParameterConfig() { delete m; } ParameterConfig* ParameterConfig::createParameterConfigFromParameterSharedPtr( void* ptr) { @@ -98,11 +94,7 @@ void* ParameterConfig::getRawPtr() { return m->getConfigPtr(); } OptimizationConfig::OptimizationConfig() : m(new OptimizationConfigPrivate()) {} -OptimizationConfig::~OptimizationConfig() { - if (m) { - delete m; - } -} +OptimizationConfig::~OptimizationConfig() { delete m; } std::string OptimizationConfig::toProtoString() { return m->getConfig().SerializeAsString(); diff --git a/paddle/api/ParameterOptimizer.cpp b/paddle/api/ParameterOptimizer.cpp index 21b851dd5e26c4752888067b20d0b1e16a4ab52d..120eea3f70125a57fb5ad685f2a11479bce12d0c 100644 --- a/paddle/api/ParameterOptimizer.cpp +++ b/paddle/api/ParameterOptimizer.cpp @@ -53,11 +53,7 @@ struct ParameterTraverseCallbackPrivate { ParameterOptimizer::ParameterOptimizer() : m(new ParameterOptimizerPrivate()) {} -ParameterOptimizer::~ParameterOptimizer() { - if (m) { - delete m; - } -} +ParameterOptimizer::~ParameterOptimizer() { delete m; } ParameterOptimizer* ParameterOptimizer::create(OptimizationConfig* config) { CHECK(config != nullptr); @@ -104,11 +100,7 @@ std::vector ParameterOptimizer::getParameterTypes() const { ParameterTraverseCallback::ParameterTraverseCallback() : m(new ParameterTraverseCallbackPrivate()) {} -ParameterTraverseCallback::~ParameterTraverseCallback() { - if (m) { - delete m; - } -} +ParameterTraverseCallback::~ParameterTraverseCallback() { delete m; } void ParameterTraverseCallback::apply(const std::vector& vecs, const ParameterConfig& conf, diff --git a/paddle/api/Vector.cpp b/paddle/api/Vector.cpp index db8f005929d90f718fc1ad42c60b68108ff55005..500bc448c92630f4fc2f4df603c955e572d868ec 100644 --- a/paddle/api/Vector.cpp +++ b/paddle/api/Vector.cpp @@ -171,11 +171,7 @@ struct VectorPrivate { Vector::Vector() : m(new VectorPrivate()) {} -Vector::~Vector() { - if (m) { - delete m; - } -} +Vector::~Vector() { delete m; } Vector* Vector::createZero(size_t sz, bool useGpu) { auto retVec = new Vector(); diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 2a93779878312aa8a84a721641b39e87db9c5ef3..31b138423526f48de7c11d45990cabfa1b559b67 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -11,8 +11,10 @@ proto_library(op_proto SRCS op_proto.proto DEPS attr_type) cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf) proto_library(op_desc SRCS op_desc.proto DEPS attr_type) cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf) + cc_library(operator SRCS operator.cc DEPS op_desc device_context) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) + cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator) py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto) @@ -21,4 +23,5 @@ add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch add_dependencies(framework_py_proto framework_py_proto_init) proto_library(net_proto SRCS net_proto.proto DEPS op_proto) -cc_library(net SRCS net.cc DEPS net_proto) +cc_library(net SRCS net.cc DEPS operator net_proto op_registry) +cc_test(net_op_test SRCS net_op_test.cc DEPS net) diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc index fe8f79abd4fe5d94a8805fa2ddcd8103706dd083..87a3618e095c544b422746ed3f497b21f3824fbd 100644 --- a/paddle/framework/ddim.cc +++ b/paddle/framework/ddim.cc @@ -1,10 +1,23 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #include "paddle/framework/ddim.h" -#include "paddle/framework/enforce.h" namespace paddle { namespace framework { -///@cond HIDDEN +/// @cond HIDDEN template Dim make_dim(const int* d) { @@ -51,7 +64,7 @@ void make_ddim(DDim& ddim, const int* dims, int n) { } } -///@endcond +/// @endcond DDim make_ddim(std::initializer_list dims) { DDim result(make_dim(0)); @@ -65,11 +78,11 @@ DDim make_ddim(const std::vector& dims) { return result; } -///@cond HIDDEN +/// @cond HIDDEN // XXX For some reason, putting this in an anonymous namespace causes errors class DynamicMutableIndexer : public boost::static_visitor { public: - DynamicMutableIndexer(int idx) : idx_(idx) {} + explicit DynamicMutableIndexer(int idx) : idx_(idx) {} template int& operator()(Dim& dim) const { @@ -82,7 +95,7 @@ class DynamicMutableIndexer : public boost::static_visitor { class DynamicConstIndexer : public boost::static_visitor { public: - DynamicConstIndexer(int idx) : idx_(idx) {} + explicit DynamicConstIndexer(int idx) : idx_(idx) {} template int operator()(const Dim& dim) const { @@ -93,7 +106,7 @@ class DynamicConstIndexer : public boost::static_visitor { int idx_; }; -///@endcond +/// @endcond int& DDim::operator[](int idx) { return boost::apply_visitor(DynamicMutableIndexer(idx), var); @@ -156,11 +169,11 @@ int get(const DDim& ddim, int idx) { return ddim[idx]; } void set(DDim& ddim, int idx, int value) { ddim[idx] = value; } -///@cond HIDDEN +/// @cond HIDDEN struct VectorizeVisitor : public boost::static_visitor<> { std::vector& vector; - VectorizeVisitor(std::vector& v) : vector(v) {} + explicit VectorizeVisitor(std::vector& v) : vector(v) {} template void operator()(const T& t) { @@ -170,7 +183,7 @@ struct VectorizeVisitor : public boost::static_visitor<> { void operator()(const Dim<1>& t) { vector.push_back(t.head); } }; -///@endcond +/// @endcond std::vector vectorize(const DDim& ddim) { std::vector result; @@ -188,7 +201,7 @@ ssize_t product(const DDim& ddim) { return result; } -///\cond HIDDEN +/// \cond HIDDEN struct ArityVisitor : boost::static_visitor { template @@ -197,15 +210,15 @@ struct ArityVisitor : boost::static_visitor { } }; -///\endcond +/// \endcond int arity(const DDim& d) { return boost::apply_visitor(ArityVisitor(), d); } -///\cond HIDDEN +/// \cond HIDDEN struct DDimPrinter : boost::static_visitor { std::ostream& os; - DDimPrinter(std::ostream& os_) : os(os_) {} + explicit DDimPrinter(std::ostream& os_) : os(os_) {} template void operator()(const T& t) { @@ -213,7 +226,7 @@ struct DDimPrinter : boost::static_visitor { } }; -///\endcond +/// \endcond std::ostream& operator<<(std::ostream& os, const DDim& ddim) { DDimPrinter printer(os); @@ -221,16 +234,5 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) { return os; } -template -Eigen::DSizes ToEigenDSizes(const DDim& dims) { - int rank = arity(dims); - PADDLE_ENFORCE(rank == NDIMS, "DDim and NDIMS must be same"); - Eigen::DSizes dsizes; - for (int d = 0; d < rank; d++) { - dsizes[d] = dims[d]; - } - return dsizes; -} - } // namespace framework } // namespace paddle diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h index 18395c3636cb710901d13b3660ac81a73270e1cd..bd51b4fada773c986b92a61ec754fcf141b90e16 100644 --- a/paddle/framework/ddim.h +++ b/paddle/framework/ddim.h @@ -1,11 +1,25 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #pragma once #include #include #include #include - #include "paddle/framework/dim.h" +#include "paddle/framework/enforce.h" #include "unsupported/Eigen/CXX11/Tensor" namespace paddle { @@ -28,7 +42,7 @@ struct DDim { DDim() : var(Dim<1>()) {} template - DDim(const Dim& in) : var(in) {} + explicit DDim(const Dim& in) : var(in) {} template DDim& operator=(const Dim& in) { @@ -93,7 +107,15 @@ int arity(const DDim& ddim); std::ostream& operator<<(std::ostream&, const DDim&); template -Eigen::DSizes ToEigenDSizes(const DDim& dims); +Eigen::DSizes ToEigenDSizes(const DDim& dims) { + int rank = arity(dims); + PADDLE_ENFORCE(rank == NDIMS, "DDim and NDIMS must be same"); + Eigen::DSizes dsizes; + for (int d = 0; d < rank; d++) { + dsizes[d] = dims[d]; + } + return dsizes; +} } // namespace framework } // namespace paddle diff --git a/paddle/framework/net.cc b/paddle/framework/net.cc index 73b3051235ee90b31bd65acb22f454fc13d64da9..7311cda9a9ad282b21711d8eb0b9ba1cf9542296 100644 --- a/paddle/framework/net.cc +++ b/paddle/framework/net.cc @@ -1,20 +1,59 @@ +/* + Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + #include "paddle/framework/net.h" namespace paddle { namespace framework { -PlainNet::PlainNet(const NetDesc& def) {} - -void PlainNet::InferShape(Scope* scope) { +void PlainNet::CompleteAddOp() { + std::unordered_set input_set; + std::unordered_set output_set; + std::unordered_set temp_output; for (auto& op : ops_) { - op.InferShape(); + for (auto& ipt : op->inputs_) { + if (!Contains(output_set, ipt)) { // Not other op's output + input_set.insert(ipt); + } else { + temp_output.insert(ipt); + } + } + + for (auto& opt : op->outputs_) { + output_set.insert(opt); + } } -} + inputs_.reserve(input_set.size()); + std::copy(input_set.begin(), input_set.end(), std::back_inserter(inputs_)); -void PlainNet::Run(std::shared_ptr scope, DeviceContext* ctx) { - for (auto& op : ops_) { - op.Run(ctx); + outputs_.reserve(output_set.size()); + std::vector tmp_index; + tmp_index.reserve(temp_output.size()); + int idx = 0; + for (auto& opt : output_set) { + if (Contains(temp_output, opt)) { + tmp_index.push_back(idx); + } + outputs_.push_back(opt); + ++idx; } + + attrs_["temporary_index"] = tmp_index; + add_op_done_ = true; } + } // namespace framework } // namespace paddle diff --git a/paddle/framework/net.h b/paddle/framework/net.h index 76992e07282904fd1074bb0ced2367a8d20e3ec2..19a1620e29b86fbccfc112a5f85a1784a197dd0b 100644 --- a/paddle/framework/net.h +++ b/paddle/framework/net.h @@ -1,99 +1,51 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ #pragma once +#include +#include #include "paddle/framework/net_proto.pb.h" #include "paddle/framework/op_proto.pb.h" +#include "paddle/framework/op_registry.h" #include "paddle/framework/scope.h" #include "paddle/platform/device_context.h" namespace paddle { namespace framework { -using namespace paddle::platform; - -// operator's index stored in a network. -typedef int OpIndex; -/** - * NOTE following codes are some definitions of unimplemented concepts. - * We write some basic implementation to make Net compilable. These APIs will - * keep updating if the concepts related are implemented. - */ - -struct OpDesc; -struct OpAttrs {}; - -class Operator { - public: - Operator(const OpDesc &def) {} - void InferShape() {} - void Run(DeviceContext *ctx) {} -}; - /** - * @brief Network that manage the operators it has. + * @brief Network is also a type of Operator + * + * It will manage the operators it has. * - * Network is the container and controller of a set of operators, user can build - * a real network from a NetDesc which is a protobuf message and use - * Network.Run() * to run all the operators in the network. + * Network is the container and controller of a set of operators. * A network object knows all Operators belonging to this network. Variables, * which are inputs and outputs of these operators, are created and managed by a * hierarchy of Scope objects. * - * This is the base class of network, all the networks should implement the apis + * This is the base class of network, all the networks should implement the APIs * it defines. */ -class Net { +class Net : public OperatorBase { public: - /** - * @brief Infer shapes of all inputs and outputs of operators. - */ - virtual void InferShape(Scope *scope) = 0; - /** - * @brief Run the network. - * - * Run all the operators and return success(true) or not, with all the - * variables are located in `scope`. `context` describes the detail execution - * environment for ops. `begin` and `end` specify the scope of `ops_` to run, - * If no positive indexes are provided, all operators in `ops_` will run. - */ - virtual void Run(std::shared_ptr scope, DeviceContext *ctx) = 0; - - /** - * @brief Add an Operator according to `def`. - */ - virtual OpIndex AddOp(const OpProto &def) = 0; - - /** - * @brief Add optimizer operators acctording to `attrs`. - */ - virtual void AddOptimizerOps(const OpAttrs &attrs) = 0; - - /** - * @brief Add backward operators. - */ - virtual void AddBackwardOps() = 0; - - /** - * @brief Create a network. - */ - static std::unique_ptr Create(const NetDesc &def = NetDesc()); - - virtual ~Net() {} + virtual void AddOp(const OperatorPtr& op) = 0; + virtual void CompleteAddOp() = 0; }; +using NetPtr = std::shared_ptr; + /** * @brief a basic implementation of Net. * @@ -103,18 +55,14 @@ class Net { class PlainNet : public Net { public: /** - * @brief Initialize a PlainNet. - * - * Initialize from a network describe by `def`. NetDesc is the definition of - * a network. - */ - PlainNet(const NetDesc &def); - - /** - * Infer all the operators' input and output varialbes' shapes, will be called + * Infer all the operators' input and output variables' shapes, will be called * before every mini-batch */ - virtual void InferShape(Scope *scope) override; + void InferShape(const ScopePtr& scope) const override { + for (auto& op : ops_) { + op->InferShape(scope); + } + } /** * @brief Run the network. @@ -123,48 +71,32 @@ class PlainNet : public Net { * scope will be used instead. If no OpContext is provicded, default context * will be used. */ - virtual void Run(std::shared_ptr scope, DeviceContext *ctx) override; + void Run(const ScopePtr& scope, + const platform::DeviceContext& dev_ctx) const override { + for (auto& op : ops_) { + op->Run(scope, dev_ctx); + } + } /** - * @brief Add an operator to this network. + * @brief Add an operator by ptr */ - virtual OpIndex AddOp(const OpProto &def) override; + void AddOp(const OperatorPtr& op) override { + PADDLE_ENFORCE(!add_op_done_, "Cannot AddOp when this network is sealed"); + ops_.push_back(op); + } - /** - * @brief Add all optimizer operators related into the network. - */ - virtual void AddOptimizerOps(const OpAttrs &attrs) override; + void CompleteAddOp() override; - /** - * @brief Add all backward operators related into the network. - */ - virtual void AddBackwardOps() override; - - virtual ~PlainNet() override {} - - protected: - /** - * @brief Build the network. - * - * Create operators accordding to `def`, will be called by the constructor. - */ - void BuildNet(const NetDesc &def); - - /** - * @brief Add an operator into this network. - * - * Add a operator which is identified as `type` and has attributes described - * in `attrs`, the `inputs` are the keys of readonly input variables, - * `outputs` are keys of mutable output variables. An `OpIndex` will be - * returned to indicate the offset of the new operator in `ops_`. - */ - OpIndex AddOp(const std::string &type, const std::vector &inputs, - const std::vector &outputs, - const OpAttrs &attrs = OpAttrs()); + std::vector ops_; private: - // the operators owned by `Network`. - std::vector ops_; + bool add_op_done_{false}; + + template + static bool Contains(T container, KeyType key) { + return container.find(key) != container.end(); + } }; } // namespace framework diff --git a/paddle/framework/net_op_test.cc b/paddle/framework/net_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..f5e1c22400a73c3aa09839ef9654f87def99bc77 --- /dev/null +++ b/paddle/framework/net_op_test.cc @@ -0,0 +1,67 @@ +#include +#include +#include +#include + +namespace pd = paddle::framework; + +static int infer_shape_cnt = 0; +static int run_cnt = 0; + +class TestOp : public pd::OperatorBase { + public: + void InferShape(const paddle::framework::ScopePtr& scope) const override { + ++infer_shape_cnt; + } + void Run(const paddle::framework::ScopePtr& scope, + const paddle::platform::DeviceContext& dev_ctx) const override { + ++run_cnt; + } +}; + +template +void AssertSameVectorWithoutOrder(const std::vector& expected, + const std::vector& actual) { + ASSERT_EQ(expected.size(), actual.size()); + std::unordered_set expected_set; + for (auto& tmp : expected) { + expected_set.insert(tmp); + } + for (auto& act : actual) { + ASSERT_NE(expected_set.end(), expected_set.find(act)); + } +} + +TEST(OpKernel, all) { + auto net = std::make_shared(); + ASSERT_NE(net, nullptr); + + auto op1 = std::make_shared(); + op1->inputs_ = {"x", "w1", "b1"}; + op1->outputs_ = {"y"}; + net->AddOp(op1); + + auto op2 = std::make_shared(); + op2->inputs_ = {"y", "w2", "b2"}; + op2->outputs_ = {"z"}; + net->AddOp(op2); + + net->CompleteAddOp(); + AssertSameVectorWithoutOrder({"x", "w1", "b1", "w2", "b2"}, net->inputs_); + AssertSameVectorWithoutOrder({"y", "z"}, net->outputs_); + auto tmp_idx_iter = net->attrs_.find("temporary_index"); + ASSERT_NE(net->attrs_.end(), tmp_idx_iter); + auto& tmp_idx = boost::get>(tmp_idx_iter->second); + ASSERT_EQ(1UL, tmp_idx.size()); + ASSERT_EQ("y", net->outputs_[tmp_idx[0]]); + + auto scope = std::make_shared(); + paddle::platform::CPUDeviceContext dev_ctx; + + net->InferShape(scope); + net->Run(scope, dev_ctx); + ASSERT_EQ(2, infer_shape_cnt); + ASSERT_EQ(2, run_cnt); + + ASSERT_THROW(net->AddOp(op2), paddle::framework::EnforceNotMet); +} diff --git a/paddle/framework/op_proto.proto b/paddle/framework/op_proto.proto index 22df6f9c6b70277ddbf31e0432401889e3aa7483..596b8588e783722362815f75db876931f83484ec 100644 --- a/paddle/framework/op_proto.proto +++ b/paddle/framework/op_proto.proto @@ -34,6 +34,11 @@ message AttrProto { // Supported attribute comments. It helps 3rd-party language generate doc-string. required string comment = 3; + + // If that attribute is generated, it means the Paddle third language + // binding has responsibility to fill that attribute. End-User should + // not set that attribute. + optional bool generated = 4 [default=false]; } // Input or output message for 3rd-party language binding. @@ -45,6 +50,40 @@ message VarProto { // The comment for that input. It helps 3rd-party language generate doc-string. required string comment = 2; + + // Is that input/output could be a list or not. + // If so, that Op should write a attributed named `input_format` or + // `output_format`. + // + // e.g. + // If the op is a fc op, the inputs are `X`, `W`, `b`. The `X` and `W` + // could be multiple, so the multiple of `X` and `W` is True, and OpDesc + // will hold a attribute of them. + // + // The Op desc of same fc could be + // { + // "type": "fc", + // "input": ["X1", "X2", "W1", "W2", "b"], + // "output": "fc.out", + // "attrs" : { + // "input_format": [0, 2, 4, 5] + // } + // } + // + optional bool multiple = 3 [default=false]; + + // It marks that output is a temporary output. That output is not used by + // user, but used by other op internally as input. If other op is not use + // that output, it could be optimized early. + // + // Attribute temporary_index will be set in OpDesc if there is some + // outputs are temporary. + // + // output = [ "xxx.out1", "xxx.tmp", "xxx.out2"], + // attrs = { + // "temporary_index": [1] + // } + optional bool temporary = 4 [default=false]; } // Op protocol message for 3rd-party language binding. diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc index 4b35e04e681b414c36cf6d9aee9e64dd68ba5da9..1d14535c50b542733663a6900a8b5f2033290ea6 100644 --- a/paddle/framework/op_registry.cc +++ b/paddle/framework/op_registry.cc @@ -1,3 +1,17 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #include namespace paddle { @@ -33,4 +47,4 @@ void AttrTypeHelper::SetAttrType>(AttrProto* attr) { attr->set_type(paddle::framework::AttrType::STRINGS); } } // namespace framework -} // namespace paddle \ No newline at end of file +} // namespace paddle diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index 61dfcb704964cd730a8fc9ab6ad394cd47cb4666..24f56b281282881fb12fc8f2d477da310df5db6f 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -2,6 +2,8 @@ #include #include +#include +#include #include "paddle/framework/attr_checker.h" #include "paddle/framework/op_desc.pb.h" #include "paddle/framework/op_proto.pb.h" @@ -59,25 +61,52 @@ class OpProtoAndCheckerMaker { OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) : proto_(proto), op_checker_(op_checker) {} + ~OpProtoAndCheckerMaker() { CheckNoDuplicatedAttrs(); } + protected: - void AddInput(const std::string& name, const std::string& comment) { + void AddInput(const std::string& name, const std::string& comment, + bool multiple = false) { auto input = proto_->mutable_inputs()->Add(); *input->mutable_name() = name; *input->mutable_comment() = comment; + input->set_multiple(multiple); + if (multiple) { + SetHasMultipleInput(); + } } - void AddOutput(const std::string& name, const std::string& comment) { + void AddInputs(const std::string& name, const std::string& comment) { + AddInput(name, comment, true); + } + + void AddOutput(const std::string& name, const std::string& comment, + bool temporary = false, bool multiple = false) { auto output = proto_->mutable_outputs()->Add(); *output->mutable_name() = name; *output->mutable_comment() = comment; + output->set_multiple(multiple); + if (multiple) { + SetHasMultipleOutput(); + } + output->set_temporary(temporary); + if (temporary) { + SetHasTemporaryOutput(); + } + } + + void AddOutputs(const std::string& name, const std::string& comment, + bool temporary = false) { + AddOutput(name, comment, temporary, true); } template TypedAttrChecker& AddAttr(const std::string& name, - const std::string& comment) { + const std::string& comment, + bool generated = false) { auto attr = proto_->mutable_attrs()->Add(); *attr->mutable_name() = name; *attr->mutable_comment() = comment; + attr->set_generated(generated); AttrTypeHelper::SetAttrType(attr); return op_checker_->AddAttrChecker(name); } @@ -86,8 +115,70 @@ class OpProtoAndCheckerMaker { *(proto_->mutable_comment()) = comment; } + private: + void SetHasMultiple(const std::string& in_out, bool* flag) { + if (!*flag) { + AddAttr>(in_out + "_format", + "The multiple index of " + in_out + + "\n" + R"DOC( +This attribute is used by Paddle core framework. Paddle's Op support each input +or output could be a list of variable. This attribute is used to show how that +list organized. + +e.g. + input = ["a", "b", "c", "d", "e", "f"] + input_format = [0, 4, 5, 6] + +means + The number of all input variables this op is six, and they are segmented into + three inputs. + + The first input is input[0:4], second is input[4:5], third is input[5:6]. +)DOC", + /*generated*/ true); + *flag = true; + } + } + + void SetHasMultipleInput() { SetHasMultiple("input", &has_multiple_input_); } + void SetHasMultipleOutput() { + SetHasMultiple("output", &has_multiple_output_); + } + + void SetHasTemporaryOutput() { + if (!has_temporary_output_) { + AddAttr>("temporary_index", + R"DOC(The temporary index of output. + +Not all output of Paddle Op is used by user. For faster computation, each op +could output some its internal state to other op, other op could take that +output to make compute faster. + +Add a mark to which output is temporary is helpful for future optimization. +)DOC", + /*generated*/ true) + .SetDefault(std::vector()); + has_temporary_output_ = true; + } + } + + void CheckNoDuplicatedAttrs() { + std::unordered_set names; + size_t cnt = 0; + for (auto& attr : proto_->attrs()) { + names.insert(attr.name()); + ++cnt; + } + PADDLE_ENFORCE(names.size() == cnt, + "Cannot register two attribute in same name!"); + } + OpProto* proto_; OpAttrChecker* op_checker_; + bool has_multiple_input_{false}; + bool has_multiple_output_{false}; + bool has_temporary_output_{false}; }; class OpRegistry { @@ -107,10 +198,10 @@ class OpRegistry { op_type, op_proto.InitializationErrorString()); } - static OperatorBase* CreateOp(const OpDesc& op_desc) { + static OperatorPtr CreateOp(const OpDesc& op_desc) { std::string op_type = op_desc.type(); - OperatorBase* op = creators().at(op_type)(); - op->desc_ = op_desc; + OperatorPtr op(creators().at(op_type)()); + op->type_ = op_desc.type(); op->inputs_.reserve((size_t)op_desc.inputs_size()); std::copy(op_desc.inputs().begin(), op_desc.inputs().end(), std::back_inserter(op->inputs_)); @@ -125,17 +216,17 @@ class OpRegistry { return op; } + static std::unordered_map& protos() { + static std::unordered_map protos_; + return protos_; + }; + private: static std::unordered_map& creators() { static std::unordered_map creators_; return creators_; } - static std::unordered_map& protos() { - static std::unordered_map protos_; - return protos_; - }; - static std::unordered_map& op_checkers() { static std::unordered_map op_checkers_; return op_checkers_; @@ -150,12 +241,18 @@ class OpRegisterHelper { } }; +/** + * check if MACRO is used in GLOBAL NAMESPACE. + */ #define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \ struct __test_global_namespace_##uniq_name##__ {}; \ static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \ __test_global_namespace_##uniq_name##__>::value, \ msg) +/** + * Macro to Register Operator. + */ #define REGISTER_OP(__op_type, __op_class, __op_maker_class) \ STATIC_ASSERT_GLOBAL_NAMESPACE(__reg_op__##__op_type, \ "REGISTER_OP must be in global namespace"); \ @@ -163,9 +260,12 @@ class OpRegisterHelper { __op_register_##__op_type##__(#__op_type); \ int __op_register_##__op_type##_handle__() { return 0; } -#define REGISTER_OP_KERNEL(type, GPU_OR_CPU, PlaceType, KernelType) \ +/** + * Macro to Register OperatorKernel. + */ +#define REGISTER_OP_KERNEL(type, DEVICE_TYPE, PlaceType, KernelType) \ STATIC_ASSERT_GLOBAL_NAMESPACE( \ - __reg_op_kernel_##type##_##GPU_OR_CPU##__, \ + __reg_op_kernel_##type##_##DEVICE_TYPE##__, \ "REGISTER_OP_KERNEL must be in global namespace"); \ struct __op_kernel_register__##type##__ { \ __op_kernel_register__##type##__() { \ @@ -176,7 +276,7 @@ class OpRegisterHelper { } \ }; \ static __op_kernel_register__##type##__ __reg_kernel_##type##__; \ - int __op_kernel_register_##type##_handle_##GPU_OR_CPU##__() { return 0; } + int __op_kernel_register_##type##_handle_##DEVICE_TYPE##__() { return 0; } #define REGISTER_OP_GPU_KERNEL(type, KernelType) \ REGISTER_OP_KERNEL(type, GPU, ::paddle::platform::GPUPlace, KernelType) @@ -184,6 +284,10 @@ class OpRegisterHelper { #define REGISTER_OP_CPU_KERNEL(type, KernelType) \ REGISTER_OP_KERNEL(type, CPU, ::paddle::platform::CPUPlace, KernelType) +/** + * Macro to mark what Operator and Kernel we will use and tell the compiler to + * link them into target. + */ #define USE_OP_WITHOUT_KERNEL(op_type) \ STATIC_ASSERT_GLOBAL_NAMESPACE( \ __use_op_without_kernel_##op_type, \ @@ -201,15 +305,16 @@ class OpRegisterHelper { __attribute__((unused)) = \ __op_kernel_register_##op_type##_handle_##DEVICE_TYPE##__() -#ifdef PADDLE_ONLY_CPU -#define USE_OP(op_type) \ +// use Operator with only cpu kernel. +#define USE_OP_CPU(op_type) \ USE_OP_WITHOUT_KERNEL(op_type); \ - USE_OP_KERNEL(op_type, CPU); + USE_OP_KERNEL(op_type, CPU) +#ifdef PADDLE_ONLY_CPU +#define USE_OP(op_type) USE_OP_CPU(op_type) #else -#define USE_OP(op_type) \ - USE_OP_WITHOUT_KERNEL(op_type); \ - USE_OP_KERNEL(op_type, CPU); \ +#define USE_OP(op_type) \ + USE_OP_CPU(op_type); \ USE_OP_KERNEL(op_type, GPU) #endif diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index 9bcc0407addca555e0b47b8178f6304396ce81fc..4791d4aaab4cfe19d1cca7741ce259f8f7aeb18a 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -5,9 +5,9 @@ namespace paddle { namespace framework { class CosineOp : public OperatorBase { public: - void Run(const std::shared_ptr& scope, + void Run(const ScopePtr& scope, const platform::DeviceContext& dev_ctx) const override {} - void InferShape(const std::shared_ptr& scope) const override {} + void InferShape(const ScopePtr& scope) const override {} }; class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { @@ -25,8 +25,8 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { class MyTestOp : public OperatorBase { public: - void InferShape(const std::shared_ptr& scope) const override {} - void Run(const std::shared_ptr& scope, + void InferShape(const ScopePtr& scope) const override {} + void Run(const ScopePtr& scope, const platform::DeviceContext& dev_ctx) const override {} public: @@ -36,8 +36,9 @@ class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker { public: MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("input", "input of cosine op"); - AddOutput("output", "output of cosine op"); + AddInputs("input", "input of cosine op"); + AddOutput("output", "output of cosine op", + /*temporary*/ true); auto my_checker = [](int i) { PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!"); }; @@ -66,7 +67,7 @@ TEST(OpRegistry, CreateOp) { attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_f(scale); - paddle::framework::OperatorBase* op = + paddle::framework::OperatorPtr op = paddle::framework::OpRegistry::CreateOp(op_desc); auto scope = std::make_shared(); paddle::platform::CPUDeviceContext dev_ctx; @@ -88,7 +89,7 @@ TEST(OpRegistry, IllegalAttr) { bool caught = false; try { - paddle::framework::OperatorBase* op __attribute__((unused)) = + paddle::framework::OperatorPtr op __attribute__((unused)) = paddle::framework::OpRegistry::CreateOp(op_desc); } catch (paddle::framework::EnforceNotMet err) { caught = true; @@ -109,7 +110,7 @@ TEST(OpRegistry, DefaultValue) { ASSERT_TRUE(op_desc.IsInitialized()); - paddle::framework::OperatorBase* op = + paddle::framework::OperatorPtr op = paddle::framework::OpRegistry::CreateOp(op_desc); auto scope = std::make_shared(); paddle::platform::CPUDeviceContext dev_ctx; @@ -117,16 +118,25 @@ TEST(OpRegistry, DefaultValue) { ASSERT_EQ(op->GetAttr("scale"), 1.0); } +static void SetInputFormat(paddle::framework::OpDesc* desc) { + auto attr = desc->add_attrs(); + attr->set_name("input_format"); + attr->set_type(paddle::framework::INTS); + attr->mutable_ints()->Add(0); + attr->mutable_ints()->Add(1); +} + TEST(OpRegistry, CustomChecker) { paddle::framework::OpDesc op_desc; op_desc.set_type("my_test_op"); op_desc.add_inputs("ii"); op_desc.add_outputs("oo"); + SetInputFormat(&op_desc); // attr 'test_attr' is not set bool caught = false; try { - paddle::framework::OperatorBase* op __attribute__((unused)) = + paddle::framework::OperatorPtr op __attribute__((unused)) = paddle::framework::OpRegistry::CreateOp(op_desc); } catch (paddle::framework::EnforceNotMet err) { caught = true; @@ -145,7 +155,7 @@ TEST(OpRegistry, CustomChecker) { attr->set_i(3); caught = false; try { - paddle::framework::OperatorBase* op __attribute__((unused)) = + paddle::framework::OperatorPtr op __attribute__((unused)) = paddle::framework::OpRegistry::CreateOp(op_desc); } catch (paddle::framework::EnforceNotMet err) { caught = true; @@ -163,7 +173,8 @@ TEST(OpRegistry, CustomChecker) { attr->set_name("test_attr"); attr->set_type(paddle::framework::AttrType::INT); attr->set_i(4); - paddle::framework::OperatorBase* op = + SetInputFormat(&op_desc); + paddle::framework::OperatorPtr op = paddle::framework::OpRegistry::CreateOp(op_desc); paddle::platform::CPUDeviceContext dev_ctx; auto scope = std::make_shared(); diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 3c6376c1503f3c3d816e0500b9ad79a99857ef20..aa859591f08cfff10e2fe40016b6aec24333995f 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -34,7 +34,7 @@ DeviceType* OpKernel::KernelContext::get_eigen_device() std::string OperatorBase::DebugString() const { std::stringstream ss; ss << "=================\n"; - ss << "type = " << desc_.type() << "\n"; + ss << "type = " << type_ << "\n"; ss << "inputs = ["; for (auto& ipt : inputs_) { ss << ipt << ", "; @@ -54,4 +54,4 @@ std::string OperatorBase::DebugString() const { } } // namespace framework -} // namespace paddle \ No newline at end of file +} // namespace paddle diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 558d4a0b6769ad11e775e6eeb5a48bcbf48067f3..c48d990eb275ce60d83932f3b82c592077801718 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -45,7 +45,7 @@ struct EigenDeviceConverter { #endif class OperatorBase; - +using OperatorPtr = std::shared_ptr; /** * OperatorBase has the basic element that Net will call to do computation. * Only CreateOperator from OpRegistry will new Operator directly. User @@ -71,17 +71,14 @@ class OperatorBase { /// InferShape infer the size of Variables used by this Operator with /// information inside scope - virtual void InferShape(const std::shared_ptr& scope) const = 0; + virtual void InferShape(const ScopePtr& scope) const = 0; /// Net will call this function to Run an op. - virtual void Run(const std::shared_ptr& scope, + virtual void Run(const ScopePtr& scope, const platform::DeviceContext& dev_ctx) const = 0; - protected: - std::string Type() const { return desc_.type(); } - public: - OpDesc desc_; + std::string type_; std::vector inputs_; std::vector outputs_; AttributeMap attrs_; @@ -97,7 +94,7 @@ class OpKernel { */ class KernelContext { public: - KernelContext(const OperatorBase* op, const std::shared_ptr& scope, + KernelContext(const OperatorBase* op, const ScopePtr& scope, const platform::DeviceContext& device_context) : op_(*op), scope_(scope), device_context_(device_context) {} @@ -115,7 +112,7 @@ class OpKernel { DeviceType* get_eigen_device() const; const OperatorBase& op_; - const std::shared_ptr& scope_; + const ScopePtr& scope_; const platform::DeviceContext& device_context_; }; @@ -160,9 +157,9 @@ class OperatorWithKernel : public OperatorBase { using OpKernelMap = std::unordered_map, OpKernelHash>; - void Run(const std::shared_ptr& scope, + void Run(const ScopePtr& scope, const platform::DeviceContext& dev_ctx) const final { - auto& opKernel = AllOpKernels().at(Type()).at(OpKernelKey(dev_ctx)); + auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx)); opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx)); } diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index 204b601a4aad52ee57b81235851c9806b607799f..19ac4ecafa21d0a6fde57ef5e867670d7823fde0 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -19,14 +19,18 @@ limitations under the License. */ namespace paddle { namespace framework { -class OperatorTest : public OperatorBase { +static int op_run_num = 0; + +class OpWithoutKernelTest : public OperatorBase { public: void Init() override { x = 1; } - void InferShape(const std::shared_ptr& scope) const override {} - void Run(const std::shared_ptr& scope, + void InferShape(const ScopePtr& scope) const override {} + void Run(const ScopePtr& scope, const platform::DeviceContext& dev_ctx) const override { - float scale = GetAttr("scale"); - ASSERT_NEAR(scale, 3.14, 1e-5); + op_run_num++; + ASSERT_EQ((int)inputs_.size(), 1); + ASSERT_EQ((int)outputs_.size(), 1); + ASSERT_NEAR(GetAttr("scale"), 3.14, 1e-5); ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr); ASSERT_EQ(x, 1); ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr); @@ -36,19 +40,61 @@ class OperatorTest : public OperatorBase { float x = 0; }; +class OpeWithoutKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { + public: + OpeWithoutKernelTestProtoAndCheckerMaker(OpProto* proto, + OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("input", "input of test op"); + AddOutput("output", "output of test op"); + AddAttr("scale", "scale of cosine op"); + AddComment("This is test op"); + } +}; + +} // namespace framework +} // namespace paddle + +REGISTER_OP(test_operator, paddle::framework::OpWithoutKernelTest, + paddle::framework::OpeWithoutKernelTestProtoAndCheckerMaker); + +TEST(OperatorBase, all) { + paddle::framework::OpDesc op_desc; + op_desc.set_type("test_operator"); + *op_desc.mutable_inputs()->Add() = "IN1"; + *op_desc.mutable_outputs()->Add() = "OUT1"; + auto attr = op_desc.mutable_attrs()->Add(); + attr->set_name("scale"); + attr->set_type(paddle::framework::AttrType::FLOAT); + attr->set_f(3.14); + + paddle::platform::CPUDeviceContext device_context; + auto scope = std::make_shared(); + + paddle::framework::OperatorPtr op = + paddle::framework::OpRegistry::CreateOp(op_desc); + scope->CreateVariable("OUT1"); + ASSERT_EQ(paddle::framework::op_run_num, 0); + op->Run(scope, device_context); + ASSERT_EQ(paddle::framework::op_run_num, 1); +} + +namespace paddle { +namespace framework { + class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker { public: OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("input", "input of test op"); AddOutput("output", "output of test op"); - AddAttr("scale", "scale of cosine op") - .SetDefault(1.0) - .LargerThan(0.0); + AddAttr("scale", "scale of cosine op"); AddComment("This is test op"); } }; +static int cpu_kernel_run_num = 0; + class OpWithKernelTest : public OperatorWithKernel { protected: void InferShape(const std::vector& inputs, @@ -58,10 +104,10 @@ class OpWithKernelTest : public OperatorWithKernel { class CPUKernelTest : public OpKernel { public: void Compute(const KernelContext& context) const { - float scale = context.op_.GetAttr("scale"); - ASSERT_NEAR(scale, 3.14, 1e-5); - std::cout << "this is cpu kernel" << std::endl; - std::cout << context.op_.DebugString() << std::endl; + cpu_kernel_run_num++; + ASSERT_EQ((int)context.op_.inputs_.size(), 1); + ASSERT_EQ((int)context.op_.outputs_.size(), 1); + ASSERT_NEAR(context.op_.GetAttr("scale"), 3.14, 1e-5); } }; @@ -73,9 +119,7 @@ REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest, REGISTER_OP_CPU_KERNEL(op_with_kernel, paddle::framework::CPUKernelTest); TEST(OpKernel, all) { - using namespace paddle::framework; - - OpDesc op_desc; + paddle::framework::OpDesc op_desc; op_desc.set_type("op_with_kernel"); *op_desc.mutable_inputs()->Add() = "IN1"; *op_desc.mutable_outputs()->Add() = "OUT1"; @@ -85,10 +129,11 @@ TEST(OpKernel, all) { attr->set_f(3.14); paddle::platform::CPUDeviceContext cpu_device_context; - auto scope = std::make_shared(); + auto scope = std::make_shared(); - OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc); + paddle::framework::OperatorPtr op = + paddle::framework::OpRegistry::CreateOp(op_desc); + ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0); op->Run(scope, cpu_device_context); - - delete op; + ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1); } diff --git a/paddle/framework/scope.h b/paddle/framework/scope.h index a4470f726fb0d59a82db29b3239c111ea1569c55..ec62c9189fd2a5ea74c6d6e5635a4d500e4823e2 100644 --- a/paddle/framework/scope.h +++ b/paddle/framework/scope.h @@ -23,6 +23,9 @@ limitations under the License. */ namespace paddle { namespace framework { +class Scope; +using ScopePtr = std::shared_ptr; + /** * @brief Scope that manage all variables. * @@ -41,7 +44,7 @@ class Scope { /** * @brief Initialize a Scope with parent. */ - explicit Scope(const std::shared_ptr& parent) : parent_(parent) {} + explicit Scope(const ScopePtr& parent) : parent_(parent) {} /** * @brief Create Variable @@ -88,7 +91,7 @@ class Scope { private: std::unordered_map> vars_; - std::shared_ptr parent_{nullptr}; + ScopePtr parent_{nullptr}; }; } // namespace framework diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index 784d52cc426b92b509c320e86c0f8e0b6b0e2c13..30e00d0e0fa64f459f02a1b696fda2f2ea71ffd8 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -15,8 +15,8 @@ limitations under the License. */ #pragma once #include +#include #include -#include #include "paddle/framework/ddim.h" #include "paddle/framework/enforce.h" #include "paddle/framework/tensor_types.h" @@ -29,68 +29,78 @@ namespace framework { class Tensor { public: - Tensor() : offset_(0) {} + Tensor() : numel_(0), offset_(0) {} - explicit Tensor(const DDim& dims) : dims_(dims), offset_(0) {} + Tensor& operator=(const Tensor& src) = delete; template + const T* data() const { + CheckDims(); + return reinterpret_cast( + reinterpret_cast(holder_->ptr()) + offset_); + } - T* data() const { - PADDLE_ENFORCE( - holder_ != nullptr, - "Tenosr has not been initialized. Call Tensor::mutable_data first."); - return reinterpret_cast(reinterpret_cast(holder_->Ptr()) + + template + T* raw_data() const { + CheckDims(); + return reinterpret_cast(reinterpret_cast(holder_->ptr()) + offset_); } - template ::value>::type* = nullptr> + template T* mutable_data(DDim dims, paddle::platform::Place place) { - if (holder_ == nullptr || - !(holder_->Place() == - place) /* some versions of boost::variant don't have operator!= */ - || holder_->Size() < product(dims) * sizeof(T) + offset_) { - holder_.reset(new PlaceholderImpl(place, product(dims) * sizeof(T))); - dims_ = dims; - offset_ = 0; - } - return reinterpret_cast(reinterpret_cast(holder_->Ptr()) + - offset_); + set_dims(dims); + return mutable_data(place); } - template ::value>::type* = nullptr> + template T* mutable_data(paddle::platform::Place place) { + PADDLE_ENFORCE(numel_ > 0, + "Tensor::numel_ must be larger than zero to call " + "Tensor::mutable_data. Call Tensor::set_dim first."); if (holder_ == nullptr || - !(holder_->Place() == + !(holder_->place() == place) /* some versions of boost::variant don't have operator!= */ - || holder_->Size() < product(dims_) * sizeof(T) + offset_) { - holder_.reset(new PlaceholderImpl(place, product(dims_) * sizeof(T))); + || holder_->size() < numel_ * sizeof(T) + offset_) { +#ifdef __CUDACC__ + switch (place.which()) { + case 0: + holder_.reset(new PlaceholderImpl( + boost::get(place), numel_ * sizeof(T))); + break; + + case 1: + holder_.reset(new PlaceholderImpl( + boost::get(place), numel_ * sizeof(T))); + break; + } +#else + holder_.reset(new PlaceholderImpl( + boost::get(place), numel_ * sizeof(T))); +#endif offset_ = 0; } - return reinterpret_cast(reinterpret_cast(holder_->Ptr()) + + return reinterpret_cast(reinterpret_cast(holder_->ptr()) + offset_); } - size_t NumElements() const { return product(dims_); } - template typename TTypes::Tensor shaped(DDim new_dims) { Eigen::array dims = paddle::framework::ToEigenDSizes(new_dims); - return typename TTypes::Tensor(data(), dims); + return typename TTypes::Tensor(raw_data(), dims); } template typename TTypes::Tensor tensor() { return typename TTypes::Tensor( - data(), paddle::framework::ToEigenDSizes(dims_)); + raw_data(), paddle::framework::ToEigenDSizes(dims_)); } // flat to rank = 1 template typename TTypes::Flat flat() { - return shaped(make_ddim({static_cast(NumElements())})); + return shaped(make_ddim({static_cast(numel_)})); } // to TensorType Vec @@ -106,6 +116,13 @@ class Tensor { } // const versions of all the methods above. + template + typename TTypes::Tensor shaped(DDim new_dims) const { + Eigen::array dims = + paddle::framework::ToEigenDSizes(new_dims); + return typename TTypes::Tensor(data(), dims); + } + template typename TTypes::ConstantTensor tensor() const { return typename TTypes::Tensor( @@ -114,7 +131,7 @@ class Tensor { template typename TTypes::ConstFlat flat() const { - return shaped(make_ddim({static_cast(NumElements())})); + return shaped(make_ddim({static_cast(numel_)})); } template @@ -127,17 +144,30 @@ class Tensor { return tensor(); } + template void ShareDataFrom(const Tensor& src) { - PADDLE_ENFORCE(src.holder_ != nullptr, - "Can not share data from an uninitialized tensor."); + src.CheckDims(); holder_ = src.holder_; - dims_ = src.dims_; + set_dims(src.dims()); offset_ = src.offset_; } + template + void CopyFrom(const Tensor& src, paddle::platform::Place dst_place) { + PADDLE_ENFORCE(platform::is_cpu_place(src.holder_->place()) && + platform::is_cpu_place(dst_place), + "Tensor::CopyFrom only support CPU now."); + src.CheckDims(); + size_t size = src.numel_ * sizeof(T); + set_dims(src.dims()); + const void* src_ptr = static_cast(src.data()); + void* dst_ptr = static_cast(mutable_data(dst_place)); + memcpy(dst_ptr, src_ptr, size); + } + + template Tensor Slice(const int& begin_idx, const int& end_idx) const { - PADDLE_ENFORCE(holder_ != nullptr, - "The sliced tenosr has not been initialized."); + CheckDims(); PADDLE_ENFORCE(begin_idx >= 0 && end_idx <= dims_[0], "Slice index is less than zero or out of bound."); PADDLE_ENFORCE(begin_idx < end_idx, @@ -150,12 +180,21 @@ class Tensor { } Tensor dst; dst.holder_ = holder_; - dst.dims_ = dims_; - dst.dims_[0] = end_idx - begin_idx; - dst.offset_ = offset_ + begin_idx * base * holder_->TypeSize(); + DDim dst_dims = dims_; + dst_dims[0] = end_idx - begin_idx; + dst.set_dims(dst_dims); + dst.offset_ = offset_ + begin_idx * base * sizeof(T); return dst; } + void set_dims(const DDim& dims) { + if (dims == dims_) { + return; + } + dims_ = dims; + numel_ = product(dims_); + } + DDim dims() const { return dims_; } private: @@ -163,45 +202,54 @@ class Tensor { // parameter of Variable. struct Placeholder { virtual ~Placeholder() {} - virtual void* Ptr() const = 0; - virtual paddle::platform::Place Place() const = 0; - virtual size_t Size() const = 0; - virtual size_t TypeSize() const = 0; + virtual void* ptr() const = 0; + virtual paddle::platform::Place place() const = 0; + virtual size_t size() const = 0; }; - template + template struct PlaceholderImpl : public Placeholder { private: + template class Deleter { public: - Deleter(platform::Place place) : place_(place) {} + Deleter(PType place) : place_(place) {} void operator()(T* ptr) { paddle::memory::Free(place_, static_cast(ptr)); } private: - paddle::platform::Place place_; + PType place_; }; public: - PlaceholderImpl(paddle::platform::Place place, size_t size) + PlaceholderImpl(PlaceType place, size_t size) : ptr_(static_cast(paddle::memory::Alloc(place, size)), - Deleter(place)), + Deleter(place)), place_(place), size_(size) {} - virtual void* Ptr() const { return static_cast(ptr_.get()); } - virtual size_t Size() const { return size_; } - virtual paddle::platform::Place Place() const { return place_; } - virtual size_t TypeSize() const { return sizeof(T); } + virtual void* ptr() const { return static_cast(ptr_.get()); } + virtual size_t size() const { return size_; } + virtual paddle::platform::Place place() const { return place_; } - std::unique_ptr ptr_; + std::unique_ptr> ptr_; paddle::platform::Place place_; // record the place of ptr_. size_t size_; // size of the memory block. }; + template + inline void CheckDims() const { + PADDLE_ENFORCE(holder_ != nullptr, + "Tenosr holds no memory. Call Tensor::mutable_data first."); + PADDLE_ENFORCE(holder_->size() >= numel_ * sizeof(T) + offset_, + "Tensor's dims_ is out of bound. Call Tensor::mutable_data " + "first to re-allocate memory."); + } + std::shared_ptr holder_; // holds the memory block if allocated. DDim dims_; + size_t numel_; // cache of `product(dims_)` size_t offset_; // marks the begin of tensor data area. }; diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc index f4822838cfbd27656232a23b14f716f2fbe510e0..255f69372f4f06609471b9ff7a9b9ce790fcddf0 100644 --- a/paddle/framework/tensor_test.cc +++ b/paddle/framework/tensor_test.cc @@ -18,7 +18,8 @@ TEST(Tensor, Dims) { using namespace paddle::framework; using namespace paddle::platform; - Tensor tt(make_ddim({2, 3, 4})); + Tensor tt; + tt.set_dims(make_ddim({2, 3, 4})); DDim dims = tt.dims(); ASSERT_EQ(arity(dims), 3); for (int i = 0; i < 3; ++i) { @@ -35,7 +36,7 @@ TEST(Tensor, DataAssert) { } catch (paddle::framework::EnforceNotMet err) { caught = true; std::string msg = - "Tenosr has not been initialized. Call Tensor::mutable_data first."; + "Tenosr holds no memory. Call Tensor::mutable_data first."; const char* what = err.what(); for (size_t i = 0; i < msg.length(); ++i) { ASSERT_EQ(what[i], msg[i]); @@ -104,19 +105,18 @@ TEST(Tensor, ShareDataFrom) { // Try to share data form uninitialized tensor bool caught = false; try { - dst_tensor.ShareDataFrom(src_tensor); + dst_tensor.ShareDataFrom(src_tensor); } catch (EnforceNotMet err) { caught = true; - std::string msg = "Can not share data from an uninitialized tensor."; - const char* what = err.what(); - for (size_t i = 0; i < msg.length(); ++i) { - ASSERT_EQ(what[i], msg[i]); + std::string msg = "Tenosr holds no memory. Call Tensor::mutable_data +first."; const char* what = err.what(); for (size_t i = 0; i < msg.length(); +++i) { ASSERT_EQ(what[i], msg[i]); } } ASSERT_TRUE(caught); src_tensor.mutable_data(make_ddim({2, 3, 4}), CPUPlace()); - dst_tensor.ShareDataFrom(src_tensor); + dst_tensor.ShareDataFrom(src_tensor); ASSERT_EQ(src_tensor.data(), dst_tensor.data()); } @@ -124,7 +124,7 @@ TEST(Tensor, ShareDataFrom) { Tensor src_tensor; Tensor dst_tensor; src_tensor.mutable_data(make_ddim({2, 3, 4}), GPUPlace()); - dst_tensor.ShareDataFrom(src_tensor); + dst_tensor.ShareDataFrom(src_tensor); ASSERT_EQ(src_tensor.data(), dst_tensor.data()); } } @@ -135,7 +135,7 @@ TEST(Tensor, Slice) { { Tensor src_tensor; src_tensor.mutable_data(make_ddim({5, 3, 4}), CPUPlace()); - Tensor slice_tensor = src_tensor.Slice(1, 3); + Tensor slice_tensor = src_tensor.Slice(1, 3); DDim slice_dims = slice_tensor.dims(); ASSERT_EQ(arity(slice_dims), 3); EXPECT_EQ(slice_dims[0], 2); @@ -158,7 +158,7 @@ TEST(Tensor, Slice) { { Tensor src_tensor; src_tensor.mutable_data(make_ddim({6, 9}), GPUPlace()); - Tensor slice_tensor = src_tensor.Slice(2, 6); + Tensor slice_tensor = src_tensor.Slice(2, 6); DDim slice_dims = slice_tensor.dims(); ASSERT_EQ(arity(slice_dims), 2); EXPECT_EQ(slice_dims[0], 4); @@ -178,4 +178,29 @@ TEST(Tensor, Slice) { } } +TEST(Tensor, CopyFrom) { + using namespace paddle::framework; + using namespace paddle::platform; + + Tensor src_tensor; + int* src_ptr = src_tensor.mutable_data(make_ddim({3, 3}), CPUPlace()); + int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9}; + memcpy(src_ptr, arr, 9 * sizeof(int)); + Tensor dst_tensor; + dst_tensor.CopyFrom(src_tensor, CPUPlace()); + const int* dst_ptr = dst_tensor.data(); + ASSERT_NE(src_ptr, dst_ptr); + for (size_t i = 0; i < 9; ++i) { + EXPECT_EQ(src_ptr[i], dst_ptr[i]); + } + + Tensor slice_tensor = src_tensor.Slice(1, 2); + dst_tensor.CopyFrom(slice_tensor, CPUPlace()); + const int* slice_ptr = slice_tensor.data(); + dst_ptr = dst_tensor.data(); + ASSERT_NE(dst_ptr, slice_ptr); + for (size_t i = 0; i < 3; ++i) { + EXPECT_EQ(dst_ptr[i], slice_ptr[i]); + } +} */ \ No newline at end of file diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index a40e5d9d2e76605525f0956445fc43c693933cf8..00880effc59cc80b2761fb6a4d9f3246439afd3f 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -117,8 +117,7 @@ public: ConvFunctionBase::init(config); } - virtual void check(const BufferArgs& inputs, - const BufferArgs& outputs) override { + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { const TensorShape& input = inputs[0].shape(); const TensorShape& filter = inputs[1].shape(); const TensorShape& output = outputs[0].shape(); @@ -217,8 +216,7 @@ public: ConvFunctionBase::init(config); } - virtual void check(const BufferArgs& inputs, - const BufferArgs& outputs) override { + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { const TensorShape& output = inputs[0].shape(); const TensorShape& filter = inputs[1].shape(); const TensorShape& input = outputs[0].shape(); @@ -311,8 +309,7 @@ public: ConvFunctionBase::init(config); } - virtual void check(const BufferArgs& inputs, - const BufferArgs& outputs) override { + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { const TensorShape& output = inputs[0].shape(); const TensorShape& input = inputs[1].shape(); const TensorShape& filter = outputs[0].shape(); diff --git a/paddle/function/NaiveConvOp.cpp b/paddle/function/NaiveConvOp.cpp index 4348f0f775e9442c50a3c45b9a8e6dad5c6b198d..e0692fa06d6e0c35cfa742ca3eac7fe2037b1a80 100644 --- a/paddle/function/NaiveConvOp.cpp +++ b/paddle/function/NaiveConvOp.cpp @@ -90,8 +90,7 @@ public: ConvFunctionBase::init(config); } - virtual void check(const BufferArgs& inputs, - const BufferArgs& outputs) override { + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { const TensorShape& input = inputs[0].shape(); const TensorShape& filter = inputs[1].shape(); const TensorShape& output = outputs[0].shape(); diff --git a/paddle/gserver/dataproviders/DataProvider.h b/paddle/gserver/dataproviders/DataProvider.h index 40036762179ebb1495b90907f16b97e3c60c50d8..265dbb54933540ff8b0d1e2e2d985b4b7fa51ecd 100644 --- a/paddle/gserver/dataproviders/DataProvider.h +++ b/paddle/gserver/dataproviders/DataProvider.h @@ -205,10 +205,8 @@ public: hl_destroy_event(hlEvent_); hlEvent_ = NULL; } - if (batchData_) { - delete batchData_; - batchData_ = NULL; - } + delete batchData_; + batchData_ = NULL; } void setDataBatch(DataBatch* batchData) { batchData_ = batchData; } diff --git a/paddle/gserver/gradientmachines/NeuralNetwork.cpp b/paddle/gserver/gradientmachines/NeuralNetwork.cpp index 2e839f640503b8f4e390fc87d9d59960dbc37f6e..cfa80a89365af5111746eec9599d16e37532a9f7 100644 --- a/paddle/gserver/gradientmachines/NeuralNetwork.cpp +++ b/paddle/gserver/gradientmachines/NeuralNetwork.cpp @@ -403,7 +403,7 @@ public: : layerName_(layerName) { addEvaluator(std::move(evaluator)); } - virtual void eval(const NeuralNetwork& nn) override { + void eval(const NeuralNetwork& nn) override { const LayerPtr& layer = nn.getLayer(layerName_); CHECK(layer) << "Nonexisted layer: " << layerName_ << " in submodel " << nn.getName(); diff --git a/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp b/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp index 9a972466d66ba1417b2c31e66dc375b3da229aa8..9ddd449de7500f5682d59469328f06971c6e83bf 100644 --- a/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp +++ b/paddle/gserver/gradientmachines/RecurrentGradientMachine.cpp @@ -636,7 +636,7 @@ void lenToStarts(std::vector& starts) { } starts.back() = pos; } -} +} // namespace void RecurrentGradientMachine::calcSequenceStartPositions() { std::vector starts(commonSeqInfo_.size() + 1); diff --git a/paddle/gserver/layers/AgentLayer.cpp b/paddle/gserver/layers/AgentLayer.cpp index 15e7411b5fde0fa3a532394cf7d0e8477ef052d0..bdae7e623ae0472d4fe5ef3a88fc1e93bbf1e52c 100644 --- a/paddle/gserver/layers/AgentLayer.cpp +++ b/paddle/gserver/layers/AgentLayer.cpp @@ -124,7 +124,7 @@ void copyElements(const IVector& srcVec, dest[index[i]] = src[i]; } } -} +} // namespace void GatherAgentLayer::forwardIds(PassType passType) { IVectorPtr realId = realLayers_[0]->getOutputLabel(); diff --git a/paddle/math/Storage.cpp b/paddle/math/Storage.cpp index 7ce17a3207becb176a852a16fca52376009db9ee..4adaaef9838f0d178468af3af142031325bfc11d 100644 --- a/paddle/math/Storage.cpp +++ b/paddle/math/Storage.cpp @@ -32,9 +32,7 @@ static InitFunction __init_storage_engine([]() { StorageEngine::singleton(); }, StorageEngine::StorageEngine() : cpuAllocator_(nullptr) {} StorageEngine::~StorageEngine() { - if (cpuAllocator_) { - delete cpuAllocator_; - } + delete cpuAllocator_; for (auto it : gpuAllocator_) { delete it; } diff --git a/paddle/memory/CMakeLists.txt b/paddle/memory/CMakeLists.txt index 3943c3cfad31d13a00645aba6fc153d3d13da987..fac442cca56b81f56a750bd3b1c2c0911e79e468 100644 --- a/paddle/memory/CMakeLists.txt +++ b/paddle/memory/CMakeLists.txt @@ -1 +1,11 @@ add_subdirectory(detail) + +cc_library(memory SRCS memory.cc) + +cc_library(paddle_memory + DEPS + memory meta_data + meta_cache memory_block + buddy_allocator system_allocator) + +cc_test(memory_test SRCS memory_test.cc DEPS place paddle_memory) diff --git a/paddle/memory/detail/CMakeLists.txt b/paddle/memory/detail/CMakeLists.txt index 72d3749ad789eca9a4b10944131171c0cf8dfe5a..b9c3fc31c1523abf3acbd116745bbf1596454aac 100644 --- a/paddle/memory/detail/CMakeLists.txt +++ b/paddle/memory/detail/CMakeLists.txt @@ -1,7 +1,15 @@ if(${WITH_GPU}) - nv_library(system_allocator SRCS system_allocator.cc DEPS gflags) - nv_test(system_allocator_test SRCS system_allocator_test.cc DEPS system_allocator gflags) + nv_library(system_allocator SRCS system_allocator.cc DEPS gflags cpu_info gpu_info) else(${WITH_GPU}) - cc_library(system_allocator SRCS system_allocator.cc DEPS gflags) - cc_test(system_allocator_test SRCS system_allocator_test.cc DEPS system_allocator gflags) + cc_library(system_allocator SRCS system_allocator.cc DEPS gflags cpu_info) endif(${WITH_GPU}) + +cc_test(system_allocator_test SRCS system_allocator_test.cc DEPS system_allocator) + +cc_library(meta_data SRCS meta_data.cc) + +cc_library(meta_cache SRCS meta_cache.cc) + +cc_library(memory_block SRCS memory_block.cc) + +cc_library(buddy_allocator SRCS buddy_allocator.cc DEPS glog) diff --git a/paddle/memory/detail/buddy_allocator.cc b/paddle/memory/detail/buddy_allocator.cc index ebe680f5eea4948339fb8c5584a5b9f5d71c752e..27c1b4033b53b059d38ed88694b20b429cbb4cce 100644 --- a/paddle/memory/detail/buddy_allocator.cc +++ b/paddle/memory/detail/buddy_allocator.cc @@ -12,22 +12,317 @@ See the License for the specific language governing permissions and limitations under the License. */ -#pragma once - #include "paddle/memory/detail/buddy_allocator.h" +#include "glog/logging.h" namespace paddle { namespace memory { namespace detail { -BuddyAllocator::BuddyAllocator(size_t pool_size, size_t max_pools, - SystemAllocator* system_allocator) - : pool_size_(pool_size), - max_pools_(max_pools), - system_allocator_(system_allocator) { - PADDLE_ASSERT(pool_size > 0); - PADDLE_ASSERT(max_pools > 0); - PADDLE_ASSERT(system_allocator != nullptr); +BuddyAllocator::BuddyAllocator(SystemAllocator* system_allocator, + size_t min_chunk_size, size_t max_chunk_size) + : min_chunk_size_(min_chunk_size), + max_chunk_size_(max_chunk_size), + cache_(system_allocator->UseGpu()), + system_allocator_(std::move(system_allocator)) {} + +BuddyAllocator::~BuddyAllocator() { + DLOG(INFO) << "BuddyAllocator Disconstructor makes sure that all of these " + "have actually been freed"; + while (!pool_.empty()) { + auto block = static_cast(std::get<2>(*pool_.begin())); + DLOG(INFO) << "Free from block (" << block << ", " << max_chunk_size_ + << ")"; + + system_allocator_->Free(block, max_chunk_size_, block->index(cache_)); + cache_.invalidate(block); + pool_.erase(pool_.begin()); + } +} + +inline size_t align(size_t size, size_t alignment) { + size_t remaining = size % alignment; + return remaining == 0 ? size : size + (alignment - remaining); +} + +void* BuddyAllocator::Alloc(size_t unaligned_size) { + // adjust allocation alignment + size_t size = align(unaligned_size + sizeof(Metadata), min_chunk_size_); + + // acquire the allocator lock + std::lock_guard lock(mutex_); + + DLOG(INFO) << "Allocate " << unaligned_size << " bytes from chunk size " + << size; + + // if the allocation is huge, send directly to the system allocator + if (size > max_chunk_size_) { + DLOG(INFO) << "Allocate from system allocator."; + return SystemAlloc(size); + } + + // query and allocate from the existing chunk + auto it = FindExistChunk(size); + + // refill the pool if failure + if (it == pool_.end()) { + it = RefillPool(); + // if still failure, fail fatally + if (it == pool_.end()) { + return nullptr; + } + } else { + DLOG(INFO) << "Allocation from existing memory block " << std::get<2>(*it) + << " at address " + << reinterpret_cast(std::get<2>(*it))->data(); + } + + total_used_ += size; + total_free_ -= size; + + // split the allocation and return data for use + return reinterpret_cast(SplitToAlloc(it, size))->data(); +} + +void BuddyAllocator::Free(void* p) { + // Point back to metadata + auto block = static_cast(p)->metadata(); + + // Acquire the allocator lock + std::lock_guard lock(mutex_); + + DLOG(INFO) << "Free from address " << block; + + if (block->type(cache_) == MemoryBlock::HUGE_CHUNK) { + DLOG(INFO) << "Free directly from system allocator"; + system_allocator_->Free(block, block->total_size(cache_), + block->index(cache_)); + + // Invalidate GPU allocation from cache + cache_.invalidate(block); + + return; + } + + block->mark_as_free(cache_); + + total_used_ -= block->total_size(cache_); + total_free_ += block->total_size(cache_); + + // Trying to merge the right buddy + if (block->has_right_buddy(cache_)) { + DLOG(INFO) << "Merging this block " << block << " with its right buddy " + << block->right_buddy(cache_); + + auto right_buddy = block->right_buddy(cache_); + + if (right_buddy->type(cache_) == MemoryBlock::FREE_CHUNK) { + // Take away right buddy from pool + pool_.erase(IndexSizeAddress(right_buddy->index(cache_), + right_buddy->total_size(cache_), + right_buddy)); + + // merge its right buddy to the block + block->merge(cache_, right_buddy); + } + } + + // Trying to merge the left buddy + if (block->has_left_buddy(cache_)) { + DLOG(INFO) << "Merging this block " << block << " with its left buddy " + << block->left_buddy(cache_); + + auto left_buddy = block->left_buddy(cache_); + + if (left_buddy->type(cache_) == MemoryBlock::FREE_CHUNK) { + // Take away right buddy from pool + pool_.erase(IndexSizeAddress(left_buddy->index(cache_), + left_buddy->total_size(cache_), left_buddy)); + + // merge the block to its left buddy + left_buddy->merge(cache_, block); + block = left_buddy; + } + } + + // Dumping this block into pool + DLOG(INFO) << "Inserting free block (" << block << ", " + << block->total_size(cache_) << ")"; + pool_.insert( + IndexSizeAddress(block->index(cache_), block->total_size(cache_), block)); + + // Clean up if existing too much free memory + + // Prefer freeing fallback allocation first + CleanIdleFallBackAlloc(); + + // Free normal allocation + CleanIdleNormalAlloc(); +} + +size_t BuddyAllocator::Used() { return total_used_; } + +void* BuddyAllocator::SystemAlloc(size_t size) { + size_t index = 0; + void* p = system_allocator_->Alloc(index, size); + + DLOG(INFO) << "Allocated " << p << " from system allocator."; + + if (p == nullptr) return nullptr; + + static_cast(p)->init(cache_, MemoryBlock::HUGE_CHUNK, index, + size, nullptr, nullptr); + + return static_cast(p)->data(); +} + +BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() { +#ifndef PADDLE_ONLY_CPU + if (system_allocator_->UseGpu()) { + if ((total_used_ + total_free_) == 0) { + // Compute the maximum allocation size for the first allocation. + max_chunk_size_ = platform::GpuMaxChunkSize(); + } + } +#endif // PADDLE_ONLY_CPU + + // Allocate a new maximum sized block + size_t index = 0; + void* p = system_allocator_->Alloc(index, max_chunk_size_); + + if (p == nullptr) return pool_.end(); + + DLOG(INFO) << "Creating and inserting new block " << p + << " from system allocator"; + + static_cast(p)->init(cache_, MemoryBlock::FREE_CHUNK, index, + max_chunk_size_, nullptr, nullptr); + + // gpu fallback allocation + if (system_allocator_->UseGpu() && + static_cast(p)->index(cache_) == 1) { + fallback_alloc_count_++; + } + + total_free_ += max_chunk_size_; + + // dump the block into pool + return pool_.insert(IndexSizeAddress(index, max_chunk_size_, p)).first; +} + +BuddyAllocator::PoolSet::iterator BuddyAllocator::FindExistChunk(size_t size) { + size_t index = 0; + + while (1) { + auto it = pool_.lower_bound(IndexSizeAddress(index, size, nullptr)); + + // no match chunk memory + if (it == pool_.end()) return it; + + if (std::get<0>(*it) > index) { + // find suitable one + if (std::get<1>(*it) >= size) { + return it; + } + // update and continue + index = std::get<0>(*it); + continue; + } + return it; + } +} + +void* BuddyAllocator::SplitToAlloc(BuddyAllocator::PoolSet::iterator it, + size_t size) { + auto block = static_cast(std::get<2>(*it)); + pool_.erase(it); + + DLOG(INFO) << "Split block (" << block << ", " << block->total_size(cache_) + << ") into"; + block->split(cache_, size); + + DLOG(INFO) << "Left block (" << block << ", " << block->total_size(cache_) + << ")"; + block->set_type(cache_, MemoryBlock::ARENA_CHUNK); + + // the rest of memory if exist + if (block->has_right_buddy(cache_)) { + if (block->right_buddy(cache_)->type(cache_) == MemoryBlock::FREE_CHUNK) { + DLOG(INFO) << "Insert right block (" << block->right_buddy(cache_) << ", " + << block->right_buddy(cache_)->total_size(cache_) << ")"; + + pool_.insert( + IndexSizeAddress(block->right_buddy(cache_)->index(cache_), + block->right_buddy(cache_)->total_size(cache_), + block->right_buddy(cache_))); + } + } + + return block; +} + +void BuddyAllocator::CleanIdleFallBackAlloc() { + // If fallback allocation does not exist, return directly + if (!fallback_alloc_count_) return; + + for (auto pool = pool_.rbegin(); pool != pool_.rend();) { + // If free memory block less than max_chunk_size_, return directly + if (std::get<1>(*pool) < max_chunk_size_) return; + + MemoryBlock* block = static_cast(std::get<2>(*pool)); + + // If no GPU fallback allocator, return + if (!system_allocator_->UseGpu() || block->index(cache_) == 0) { + return; + } + + DLOG(INFO) << "Return block " << block << " to fallback allocator."; + + system_allocator_->Free(block, max_chunk_size_, block->index(cache_)); + cache_.invalidate(block); + + pool = PoolSet::reverse_iterator(pool_.erase(std::next(pool).base())); + + total_free_ -= max_chunk_size_; + fallback_alloc_count_--; + + // If no fall allocation exists, return directly + if (!fallback_alloc_count_) return; + } +} + +void BuddyAllocator::CleanIdleNormalAlloc() { + auto shall_free_alloc = [&]() -> bool { + // free all fallback allocations + if (fallback_alloc_count_ > 0) { + return true; + } + // keep 2x overhead if we haven't fallen back + if ((total_used_ + max_chunk_size_) * 2 < total_free_) { + return true; + } + return false; + }; + + if (!shall_free_alloc()) return; + + for (auto pool = pool_.rbegin(); pool != pool_.rend();) { + // If free memory block less than max_chunk_size_, return directly + if (std::get<1>(*pool) < max_chunk_size_) return; + + MemoryBlock* block = static_cast(std::get<2>(*pool)); + + DLOG(INFO) << "Return block " << block << " to base allocator."; + + system_allocator_->Free(block, max_chunk_size_, block->index(cache_)); + cache_.invalidate(block); + + pool = PoolSet::reverse_iterator(pool_.erase(std::next(pool).base())); + + total_free_ -= max_chunk_size_; + + if (!shall_free_alloc()) return; + } } } // namespace detail diff --git a/paddle/memory/detail/buddy_allocator.h b/paddle/memory/detail/buddy_allocator.h index 82e6aaedc719966b4074449ce1ef7193c73dc265..4fa3fb0ee5f826d2b084c0ba184c505aee3acc48 100644 --- a/paddle/memory/detail/buddy_allocator.h +++ b/paddle/memory/detail/buddy_allocator.h @@ -14,9 +14,16 @@ #pragma once +#include "paddle/memory/detail/meta_cache.h" +#include "paddle/memory/detail/meta_data.h" #include "paddle/memory/detail/system_allocator.h" +#include "paddle/platform/assert.h" +#include "paddle/platform/cpu_info.h" +#include "paddle/platform/gpu_info.h" #include +#include +#include #include namespace paddle { @@ -25,61 +32,80 @@ namespace detail { class BuddyAllocator { public: - BuddyAllocator(size_t pool_size, size_t max_pools, - SystemAllocator* system_allocator); + BuddyAllocator(SystemAllocator* system_allocator, size_t min_chunk_size, + size_t max_chunk_size); + ~BuddyAllocator(); - void* Alloc(size_t size); + public: + void* Alloc(size_t unaligned_size); void Free(void*); size_t Used(); + public: + // Disable copy and assignment + BuddyAllocator(const BuddyAllocator&) = delete; + BuddyAllocator& operator=(const BuddyAllocator&) = delete; + private: - struct Block { - size_t size_; - Block* left_; // left buddy - Block* right_; // right buddy - }; + // Tuple (allocator index, memory size, memory address) + using IndexSizeAddress = std::tuple; + // Each element in PoolSet is a free allocation + using PoolSet = std::set; - // Initially, there is only one pool. If a Alloc founds not enough - // memory from that pool, and there has not been max_num_pools_, - // create a new pool by calling system_allocator_.Alloc(pool_size_). - std::vector pools_; + /*! \brief Allocate fixed-size memory from system */ + void* SystemAlloc(size_t size); - size_t pool_size_; // the size of each pool; - size_t max_num_pools_; // the size of all pools; + /*! \brief If existing chunks are not suitable, refill pool */ + PoolSet::iterator RefillPool(); - SystemAllocator* system_allocator_; + /** + * \brief Find the suitable chunk from existing pool and split + * it to left and right buddies + * + * \param it the iterator of pool list + * \param size the size of allocation + * + * \return the left buddy address + */ + void* SplitToAlloc(PoolSet::iterator it, size_t size); - std::mutex mutex_; + /*! \brief Find the existing chunk which used to allocation */ + PoolSet::iterator FindExistChunk(size_t size); - // Disable copy and assignment. - BuddyAllocator(const BuddyAllocator&) = delete; - BuddyAllocator& operator=(const BuddyAllocator&) = delete; -}; + /*! \brief Clean idle fallback allocation */ + void CleanIdleFallBackAlloc(); + + /*! \brief Clean idle normal allocation */ + void CleanIdleNormalAlloc(); -BuddyAllocator* GetCPUBuddyAllocator() { - static BuddyAllocator* a = nullptr; - if (a == nullptr) { - a = new BuddyAllocator(); - } - return a; -} - -#ifndef PADDLE_ONLY_CPU // The following code are for CUDA. - -BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { - static BuddyAllocator** as = NULL; - if (as == NULL) { - int gpu_num = platform::GetDeviceCount(); - as = new BuddyAllocator*[gpu_num]; - for (int gpu = 0; gpu < gpu_num; gpu++) { - as[gpu] = new BuddyAllocator(); - } - } - return as[gpu_id]; -} - -#endif // PADDLE_ONLY_CPU + private: + size_t total_used_ = 0; // the total size of used memory + size_t total_free_ = 0; // the total size of free memory + + size_t min_chunk_size_; // the minimum size of each chunk + size_t max_chunk_size_; // the maximum size of each chunk + + private: + /** + * \brief A list of free allocation + * + * \note Only store free chunk memory in pool + */ + PoolSet pool_; + + /*! Record fallback allocation count for auto-scaling */ + size_t fallback_alloc_count_ = 0; + + private: + /*! Unify the metadata format between GPU and CPU allocations */ + MetadataCache cache_; + + private: + /*! Allocate CPU/GPU memory from system */ + SystemAllocator* system_allocator_; + std::mutex mutex_; +}; } // namespace detail } // namespace memory diff --git a/paddle/memory/detail/memory_block.cc b/paddle/memory/detail/memory_block.cc new file mode 100644 index 0000000000000000000000000000000000000000..fc40993208323f1f5d18103165c8835b5f829613 --- /dev/null +++ b/paddle/memory/detail/memory_block.cc @@ -0,0 +1,157 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/memory/detail/memory_block.h" +#include "paddle/memory/detail/meta_cache.h" +#include "paddle/memory/detail/meta_data.h" +#include "paddle/platform/assert.h" + +namespace paddle { +namespace memory { +namespace detail { + +void MemoryBlock::init(MetadataCache& cache, Type t, size_t index, size_t size, + void* left_buddy, void* right_buddy) { + cache.store(this, Metadata(t, index, size - sizeof(Metadata), size, + static_cast(left_buddy), + static_cast(right_buddy))); +} + +MemoryBlock::Type MemoryBlock::type(MetadataCache& cache) const { + return cache.load(this).type; +} + +size_t MemoryBlock::size(MetadataCache& cache) const { + return cache.load(this).size; +} + +size_t MemoryBlock::total_size(MetadataCache& cache) const { + return cache.load(this).total_size; +} + +MemoryBlock* MemoryBlock::left_buddy(MetadataCache& cache) const { + return cache.load(this).left_buddy; +} + +MemoryBlock* MemoryBlock::right_buddy(MetadataCache& cache) const { + return cache.load(this).right_buddy; +} + +void MemoryBlock::split(MetadataCache& cache, size_t size) { + // make sure the split fits + PADDLE_ASSERT(total_size(cache) >= size); + + // bail out if there is no room for another partition + if (total_size(cache) - size <= sizeof(Metadata)) { + return; + } + + // find the position of the split + void* right_partition = reinterpret_cast(this) + size; + + size_t remaining_size = total_size(cache) - size; + + // Add the new block as a buddy + auto metadata = cache.load(this); + + // Write the metadata for the new block + auto new_block_right_buddy = metadata.right_buddy; + + cache.store( + static_cast(right_partition), + Metadata(FREE_CHUNK, index(cache), remaining_size - sizeof(Metadata), + remaining_size, this, new_block_right_buddy)); + + metadata.right_buddy = static_cast(right_partition); + metadata.size = size - sizeof(Metadata); + metadata.total_size = size; + + cache.store(this, metadata); + + // Write metadata for the new block's right buddy + if (new_block_right_buddy != nullptr) { + auto buddy_metadata = cache.load(new_block_right_buddy); + + buddy_metadata.left_buddy = static_cast(right_partition); + + cache.store(new_block_right_buddy, buddy_metadata); + } +} + +void MemoryBlock::merge(MetadataCache& cache, MemoryBlock* right_buddy) { + // only free blocks can be merged + PADDLE_ASSERT(type(cache) == FREE_CHUNK); + PADDLE_ASSERT(right_buddy->type(cache) == FREE_CHUNK); + + auto metadata = cache.load(this); + + // link this->buddy's buddy + metadata.right_buddy = right_buddy->right_buddy(cache); + + // link buddy's buddy -> this + if (metadata.right_buddy != nullptr) { + auto buddy_metadata = cache.load(metadata.right_buddy); + + buddy_metadata.left_buddy = this; + + cache.store(metadata.right_buddy, buddy_metadata); + } + + metadata.size += right_buddy->total_size(cache); + metadata.total_size += right_buddy->total_size(cache); + + cache.store(this, metadata); + cache.store(right_buddy, Metadata(INVALID_CHUNK, 0, 0, 0, nullptr, nullptr)); +} + +void MemoryBlock::mark_as_free(MetadataCache& cache) { + // check for double free or corruption + PADDLE_ASSERT(type(cache) != FREE_CHUNK); + PADDLE_ASSERT(type(cache) != INVALID_CHUNK); + + set_type(cache, FREE_CHUNK); +} + +void MemoryBlock::set_type(MetadataCache& cache, Type t) { + auto metadata = cache.load(this); + + metadata.type = t; + + cache.store(this, metadata); +} + +bool MemoryBlock::has_left_buddy(MetadataCache& cache) const { + return left_buddy(cache) != nullptr; +} + +bool MemoryBlock::has_right_buddy(MetadataCache& cache) const { + return right_buddy(cache) != nullptr; +} + +size_t MemoryBlock::index(MetadataCache& cache) const { + return cache.load(this).index; +} + +void* MemoryBlock::data() const { + return const_cast(reinterpret_cast(this)) + 1; +} + +MemoryBlock* MemoryBlock::metadata() const { + return const_cast(reinterpret_cast( + reinterpret_cast(this) - 1)); +} + +} // namespace detail +} // namespace memory +} // namespace paddle diff --git a/paddle/memory/detail/memory_block.h b/paddle/memory/detail/memory_block.h new file mode 100644 index 0000000000000000000000000000000000000000..a5168b519f3a3747f34ef2ea7b87d72dce70064d --- /dev/null +++ b/paddle/memory/detail/memory_block.h @@ -0,0 +1,91 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include + +namespace paddle { +namespace memory { +namespace detail { + +// Forward Declarations +class MetadataCache; + +/*! \brief A class used to interpret the contents of a memory block */ +class MemoryBlock { + public: + enum Type { + FREE_CHUNK, // memory is free and idle + ARENA_CHUNK, // memory is being occupied + HUGE_CHUNK, // memory is out of management + INVALID_CHUNK // memory is invalid + }; + + public: + void init(MetadataCache& cache, Type t, size_t index, size_t size, + void* left_buddy, void* right_buddy); + + public: + /*! \brief The type of the allocation */ + Type type(MetadataCache& cache) const; + + /*! \brief The size of the data region */ + size_t size(MetadataCache& cache) const; + + /*! \brief An index to track the allocator */ + size_t index(MetadataCache& cache) const; + + /*! \brief The total size of the block */ + size_t total_size(MetadataCache& cache) const; + + /*! \brief Check the left buddy of the block */ + bool has_left_buddy(MetadataCache& cache) const; + + /*! \brief Check the right buddy of the block */ + bool has_right_buddy(MetadataCache& cache) const; + + /*! \brief Get the left buddy */ + MemoryBlock* left_buddy(MetadataCache& cache) const; + + /*! \brief Get the right buddy */ + MemoryBlock* right_buddy(MetadataCache& cache) const; + + public: + /*! \brief Split the allocation into left/right blocks */ + void split(MetadataCache& cache, size_t size); + + /*! \brief Merge left and right blocks together */ + void merge(MetadataCache& cache, MemoryBlock* right_buddy); + + /*! \brief Mark the allocation as free */ + void mark_as_free(MetadataCache& cache); + + /*! \brief Change the type of the allocation */ + void set_type(MetadataCache& cache, Type t); + + public: + /*! \brief Get a pointer to the memory block's data */ + void* data() const; + + /*! \brief Get a pointer to the memory block's metadata */ + MemoryBlock* metadata() const; + + public: + static size_t overhead(); +}; + +} // namespace detail +} // namespace memory +} // namespace paddle diff --git a/paddle/memory/detail/meta_cache.cc b/paddle/memory/detail/meta_cache.cc new file mode 100644 index 0000000000000000000000000000000000000000..30ff80e7bac0b595fe60aeab0a3c59f4e23eae2d --- /dev/null +++ b/paddle/memory/detail/meta_cache.cc @@ -0,0 +1,57 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/memory/detail/meta_cache.h" +#include "paddle/memory/detail/memory_block.h" +#include "paddle/platform/assert.h" + +namespace paddle { +namespace memory { +namespace detail { + +MetadataCache::MetadataCache(bool uses_gpu) : uses_gpu_(uses_gpu) {} + +Metadata MetadataCache::load(const MemoryBlock* block) { + if (uses_gpu_) { + auto existing_metadata = cache_.find(block); + PADDLE_ASSERT(existing_metadata->second.check_guards()); + return existing_metadata->second; + } else { + PADDLE_ASSERT(reinterpret_cast(block)->check_guards()); + return *reinterpret_cast(block); + } +} + +void MetadataCache::store(MemoryBlock* block, + const Metadata& original_metadata) { + auto metadata = original_metadata; + + metadata.update_guards(); + + if (uses_gpu_) { + cache_[block] = metadata; + } else { + *reinterpret_cast(block) = metadata; + } +} + +void MetadataCache::invalidate(MemoryBlock* block) { + if (uses_gpu_) { + cache_.erase(block); + } +} + +} // namespace detail +} // namespace memory +} // namespace paddle diff --git a/paddle/memory/detail/meta_cache.h b/paddle/memory/detail/meta_cache.h new file mode 100644 index 0000000000000000000000000000000000000000..ca0789779e273fb71c3d6282c0a921cda2d776cc --- /dev/null +++ b/paddle/memory/detail/meta_cache.h @@ -0,0 +1,64 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/memory/detail/memory_block.h" +#include "paddle/memory/detail/meta_data.h" + +#include + +namespace paddle { +namespace memory { +namespace detail { + +/** + * \brief A cache for accessing memory block meta-data that may be expensive + * to access directly. + * + * \note This class exists to unify the metadata format between GPU and CPU + * allocations. It should be removed when the CPU can access all GPU + * allocations directly via UVM. + */ +class MetadataCache { + public: + MetadataCache(bool uses_gpu); + + public: + /*! \brief Load the associated metadata for the specified memory block. */ + Metadata load(const MemoryBlock*); + + /*! \brief Store the associated metadata for the specified memory block. */ + void store(MemoryBlock*, const Metadata&); + + /*! \brief Indicate that the specified metadata will no longer be used. */ + void invalidate(MemoryBlock*); + + public: + MetadataCache(const MetadataCache&) = delete; + MetadataCache& operator=(const MetadataCache&) = delete; + + private: + bool uses_gpu_; + + private: + typedef std::unordered_map MetadataMap; + + private: + MetadataMap cache_; +}; + +} // namespace detail +} // namespace memory +} // namespace paddle diff --git a/paddle/memory/detail/meta_data.cc b/paddle/memory/detail/meta_data.cc new file mode 100644 index 0000000000000000000000000000000000000000..70c5c1f439e84ec33cf0507beae33f9cdfa51727 --- /dev/null +++ b/paddle/memory/detail/meta_data.cc @@ -0,0 +1,70 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/memory/detail/meta_data.h" + +#include + +namespace paddle { +namespace memory { +namespace detail { + +Metadata::Metadata(MemoryBlock::Type t, size_t i, size_t s, size_t ts, + MemoryBlock* l, MemoryBlock* r) + : type(t), + index(i), + size(s), + total_size(ts), + left_buddy(l), + right_buddy(r) {} + +Metadata::Metadata() + : type(MemoryBlock::INVALID_CHUNK), + index(0), + size(0), + total_size(0), + left_buddy(nullptr), + right_buddy(nullptr) {} + +template +inline void hash_combine(std::size_t& seed, const T& v) { + std::hash hasher; + seed ^= hasher(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); +} + +inline size_t hash(const Metadata* metadata, size_t initial_seed) { + size_t seed = initial_seed; + + hash_combine(seed, (size_t)metadata->type); + hash_combine(seed, metadata->index); + hash_combine(seed, metadata->size); + hash_combine(seed, metadata->total_size); + hash_combine(seed, metadata->left_buddy); + hash_combine(seed, metadata->right_buddy); + + return seed; +} + +void Metadata::update_guards() { + guard_begin = hash(this, 1); + guard_end = hash(this, 2); +} + +bool Metadata::check_guards() const { + return guard_begin == hash(this, 1) && guard_end == hash(this, 2); +} + +} // namespace detail +} // namespace memory +} // namespace paddle diff --git a/paddle/memory/detail/meta_data.h b/paddle/memory/detail/meta_data.h new file mode 100644 index 0000000000000000000000000000000000000000..628cf1f2e347e288d1bf34c14c7b2f13a28d3662 --- /dev/null +++ b/paddle/memory/detail/meta_data.h @@ -0,0 +1,54 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/memory/detail/memory_block.h" + +#include + +namespace paddle { +namespace memory { +namespace detail { + +class Metadata { + public: + Metadata(MemoryBlock::Type t, size_t i, size_t s, size_t ts, MemoryBlock* l, + MemoryBlock* r); + Metadata(); + + public: + /*! \brief Update the guards when metadata is changed */ + void update_guards(); + + /*! \brief Check consistency to previous modification */ + bool check_guards() const; + + public: + // TODO(gangliao): compress this + // clang-format off + size_t guard_begin = 0; + MemoryBlock::Type type = MemoryBlock::INVALID_CHUNK; + size_t index = 0; + size_t size = 0; + size_t total_size = 0; + MemoryBlock* left_buddy = nullptr; + MemoryBlock* right_buddy = nullptr; + size_t guard_end = 0; + // clang-format on +}; + +} // namespace detail +} // namespace memory +} // namespace paddle diff --git a/paddle/memory/detail/system_allocator.cc b/paddle/memory/detail/system_allocator.cc index 50bec926f83dee8a4343d0b16aeb088f9d2a4871..1579174b1a6ff08824629d833d01411cff651f48 100644 --- a/paddle/memory/detail/system_allocator.cc +++ b/paddle/memory/detail/system_allocator.cc @@ -13,76 +13,128 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/memory/detail/system_allocator.h" +#include "paddle/platform/assert.h" +#include "paddle/platform/error.h" +#include "paddle/platform/gpu_info.h" #include // for malloc and free #include // for mlock and munlock #include "gflags/gflags.h" -#include "paddle/platform/assert.h" -#include "paddle/platform/cuda.h" // If use_pinned_memory is true, CPUAllocator calls mlock, which // returns pinned and locked memory as staging areas for data exchange // between host and device. Allocates too much would reduce the amount // of memory available to the system for paging. So, by default, we // should set false to use_pinned_memory. -DEFINE_bool(use_pinned_memory, false, - "If set, allocate cpu/gpu pinned memory."); +DEFINE_bool(use_pinned_memory, false, "If set, allocate cpu pinned memory."); namespace paddle { namespace memory { namespace detail { -void* CPUAllocator::Alloc(size_t size) { +void* CPUAllocator::Alloc(size_t& index, size_t size) { // According to http://www.cplusplus.com/reference/cstdlib/malloc/, // malloc might not return nullptr if size is zero, but the returned // pointer shall not be dereferenced -- so we make it nullptr. if (size <= 0) return nullptr; + index = 0; // unlock memory + void* p = malloc(size); - if (p != nullptr && FLAGS_use_pinned_memory) { - mlock(p, size); + + if (p != nullptr) { + if (FLAGS_use_pinned_memory) { + index = 1; + mlock(p, size); // lock memory + } } + return p; } -void CPUAllocator::Free(void* p, size_t size) { - if (p != nullptr && FLAGS_use_pinned_memory) { +void CPUAllocator::Free(void* p, size_t size, size_t index) { + if (p != nullptr && index == 1) { munlock(p, size); } free(p); } +bool CPUAllocator::UseGpu() const { return false; } + #ifndef PADDLE_ONLY_CPU -void* GPUAllocator::Alloc(size_t size) { +void* GPUAllocator::Alloc(size_t& index, size_t size) { // CUDA documentation doesn't explain if cudaMalloc returns nullptr // if size is 0. We just make sure it does. - if (size <= 0) { - return nullptr; - } + if (size <= 0) return nullptr; + size_t available = 0; + size_t capacity = 0; + paddle::platform::GpuMemoryUsage(available, capacity); + + // Reserve memory for page tables, etc. + size_t reserving = capacity - paddle::platform::GpuMaxAllocSize(); + size_t usable = available > reserving ? available - reserving : 0; + + // If remaining size no less than expected size, using general + // cudaMalloc to allocate GPU memory. void* p = 0; - cudaError_t result = - FLAGS_use_pinned_memory ? cudaMallocHost(&p, size) : cudaMalloc(&p, size); - if (result != cudaSuccess) { - cudaGetLastError(); // clear error if there is any. + if (size <= usable) { + cudaError_t result = cudaMalloc(&p, size); + if (result == cudaSuccess) { + index = 0; + gpu_alloc_size_ += size; + return p; + } + } + + // If remaining size less than expected size or cudaMalloc failed, + // cudaMallocHost will be considered as a fallback allocator. + // + // NOTE: here, we use GpuMaxAllocSize() as the maximum memory size + // of host fallback allocation. Allocates too much would reduce + // the amount of memory available to the underlying system for paging. + usable = paddle::platform::GpuMaxAllocSize() - fallback_alloc_size_; + + if (size > usable) return nullptr; + + cudaError_t result = cudaMallocHost(&p, size); + if (result == cudaSuccess) { + index = 1; + fallback_alloc_size_ += size; + return p; } - return result == cudaSuccess ? p : nullptr; + + return nullptr; } -void GPUAllocator::Free(void* p, size_t size) { +void GPUAllocator::Free(void* p, size_t size, size_t index) { + cudaError_t err; + + if (index == 0) { + PADDLE_ASSERT(gpu_alloc_size_ >= size); + gpu_alloc_size_ -= size; + err = cudaFree(p); + } else { + PADDLE_ASSERT(fallback_alloc_size_ >= size); + fallback_alloc_size_ -= size; + err = cudaFreeHost(p); + } + // Purposefully allow cudaErrorCudartUnloading, because // that is returned if you ever call cudaFree after the // driver has already shutdown. This happens only if the // process is terminating, in which case we don't care if // cudaFree succeeds. - cudaError_t err = FLAGS_use_pinned_memory ? cudaFreeHost(p) : cudaFree(p); if (err != cudaErrorCudartUnloading) { - platform::throw_on_error(err, "cudaFree{Host} failed"); + platform::throw_on_error(err, + "cudaFree{Host} failed in GPUAllocator::Free."); } } +bool GPUAllocator::UseGpu() const { return true; } + #endif // PADDLE_ONLY_CPU } // namespace detail diff --git a/paddle/memory/detail/system_allocator.h b/paddle/memory/detail/system_allocator.h index 184b383f7f78244fa6632a3bffb1a0a78b3aa664..82ba322e057575c460b1d51d719c9b0fa459273e 100644 --- a/paddle/memory/detail/system_allocator.h +++ b/paddle/memory/detail/system_allocator.h @@ -20,31 +20,36 @@ namespace paddle { namespace memory { namespace detail { -// SystemAllocator is the parent class of CPUAllocator and -// GPUAllocator. A BuddyAllocator object uses a SystemAllocator* -// pointing to the underlying system allocator. An alternative to -// this class hierarchy is to pass a system allocator class to -// BuddyAllocator as a template parameter. This approach makes -// BuddyAllocator a class template, and it's very complicated -// algorithm would make the buddy_allocator.h messy. +/** + * \brief SystemAllocator is the parent class of CPUAllocator and GPUAllocator. + * A BuddyAllocator object uses a SystemAllocator* pointing to the + * underlying system allocator. + */ class SystemAllocator { public: virtual ~SystemAllocator() {} - virtual void* Alloc(size_t size) = 0; - virtual void Free(void* p, size_t size) = 0; + virtual void* Alloc(size_t& index, size_t size) = 0; + virtual void Free(void* p, size_t size, size_t index) = 0; + virtual bool UseGpu() const = 0; }; class CPUAllocator : public SystemAllocator { public: - virtual void* Alloc(size_t size); - virtual void Free(void* p, size_t size); + virtual void* Alloc(size_t& index, size_t size); + virtual void Free(void* p, size_t size, size_t index); + virtual bool UseGpu() const; }; #ifndef PADDLE_ONLY_CPU class GPUAllocator : public SystemAllocator { public: - virtual void* Alloc(size_t size); - virtual void Free(void* p, size_t size); + virtual void* Alloc(size_t& index, size_t size); + virtual void Free(void* p, size_t size, size_t index); + virtual bool UseGpu() const; + + private: + size_t gpu_alloc_size_ = 0; + size_t fallback_alloc_size_ = 0; }; #endif // PADDLE_ONLY_CPU diff --git a/paddle/memory/detail/system_allocator_test.cc b/paddle/memory/detail/system_allocator_test.cc index 9bd5706a4e4d1546a8c879ebbac0f3349c9d59f6..ba44e06ddb68e92e4086a8006b868557b0c89b50 100644 --- a/paddle/memory/detail/system_allocator_test.cc +++ b/paddle/memory/detail/system_allocator_test.cc @@ -25,7 +25,8 @@ DECLARE_bool(use_pinned_memory); void TestAllocator(paddle::memory::detail::SystemAllocator& a, size_t size) { bool freed = false; { - void* p = a.Alloc(size); + size_t index; + void* p = a.Alloc(index, size); if (size > 0) { EXPECT_NE(p, nullptr); } else { @@ -35,7 +36,7 @@ void TestAllocator(paddle::memory::detail::SystemAllocator& a, size_t size) { int* i = static_cast(p); std::shared_ptr ptr(i, [&](void* p) { freed = true; - a.Free(p, size); + a.Free(p, size, index); }); } EXPECT_TRUE(freed); @@ -56,14 +57,7 @@ TEST(CPUAllocator, LockMem) { } #ifndef PADDLE_ONLY_CPU -TEST(GPUAllocator, NoStaging) { - FLAGS_use_pinned_memory = false; - paddle::memory::detail::GPUAllocator a; - TestAllocator(a, 2048); - TestAllocator(a, 0); -} -TEST(GPUAllocator, Staging) { - FLAGS_use_pinned_memory = true; +TEST(GPUAllocator, Alloc) { paddle::memory::detail::GPUAllocator a; TestAllocator(a, 2048); TestAllocator(a, 0); diff --git a/paddle/memory/memory.cc b/paddle/memory/memory.cc index 0d123d99e234a378ee64850eebacece223e2b121..df3d57d629184d28fd42130df9b020a7b52ade72 100644 --- a/paddle/memory/memory.cc +++ b/paddle/memory/memory.cc @@ -17,43 +17,67 @@ limitations under the License. */ #include "paddle/memory/detail/system_allocator.h" #include "paddle/platform/assert.h" -#include - namespace paddle { namespace memory { -void* Alloc(platform::Place pl, size_t size) { -#ifndef PADDLE_ONLY_CPU - if (paddle::platform::is_gpu_place(pl)) { - size_t gpu_id = boost::get(pl).device; - return detail::GetGPUBuddyAllocator(gpu_id)->Alloc(size); +detail::BuddyAllocator* GetCPUBuddyAllocator() { + static detail::BuddyAllocator* a = nullptr; + if (a == nullptr) { + a = new detail::BuddyAllocator(new detail::CPUAllocator, + platform::CpuMinChunkSize(), + platform::CpuMaxChunkSize()); } -#endif // PADDLE_ONLY_CPU - PADDLE_ASSERT(paddle::platform::is_cpu_place(pl)); - return detail::GetCPUBuddyAllocator()->Alloc(size); + return a; } -void Free(paddle::platform::Place pl, void* p) { -#ifndef PADDLE_ONLY_CPU - if (paddle::platform::is_gpu_place(pl)) { - size_t gpu_id = boost::get(pl).device; - detail::GetGPUBuddyAllocator(gpu_id)->Free(p); - } -#endif // PADDLE_ONLY_CPU - PADDLE_ASSERT(paddle::platform::is_cpu_place(pl)); - detail::GetCPUBuddyAllocator()->Free(p); +template <> +void* Alloc(platform::CPUPlace place, size_t size) { + return GetCPUBuddyAllocator()->Alloc(size); +} + +template <> +void Free(platform::CPUPlace place, void* p) { + GetCPUBuddyAllocator()->Free(p); +} + +template <> +size_t Used(platform::CPUPlace place) { + return GetCPUBuddyAllocator()->Used(); } -size_t Used(paddle::platform::Place pl) { #ifndef PADDLE_ONLY_CPU - if (paddle::platform::is_gpu_place(pl)) { - size_t gpu_id = boost::get(pl).device; - return detail::GetGPUBuddyAllocator(gpu_id)->Used(); + +detail::BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) { + static detail::BuddyAllocator** as = NULL; + if (as == NULL) { + int gpu_num = platform::GetDeviceCount(); + as = new detail::BuddyAllocator*[gpu_num]; + for (int gpu = 0; gpu < gpu_num; gpu++) { + platform::SetDeviceId(gpu); + as[gpu] = new detail::BuddyAllocator(new detail::GPUAllocator, + platform::GpuMinChunkSize(), + platform::GpuMaxChunkSize()); + } } -#endif // PADDLE_ONLY_CPU - PADDLE_ASSERT(paddle::platform::is_cpu_place(pl)); - return detail::GetCPUBuddyAllocator()->Used(); + return as[gpu_id]; +} + +template <> +void* Alloc(platform::GPUPlace place, size_t size) { + return GetGPUBuddyAllocator(place.device)->Alloc(size); +} + +template <> +void Free(platform::GPUPlace place, void* p) { + GetGPUBuddyAllocator(place.device)->Free(p); +} + +template <> +size_t Used(platform::GPUPlace place) { + return GetGPUBuddyAllocator(place.device)->Used(); } +#endif // PADDLE_ONLY_CPU + } // namespace memory } // namespace paddle diff --git a/paddle/memory/memory.h b/paddle/memory/memory.h index a33092bade65e6df0faee226a8967c9fc9caa032..2d6f4fd2a08ee0039647d276476263d0f8d00329 100644 --- a/paddle/memory/memory.h +++ b/paddle/memory/memory.h @@ -19,9 +19,14 @@ limitations under the License. */ namespace paddle { namespace memory { -void* Alloc(paddle::platform::Place, size_t); -void Free(paddle::platform::Place, void*); -size_t Used(paddle::platform::Place); +template +void* Alloc(Place, size_t); + +template +void Free(Place, void*); + +template +size_t Used(Place); } // namespace memory } // namespace paddle diff --git a/paddle/memory/memory_test.cc b/paddle/memory/memory_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..53cc63a098d0802479e3a371717adb7596c249ed --- /dev/null +++ b/paddle/memory/memory_test.cc @@ -0,0 +1,138 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/memory/memory.h" +#include "paddle/memory/detail/memory_block.h" +#include "paddle/memory/detail/meta_data.h" + +#include "paddle/platform/cpu_info.h" +#include "paddle/platform/gpu_info.h" +#include "paddle/platform/place.h" + +#include +#include + +inline bool is_aligned(void const *p) { + return 0 == (reinterpret_cast(p) & 0x3); +} + +size_t align(size_t size, paddle::platform::CPUPlace place) { + size += sizeof(paddle::memory::detail::Metadata); + size_t alignment = paddle::platform::CpuMinChunkSize(); + size_t remaining = size % alignment; + return remaining == 0 ? size : size + (alignment - remaining); +} + +TEST(BuddyAllocator, CPUAllocation) { + void *p = nullptr; + + EXPECT_EQ(p, nullptr); + + paddle::platform::CPUPlace cpu; + p = paddle::memory::Alloc(cpu, 4096); + + EXPECT_NE(p, nullptr); + + paddle::memory::Free(cpu, p); +} + +TEST(BuddyAllocator, CPUMultAlloc) { + paddle::platform::CPUPlace cpu; + + std::unordered_map ps; + + size_t total_size = paddle::memory::Used(cpu); + EXPECT_EQ(total_size, 0UL); + + for (auto size : + {128, 256, 1024, 4096, 16384, 65536, 262144, 1048576, 4194304}) { + ps[paddle::memory::Alloc(cpu, size)] = size; + + // Buddy Allocator doesn't manage too large memory chunk + if (paddle::memory::Used(cpu) == total_size) continue; + + size_t aligned_size = align(size, cpu); + total_size += aligned_size; + EXPECT_EQ(total_size, paddle::memory::Used(cpu)); + } + + for (auto p : ps) { + EXPECT_EQ(is_aligned(p.first), true); + paddle::memory::Free(cpu, p.first); + + // Buddy Allocator doesn't manage too large memory chunk + if (paddle::memory::Used(cpu) == total_size) continue; + + size_t aligned_size = align(p.second, cpu); + total_size -= aligned_size; + EXPECT_EQ(total_size, paddle::memory::Used(cpu)); + } +} + +#ifndef PADDLE_ONLY_CPU + +size_t align(size_t size, paddle::platform::GPUPlace place) { + size += sizeof(paddle::memory::detail::Metadata); + size_t alignment = paddle::platform::GpuMinChunkSize(); + size_t remaining = size % alignment; + return remaining == 0 ? size : size + (alignment - remaining); +} + +TEST(BuddyAllocator, GPUAllocation) { + void *p = nullptr; + + EXPECT_EQ(p, nullptr); + + paddle::platform::GPUPlace gpu(0); + p = paddle::memory::Alloc(gpu, 4096); + + EXPECT_NE(p, nullptr); + + paddle::memory::Free(gpu, p); +} + +TEST(BuddyAllocator, GPUMultAlloc) { + paddle::platform::GPUPlace gpu; + + std::unordered_map ps; + + size_t total_size = paddle::memory::Used(gpu); + EXPECT_EQ(total_size, 0UL); + + for (auto size : + {128, 256, 1024, 4096, 16384, 65536, 262144, 1048576, 4194304}) { + ps[paddle::memory::Alloc(gpu, size)] = size; + + // Buddy Allocator doesn't manage too large memory chunk + if (paddle::memory::Used(gpu) == total_size) continue; + + size_t aligned_size = align(size, gpu); + total_size += aligned_size; + EXPECT_EQ(total_size, paddle::memory::Used(gpu)); + } + + for (auto p : ps) { + EXPECT_EQ(is_aligned(p.first), true); + paddle::memory::Free(gpu, p.first); + + // Buddy Allocator doesn't manage too large memory chunk + if (paddle::memory::Used(gpu) == total_size) continue; + + size_t aligned_size = align(p.second, gpu); + total_size -= aligned_size; + EXPECT_EQ(total_size, paddle::memory::Used(gpu)); + } +} + +#endif // PADDLE_ONLY_CPU diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 40bb326512c118178184120d4bc26dc127689ff3..50ecc6f85ca51f57da069ae445918bcf58a4a146 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -1,6 +1,6 @@ if(WITH_GPU) - nv_library(add_op SRCS add_op.cc add_op.cu DEPS operator op_registry glog ddim) + nv_library(add_op SRCS add_op.cc add_op.cu DEPS operator op_registry ddim glog paddle_memory) else() - cc_library(add_op SRCS add_op.cc DEPS operator op_registry glog ddim) + cc_library(add_op SRCS add_op.cc DEPS operator op_registry ddim glog paddle_memory) endif() cc_test(add_op_test SRCS add_op_test.cc DEPS add_op) diff --git a/paddle/operators/add_op.cc b/paddle/operators/add_op.cc index 7dc6414af2b3378c68b833568d7ac05251461a97..41d044cdb72b5fb2a7f8654e8ad103778e0857d1 100644 --- a/paddle/operators/add_op.cc +++ b/paddle/operators/add_op.cc @@ -1,3 +1,17 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #include "paddle/operators/add_op.h" #include "paddle/framework/op_registry.h" #include "paddle/framework/tensor.h" @@ -17,8 +31,7 @@ protected: "Inputs/Outputs of AddOp must all be set"); PADDLE_ENFORCE(inputs[0]->dims() == inputs[1]->dims(), "Two input of Add Op's dimension must be same."); - // Need set dims in Tensor - // outputs[0]->set_dims(inputs[0]->dims()) + outputs[0]->set_dims(inputs[0]->dims()); } }; @@ -42,4 +55,4 @@ The equation is: Out = X + Y REGISTER_OP(add_two, paddle::operators::AddOp, paddle::operators::AddOpMaker); typedef paddle::operators::AddKernel<::paddle::platform::CPUPlace, float> AddKernel_CPU_float; -REGISTER_OP_CPU_KERNEL(add_two, AddKernel_CPU_float); \ No newline at end of file +REGISTER_OP_CPU_KERNEL(add_two, AddKernel_CPU_float); diff --git a/paddle/operators/add_op.h b/paddle/operators/add_op.h index 568cb19742cdeebf9752149706b37388c0ab3ad6..e8c718669a9959252ce473e989a0ad27ebc487cc 100644 --- a/paddle/operators/add_op.h +++ b/paddle/operators/add_op.h @@ -1,7 +1,20 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #pragma once #include "glog/logging.h" #include "paddle/framework/operator.h" -//#include "paddle/operators/add_op_functor.h" namespace paddle { namespace operators { diff --git a/paddle/operators/add_op_test.cc b/paddle/operators/add_op_test.cc index f554ac1bef3255f136ad4407a7a1096bdc2b1db5..53b354fedcacf2176aed8b504daf2046bdf96bb6 100644 --- a/paddle/operators/add_op_test.cc +++ b/paddle/operators/add_op_test.cc @@ -1,3 +1,17 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + #include #define private public #include diff --git a/paddle/optimizer/parameter_optimizer_test.cpp b/paddle/optimizer/parameter_optimizer_test.cpp index 4e6254d9e4dab48279b4a880695959526d30d70c..edf4ae37a9beee2911d23dd1ab23e67a18065b1b 100644 --- a/paddle/optimizer/parameter_optimizer_test.cpp +++ b/paddle/optimizer/parameter_optimizer_test.cpp @@ -1,3 +1,19 @@ +/* + Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + #include "parameter_optimizer.h" #include #include @@ -5,21 +21,18 @@ #include "gtest/gtest.h" #include "lr_policy.h" -using namespace paddle; -using namespace paddle::optimizer; - -Tensor* FillTensor(size_t size) { - Tensor* param = new Tensor(size); - Tensor& p = *param; +paddle::optimizer::Tensor* FillTensor(size_t size) { + paddle::optimizer::Tensor* param = new paddle::optimizer::Tensor(size); + paddle::optimizer::Tensor& p = *param; for (size_t i = 0; i < p.size(); ++i) { p[i] = (float)rand() / (float)RAND_MAX; } return param; } -Tensor* FixedTensor(size_t size) { - Tensor* param = new Tensor(size); - Tensor& p = *param; +paddle::optimizer::Tensor* FixedTensor(size_t size) { + paddle::optimizer::Tensor* param = new paddle::optimizer::Tensor(size); + paddle::optimizer::Tensor& p = *param; for (size_t i = 0; i < p.size(); ++i) { p[i] = i; } @@ -28,7 +41,8 @@ Tensor* FixedTensor(size_t size) { class OptimizerTest : public testing::Test { public: - // init tensor shape + virtual ~OptimizerTest() {} + // init paddle::optimizer::Tensor shape const size_t kSize = 5; virtual void SetUp() { @@ -38,34 +52,36 @@ public: virtual void TearDown() {} void CreateSGD() { - Tensor* parameter = FixedTensor(kSize); - config_.set_optimizer(OptimizerConfig::SGD); + paddle::optimizer::Tensor* parameter = FixedTensor(kSize); + config_.set_optimizer(paddle::OptimizerConfig::SGD); config_.mutable_sgd()->set_momentum(0.0); config_.mutable_sgd()->set_decay(0.0); config_.mutable_sgd()->set_nesterov(false); - config_.set_lr_policy(OptimizerConfig::Const); + config_.set_lr_policy(paddle::OptimizerConfig::Const); config_.mutable_const_lr()->set_learning_rate(0.1); std::string str = config_.SerializeAsString(); - ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter); + paddle::optimizer::ParameterOptimizer* opt = + paddle::optimizer::ParameterOptimizer::Create(str, parameter); opts_.push_back(opt); } void CreateAdam() { - Tensor* parameter = FixedTensor(kSize); - config_.set_optimizer(OptimizerConfig::Adam); + paddle::optimizer::Tensor* parameter = FixedTensor(kSize); + config_.set_optimizer(paddle::OptimizerConfig::Adam); config_.mutable_adam()->set_beta_1(0.9); config_.mutable_adam()->set_beta_2(0.1); config_.mutable_adam()->set_epsilon(1e-3); config_.mutable_adam()->set_decay(0.0); - config_.set_lr_policy(OptimizerConfig::Const); + config_.set_lr_policy(paddle::OptimizerConfig::Const); config_.mutable_const_lr()->set_learning_rate(0.1); std::string str = config_.SerializeAsString(); - ParameterOptimizer* opt = ParameterOptimizer::Create(str, parameter); + paddle::optimizer::ParameterOptimizer* opt = + paddle::optimizer::ParameterOptimizer::Create(str, parameter); opts_.push_back(opt); } void TestGetWeight() { - Tensor* p = FixedTensor(kSize); + paddle::optimizer::Tensor* p = FixedTensor(kSize); for (size_t i = 0; i < opts_.size(); ++i) { int s = 0; float* newp = (float*)opts_[i]->get_weight(&s); @@ -76,7 +92,7 @@ public: } void TestUpdate() { - Tensor* g = FixedTensor(kSize); + paddle::optimizer::Tensor* g = FixedTensor(kSize); for (size_t i = 0; i < opts_.size(); ++i) { opts_[i]->Update(g); } @@ -91,8 +107,8 @@ public: } private: - std::vector opts_; - OptimizerConfig config_; + std::vector opts_; + paddle::OptimizerConfig config_; }; TEST_F(OptimizerTest, TestGetWeight) { TestGetWeight(); } diff --git a/paddle/optimizer/serialization_test.cpp b/paddle/optimizer/serialization_test.cpp index d2454140dc243b40ed8348578360b30894213838..e4d97cbdba545c4ba5adf5b30efd3fc9f3f744ee 100644 --- a/paddle/optimizer/serialization_test.cpp +++ b/paddle/optimizer/serialization_test.cpp @@ -1,19 +1,32 @@ +/* + Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + #include "serialization.h" #include "gtest/gtest.h" -using namespace paddle; -using namespace paddle::optimizer; - TEST(TensorToProto, Case1) { - Tensor t(3), t1(3); + paddle::optimizer::Tensor t(3), t1(3); for (size_t i = 0; i < t.size(); ++i) { t[i] = i; t1[i] = 0; } - TensorProto proto; - TensorToProto(t, &proto); - ProtoToTensor(proto, &t1); + paddle::TensorProto proto; + paddle::optimizer::TensorToProto(t, &proto); + paddle::optimizer::ProtoToTensor(proto, &t1); for (size_t i = 0; i < t1.size(); ++i) { EXPECT_EQ(t1[i], t[i]); } diff --git a/paddle/platform/CMakeLists.txt b/paddle/platform/CMakeLists.txt index 358d14f4555e1d046c8e7b91e23d54fb504926e5..6ac4035c0f863c5f63d17b523a7a8be668ff3da0 100644 --- a/paddle/platform/CMakeLists.txt +++ b/paddle/platform/CMakeLists.txt @@ -1,10 +1,13 @@ -add_subdirectory(dynload) +cc_library(cpu_info SRCS cpu_info.cc DEPS gflags glog) +cc_test(cpu_info_test SRCS cpu_info_test.cc DEPS cpu_info) -nv_test(cuda_test SRCS cuda_test.cu) +nv_library(gpu_info SRCS gpu_info.cc DEPS gflags) cc_library(place SRCS place.cc) cc_test(place_test SRCS place_test.cc DEPS place glog gflags) +add_subdirectory(dynload) + IF(WITH_GPU) set(GPU_CTX_DEPS dynload_cuda dynamic_loader) ELSE() @@ -12,4 +15,4 @@ ELSE() ENDIF() cc_library(device_context SRCS device_context.cc DEPS place eigen3 ${GPU_CTX_DEPS}) -nv_test(device_context_test SRCS device_context_test.cc DEPS device_context glog gflags) +nv_test(device_context_test SRCS device_context_test.cc DEPS device_context gpu_info) diff --git a/paddle/platform/cpu_info.cc b/paddle/platform/cpu_info.cc new file mode 100644 index 0000000000000000000000000000000000000000..dfab391cfbe1f04bc2a998233f7e7909579ca72b --- /dev/null +++ b/paddle/platform/cpu_info.cc @@ -0,0 +1,67 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/platform/cpu_info.h" + +#ifdef __APPLE__ +#include +#include +#else +#include +#endif + +#include "gflags/gflags.h" +#include "paddle/platform/error.h" + +DEFINE_double(fraction_of_cpu_memory_to_use, 1, + "Default use 100% of CPU memory for PaddlePaddle," + "reserve the rest for page tables, etc"); + +namespace paddle { +namespace platform { + +inline size_t CpuTotalPhysicalMemory() { +#ifdef __APPLE__ + int mib[2]; + mib[0] = CTL_HW; + mib[1] = HW_MEMSIZE; + int64_t size = 0; + size_t len = sizeof(size); + if (sysctl(mib, 2, &size, &len, NULL, 0) == 0) return (size_t)size; + return 0L; +#else + int64_t pages = sysconf(_SC_PHYS_PAGES); + int64_t page_size = sysconf(_SC_PAGE_SIZE); + return pages * page_size; +#endif +} + +size_t CpuMaxAllocSize() { + // For distributed systems, it requires configuring and limiting + // the fraction of memory to use. + return FLAGS_fraction_of_cpu_memory_to_use * CpuTotalPhysicalMemory(); +} + +size_t CpuMinChunkSize() { + // Allow to allocate the minimum chunk size is 4 KB. + return 1 << 12; +} + +size_t CpuMaxChunkSize() { + // Allow to allocate the maximum chunk size is roughly 3% of CPU memory. + return CpuMaxAllocSize() / 32; +} + +} // namespace platform +} // namespace paddle diff --git a/paddle/operators/add_op_functor.h b/paddle/platform/cpu_info.h similarity index 57% rename from paddle/operators/add_op_functor.h rename to paddle/platform/cpu_info.h index 904f24b03054c5eb6149cc9400f635036091316c..8df7c7b4bca5bc88f6ed95d6ab82c81b73918e92 100644 --- a/paddle/operators/add_op_functor.h +++ b/paddle/platform/cpu_info.h @@ -14,22 +14,19 @@ limitations under the License. */ #pragma once -#include "paddle/framework/tensor_types.h" -#include "unsupported/Eigen/CXX11/Tensor" +#include namespace paddle { -namespace operators { -namespace functor { - -template -struct Add { - void Operator()(const Device& d, - typename TTypes::ConstTensor input1, - typename TTypes::ConstTensor input2, - typename TTypes::Tensor output) { - output.device(d) = input1 + input2; - } -}; -} // namespace functor -} // namespace operators +namespace platform { + +//! Get the maximum allocation size for a machine. +size_t CpuMaxAllocSize(); + +//! Get the minimum chunk size for buddy allocator. +size_t CpuMinChunkSize(); + +//! Get the maximum chunk size for buddy allocator. +size_t CpuMaxChunkSize(); + +} // namespace platform } // namespace paddle diff --git a/paddle/platform/cpu_info_test.cc b/paddle/platform/cpu_info_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8fb195aa7c0a41b7417ff5cf63394046e9c72267 --- /dev/null +++ b/paddle/platform/cpu_info_test.cc @@ -0,0 +1,21 @@ +#include "paddle/platform/cpu_info.h" +#include "paddle/string/printf.h" + +#include +#include + +#include "gflags/gflags.h" +#include "glog/logging.h" +#include "gtest/gtest.h" + +DECLARE_double(fraction_of_cpu_memory_to_use); + +TEST(CpuMemoryUsage, Print) { + std::stringstream ss; + size_t memory_size = paddle::platform::CpuMaxAllocSize() / 1024 / 1024 / 1024; + float use_percent = FLAGS_fraction_of_cpu_memory_to_use * 100; + + std::cout << paddle::string::Sprintf("\n%.2f %% of CPU Memory Usage: %d GB\n", + use_percent, memory_size) + << std::endl; +} diff --git a/paddle/platform/cuda_test.cu b/paddle/platform/cuda_test.cu deleted file mode 100644 index 4067dda2f19f7661722d8a14a27c7b32ed6afc92..0000000000000000000000000000000000000000 --- a/paddle/platform/cuda_test.cu +++ /dev/null @@ -1,59 +0,0 @@ -#include -#include -#include "gtest/gtest.h" - -#define CHECK_ERR(x) \ - if (x != cudaSuccess) { \ - fprintf(stderr, \ - "%s in %s at line %d\n", \ - cudaGetErrorString(err), \ - __FILE__, \ - __LINE__); \ - exit(-1); \ - } - -__global__ void vecAdd(float *d_A, float *d_B, float *d_C, int n) { - int i = blockDim.x * blockIdx.x + threadIdx.x; - if (i < n) { - d_C[i] = d_A[i] + d_B[i]; - } -} - -TEST(Cuda, Equality) { - int n = 10; - // Memory allocation for h_A, h_B and h_C (in the host) - float h_A[10] = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0}; - float h_B[10] = {0.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0}; - float h_C[10]; - float *d_A, *d_B, *d_C; - cudaError_t err; - // Memory allocation for d_A, d_B and d_C (in the device) - err = cudaMalloc((void **)&d_A, sizeof(float) * n); - CHECK_ERR(err); - - err = cudaMalloc((void **)&d_B, sizeof(float) * n); - CHECK_ERR(err); - - err = cudaMalloc((void **)&d_C, sizeof(float) * n); - CHECK_ERR(err); - - // Copying memory to device - err = cudaMemcpy(d_A, h_A, sizeof(float) * n, cudaMemcpyHostToDevice); - CHECK_ERR(err); - - err = cudaMemcpy(d_B, h_B, sizeof(float) * n, cudaMemcpyHostToDevice); - CHECK_ERR(err); - - // Calling the kernel - vecAdd<<>>(d_A, d_B, d_C, n); - - // Copying results back to host - err = cudaMemcpy(h_C, d_C, sizeof(float) * n, cudaMemcpyDeviceToHost); - CHECK_ERR(err); - - EXPECT_EQ(h_C[0], 1.0); - for (int i = 1; i < n - 1; ++i) { - EXPECT_EQ(h_C[i], 11.0); - } - EXPECT_EQ(h_C[9], 1.0); -} diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index 2ec7b055994b019cd81af191a6b9cf511bc83489..5f8ad159517ef4deaa8c241cf8b13073228022b9 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -13,10 +13,11 @@ limitations under the License. */ #include "paddle/framework/enforce.h" #ifndef PADDLE_ONLY_CPU -#include "paddle/platform/cuda.h" #include "paddle/platform/dynload/cublas.h" #include "paddle/platform/dynload/cudnn.h" #include "paddle/platform/dynload/curand.h" +#include "paddle/platform/error.h" +#include "paddle/platform/gpu_info.h" #define EIGEN_USE_GPU #endif #include diff --git a/paddle/platform/error.h b/paddle/platform/error.h new file mode 100644 index 0000000000000000000000000000000000000000..93424bb61096503a4843da7942853a113f612e3b --- /dev/null +++ b/paddle/platform/error.h @@ -0,0 +1,87 @@ +#pragma once + +#include +#include +#include + +#ifndef PADDLE_ONLY_CPU + +#include +#include +#include +#include +#include + +#endif // PADDLE_ONLY_CPU + +namespace paddle { +namespace platform { + +#ifndef PADDLE_ONLY_CPU + +inline void throw_on_error(cudaError_t e, const char* message) { + if (e) { + throw thrust::system_error(e, thrust::cuda_category(), message); + } +} + +inline void throw_on_error(curandStatus_t stat, const char* message) { + if (stat != CURAND_STATUS_SUCCESS) { + throw thrust::system_error(cudaErrorLaunchFailure, thrust::cuda_category(), + message); + } +} + +inline void throw_on_error(cudnnStatus_t stat, const char* message) { + std::stringstream ss; + if (stat == CUDNN_STATUS_SUCCESS) { + return; + } else { + ss << cudnnGetErrorString(stat); + ss << ", " << message; + throw std::runtime_error(ss.str()); + } +} + +inline void throw_on_error(cublasStatus_t stat, const char* message) { + std::stringstream ss; + if (stat == CUBLAS_STATUS_SUCCESS) { + return; + } else if (stat == CUBLAS_STATUS_NOT_INITIALIZED) { + ss << "CUBLAS: not initialized"; + } else if (stat == CUBLAS_STATUS_ALLOC_FAILED) { + ss << "CUBLAS: alloc failed"; + } else if (stat == CUBLAS_STATUS_INVALID_VALUE) { + ss << "CUBLAS: invalid value"; + } else if (stat == CUBLAS_STATUS_ARCH_MISMATCH) { + ss << "CUBLAS: arch mismatch"; + } else if (stat == CUBLAS_STATUS_MAPPING_ERROR) { + ss << "CUBLAS: mapping error"; + } else if (stat == CUBLAS_STATUS_EXECUTION_FAILED) { + ss << "CUBLAS: execution failed"; + } else if (stat == CUBLAS_STATUS_INTERNAL_ERROR) { + ss << "CUBLAS: internal error"; + } else if (stat == CUBLAS_STATUS_NOT_SUPPORTED) { + ss << "CUBLAS: not supported"; + } else if (stat == CUBLAS_STATUS_LICENSE_ERROR) { + ss << "CUBLAS: license error"; + } + ss << ", " << message; + throw std::runtime_error(ss.str()); +} + +inline void throw_on_error(cublasStatus_t stat) { + const char* message = ""; + throw_on_error(stat, message); +} + +#endif // PADDLE_ONLY_CPU + +inline void throw_on_error(int stat, const char* message) { + if (stat) { + throw std::runtime_error(message + (", stat = " + std::to_string(stat))); + } +} + +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/gpu_info.cc b/paddle/platform/gpu_info.cc new file mode 100644 index 0000000000000000000000000000000000000000..a1383d3524aedf834c329425419b989d47668bea --- /dev/null +++ b/paddle/platform/gpu_info.cc @@ -0,0 +1,86 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/platform/gpu_info.h" +#include "gflags/gflags.h" +#include "paddle/platform/error.h" + +DEFINE_double(fraction_of_gpu_memory_to_use, 0.95, + "Default use 95% of GPU memory for PaddlePaddle," + "reserve the rest for page tables, etc"); + +namespace paddle { +namespace platform { + +int GetDeviceCount() { + int count; + throw_on_error( + cudaGetDeviceCount(&count), + "cudaGetDeviceCount failed in paddle::platform::GetDeviceCount"); + return count; +} + +int GetCurrentDeviceId() { + int device_id; + throw_on_error( + cudaGetDevice(&device_id), + "cudaGetDevice failed in paddle::platform::GetCurrentDeviceId"); + return device_id; +} + +void SetDeviceId(int id) { + throw_on_error(cudaSetDevice(id), + "cudaSetDevice failed in paddle::platform::SetDeviceId"); +} + +void GpuMemoryUsage(size_t& available, size_t& total) { + throw_on_error(cudaMemGetInfo(&available, &total), + "cudaMemGetInfo failed in paddle::platform::GetMemoryUsage"); +} + +size_t GpuMaxAllocSize() { + size_t total = 0; + size_t available = 0; + + GpuMemoryUsage(available, total); + + // Reserve the rest for page tables, etc. + return static_cast(total * FLAGS_fraction_of_gpu_memory_to_use); +} + +size_t GpuMinChunkSize() { + // Allow to allocate the minimum chunk size is 256 bytes. + return 1 << 8; +} + +size_t GpuMaxChunkSize() { + size_t total = 0; + size_t available = 0; + + GpuMemoryUsage(available, total); + + // Reserving the rest memory for page tables, etc. + size_t reserving = (1 - FLAGS_fraction_of_gpu_memory_to_use) * total; + + // If available less than minimum chunk size, no usable memory exists. + available = std::max(available, GpuMinChunkSize()) - GpuMinChunkSize(); + + // If available less than reserving, no usable memory exists. + size_t usable = std::max(available, reserving) - reserving; + + return usable; +} + +} // namespace platform +} // namespace paddle diff --git a/paddle/platform/cuda.h b/paddle/platform/gpu_info.h similarity index 54% rename from paddle/platform/cuda.h rename to paddle/platform/gpu_info.h index 96889abf9eb14dd203eb55ffd0b720450323b38e..79e71956bd32e8c253ac4192a04e5903bed1c94a 100644 --- a/paddle/platform/cuda.h +++ b/paddle/platform/gpu_info.h @@ -16,33 +16,31 @@ limitations under the License. */ #ifndef PADDLE_ONLY_CPU -#include -#include +#include namespace paddle { namespace platform { -inline void throw_on_error(cudaError_t e, const char* message) { - if (e) { - throw thrust::system_error(e, thrust::cuda_category(), message); - } -} - -inline int GetDeviceCount(void) { - int count; - throw_on_error(cudaGetDeviceCount(&count), "cudaGetDeviceCount failed"); - return count; -} - -inline int GetCurrentDeviceId(void) { - int device_id; - throw_on_error(cudaGetDevice(&device_id), "cudaGetDevice failed"); - return device_id; -} - -inline void SetDeviceId(int device_id) { - throw_on_error(cudaSetDevice(device_id), "cudaSetDevice failed"); -} +//! Get the total number of GPU devices in system. +int GetDeviceCount(); + +//! Get the current GPU device id in system. +int GetCurrentDeviceId(); + +//! Set the GPU device id for next execution. +void SetDeviceId(int device_id); + +//!Get the memory usage of current GPU device. +void GpuMemoryUsage(size_t& available, size_t& total); + +//! Get the maximum allocation size of current GPU device. +size_t GpuMaxAllocSize(); + +//! Get the minimum chunk size for GPU buddy allocator. +size_t GpuMinChunkSize(); + +//! Get the maximum chunk size for GPU buddy allocator. +size_t GpuMaxChunkSize(); } // namespace platform } // namespace paddle diff --git a/paddle/platform/place.cc b/paddle/platform/place.cc index 0704820aa05079401eb56814d689d6e280311edb..b31515e1f028acac885a506ff1c20479407a05e3 100644 --- a/paddle/platform/place.cc +++ b/paddle/platform/place.cc @@ -1,3 +1,17 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + #include "paddle/platform/place.h" namespace paddle { @@ -7,7 +21,7 @@ namespace detail { class PlacePrinter : public boost::static_visitor<> { public: - PlacePrinter(std::ostream &os) : os_(os) {} + explicit PlacePrinter(std::ostream &os) : os_(os) {} void operator()(const CPUPlace &) { os_ << "CPUPlace"; } void operator()(const GPUPlace &p) { os_ << "GPUPlace(" << p.device << ")"; } diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index af85fdeecb57729d7fb580ebd4c59c1afc61d61a..8564a5f5fe474dbd55ab3e413f9c2cf93f88e38e 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -1 +1 @@ -cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python) +cc_library(paddle_pybind SHARED SRCS pybind.cc DEPS pybind python add_op) diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index f9f87acf15a6b62c343cc0e3db9ebc7e0aabb786..c1a025ed0492f10237ee552a9b854f1937aa465c 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -13,12 +13,18 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include #include #include +#include +#include +#include namespace py = pybind11; namespace pd = paddle::framework; +USE_OP(add_two); + PYBIND11_PLUGIN(core) { py::module m("core", "C++ core of Paddle Paddle"); @@ -43,5 +49,20 @@ All parameter, weight, gradient are variables in Paddle. &pd::Scope::CreateVariable, py::return_value_policy::reference); + //! @note: Be careful! PyBind will return std::string as an unicode, not + //! Python str. If you want a str object, you should cast them in Python. + m.def("get_all_op_protos", []() -> std::vector { + auto& protos = pd::OpRegistry::protos(); + std::vector ret_values; + for (auto it = protos.begin(); it != protos.end(); ++it) { + PADDLE_ENFORCE(it->second.IsInitialized(), + "OpProto must all be initialized"); + ret_values.emplace_back(); + PADDLE_ENFORCE(it->second.SerializeToString(&ret_values.back()), + "Serialize OpProto Error. This could be a bug of Paddle."); + } + return ret_values; + }); + return m.ptr(); } diff --git a/paddle/string/piece.h b/paddle/string/piece.h index db7c3e69804a6a8f0510ba376432fe560ae74442..0272529d1c9b2cb6000a26f1d4d80276d06bf27b 100644 --- a/paddle/string/piece.h +++ b/paddle/string/piece.h @@ -35,7 +35,7 @@ public: // We provide non-explicit singleton constructors so users can // pass in a "const char*" or a "string" wherever a "Piece" - // is expected. These contructors ensure that if data_ is NULL, + // is expected. These constructors ensure that if data_ is NULL, // size_ is 0. Piece(); Piece(const char* d, size_t n); diff --git a/paddle/trainer/TrainerConfigHelper.cpp b/paddle/trainer/TrainerConfigHelper.cpp index 60ac8459a12db801321da4a9d9c1d48ac8bd6d16..133e2be104c6fbfddefd8698d2b6aa8315c56c70 100644 --- a/paddle/trainer/TrainerConfigHelper.cpp +++ b/paddle/trainer/TrainerConfigHelper.cpp @@ -62,11 +62,7 @@ TrainerConfigHelper::TrainerConfigHelper(const TrainerConfig &config) m->conf = config; } -TrainerConfigHelper::~TrainerConfigHelper() { - if (m) { - delete m; - } -} +TrainerConfigHelper::~TrainerConfigHelper() { delete m; } const TrainerConfig &TrainerConfigHelper::getConfig() const { return m->conf; } diff --git a/paddle/utils/DynamicLoader.h b/paddle/utils/DynamicLoader.h index 9b5ad21724afd7176f958619e7e10d12dc08fa49..2e5ff76a06152b6a12818f06baaeaa6a69726ba8 100644 --- a/paddle/utils/DynamicLoader.h +++ b/paddle/utils/DynamicLoader.h @@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifndef DYNAMIC_LOAD_H_ -#define DYNAMIC_LOAD_H_ +#pragma once #include #include @@ -59,5 +58,3 @@ void GetWarpCTCDsoHandle(void** dso_handle); * */ void GetLapackDsoHandle(void** dso_handle); - -#endif // DYNAMIC_LOAD_H_ diff --git a/paddle/utils/ThreadLocal.h b/paddle/utils/ThreadLocal.h index b5e2862546212041a774599ec664b95e56224a07..0a27b8b97b83a9066af23039a317c437ea56777a 100644 --- a/paddle/utils/ThreadLocal.h +++ b/paddle/utils/ThreadLocal.h @@ -51,7 +51,7 @@ template class ThreadLocal { public: ThreadLocal() { - CHECK(pthread_key_create(&threadSpecificKey_, dataDestructor) == 0); + CHECK_EQ(pthread_key_create(&threadSpecificKey_, dataDestructor), 0); } ~ThreadLocal() { pthread_key_delete(threadSpecificKey_); } @@ -65,7 +65,7 @@ public: if (!p && createLocal) { p = new T(); int ret = pthread_setspecific(threadSpecificKey_, p); - CHECK(ret == 0); + CHECK_EQ(ret, 0); } return p; } @@ -79,7 +79,7 @@ public: if (T* q = get(false)) { dataDestructor(q); } - CHECK(pthread_setspecific(threadSpecificKey_, p) == 0); + CHECK_EQ(pthread_setspecific(threadSpecificKey_, p), 0); } /** @@ -112,7 +112,7 @@ private: template class ThreadLocalD { public: - ThreadLocalD() { CHECK(pthread_key_create(&threadSpecificKey_, NULL) == 0); } + ThreadLocalD() { CHECK_EQ(pthread_key_create(&threadSpecificKey_, NULL), 0); } ~ThreadLocalD() { pthread_key_delete(threadSpecificKey_); for (auto t : threadMap_) { @@ -127,7 +127,7 @@ public: T* p = (T*)pthread_getspecific(threadSpecificKey_); if (!p) { p = new T(); - CHECK(pthread_setspecific(threadSpecificKey_, p) == 0); + CHECK_EQ(pthread_setspecific(threadSpecificKey_, p), 0); updateMap(p); } return p; @@ -141,7 +141,7 @@ public: if (T* q = (T*)pthread_getspecific(threadSpecificKey_)) { dataDestructor(q); } - CHECK(pthread_setspecific(threadSpecificKey_, p) == 0); + CHECK_EQ(pthread_setspecific(threadSpecificKey_, p), 0); updateMap(p); } diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index b0524a507bacec6768424045e58bf91305de2d08..78aa0778f8d1dca9fae82f0411be5a00e636cbc9 100755 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -126,6 +126,7 @@ __all__ = [ 'row_conv_layer', 'dropout_layer', 'prelu_layer', + 'gated_unit_layer', ] @@ -5862,7 +5863,7 @@ def prelu_layer(input, :rtype: LayerOutput """ - assert isinstance(input, LayerOutput), 'prelu_layer only accepts one input' + assert isinstance(input, LayerOutput), 'prelu_layer accepts only one input.' assert isinstance(param_attr, ParameterAttribute) l = Layer( @@ -5876,3 +5877,96 @@ def prelu_layer(input, layer_type=LayerType.PRELU, parents=input, size=l.config.size) + + +@wrap_name_default() +@layer_support(ERROR_CLIPPING, DROPOUT) +@wrap_act_default(act=LinearActivation()) +def gated_unit_layer(input, + size, + act=None, + name=None, + gate_attr=None, + gate_param_attr=None, + gate_bias_attr=True, + inproj_attr=None, + inproj_param_attr=None, + inproj_bias_attr=True, + layer_attr=None): + """ + The gated unit layer implements a simple gating mechanism over the input. + The input :math:`X` is first projected into a new space :math:`X'`, and + it is also used to produce a gate weight :math:`\sigma`. Element-wise + prodict between :match:`X'` and :math:`\sigma` is finally returned. + + Reference: + Language Modeling with Gated Convolutional Networks + https://arxiv.org/abs/1612.08083 + + .. math:: + y=\\text{act}(X \cdot W + b)\otimes \sigma(X \cdot V + c) + + The example usage is: + + .. code-block:: python + gated_unit = gated_unit_layer(size=128, input=input_layer)) + + :param input: input for this layer. + :type input: LayerOutput + :param size: output size of the gated unit. + :type size: int + :param act: activation type of the projected input. + :type act: BaseActivation + :param name: name of this layer. + :type name: basestring + :param gate_attr: Attributes to tune the gate output, for example, error + clipping threshold, dropout and so on. See ExtraLayerAttribute for + more details. + :type gate_attr: ExtraLayerAttribute|None + :param gate_param_attr: Attributes to tune the learnable projected matrix + parameter of the gate. + :type gate_param_attr: ParameterAttribute|None + :param gate_bias_attr: Attributes to tune the learnable bias of the gate. + :type gate_bias_attr: ParameterAttribute|None + :param inproj_attr: Attributes to the tune the projected input, for + example, error clipping threshold, dropout and so on. See + ExtraLayerAttribute for more details. + :type inproj_attr: ExtraLayerAttribute|None + :param inproj_param_attr: Attributes to tune the learnable parameter of + the projection of input. + :type inproj_param_attr: ParameterAttribute|None + :param inproj_bias_attr: Attributes to tune the learnable bias of + projection of the input. + :type inproj_bias_attr: ParameterAttribute|None + :param layer_attr: Attributes to tune the final output of the gated unit, + for example, error clipping threshold, dropout and so on. See + ExtraLayerAttribute for more details. + :type layer_attr: ExtraLayerAttribute|None + :return: LayerOutput object. + :rtype: LayerOutput + """ + + assert isinstance( + input, LayerOutput), 'The gated linear unit accepts only one input.' + + input_proj = fc_layer( + input=input, + name="%s_input_proj" % name, + size=size, + act=act, + layer_attr=inproj_attr, + param_attr=inproj_param_attr, + bias_attr=inproj_bias_attr) + + gate = fc_layer( + size=size, + name="%s_gate" % name, + act=SigmoidActivation(), + input=input, + layer_attr=gate_attr, + param_attr=gate_param_attr, + bias_attr=gate_bias_attr) + return mixed_layer( + name="%s_gated_act" % name, + input=dotmul_operator(input_proj, gate), + layer_attr=layer_attr) diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh index 70e342fb79ab51e3376ea6ad8f593c4c3a1fff18..cdf9b2eab733adb173cf33cd6a93ef7b5abefc50 100755 --- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh +++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh @@ -7,6 +7,6 @@ test_rnn_group shared_fc shared_lstm shared_gru test_cost_layers_with_weight test_spp_layer test_bilinear_interp test_maxout test_bi_grumemory math_ops test_seq_concat_reshape test_pad test_smooth_l1 test_multiplex_layer test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_layer -test_recursive_topology) +test_recursive_topology test_gated_unit_layer) export whole_configs=(test_split_datasource) diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_gated_unit_layer.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_gated_unit_layer.protostr new file mode 100644 index 0000000000000000000000000000000000000000..f1e4d894a5fb0040f48bdb5a751c3f0d956c23bb --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_gated_unit_layer.protostr @@ -0,0 +1,106 @@ +type: "nn" +layers { + name: "input" + type: "data" + size: 256 + active_type: "" +} +layers { + name: "__gated_unit_layer_0___input_proj" + type: "fc" + size: 512 + active_type: "tanh" + inputs { + input_layer_name: "input" + input_parameter_name: "___gated_unit_layer_0___input_proj.w0" + } + bias_parameter_name: "___gated_unit_layer_0___input_proj.wbias" + error_clipping_threshold: 100.0 +} +layers { + name: "__gated_unit_layer_0___gate" + type: "fc" + size: 512 + active_type: "sigmoid" + inputs { + input_layer_name: "input" + input_parameter_name: "___gated_unit_layer_0___gate.w0" + } + bias_parameter_name: "___gated_unit_layer_0___gate.wbias" + error_clipping_threshold: 100.0 +} +layers { + name: "__gated_unit_layer_0___gated_act" + type: "mixed" + size: 512 + active_type: "" + inputs { + input_layer_name: "__gated_unit_layer_0___input_proj" + } + inputs { + input_layer_name: "__gated_unit_layer_0___gate" + } + error_clipping_threshold: 100.0 + operator_confs { + type: "dot_mul" + input_indices: 0 + input_indices: 1 + input_sizes: 512 + input_sizes: 512 + output_size: 512 + dotmul_scale: 1 + } +} +parameters { + name: "___gated_unit_layer_0___input_proj.w0" + size: 131072 + initial_mean: 0.0 + initial_std: 0.0001 + dims: 256 + dims: 512 + initial_strategy: 0 + initial_smart: false +} +parameters { + name: "___gated_unit_layer_0___input_proj.wbias" + size: 512 + initial_mean: 0.0 + initial_std: 1 + dims: 1 + dims: 512 + initial_strategy: 0 + initial_smart: false +} +parameters { + name: "___gated_unit_layer_0___gate.w0" + size: 131072 + initial_mean: 0.0 + initial_std: 0.0001 + dims: 256 + dims: 512 + initial_strategy: 0 + initial_smart: false +} +parameters { + name: "___gated_unit_layer_0___gate.wbias" + size: 512 + initial_mean: 0.0 + initial_std: 1 + dims: 1 + dims: 512 + initial_strategy: 0 + initial_smart: false +} +input_layer_names: "input" +output_layer_names: "__gated_unit_layer_0___gated_act" +sub_models { + name: "root" + layer_names: "input" + layer_names: "__gated_unit_layer_0___input_proj" + layer_names: "__gated_unit_layer_0___gate" + layer_names: "__gated_unit_layer_0___gated_act" + input_layer_names: "input" + output_layer_names: "__gated_unit_layer_0___gated_act" + is_recurrent_layer_group: false +} + diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_gated_unit_layer.py b/python/paddle/trainer_config_helpers/tests/configs/test_gated_unit_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..9dab45519c65b0ca686558ec7fe2064bb9ad8824 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_gated_unit_layer.py @@ -0,0 +1,16 @@ +from paddle.trainer_config_helpers import * + +data = data_layer(name='input', size=256) +glu = gated_unit_layer( + size=512, + input=data, + act=TanhActivation(), + gate_attr=ExtraLayerAttribute(error_clipping_threshold=100.0), + gate_param_attr=ParamAttr(initial_std=1e-4), + gate_bias_attr=ParamAttr(initial_std=1), + inproj_attr=ExtraLayerAttribute(error_clipping_threshold=100.0), + inproj_param_attr=ParamAttr(initial_std=1e-4), + inproj_bias_attr=ParamAttr(initial_std=1), + layer_attr=ExtraLayerAttribute(error_clipping_threshold=100.0)) + +outputs(glu) diff --git a/python/paddle/v2/__init__.py b/python/paddle/v2/__init__.py index 3ba5c31871807027e452df5d889b3b403e1c6414..3c75ca4c3abf1e94fc00b87f3af51d1cbf6dc430 100644 --- a/python/paddle/v2/__init__.py +++ b/python/paddle/v2/__init__.py @@ -20,7 +20,6 @@ import trainer import event import data_type import topology -import data_feeder import networks import evaluator from . import dataset @@ -31,7 +30,6 @@ import op import pooling import inference import networks -import py_paddle.swig_paddle as api import minibatch import plot import image @@ -47,7 +45,6 @@ __all__ = [ 'data_type', 'attr', 'pooling', - 'data_feeder', 'dataset', 'reader', 'topology', @@ -61,6 +58,7 @@ __all__ = [ def init(**kwargs): + import py_paddle.swig_paddle as api args = [] args_dict = {} # NOTE: append arguments if they are in ENV diff --git a/python/paddle/v2/data_feeder.py b/python/paddle/v2/data_feeder.py index 2698251b9e15046eb14f71c3f5b0546ecbb4a5dd..98dfb85a0ea57050bf8dd8d46fca9574801d8eb3 100644 --- a/python/paddle/v2/data_feeder.py +++ b/python/paddle/v2/data_feeder.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from py_paddle import DataProviderConverter import collections import paddle.trainer.PyDataProvider2 as pydp2 diff --git a/python/paddle/v2/event.py b/python/paddle/v2/event.py index fd6050fa339d280ad54e40128ea6bae25132c873..7589cc9917f26375d595e200245d5ba099bc38d7 100644 --- a/python/paddle/v2/event.py +++ b/python/paddle/v2/event.py @@ -9,8 +9,6 @@ There are: * BeginPass * EndPass """ -import py_paddle.swig_paddle as api - __all__ = [ 'EndIteration', 'BeginIteration', 'BeginPass', 'EndPass', 'TestResult' ] @@ -18,6 +16,7 @@ __all__ = [ class WithMetric(object): def __init__(self, evaluator): + import py_paddle.swig_paddle as api if not isinstance(evaluator, api.Evaluator): raise TypeError("Evaluator should be api.Evaluator type") self.__evaluator__ = evaluator diff --git a/python/paddle/v2/framework/create_op_creation_methods.py b/python/paddle/v2/framework/create_op_creation_methods.py new file mode 100644 index 0000000000000000000000000000000000000000..2fcdfead25414ccf44e9bfa964c83b98c852f6be --- /dev/null +++ b/python/paddle/v2/framework/create_op_creation_methods.py @@ -0,0 +1,11 @@ +import paddle.v2.framework.core as core +import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2 + + +def get_all_op_protos(): + protostrs = core.get_all_op_protos() + ret_values = [] + for pbstr in protostrs: + op_proto = op_proto_pb2.OpProto.FromString(str(pbstr)) + ret_values.append(op_proto) + return ret_values diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt index 7023e82b5f08eb49fa1fee27118a7907d58312e2..86fc60f26aee0fbdcf4ac4938d20d26d35df57f6 100644 --- a/python/paddle/v2/framework/tests/CMakeLists.txt +++ b/python/paddle/v2/framework/tests/CMakeLists.txt @@ -1,2 +1,2 @@ add_python_test(test_framework test_protobuf.py test_scope.py - test_default_scope_funcs.py) + test_default_scope_funcs.py test_op_creation_methods.py) diff --git a/python/paddle/v2/framework/tests/test_op_creation_methods.py b/python/paddle/v2/framework/tests/test_op_creation_methods.py new file mode 100644 index 0000000000000000000000000000000000000000..b205e2cabb99ab08604ab3c3ce073bcb95ec4bb3 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_op_creation_methods.py @@ -0,0 +1,15 @@ +import unittest +import paddle.v2.framework.create_op_creation_methods as creation + + +class TestOpCreationsMethods(unittest.TestCase): + def test_all_protos(self): + all_protos = creation.get_all_op_protos() + self.assertNotEqual(0, len(all_protos)) + + for each in all_protos: + self.assertTrue(each.IsInitialized()) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/inference.py b/python/paddle/v2/inference.py index 34b7308601390a4ccb0c19ef10d2c7a60b3fa576..40134a3270c3579fd2f6a891af66ff241050f60c 100644 --- a/python/paddle/v2/inference.py +++ b/python/paddle/v2/inference.py @@ -1,9 +1,7 @@ import numpy -import py_paddle.swig_paddle as api import collections import topology import minibatch -from data_feeder import DataFeeder __all__ = ['infer', 'Inference'] @@ -28,6 +26,7 @@ class Inference(object): """ def __init__(self, output_layer, parameters): + import py_paddle.swig_paddle as api topo = topology.Topology(output_layer) gm = api.GradientMachine.createFromConfigProto( topo.proto(), api.CREATE_MODE_TESTING, [api.PARAMETER_VALUE]) @@ -40,6 +39,7 @@ class Inference(object): self.__data_types__ = topo.data_type() def iter_infer(self, input, feeding=None): + from data_feeder import DataFeeder feeder = DataFeeder(self.__data_types__, feeding) batch_size = len(input) diff --git a/python/paddle/v2/optimizer.py b/python/paddle/v2/optimizer.py index 390c22ee552c506fde1567efba1326a6d735ad2e..b6ee51cfe899fd0652fd3bf702ddcb440c3c7566 100644 --- a/python/paddle/v2/optimizer.py +++ b/python/paddle/v2/optimizer.py @@ -1,5 +1,3 @@ -import py_paddle.swig_paddle as swig_api - import paddle.trainer_config_helpers.config_parser_utils as config_parser_utils import paddle.trainer_config_helpers.optimizers as v1_optimizers """ @@ -18,6 +16,7 @@ __all__ = [ class Optimizer(object): def __init__(self, **kwargs): + import py_paddle.swig_paddle as swig_api if 'batch_size' in kwargs: del kwargs['batch_size'] # not important for python library. @@ -268,6 +267,7 @@ ModelAverage = v1_optimizers.ModelAverage L2Regularization = v1_optimizers.L2Regularization if __name__ == '__main__': + import py_paddle.swig_paddle as swig_api swig_api.initPaddle('--use_gpu=false') for opt in [ Momentum(), Adam(), Adamax(), AdaGrad(), DecayedAdaGrad(), diff --git a/python/paddle/v2/parameters.py b/python/paddle/v2/parameters.py index bbaf8bfa979fbbf460561ebf7077b75b9c41a11a..a9cba8ca0b1efd4149463f6c7bf2dcdfbea350c9 100644 --- a/python/paddle/v2/parameters.py +++ b/python/paddle/v2/parameters.py @@ -1,5 +1,4 @@ import numpy as np -import py_paddle.swig_paddle as api from paddle.proto.ParameterConfig_pb2 import ParameterConfig import paddle.trainer.config_parser as cp import struct @@ -124,6 +123,7 @@ class Parameters(object): :return: parameter value :rtype: np.ndarray """ + import py_paddle.swig_paddle as api shape = self.get_shape(key) if len(self.__gradient_machines__) == 0: @@ -223,7 +223,7 @@ class Parameters(object): :type gradient_machine: api.GradientMachine :return: """ - + import py_paddle.swig_paddle as api if not isinstance(gradient_machine, api.GradientMachine): raise ValueError("gradient_machine should be api.GradientMachine") @@ -359,6 +359,7 @@ def __copy_parameter_to_gradient_machine__(gradient_machine, name, arr): :return: :rtype: api.Parameter """ + import py_paddle.swig_paddle as api param = __get_parameter_in_gradient_machine__(gradient_machine, name) vec = param.getBuf(api.PARAMETER_VALUE) assert isinstance(vec, api.Vector) diff --git a/python/paddle/v2/trainer.py b/python/paddle/v2/trainer.py index 96c6c4b89a2f2e2c3ecb95213e0e0191b1998f50..92fdf98e9030993cc9f250b2f9e6317073cb49de 100644 --- a/python/paddle/v2/trainer.py +++ b/python/paddle/v2/trainer.py @@ -2,12 +2,6 @@ Module Trainer """ import collections -import gzip -import os - -import py_paddle.swig_paddle as api - -from data_feeder import DataFeeder from topology import Topology from . import event as v2_event from . import optimizer as v2_optimizer @@ -59,6 +53,7 @@ class SGD(object): if not isinstance(update_equation, v2_optimizer.Optimizer): raise TypeError("update equation parameter must be " "paddle.v2.optimizer.Optimizer") + import py_paddle.swig_paddle as api topology = Topology(cost, extra_layers=extra_layers) self.__optimizer__ = update_equation self.__topology__ = topology @@ -124,6 +119,8 @@ class SGD(object): :type feeding: dict|list :return: """ + import py_paddle.swig_paddle as api + from data_feeder import DataFeeder if event_handler is None: event_handler = default_event_handler __check_train_args__(**locals()) @@ -187,6 +184,8 @@ class SGD(object): :type feeding: dict :return: """ + import py_paddle.swig_paddle as api + from data_feeder import DataFeeder feeder = DataFeeder(self.__data_types__, feeding) evaluator = self.__gradient_machine__.makeEvaluator() out_args = api.Arguments.createArguments(0) diff --git a/python/setup.py.in b/python/setup.py.in index 271ee6e5526981ad94710315d1472b0f4069a1aa..b1041f6102a56f5a200aa909e77729095c052f31 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -19,7 +19,8 @@ setup_requires=["requests", "recordio", "matplotlib", "rarfile", - "scipy>=0.19.0"] + "scipy>=0.19.0", + "nltk"] if '${CMAKE_SYSTEM_PROCESSOR}' not in ['arm', 'armv7-a', 'aarch64']: setup_requires+=["opencv-python"]