diff --git a/Dockerfile b/Dockerfile index d24042d63ebbbe1db1ba1600502e3177da021684..0d2f56911442f989c0abfa8fdea924af1be48f9e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -27,8 +27,9 @@ RUN apt-get update && \ git python-pip python-dev openssh-server bison \ wget unzip tar xz-utils bzip2 gzip coreutils \ curl sed grep graphviz libjpeg-dev zlib1g-dev \ - python-numpy python-matplotlib gcc g++ gfortran \ + python-numpy python-matplotlib gcc g++ \ automake locales clang-format-3.8 swig doxygen cmake \ + liblapack-dev liblapacke-dev \ clang-3.8 llvm-3.8 libclang-3.8-dev && \ apt-get clean -y diff --git a/cmake/cblas.cmake b/cmake/cblas.cmake index 0918e6cc633e7067b8bd2d5c5e1622d4139d4d14..913f711afff3b8f9f77b8da978a3b9e7165d0077 100644 --- a/cmake/cblas.cmake +++ b/cmake/cblas.cmake @@ -33,20 +33,18 @@ find_library(MKL_INTEL_LP64 NAMES mkl_intel_lp64 PATHS ${MKL_ROOT}/lib ${MKL_ROOT}/lib/intel64) - -if(MKL_INC_DIR AND MKL_CORE_LIB AND MKL_SEQUENTIAL_LIB AND MKL_INTEL_LP64) +if(MKL_LAPACK_INC_DIR AND MKL_INC_DIR AND MKL_CORE_LIB AND MKL_SEQUENTIAL_LIB AND MKL_INTEL_LP64) + set(CBLAS_FOUND ON) set(CBLAS_PROVIDER MKL) - set(CBLAS_INC_DIR ${MKL_INC_DIR}) - set(CBLAS_LIBRARIES ${MKL_INTEL_LP64} - ${MKL_SEQUENTIAL_LIB} - ${MKL_CORE_LIB}) + set(CBLAS_INC_DIR ${MKL_INC_DIR} ${MKL_LAPACK_INC_DIR}) + set(CBLAS_LIBRARIES ${MKL_INTEL_LP64} ${MKL_SEQUENTIAL_LIB} ${MKL_CORE_LIB}) + add_definitions(-DPADDLE_USE_MKL) - message(STATUS "Found MKL (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})") - set(CBLAS_FOUND ON) - if(${MKL_LAPACK_INC_DIR}) - message(STATUS "Found lapack in MKL (include: ${MKL_LAPACK_INC_DIR})") - endif() - return() # return file. + add_definitions(-DLAPACK_FOUND) + + message(STATUS "Found MKL (include: ${MKL_INC_DIR}, library: ${CBLAS_LIBRARIES})") + message(STATUS "Found lapack in MKL (include: ${MKL_LAPACK_INC_DIR})") + return() endif() ## Then find atlas. @@ -68,20 +66,20 @@ find_path(ATLAS_CLAPACK_INC_DIR NAMES clapack.h PATHS ${ATLAS_INCLUDE_SEARCH_PATHS}) find_library(ATLAS_CBLAS_LIB NAMES cblas libcblas.so.3 PATHS ${ATLAS_LIB_SEARCH_PATHS}) -find_library(ATLAS_LIB NAMES lapack_atlas liblapack_atlas.so.3 +find_library(ATLAS_CLAPACK_LIB NAMES lapack_atlas liblapack_atlas.so.3 PATHS ${ATLAS_LIB_SEARCH_PATHS}) -if(ATLAS_INC_DIR AND ATLAS_CBLAS_LIB AND ATLAS_LIB AND NOT CBLAS_FOUND) +if(ATLAS_CLAPACK_INC_DIR AND ATLAS_INC_DIR AND ATLAS_CBLAS_LIB AND ATLAS_CLAPACK_LIB) + set(CBLAS_FOUND ON) set(CBLAS_PROVIDER ATLAS) - set(CBLAS_INC_DIR ${ATLAS_INC_DIR}) - set(CBLAS_LIBRARIES ${ATLAS_LIB} ${ATLAS_CBLAS_LIB}) + set(CBLAS_INC_DIR ${ATLAS_INC_DIR} ${ATLAS_CLAPACK_INC_DIR}) + set(CBLAS_LIBRARIES ${ATLAS_CLAPACK_LIB} ${ATLAS_CBLAS_LIB}) + add_definitions(-DPADDLE_USE_ATLAS) - message(STATUS "Found ATLAS (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})") - set(CBLAS_FOUND ON) - if(ATLAS_CLAPACK_INC_DIR) - set(CBLAS_INC_DIR ${CBLAS_INC_DIR} ${ATLAS_CLAPACK_INC_DIR}) - message(STATUS "Found lapack in ATLAS (include: ${ATLAS_CLAPACK_INC_DIR})") - endif() + add_definitions(-DLAPACK_FOUND) + + message(STATUS "Found ATLAS (include: ${ATLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})") + message(STATUS "Found lapack in ATLAS (include: ${ATLAS_CLAPACK_INC_DIR})") return() endif() @@ -106,15 +104,17 @@ find_path(OPENBLAS_LAPACKE_INC_DIR NAMES lapacke.h find_library(OPENBLAS_LIB NAMES openblas PATHS ${OPENBLAS_LIB_SEARCH_PATHS}) -if(OPENBLAS_INC_DIR AND OPENBLAS_LIB) +if(OPENBLAS_LAPACKE_INC_DIR AND OPENBLAS_INC_DIR AND OPENBLAS_LIB) + set(CBLAS_FOUND ON) set(CBLAS_PROVIDER OPENBLAS) - set(CBLAS_INC_DIR ${OPENBLAS_INC_DIR}) + set(CBLAS_INC_DIR ${OPENBLAS_INC_DIR} ${OPENBLAS_LAPACKE_INC_DIR}) set(CBLAS_LIBRARIES ${OPENBLAS_LIB}) - message(STATUS "Found OpenBLAS (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})") - set(CBLAS_FOUND ON) - if(OPENBLAS_LAPACKE_INC_DIR) - message(STATUS "Found lapack in OpenBLAS (include: ${OPENBLAS_LAPACKE_INC_DIR})") - endif() + + add_definitions(-DPADDLE_USE_OPENBLAS) + add_definitions(-DLAPACK_FOUND) + + message(STATUS "Found OpenBLAS (include: ${OPENBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})") + message(STATUS "Found lapack in OpenBLAS (include: ${OPENBLAS_LAPACKE_INC_DIR})") return() endif() @@ -143,9 +143,10 @@ find_library(REFERENCE_CBLAS_LIBRARY NAMES cblas PATHS ${REFERENCE_CBLAS_LIB_SEARCH_PATHS}) if (REFERENCE_CBLAS_INCLUDE_DIR AND REFERENCE_CBLAS_LIBRARY) + set(CBLAS_FOUND ON) set(CBLAS_PROVIDER REFERENCE) set(CBLAS_INC_DIR ${REFERENCE_CBLAS_INCLUDE_DIR}) set(CBLAS_LIBRARIES ${REFERENCE_CBLAS_LIBRARY}) - message(STATUS "Found reference-cblas (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBS})") - set(CBLAS_FOUND ON) + add_definitions(-DPADDLE_USE_REFERENCE_CBLAS) + message(STATUS "Found reference-cblas (include: ${CBLAS_INC_DIR}, library: ${CBLAS_LIBRARIES})") endif() diff --git a/cmake/external/openblas.cmake b/cmake/external/openblas.cmake index 46398b22c27ae22abf261d61807c6b10becfff36..18ac74aa6f7531c4001fe91960f8332619c99342 100644 --- a/cmake/external/openblas.cmake +++ b/cmake/external/openblas.cmake @@ -36,20 +36,10 @@ IF(NOT ${CBLAS_FOUND}) INSTALL_DIR ${CBLAS_INSTALL_DIR} BUILD_IN_SOURCE 1 BUILD_COMMAND ${CMAKE_MAKE_PROGRAM} FC=${CMAKE_Fortran_COMPILER} CC=${CMAKE_C_COMPILER} HOSTCC=${CMAKE_C_COMPILER} NO_LAPACK=1 DYNAMIC_ARCH=1 NO_SHARED=1 libs netlib - INSTALL_COMMAND ${CMAKE_MAKE_PROGRAM} install NO_SHARED=1 PREFIX= + INSTALL_COMMAND ${CMAKE_MAKE_PROGRAM} install NO_SHARED=1 NO_LAPACK=1 PREFIX= UPDATE_COMMAND "" CONFIGURE_COMMAND "" ) - - ExternalProject_Add_Step( - openblas lapacke_install - COMMAND ${CMAKE_COMMAND} -E copy "${CBLAS_SOURCES_DIR}/src/openblas/lapack-netlib/LAPACKE/include/lapacke_mangling_with_flags.h" "${CBLAS_INSTALL_DIR}/include/lapacke_mangling.h" - COMMAND ${CMAKE_COMMAND} -E copy "${CBLAS_SOURCES_DIR}/src/openblas/lapack-netlib/LAPACKE/include/lapacke.h" "${CBLAS_INSTALL_DIR}/include/lapacke.h" - COMMAND ${CMAKE_COMMAND} -E copy "${CBLAS_SOURCES_DIR}/src/openblas/lapack-netlib/LAPACKE/include/lapacke_config.h" "${CBLAS_INSTALL_DIR}/include/lapacke_config.h" - COMMAND ${CMAKE_COMMAND} -E copy "${CBLAS_SOURCES_DIR}/src/openblas/lapack-netlib/LAPACKE/include/lapacke_utils.h" "${CBLAS_INSTALL_DIR}/include/lapacke_utils.h" - DEPENDEES install - ) - LIST(APPEND external_project_dependencies openblas) ENDIF(NOT ${CBLAS_FOUND}) diff --git a/cmake/external/python.cmake b/cmake/external/python.cmake index fc66d6b2154b73d8f6a259ecfa55c7ef5ce999fa..f4d0daab06c9fcf17f4af59c25f62b415074a52f 100644 --- a/cmake/external/python.cmake +++ b/cmake/external/python.cmake @@ -16,11 +16,13 @@ INCLUDE(ExternalProject) INCLUDE(python_module) FIND_PACKAGE(PythonInterp 2.7) -FIND_PACKAGE(PythonLibs 2.7) +IF(WITH_PYTHON) + FIND_PACKAGE(PythonLibs 2.7) +ENDIF(WITH_PYTHON) SET(py_env "") SET(USE_VIRTUALENV_FOR_TEST 1) -IF(PYTHONLIBS_FOUND AND PYTHONINTERP_FOUND) +IF(PYTHONINTERP_FOUND) find_python_module(pip REQUIRED) find_python_module(numpy REQUIRED) find_python_module(wheel REQUIRED) @@ -30,7 +32,7 @@ IF(PYTHONLIBS_FOUND AND PYTHONINTERP_FOUND) MESSAGE(FATAL_ERROR "Found Python Protobuf ${PY_GOOGLE.PROTOBUF_VERSION} < 3.0.0, " "please use pip to upgrade protobuf. pip install -U protobuf") ENDIF() -ELSE(PYTHONLIBS_FOUND AND PYTHONINTERP_FOUND) +ELSE(PYTHONINTERP_FOUND) MESSAGE(FATAL_ERROR "Please install python 2.7 before building PaddlePaddle.") ##################################### PYTHON ######################################## SET(PYTHON_SOURCES_DIR ${THIRD_PARTY_PATH}/python) @@ -217,7 +219,7 @@ ELSE(PYTHONLIBS_FOUND AND PYTHONINTERP_FOUND) LIST(APPEND external_project_dependencies python setuptools six cython wheel python-protobuf numpy) -ENDIF(PYTHONLIBS_FOUND AND PYTHONINTERP_FOUND) +ENDIF(PYTHONINTERP_FOUND) IF(WITH_PYTHON) INCLUDE_DIRECTORIES(${PYTHON_INCLUDE_DIR}) diff --git a/cmake/util.cmake b/cmake/util.cmake index 966e0a7bf60fdeaac575b9d9c19c7095e104b13c..b828eef322bc570c07f5c357353641117a094c16 100644 --- a/cmake/util.cmake +++ b/cmake/util.cmake @@ -149,7 +149,6 @@ endfunction() # Create a python unittest using run_python_tests.sh, # which takes care of making correct running environment function(add_python_test TEST_NAME) - message("PYTHON: ${PYTHON_EXECUTABLE}") add_test(NAME ${TEST_NAME} COMMAND bash ${PROJ_ROOT}/paddle/scripts/run_python_tests.sh ${USE_VIRTUALENV_FOR_TEST} ${PYTHON_EXECUTABLE} ${ARGN} diff --git a/paddle/capi/Main.cpp b/paddle/capi/Main.cpp index 7f24561e9aafc1e900f6371ad3c7e5a45033a9ef..78c43949dfe325d0e1a6ba10ae51cb7b858f6c52 100644 --- a/paddle/capi/Main.cpp +++ b/paddle/capi/Main.cpp @@ -25,7 +25,6 @@ limitations under the License. */ static void initPaddle(int argc, char** argv) { paddle::initMain(argc, argv); paddle::initPython(argc, argv); - feenableexcept(FE_INVALID | FE_DIVBYZERO | FE_OVERFLOW); } extern "C" { diff --git a/paddle/gserver/tests/img_conv_cudnn.py b/paddle/gserver/tests/img_conv_cudnn.py new file mode 100644 index 0000000000000000000000000000000000000000..3934607fa41f9b6d401f1c9ff4aec6715786799b --- /dev/null +++ b/paddle/gserver/tests/img_conv_cudnn.py @@ -0,0 +1,32 @@ +#edit-mode: -*- python -*- +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.trainer_config_helpers import * + +settings(batch_size=10) +data = data_layer(name="input", size=8 * 16 * 16) +conv = img_conv_layer( + input=data, + filter_size=1, + filter_size_y=1, + num_channels=8, + num_filters=16, + stride=1, + bias_attr=True, + act=LinearActivation(), + groups=2, + layer_type="cudnn_conv") + +outputs(conv) diff --git a/paddle/gserver/tests/img_conv_exconv.py b/paddle/gserver/tests/img_conv_exconv.py new file mode 100644 index 0000000000000000000000000000000000000000..ad5a8ba2bde17000ca3d7057c6f399ae28d938b0 --- /dev/null +++ b/paddle/gserver/tests/img_conv_exconv.py @@ -0,0 +1,32 @@ +#edit-mode: -*- python -*- +# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.trainer_config_helpers import * + +settings(batch_size=10) +data = data_layer(name="input", size=8 * 16 * 16) +conv = img_conv_layer( + input=data, + filter_size=1, + filter_size_y=1, + num_channels=8, + num_filters=16, + stride=1, + bias_attr=True, + act=LinearActivation(), + groups=2, + layer_type="exconv") + +outputs(conv) diff --git a/paddle/gserver/tests/test_NetworkCompare.cpp b/paddle/gserver/tests/test_NetworkCompare.cpp index 4db30f37a5bc92d4348caed0aebdd8a589b55712..40e662b22bac0a2d22aea31fe99b11695bac3f57 100644 --- a/paddle/gserver/tests/test_NetworkCompare.cpp +++ b/paddle/gserver/tests/test_NetworkCompare.cpp @@ -258,12 +258,15 @@ TEST(Compare, img_conv) { // Test cudnn_conv and exconv give the same result TEST(Compare, img_conv2) { - std::string config_file_a = "./gserver/tests/img_conv_a.conf"; - std::string config_file_b = "./gserver/tests/img_conv_c.conf"; + std::string config_file_a = "./gserver/tests/img_conv_cudnn.py"; + std::string config_file_b = "./gserver/tests/img_conv_exconv.py"; bool useGpu = FLAGS_use_gpu; + double eps = FLAGS_checkgrad_eps; FLAGS_use_gpu = true; + FLAGS_checkgrad_eps = 1e-2; compareNetwork(config_file_a, config_file_b); FLAGS_use_gpu = useGpu; + FLAGS_checkgrad_eps = eps; } #endif diff --git a/paddle/math/MathFunctions.cpp b/paddle/math/MathFunctions.cpp index 802a56a0d15a2ebbe6c62350832611373001ac9f..1a3bb432bfb743fe814fa94c0c104bb6bc598cb8 100644 --- a/paddle/math/MathFunctions.cpp +++ b/paddle/math/MathFunctions.cpp @@ -34,6 +34,9 @@ void* lapack_dso_handle = nullptr; // We have to use two levels of macro to do the expansion. // See https://gcc.gnu.org/onlinedocs/cpp/Stringizing.html #define STR(x) #x + +// clang-format off +#ifndef LAPACK_FOUND #define DYNAMIC_LOAD_LAPACK_WRAP(__name) \ struct DynLoad__##__name { \ template \ @@ -46,8 +49,16 @@ void* lapack_dso_handle = nullptr; return reinterpret_cast(p_##__name)(args...); \ } \ } __name; // struct DynLoad__##__name +#else +#define DYNAMIC_LOAD_LAPACK_WRAP(__name) \ + struct DynLoad__##__name { \ + template \ + auto operator()(Args... args) -> decltype(__name(args...)) { \ + return __name(args...); \ + } \ + } __name; // struct DynLoad__##__name +#endif -// clang-format off #ifdef PADDLE_USE_ATLAS #define PADDLE_SGETRF clapack_sgetrf #define PADDLE_DGETRF clapack_dgetrf diff --git a/paddle/math/MathFunctions.h b/paddle/math/MathFunctions.h index c8559eefd8378450fc18c2ba821c65b39c8cc046..8ada0d34c6733d13a45505492909124010c85a91 100644 --- a/paddle/math/MathFunctions.h +++ b/paddle/math/MathFunctions.h @@ -18,17 +18,32 @@ limitations under the License. */ #ifdef PADDLE_USE_MKL #include #include -#else -extern "C" { -#include -} +#endif + #ifdef PADDLE_USE_ATLAS extern "C" { +#include #include } -#else +#endif + +#ifdef PADDLE_USE_OPENBLAS +#include #include #endif + +#ifndef LAPACK_FOUND +extern "C" { +#include +int LAPACKE_sgetrf( + int matrix_layout, int m, int n, float* a, int lda, int* ipiv); +int LAPACKE_dgetrf( + int matrix_layout, int m, int n, double* a, int lda, int* ipiv); +int LAPACKE_sgetri( + int matrix_layout, int n, float* a, int lda, const int* ipiv); +int LAPACKE_dgetri( + int matrix_layout, int n, double* a, int lda, const int* ipiv); +} #endif #include diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index 782a9613d8752d1fccf1a59078cea0aaa625f815..5a0dffe086c4e265d17c79dba435b66c0873e3c7 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -236,8 +236,19 @@ TEST(Matrix, unary) { testMatrixTranspose(height, width); testMatrixRotate(height, width); } +#ifdef LAPACK_FOUND // inverse matrix testMatrixInverse(height); +#else + LOG(WARNING) << "Cannot run Matrix Inverse Unit Test.\n" + << "Failed to find lapack library in current system.\n" + << "To address this issue, Please adopt one of the following " + "approaches: \n" + << "1. Simply issue `sudo apt-get install liblapacke-dev` to " + "avoid re-build source code. \n" + << "2. Install MKL/Openblas/ATLAS and re-build PaddlePaddle " + "source code."; +#endif } } diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh index 9739ec9555bec6a3ea5048929684cbd3a683f2d8..b1a8274b1dd3b909660d63b8e6ace7b9f377c827 100644 --- a/paddle/scripts/docker/build.sh +++ b/paddle/scripts/docker/build.sh @@ -55,7 +55,8 @@ Building in /paddle/build ... EOF make -j `nproc` if [ ${WITH_TESTING:-OFF} == "ON" ] && [ ${RUN_TEST:-OFF} == "ON" ] ; then - ctest -V -j `nproc` + pip uninstall -y py-paddle paddle || true + ctest -V fi diff --git a/paddle/utils/DynamicLoader.cpp b/paddle/utils/DynamicLoader.cpp index 87c36eae6fbcd0564ee46b5ca0e3e22b5cf04192..76cf3c300616e6961be905d0e54f3b9fac4922a4 100644 --- a/paddle/utils/DynamicLoader.cpp +++ b/paddle/utils/DynamicLoader.cpp @@ -165,8 +165,8 @@ void GetWarpCTCDsoHandle(void** dso_handle) { void GetLapackDsoHandle(void** dso_handle) { #if defined(__APPLE__) || defined(__OSX__) - GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapack.dylib", dso_handle); + GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapacke.dylib", dso_handle); #else - GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapack.so", dso_handle); + GetDsoHandleFromSearchPath(FLAGS_lapack_dir, "liblapacke.so", dso_handle); #endif } diff --git a/python/paddle/v2/__init__.py b/python/paddle/v2/__init__.py index 7c8f6ea62fcb74700f7356ed4b937a3aaa1c7092..d13850597e034ce538681c7a532e9cb55c996eb6 100644 --- a/python/paddle/v2/__init__.py +++ b/python/paddle/v2/__init__.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os import optimizer import layer import activation @@ -42,8 +43,16 @@ __all__ = [ def init(**kwargs): args = [] - for key in kwargs.keys(): - args.append('--%s=%s' % (key, str(kwargs[key]))) + args_dict = {} + # NOTE: append arguments if they are in ENV + for ek, ev in os.environ.iteritems(): + if ek.startswith("PADDLE_INIT_"): + args_dict[ek.replace("PADDLE_INIT_", "").lower()] = str(ev) + + args_dict.update(kwargs) + # NOTE: overwrite arguments from ENV if it is in kwargs + for key in args_dict.keys(): + args.append('--%s=%s' % (key, str(args_dict[key]))) api.initPaddle(*args) diff --git a/python/paddle/v2/config_base.py b/python/paddle/v2/config_base.py index b0e8da563e0d65d534d3f224fe5f1c39a67eeb54..acda778e0aee1a8339ad6bd0d719868151d4fabe 100644 --- a/python/paddle/v2/config_base.py +++ b/python/paddle/v2/config_base.py @@ -16,6 +16,7 @@ import collections import re from paddle.trainer_config_helpers.default_decorators import wrap_name_default import paddle.trainer_config_helpers as conf_helps +from topology import Topology class LayerType(type): @@ -161,6 +162,10 @@ class Layer(object): """ return self.__context__[self.context_name()].size + def attr(self): + topo = Topology(self) + return topo.get_layer_proto(self.name) + def __convert_to_v2__(method_name, parent_names, diff --git a/python/paddle/v2/inference.py b/python/paddle/v2/inference.py index 1fea7917e1553f63a6e6df50e1a8c6473018085f..b4bb38496937bb6fb520334331c619f9b6f64b51 100644 --- a/python/paddle/v2/inference.py +++ b/python/paddle/v2/inference.py @@ -5,15 +5,22 @@ import topology import minibatch from data_feeder import DataFeeder -__all__ = ['infer'] +__all__ = ['infer', 'Inference'] class Inference(object): """ Inference combines neural network output and parameters together to do inference. + + .. code-block:: python + + inferer = Inference(output_layer=prediction, parameters=parameters) + for data_batch in batches: + print inferer.infer(data_batch) + - :param outptut_layer: The neural network that should be inferenced. + :param output_layer: The neural network that should be inferenced. :type output_layer: paddle.v2.config_base.Layer or the sequence of paddle.v2.config_base.Layer :param parameters: The parameters dictionary. @@ -56,8 +63,14 @@ class Inference(object): item = [each_result[each_field] for each_field in field] yield item - def infer(self, field='value', **kwargs): + def infer(self, input, field='value', **kwargs): + """ + Infer a data by model. + :param input: input data batch. Should be python iterable object. + :param field: output field. + """ retv = None + kwargs['input'] = input for result in self.iter_infer_field(field=field, **kwargs): if retv is None: retv = [[] for i in xrange(len(result))] @@ -79,7 +92,7 @@ def infer(output_layer, parameters, input, feeding=None, field='value'): .. code-block:: python - result = paddle.infer(outptut_layer=prediction, + result = paddle.infer(output_layer=prediction, parameters=parameters, input=SomeData) print result diff --git a/python/paddle/v2/topology.py b/python/paddle/v2/topology.py index ff28c85c53dc8255b6ad5e3975b07f72a9a64e4b..1e46e4973f467a017de3d2b45186690af16dd123 100644 --- a/python/paddle/v2/topology.py +++ b/python/paddle/v2/topology.py @@ -130,6 +130,12 @@ class Topology(object): return [(nm, data_layers[nm].type) for nm in self.proto().input_layer_names] + def get_layer_proto(self, name): + for layer in self.__model_config__.layers: + if layer.name == name: + return layer + return None + def __check_layer_type__(layer): if not isinstance(layer, v2_layer.LayerV2):