提交 3ca5750b 编写于 作者: W wangyanfei01

fix conficts

...@@ -9,3 +9,6 @@ build/ ...@@ -9,3 +9,6 @@ build/
.pydevproject .pydevproject
Makefile Makefile
.test_env/ .test_env/
*~
bazel-*
[submodule "warp-ctc"]
path = warp-ctc
url = https://github.com/baidu-research/warp-ctc.git
...@@ -2,10 +2,12 @@ ...@@ -2,10 +2,12 @@
sha: c25201a00e6b0514370501050cf2a8538ac12270 sha: c25201a00e6b0514370501050cf2a8538ac12270
hooks: hooks:
- id: remove-crlf - id: remove-crlf
files: (?!.*warp-ctc)^.*$
- repo: https://github.com/reyoung/mirrors-yapf.git - repo: https://github.com/reyoung/mirrors-yapf.git
sha: v0.13.2 sha: v0.13.2
hooks: hooks:
- id: yapf - id: yapf
files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$ # Bazel BUILD files follow Python syntax.
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
sha: 7539d8bd1a00a3c1bfd34cdb606d3a6372e83469 sha: 7539d8bd1a00a3c1bfd34cdb606d3a6372e83469
hooks: hooks:
...@@ -13,6 +15,7 @@ ...@@ -13,6 +15,7 @@
- id: check-merge-conflict - id: check-merge-conflict
- id: check-symlinks - id: check-symlinks
- id: detect-private-key - id: detect-private-key
files: (?!.*warp-ctc)^.*$
- id: end-of-file-fixer - id: end-of-file-fixer
- repo: https://github.com/PaddlePaddle/clang-format-pre-commit-hook.git - repo: https://github.com/PaddlePaddle/clang-format-pre-commit-hook.git
sha: 28c0ea8a67a3e2dbbf4822ef44e85b63a0080a29 sha: 28c0ea8a67a3e2dbbf4822ef44e85b63a0080a29
......
...@@ -8,10 +8,13 @@ os: ...@@ -8,10 +8,13 @@ os:
env: env:
- JOB=DOCS - JOB=DOCS
- JOB=BUILD_AND_TEST - JOB=BUILD_AND_TEST
- JOB=PRE_COMMIT
matrix: matrix:
exclude: exclude:
- os: osx - os: osx
env: JOB=DOCS # Only generate documentation in linux env: JOB=DOCS # Only generate documentation in linux.
- os: osx
env: JOB=PRE_COMMIT # Only check pre-commit hook in linux
addons: addons:
apt: apt:
...@@ -26,10 +29,6 @@ addons: ...@@ -26,10 +29,6 @@ addons:
- python-pip - python-pip
- python2.7-dev - python2.7-dev
- m4 - m4
- libprotobuf-dev
- doxygen
- protobuf-compiler
- python-protobuf
- python-numpy - python-numpy
- python-wheel - python-wheel
- libgoogle-glog-dev - libgoogle-glog-dev
...@@ -39,18 +38,25 @@ addons: ...@@ -39,18 +38,25 @@ addons:
- lcov - lcov
- graphviz - graphviz
- swig - swig
- clang-format-3.8
- automake
- libtool
before_install: before_install:
- | - |
if [ ${JOB} == "BUILD_AND_TEST" ]; then if [ ${JOB} == "BUILD_AND_TEST" ]; then
if ! git diff --name-only $TRAVIS_COMMIT_RANGE | grep -qvE '(\.md$)|(\.rst$)|(\.jpg$)|(\.png$)' local change_list=`git diff --name-only $TRAVIS_COMMIT_RANGE`
if [ $? -eq 0 ]; then # if git diff return no zero, then rerun unit test.
if ! echo ${change_list} | grep -qvE '(\.md$)|(\.rst$)|(\.jpg$)|(\.png$)'
then then
echo "Only markdown docs were updated, stopping build process." echo "Only markdown docs were updated, stopping build process."
exit exit
fi fi
fi fi
fi
- if [[ "$TRAVIS_OS_NAME" == "linux" ]]; then sudo paddle/scripts/travis/before_install.linux.sh; fi - if [[ "$TRAVIS_OS_NAME" == "linux" ]]; then sudo paddle/scripts/travis/before_install.linux.sh; fi
- if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then paddle/scripts/travis/before_install.osx.sh; fi - if [[ "$TRAVIS_OS_NAME" == "osx" ]]; then paddle/scripts/travis/before_install.osx.sh; fi
- pip install wheel protobuf sphinx breathe recommonmark virtualenv numpy sphinx_rtd_theme - if [[ "$JOB" == "PRE_COMMIT" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi
- pip install wheel protobuf sphinx recommonmark virtualenv numpy sphinx_rtd_theme pre-commit requests==2.9.2 LinkChecker
script: script:
- paddle/scripts/travis/main.sh - paddle/scripts/travis/main.sh
notifications: notifications:
......
...@@ -25,8 +25,8 @@ find_package(ZLIB REQUIRED) ...@@ -25,8 +25,8 @@ find_package(ZLIB REQUIRED)
find_package(NumPy REQUIRED) find_package(NumPy REQUIRED)
find_package(Threads REQUIRED) find_package(Threads REQUIRED)
find_package(AVX QUIET) find_package(AVX QUIET)
find_package(Glog) find_package(Glog REQUIRED)
find_package(Gflags QUIET) find_package(Gflags REQUIRED)
find_package(GTest) find_package(GTest)
find_package(Sphinx) find_package(Sphinx)
find_package(Doxygen) find_package(Doxygen)
...@@ -40,8 +40,6 @@ option(WITH_AVX "Compile PaddlePaddle with avx intrinsics" ${AVX_FOUND}) ...@@ -40,8 +40,6 @@ option(WITH_AVX "Compile PaddlePaddle with avx intrinsics" ${AVX_FOUND})
option(WITH_PYTHON "Compile PaddlePaddle with python interpreter" ON) option(WITH_PYTHON "Compile PaddlePaddle with python interpreter" ON)
option(WITH_STYLE_CHECK "Style Check for PaddlePaddle" ${PYTHONINTERP_FOUND}) option(WITH_STYLE_CHECK "Style Check for PaddlePaddle" ${PYTHONINTERP_FOUND})
option(WITH_RDMA "Compile PaddlePaddle with rdma support" OFF) option(WITH_RDMA "Compile PaddlePaddle with rdma support" OFF)
option(WITH_GLOG "Compile PaddlePaddle use glog, otherwise use a log implement internally" ${LIBGLOG_FOUND})
option(WITH_GFLAGS "Compile PaddlePaddle use gflags, otherwise use a flag implement internally" ${GFLAGS_FOUND})
option(WITH_TIMER "Compile PaddlePaddle use timer" OFF) option(WITH_TIMER "Compile PaddlePaddle use timer" OFF)
option(WITH_PROFILER "Compile PaddlePaddle use gpu profiler" OFF) option(WITH_PROFILER "Compile PaddlePaddle use gpu profiler" OFF)
option(WITH_TESTING "Compile and run unittest for PaddlePaddle" ${GTEST_FOUND}) option(WITH_TESTING "Compile and run unittest for PaddlePaddle" ${GTEST_FOUND})
...@@ -51,13 +49,7 @@ option(ON_TRAVIS "Running test on travis-ci or not." OFF) ...@@ -51,13 +49,7 @@ option(ON_TRAVIS "Running test on travis-ci or not." OFF)
option(ON_COVERALLS "Generating code coverage data on coveralls or not." OFF) option(ON_COVERALLS "Generating code coverage data on coveralls or not." OFF)
option(COVERALLS_UPLOAD "Uploading the generated coveralls json." ON) option(COVERALLS_UPLOAD "Uploading the generated coveralls json." ON)
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING
"Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel"
FORCE)
endif()
include(enableCXX11)
include(cpplint) include(cpplint)
include(ccache) include(ccache)
if(WITH_RDMA) if(WITH_RDMA)
...@@ -75,26 +67,21 @@ include(coveralls) ...@@ -75,26 +67,21 @@ include(coveralls)
find_package(Git REQUIRED) find_package(Git REQUIRED)
# version.cmake will get the current PADDLE_VERSION # version.cmake will get the current PADDLE_VERSION
include(version) include(version)
add_definitions(-DPADDLE_VERSION=\"${PADDLE_VERSION}\") add_definitions(-DPADDLE_VERSION=${PADDLE_VERSION})
if(NOT WITH_GPU) if(NOT WITH_GPU)
add_definitions(-DPADDLE_ONLY_CPU) add_definitions(-DPADDLE_ONLY_CPU)
add_definitions(-DHPPL_STUB_FUNC) add_definitions(-DHPPL_STUB_FUNC)
list(APPEND CMAKE_CXX_SOURCE_FILE_EXTENSIONS cu) list(APPEND CMAKE_CXX_SOURCE_FILE_EXTENSIONS cu)
else() else()
if(${CUDA_VERSION_MAJOR} GREATER 6) if(${CUDA_VERSION_MAJOR} VERSION_LESS 7)
if(COMPILER_SUPPORT_CXX11) message(FATAL_ERROR "Paddle need CUDA >= 7.0 to compile")
LIST(APPEND CUDA_NVCC_FLAGS -std=c++11)
endif()
endif() endif()
# TODO(yuyang18): Change it to remove std=c++11 in cuda compile.
set(CUDA_PROPAGATE_HOST_FLAGS OFF)
if(NOT CUDNN_FOUND) if(NOT CUDNN_FOUND)
message(FATAL_ERROR "Paddle need cudnn to compile") message(FATAL_ERROR "Paddle need cudnn to compile")
endif() endif()
set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} "-g -O3 --use_fast_math")
if(WITH_AVX) if(WITH_AVX)
set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} "-Xcompiler ${AVX_FLAG}") set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} "-Xcompiler ${AVX_FLAG}")
...@@ -102,15 +89,15 @@ else() ...@@ -102,15 +89,15 @@ else()
set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} "-Xcompiler ${SSE3_FLAG}") set(CUDA_NVCC_FLAGS ${CUDA_NVCC_FLAGS} "-Xcompiler ${SSE3_FLAG}")
endif(WITH_AVX) endif(WITH_AVX)
if(WITH_DSO)
add_definitions(-DPADDLE_USE_DSO)
endif(WITH_DSO)
# Include cuda and cudnn # Include cuda and cudnn
include_directories(${CUDNN_INCLUDE_DIR}) include_directories(${CUDNN_INCLUDE_DIR})
include_directories(${CUDA_TOOLKIT_INCLUDE}) include_directories(${CUDA_TOOLKIT_INCLUDE})
endif(NOT WITH_GPU) endif(NOT WITH_GPU)
if(WITH_DSO)
add_definitions(-DPADDLE_USE_DSO)
endif(WITH_DSO)
if(WITH_DOUBLE) if(WITH_DOUBLE)
add_definitions(-DPADDLE_TYPE_DOUBLE) add_definitions(-DPADDLE_TYPE_DOUBLE)
set(ACCURACY double) set(ACCURACY double)
...@@ -147,16 +134,12 @@ else(WITH_RDMA) ...@@ -147,16 +134,12 @@ else(WITH_RDMA)
add_definitions(-DPADDLE_DISABLE_RDMA) add_definitions(-DPADDLE_DISABLE_RDMA)
endif(WITH_RDMA) endif(WITH_RDMA)
if(WITH_GLOG) # glog
add_definitions(-DPADDLE_USE_GLOG) include_directories(${LIBGLOG_INCLUDE_DIR})
include_directories(${LIBGLOG_INCLUDE_DIR})
endif()
if(WITH_GFLAGS) #gflags
add_definitions(-DPADDLE_USE_GFLAGS) add_definitions(-DGFLAGS_NS=${GFLAGS_NAMESPACE})
add_definitions(-DGFLAGS_NS=${GFLAGS_NAMESPACE}) include_directories(${GFLAGS_INCLUDE_DIRS})
include_directories(${GFLAGS_INCLUDE_DIRS})
endif()
if(WITH_TESTING) if(WITH_TESTING)
enable_testing() enable_testing()
...@@ -180,5 +163,4 @@ add_subdirectory(paddle) ...@@ -180,5 +163,4 @@ add_subdirectory(paddle)
add_subdirectory(python) add_subdirectory(python)
if(WITH_DOC) if(WITH_DOC)
add_subdirectory(doc) add_subdirectory(doc)
add_subdirectory(doc_cn)
endif() endif()
./doc/howto/dev/contribute_to_paddle_en.md
Copyright (c) 2016 Baidu, Inc. All Rights Reserved Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
Apache License Apache License
Version 2.0, January 2004 Version 2.0, January 2004
...@@ -188,7 +188,7 @@ Copyright (c) 2016 Baidu, Inc. All Rights Reserved ...@@ -188,7 +188,7 @@ Copyright (c) 2016 Baidu, Inc. All Rights Reserved
same "printed page" as the copyright notice for easier same "printed page" as the copyright notice for easier
identification within third-party archives. identification within third-party archives.
Copyright (c) 2016 Baidu, Inc. All Rights Reserve. Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
......
# External dependency to Google protobuf.
http_archive(
name="protobuf",
url="http://github.com/google/protobuf/archive/v3.1.0.tar.gz",
sha256="0a0ae63cbffc274efb573bdde9a253e3f32e458c41261df51c5dbc5ad541e8f7",
strip_prefix="protobuf-3.1.0")
# External dependency to gtest 1.7.0. This method comes from
# https://www.bazel.io/versions/master/docs/tutorial/cpp.html.
new_http_archive(
name="gtest",
url="https://github.com/google/googletest/archive/release-1.7.0.zip",
sha256="b58cb7547a28b2c718d1e38aee18a3659c9e3ff52440297e965f5edffe34b6d0",
build_file="third_party/gtest.BUILD",
strip_prefix="googletest-release-1.7.0")
# External dependency to gflags. This method comes from
# https://github.com/gflags/example/blob/master/WORKSPACE.
new_git_repository(
name="gflags",
tag="v2.2.0",
remote="https://github.com/gflags/gflags.git",
build_file="third_party/gflags.BUILD")
# External dependency to glog. This method comes from
# https://github.com/reyoung/bazel_playground/blob/master/WORKSPACE
new_git_repository(
name="glog",
remote="https://github.com/google/glog.git",
commit="b6a5e0524c28178985f0d228e9eaa43808dbec3c",
build_file="third_party/glog.BUILD")
...@@ -25,4 +25,3 @@ test 4 2 256 512 ...@@ -25,4 +25,3 @@ test 4 2 256 512
test 4 2 512 128 test 4 2 512 128
test 4 2 512 256 test 4 2 512 256
test 4 2 512 512 test 4 2 512 512
...@@ -72,6 +72,7 @@ function( Sphinx_add_target target_name builder conf cache source destination ) ...@@ -72,6 +72,7 @@ function( Sphinx_add_target target_name builder conf cache source destination )
${source} ${source}
${destination} ${destination}
COMMENT "Generating sphinx documentation: ${builder}" COMMENT "Generating sphinx documentation: ${builder}"
COMMAND ln -sf ${destination}/index_*.html ${destination}/index.html
) )
set_property( set_property(
......
...@@ -14,13 +14,9 @@ if(WITH_STYLE_CHECK) ...@@ -14,13 +14,9 @@ if(WITH_STYLE_CHECK)
find_package(PythonInterp REQUIRED) find_package(PythonInterp REQUIRED)
endif() endif()
if(WITH_GLOG) find_package(Glog REQUIRED)
find_package(Glog REQUIRED)
endif()
if(WITH_GFLAGS) find_package(Gflags REQUIRED)
find_package(Gflags REQUIRED)
endif()
if(WITH_TESTING) if(WITH_TESTING)
find_package(GTest REQUIRED) find_package(GTest REQUIRED)
...@@ -28,9 +24,7 @@ endif() ...@@ -28,9 +24,7 @@ endif()
if(WITH_DOC) if(WITH_DOC)
find_package(Sphinx REQUIRED) find_package(Sphinx REQUIRED)
find_package(Doxygen REQUIRED)
find_python_module(recommonmark REQUIRED) find_python_module(recommonmark REQUIRED)
find_python_module(breathe REQUIRED)
endif() endif()
if(WITH_SWIG_PY) if(WITH_SWIG_PY)
......
# Enable C++ 11 for GCC.
# NOTE: It's only tested for gcc.
include(CheckCXXCompilerFlag)
CHECK_CXX_COMPILER_FLAG("-std=c++11" COMPILER_SUPPORT_CXX11)
CHECK_CXX_COMPILER_FLAG("-std=c++0x" COMPILER_SUPPORT_CXX0X)
if(COMPILER_SUPPORT_CXX11)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
elseif(COMPILER_SUPPORT_CXX0X)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++0x")
else()
message(FATAL_ERROR "Your compiler must support c++11")
endif()
\ No newline at end of file
...@@ -2,6 +2,37 @@ ...@@ -2,6 +2,37 @@
include(CheckCXXCompilerFlag) include(CheckCXXCompilerFlag)
include(CheckCCompilerFlag) include(CheckCCompilerFlag)
include(CheckCXXSymbolExists) include(CheckCXXSymbolExists)
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "RelWithDebInfo" CACHE STRING
"Choose the type of build, options are: Debug Release RelWithDebInfo MinSizeRel"
FORCE)
endif()
function(CheckCompilerCXX11Flag)
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 4.8)
message(FATAL_ERROR "Unsupported GCC version. GCC >= 4.8 required.")
endif()
elseif(CMAKE_CXX_COMPILER_ID STREQUAL "AppleClang" OR CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
# cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang"
# Apple Clang is a different compiler than upstream Clang which havs different version numbers.
# https://gist.github.com/yamaya/2924292
if(APPLE) # cmake < 3.0 compiler id "Clang" on Mac OS X
if(${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 5.1)
message(FATAL_ERROR "Unsupported AppleClang version. AppleClang >= 5.1 required.")
endif()
else()
if (${CMAKE_CXX_COMPILER_VERSION} VERSION_LESS 3.3)
message(FATAL_ERROR "Unsupported Clang version. Clang >= 3.3 required.")
endif()
endif()
endif()
endfunction()
CheckCompilerCXX11Flag()
LIST(APPEND CMAKE_CXX_FLAGS -std=c++11)
# safe_set_flag # safe_set_flag
# #
# Set a compile flag only if compiler is support # Set a compile flag only if compiler is support
...@@ -41,9 +72,7 @@ macro(safe_set_nvflag flag_name) ...@@ -41,9 +72,7 @@ macro(safe_set_nvflag flag_name)
CHECK_C_COMPILER_FLAG(${flag_name} C_COMPILER_SUPPORT_FLAG_${safe_name}) CHECK_C_COMPILER_FLAG(${flag_name} C_COMPILER_SUPPORT_FLAG_${safe_name})
set(safe_name C_COMPILER_SUPPORT_FLAG_${safe_name}) set(safe_name C_COMPILER_SUPPORT_FLAG_${safe_name})
if(${safe_name}) if(${safe_name})
set(CUDA_NVCC_FLAGS LIST(APPEND CUDA_NVCC_FLAGS -Xcompiler ${flag_name})
--compiler-options;${flag_name}
${CUDA_NVCC_FLAGS})
endif() endif()
endmacro() endmacro()
...@@ -109,8 +138,22 @@ foreach(flag ${GPU_COMMON_FLAGS}) ...@@ -109,8 +138,22 @@ foreach(flag ${GPU_COMMON_FLAGS})
endforeach() endforeach()
set(CUDA_PROPAGATE_HOST_FLAGS OFF)
# Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc. # Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc.
# So, don't set these flags here. # So, don't set these flags here.
LIST(APPEND CUDA_NVCC_FLAGS -std=c++11)
LIST(APPEND CUDA_NVCC_FLAGS --use_fast_math)
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
LIST(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_DEBUG})
elseif(CMAKE_BUILD_TYPE STREQUAL "Release")
LIST(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_RELEASE})
elseif(CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo")
LIST(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_RELWITHDEBINFO})
elseif(CMAKE_BUILD_TYPE STREQUAL "MinSizeRel")
LIST(APPEND CUDA_NVCC_FLAGS ${CMAKE_CXX_FLAGS_MINSIZEREL})
endif()
function(specify_cuda_arch cuda_version cuda_arch) function(specify_cuda_arch cuda_version cuda_arch)
if(${cuda_version} VERSION_GREATER "8.0") if(${cuda_version} VERSION_GREATER "8.0")
......
...@@ -65,7 +65,7 @@ endmacro() ...@@ -65,7 +65,7 @@ endmacro()
# link_paddle_exe # link_paddle_exe
# add paddle library for a paddle executable, such as trainer, pserver. # add paddle library for a paddle executable, such as trainer, pserver.
# #
# It will handle WITH_PYTHON/WITH_GLOG etc. # It will handle WITH_PYTHON etc.
function(link_paddle_exe TARGET_NAME) function(link_paddle_exe TARGET_NAME)
if(WITH_RDMA) if(WITH_RDMA)
generate_rdma_links() generate_rdma_links()
...@@ -96,6 +96,7 @@ function(link_paddle_exe TARGET_NAME) ...@@ -96,6 +96,7 @@ function(link_paddle_exe TARGET_NAME)
target_circle_link_libraries(${TARGET_NAME} target_circle_link_libraries(${TARGET_NAME}
ARCHIVE_START ARCHIVE_START
paddle_gserver paddle_gserver
paddle_function
${METRIC_LIBS} ${METRIC_LIBS}
ARCHIVE_END ARCHIVE_END
paddle_pserver paddle_pserver
...@@ -106,8 +107,11 @@ function(link_paddle_exe TARGET_NAME) ...@@ -106,8 +107,11 @@ function(link_paddle_exe TARGET_NAME)
paddle_parameter paddle_parameter
paddle_proto paddle_proto
paddle_cuda paddle_cuda
paddle_test_main
${METRIC_LIBS} ${METRIC_LIBS}
${PROTOBUF_LIBRARY} ${PROTOBUF_LIBRARY}
${LIBGLOG_LIBRARY}
${GFLAGS_LIBRARIES}
${CMAKE_THREAD_LIBS_INIT} ${CMAKE_THREAD_LIBS_INIT}
${CBLAS_LIBS} ${CBLAS_LIBS}
${ZLIB_LIBRARIES} ${ZLIB_LIBRARIES}
...@@ -125,16 +129,6 @@ function(link_paddle_exe TARGET_NAME) ...@@ -125,16 +129,6 @@ function(link_paddle_exe TARGET_NAME)
${PYTHON_LIBRARIES}) ${PYTHON_LIBRARIES})
endif() endif()
if(WITH_GLOG)
target_link_libraries(${TARGET_NAME}
${LIBGLOG_LIBRARY})
endif()
if(WITH_GFLAGS)
target_link_libraries(${TARGET_NAME}
${GFLAGS_LIBRARIES})
endif()
if(WITH_GPU) if(WITH_GPU)
if(NOT WITH_DSO OR WITH_METRIC) if(NOT WITH_DSO OR WITH_METRIC)
target_link_libraries(${TARGET_NAME} target_link_libraries(${TARGET_NAME}
...@@ -148,6 +142,11 @@ function(link_paddle_exe TARGET_NAME) ...@@ -148,6 +142,11 @@ function(link_paddle_exe TARGET_NAME)
target_link_libraries(${TARGET_NAME} rt) target_link_libraries(${TARGET_NAME} rt)
endif() endif()
endif() endif()
if(NOT WITH_DSO)
target_link_libraries(${TARGET_NAME}
${WARPCTC_LIBRARY})
endif()
endfunction() endfunction()
# link_paddle_test # link_paddle_test
...@@ -201,5 +200,5 @@ function(create_resources res_file output) ...@@ -201,5 +200,5 @@ function(create_resources res_file output)
# Convert hex data for C compatibility # Convert hex data for C compatibility
string(REGEX REPLACE "([0-9a-f][0-9a-f])" "0x\\1," filedata ${filedata}) string(REGEX REPLACE "([0-9a-f][0-9a-f])" "0x\\1," filedata ${filedata})
# Append data to output file # Append data to output file
file(APPEND ${output} "const unsigned char ${filename}[] = {${filedata}};\nconst unsigned ${filename}_size = sizeof(${filename});\n") file(APPEND ${output} "const unsigned char ${filename}[] = {${filedata}0};\nconst unsigned ${filename}_size = sizeof(${filename});\n")
endfunction() endfunction()
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved #!/bin/bash
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -15,4 +16,3 @@ set -e ...@@ -15,4 +16,3 @@ set -e
wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz wget https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
tar zxf cifar-10-python.tar.gz tar zxf cifar-10-python.tar.gz
rm cifar-10-python.tar.gz rm cifar-10-python.tar.gz
...@@ -15,5 +15,3 @@ do ...@@ -15,5 +15,3 @@ do
gunzip ${fname}.gz gunzip ${fname}.gz
fi fi
done done
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -14,10 +14,9 @@ ...@@ -14,10 +14,9 @@
from paddle.trainer_config_helpers import * from paddle.trainer_config_helpers import *
mode = get_config_arg("mode", str, "generator") mode = get_config_arg("mode", str, "generator")
assert mode in set(["generator", assert mode in set([
"discriminator", "generator", "discriminator", "generator_training", "discriminator_training"
"generator_training", ])
"discriminator_training"])
is_generator_training = mode == "generator_training" is_generator_training = mode == "generator_training"
is_discriminator_training = mode == "discriminator_training" is_discriminator_training = mode == "discriminator_training"
...@@ -38,8 +37,8 @@ sample_dim = 2 ...@@ -38,8 +37,8 @@ sample_dim = 2
settings( settings(
batch_size=128, batch_size=128,
learning_rate=1e-4, learning_rate=1e-4,
learning_method=AdamOptimizer(beta1=0.5) learning_method=AdamOptimizer(beta1=0.5))
)
def discriminator(sample): def discriminator(sample):
""" """
...@@ -50,71 +49,88 @@ def discriminator(sample): ...@@ -50,71 +49,88 @@ def discriminator(sample):
of the sample is from real data. of the sample is from real data.
""" """
param_attr = ParamAttr(is_static=is_generator_training) param_attr = ParamAttr(is_static=is_generator_training)
bias_attr = ParamAttr(is_static=is_generator_training, bias_attr = ParamAttr(
initial_mean=1.0, is_static=is_generator_training, initial_mean=1.0, initial_std=0)
initial_std=0)
hidden = fc_layer(input=sample, name="dis_hidden", size=hidden_dim, hidden = fc_layer(
input=sample,
name="dis_hidden",
size=hidden_dim,
bias_attr=bias_attr, bias_attr=bias_attr,
param_attr=param_attr, param_attr=param_attr,
act=ReluActivation()) act=ReluActivation())
hidden2 = fc_layer(input=hidden, name="dis_hidden2", size=hidden_dim, hidden2 = fc_layer(
input=hidden,
name="dis_hidden2",
size=hidden_dim,
bias_attr=bias_attr, bias_attr=bias_attr,
param_attr=param_attr, param_attr=param_attr,
act=LinearActivation()) act=LinearActivation())
hidden_bn = batch_norm_layer(hidden2, hidden_bn = batch_norm_layer(
hidden2,
act=ReluActivation(), act=ReluActivation(),
name="dis_hidden_bn", name="dis_hidden_bn",
bias_attr=bias_attr, bias_attr=bias_attr,
param_attr=ParamAttr(is_static=is_generator_training, param_attr=ParamAttr(
initial_mean=1.0, is_static=is_generator_training, initial_mean=1.0,
initial_std=0.02), initial_std=0.02),
use_global_stats=False) use_global_stats=False)
return fc_layer(input=hidden_bn, name="dis_prob", size=2, return fc_layer(
input=hidden_bn,
name="dis_prob",
size=2,
bias_attr=bias_attr, bias_attr=bias_attr,
param_attr=param_attr, param_attr=param_attr,
act=SoftmaxActivation()) act=SoftmaxActivation())
def generator(noise): def generator(noise):
""" """
generator generates a sample given noise generator generates a sample given noise
""" """
param_attr = ParamAttr(is_static=is_discriminator_training) param_attr = ParamAttr(is_static=is_discriminator_training)
bias_attr = ParamAttr(is_static=is_discriminator_training, bias_attr = ParamAttr(
initial_mean=1.0, is_static=is_discriminator_training, initial_mean=1.0, initial_std=0)
initial_std=0)
hidden = fc_layer(input=noise, hidden = fc_layer(
input=noise,
name="gen_layer_hidden", name="gen_layer_hidden",
size=hidden_dim, size=hidden_dim,
bias_attr=bias_attr, bias_attr=bias_attr,
param_attr=param_attr, param_attr=param_attr,
act=ReluActivation()) act=ReluActivation())
hidden2 = fc_layer(input=hidden, name="gen_hidden2", size=hidden_dim, hidden2 = fc_layer(
input=hidden,
name="gen_hidden2",
size=hidden_dim,
bias_attr=bias_attr, bias_attr=bias_attr,
param_attr=param_attr, param_attr=param_attr,
act=LinearActivation()) act=LinearActivation())
hidden_bn = batch_norm_layer(hidden2, hidden_bn = batch_norm_layer(
hidden2,
act=ReluActivation(), act=ReluActivation(),
name="gen_layer_hidden_bn", name="gen_layer_hidden_bn",
bias_attr=bias_attr, bias_attr=bias_attr,
param_attr=ParamAttr(is_static=is_discriminator_training, param_attr=ParamAttr(
is_static=is_discriminator_training,
initial_mean=1.0, initial_mean=1.0,
initial_std=0.02), initial_std=0.02),
use_global_stats=False) use_global_stats=False)
return fc_layer(input=hidden_bn, return fc_layer(
input=hidden_bn,
name="gen_layer1", name="gen_layer1",
size=sample_dim, size=sample_dim,
bias_attr=bias_attr, bias_attr=bias_attr,
param_attr=param_attr, param_attr=param_attr,
act=LinearActivation()) act=LinearActivation())
if is_generator_training: if is_generator_training:
noise = data_layer(name="noise", size=noise_dim) noise = data_layer(name="noise", size=noise_dim)
sample = generator(noise) sample = generator(noise)
...@@ -126,7 +142,8 @@ if is_generator_training or is_discriminator_training: ...@@ -126,7 +142,8 @@ if is_generator_training or is_discriminator_training:
label = data_layer(name="label", size=1) label = data_layer(name="label", size=1)
prob = discriminator(sample) prob = discriminator(sample)
cost = cross_entropy(input=prob, label=label) cost = cross_entropy(input=prob, label=label)
classification_error_evaluator(input=prob, label=label, name=mode+'_error') classification_error_evaluator(
input=prob, label=label, name=mode + '_error')
outputs(cost) outputs(cost)
if is_generator: if is_generator:
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -15,10 +15,9 @@ from paddle.trainer_config_helpers import * ...@@ -15,10 +15,9 @@ from paddle.trainer_config_helpers import *
mode = get_config_arg("mode", str, "generator") mode = get_config_arg("mode", str, "generator")
dataSource = get_config_arg("data", str, "mnist") dataSource = get_config_arg("data", str, "mnist")
assert mode in set(["generator", assert mode in set([
"discriminator", "generator", "discriminator", "generator_training", "discriminator_training"
"generator_training", ])
"discriminator_training"])
is_generator_training = mode == "generator_training" is_generator_training = mode == "generator_training"
is_discriminator_training = mode == "discriminator_training" is_discriminator_training = mode == "discriminator_training"
...@@ -41,19 +40,28 @@ if dataSource == "mnist": ...@@ -41,19 +40,28 @@ if dataSource == "mnist":
else: else:
sample_dim = 32 sample_dim = 32
c_dim = 3 c_dim = 3
s2, s4 = int(sample_dim/2), int(sample_dim/4), s2, s4 = int(sample_dim / 2), int(sample_dim / 4),
s8, s16 = int(sample_dim/8), int(sample_dim/16) s8, s16 = int(sample_dim / 8), int(sample_dim / 16)
settings( settings(
batch_size=128, batch_size=128,
learning_rate=2e-4, learning_rate=2e-4,
learning_method=AdamOptimizer(beta1=0.5) learning_method=AdamOptimizer(beta1=0.5))
)
def conv_bn(input, channels, imgSize, num_filters, output_x, stride, name, def conv_bn(input,
param_attr, bias_attr, param_attr_bn, bn, trans=False, channels,
imgSize,
num_filters,
output_x,
stride,
name,
param_attr,
bias_attr,
param_attr_bn,
bn,
trans=False,
act=ReluActivation()): act=ReluActivation()):
""" """
conv_bn is a utility function that constructs a convolution/deconv layer conv_bn is a utility function that constructs a convolution/deconv layer
with an optional batch_norm layer with an optional batch_norm layer
...@@ -76,24 +84,35 @@ def conv_bn(input, channels, imgSize, num_filters, output_x, stride, name, ...@@ -76,24 +84,35 @@ def conv_bn(input, channels, imgSize, num_filters, output_x, stride, name,
filter_size = tmp filter_size = tmp
padding = 0 padding = 0
print (imgSize, output_x, stride, filter_size, padding) print(imgSize, output_x, stride, filter_size, padding)
if trans: if trans:
nameApx = "_conv"
else:
nameApx = "_convt" nameApx = "_convt"
else:
nameApx = "_conv"
if bn: if bn:
conv = img_conv_layer(input, filter_size=filter_size, conv = img_conv_layer(
input,
filter_size=filter_size,
num_filters=num_filters, num_filters=num_filters,
name=name + nameApx, num_channels=channels, name=name + nameApx,
act=LinearActivation(), groups=1, stride=stride, num_channels=channels,
padding=padding, bias_attr=bias_attr, act=LinearActivation(),
param_attr=param_attr, shared_biases=True, layer_attr=None, groups=1,
filter_size_y=None, stride_y=None, padding_y=None, stride=stride,
padding=padding,
bias_attr=bias_attr,
param_attr=param_attr,
shared_biases=True,
layer_attr=None,
filter_size_y=None,
stride_y=None,
padding_y=None,
trans=trans) trans=trans)
conv_bn = batch_norm_layer(conv, conv_bn = batch_norm_layer(
conv,
act=act, act=act,
name=name + nameApx + "_bn", name=name + nameApx + "_bn",
bias_attr=bias_attr, bias_attr=bias_attr,
...@@ -102,49 +121,60 @@ def conv_bn(input, channels, imgSize, num_filters, output_x, stride, name, ...@@ -102,49 +121,60 @@ def conv_bn(input, channels, imgSize, num_filters, output_x, stride, name,
return conv_bn return conv_bn
else: else:
conv = img_conv_layer(input, filter_size=filter_size, conv = img_conv_layer(
input,
filter_size=filter_size,
num_filters=num_filters, num_filters=num_filters,
name=name + nameApx, num_channels=channels, name=name + nameApx,
act=act, groups=1, stride=stride, num_channels=channels,
padding=padding, bias_attr=bias_attr, act=act,
param_attr=param_attr, shared_biases=True, layer_attr=None, groups=1,
filter_size_y=None, stride_y=None, padding_y=None, stride=stride,
padding=padding,
bias_attr=bias_attr,
param_attr=param_attr,
shared_biases=True,
layer_attr=None,
filter_size_y=None,
stride_y=None,
padding_y=None,
trans=trans) trans=trans)
return conv return conv
def generator(noise): def generator(noise):
""" """
generator generates a sample given noise generator generates a sample given noise
""" """
param_attr = ParamAttr(is_static=is_discriminator_training, param_attr = ParamAttr(
initial_mean=0.0, is_static=is_discriminator_training, initial_mean=0.0, initial_std=0.02)
initial_std=0.02) bias_attr = ParamAttr(
bias_attr = ParamAttr(is_static=is_discriminator_training, is_static=is_discriminator_training, initial_mean=0.0, initial_std=0.0)
initial_mean=0.0,
initial_std=0.0) param_attr_bn = ParamAttr(
is_static=is_discriminator_training, initial_mean=1.0, initial_std=0.02)
param_attr_bn=ParamAttr(is_static=is_discriminator_training,
initial_mean=1.0, h1 = fc_layer(
initial_std=0.02) input=noise,
h1 = fc_layer(input=noise,
name="gen_layer_h1", name="gen_layer_h1",
size=s8 * s8 * gf_dim * 4, size=s8 * s8 * gf_dim * 4,
bias_attr=bias_attr, bias_attr=bias_attr,
param_attr=param_attr, param_attr=param_attr,
act=LinearActivation()) act=LinearActivation())
h1_bn = batch_norm_layer(h1, h1_bn = batch_norm_layer(
h1,
act=ReluActivation(), act=ReluActivation(),
name="gen_layer_h1_bn", name="gen_layer_h1_bn",
bias_attr=bias_attr, bias_attr=bias_attr,
param_attr=param_attr_bn, param_attr=param_attr_bn,
use_global_stats=False) use_global_stats=False)
h2_bn = conv_bn(h1_bn, h2_bn = conv_bn(
channels=gf_dim*4, h1_bn,
channels=gf_dim * 4,
output_x=s8, output_x=s8,
num_filters=gf_dim*2, num_filters=gf_dim * 2,
imgSize=s4, imgSize=s4,
stride=2, stride=2,
name="gen_layer_h2", name="gen_layer_h2",
...@@ -154,8 +184,9 @@ def generator(noise): ...@@ -154,8 +184,9 @@ def generator(noise):
bn=True, bn=True,
trans=True) trans=True)
h3_bn = conv_bn(h2_bn, h3_bn = conv_bn(
channels=gf_dim*2, h2_bn,
channels=gf_dim * 2,
output_x=s4, output_x=s4,
num_filters=gf_dim, num_filters=gf_dim,
imgSize=s2, imgSize=s2,
...@@ -167,8 +198,8 @@ def generator(noise): ...@@ -167,8 +198,8 @@ def generator(noise):
bn=True, bn=True,
trans=True) trans=True)
return conv_bn(
return conv_bn(h3_bn, h3_bn,
channels=gf_dim, channels=gf_dim,
output_x=s2, output_x=s2,
num_filters=c_dim, num_filters=c_dim,
...@@ -191,18 +222,16 @@ def discriminator(sample): ...@@ -191,18 +222,16 @@ def discriminator(sample):
of the sample is from generator and dimension 1 is the probabblity of the sample is from generator and dimension 1 is the probabblity
of the sample is from real data. of the sample is from real data.
""" """
param_attr = ParamAttr(is_static=is_generator_training, param_attr = ParamAttr(
initial_mean=0.0, is_static=is_generator_training, initial_mean=0.0, initial_std=0.02)
initial_std=0.02) bias_attr = ParamAttr(
bias_attr = ParamAttr(is_static=is_generator_training, is_static=is_generator_training, initial_mean=0.0, initial_std=0.0)
initial_mean=0.0,
initial_std=0.0) param_attr_bn = ParamAttr(
is_static=is_generator_training, initial_mean=1.0, initial_std=0.02)
param_attr_bn=ParamAttr(is_static=is_generator_training,
initial_mean=1.0, h0 = conv_bn(
initial_std=0.02) sample,
h0 = conv_bn(sample,
channels=c_dim, channels=c_dim,
imgSize=sample_dim, imgSize=sample_dim,
num_filters=df_dim, num_filters=df_dim,
...@@ -214,10 +243,11 @@ def discriminator(sample): ...@@ -214,10 +243,11 @@ def discriminator(sample):
param_attr_bn=param_attr_bn, param_attr_bn=param_attr_bn,
bn=False) bn=False)
h1_bn = conv_bn(h0, h1_bn = conv_bn(
h0,
channels=df_dim, channels=df_dim,
imgSize=s2, imgSize=s2,
num_filters=df_dim*2, num_filters=df_dim * 2,
output_x=s4, output_x=s4,
stride=2, stride=2,
name="dis_h1", name="dis_h1",
...@@ -226,10 +256,11 @@ def discriminator(sample): ...@@ -226,10 +256,11 @@ def discriminator(sample):
param_attr_bn=param_attr_bn, param_attr_bn=param_attr_bn,
bn=True) bn=True)
h2_bn = conv_bn(h1_bn, h2_bn = conv_bn(
channels=df_dim*2, h1_bn,
channels=df_dim * 2,
imgSize=s4, imgSize=s4,
num_filters=df_dim*4, num_filters=df_dim * 4,
output_x=s8, output_x=s8,
stride=2, stride=2,
name="dis_h2", name="dis_h2",
...@@ -238,25 +269,28 @@ def discriminator(sample): ...@@ -238,25 +269,28 @@ def discriminator(sample):
param_attr_bn=param_attr_bn, param_attr_bn=param_attr_bn,
bn=True) bn=True)
return fc_layer(input=h2_bn, name="dis_prob", size=2, return fc_layer(
input=h2_bn,
name="dis_prob",
size=2,
bias_attr=bias_attr, bias_attr=bias_attr,
param_attr=param_attr, param_attr=param_attr,
act=SoftmaxActivation()) act=SoftmaxActivation())
if is_generator_training: if is_generator_training:
noise = data_layer(name="noise", size=noise_dim) noise = data_layer(name="noise", size=noise_dim)
sample = generator(noise) sample = generator(noise)
if is_discriminator_training: if is_discriminator_training:
sample = data_layer(name="sample", size=sample_dim * sample_dim*c_dim) sample = data_layer(name="sample", size=sample_dim * sample_dim * c_dim)
if is_generator_training or is_discriminator_training: if is_generator_training or is_discriminator_training:
label = data_layer(name="label", size=1) label = data_layer(name="label", size=1)
prob = discriminator(sample) prob = discriminator(sample)
cost = cross_entropy(input=prob, label=label) cost = cross_entropy(input=prob, label=label)
classification_error_evaluator(input=prob, label=label, name=mode+'_error') classification_error_evaluator(
input=prob, label=label, name=mode + '_error')
outputs(cost) outputs(cost)
if is_generator: if is_generator:
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -16,7 +16,7 @@ import argparse ...@@ -16,7 +16,7 @@ import argparse
import random import random
import numpy import numpy
import cPickle import cPickle
import sys,os import sys, os
from PIL import Image from PIL import Image
from paddle.trainer.config_parser import parse_config from paddle.trainer.config_parser import parse_config
...@@ -24,6 +24,7 @@ from paddle.trainer.config_parser import logger ...@@ -24,6 +24,7 @@ from paddle.trainer.config_parser import logger
import py_paddle.swig_paddle as api import py_paddle.swig_paddle as api
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
def plot2DScatter(data, outputfile): def plot2DScatter(data, outputfile):
''' '''
Plot the data as a 2D scatter plot and save to outputfile Plot the data as a 2D scatter plot and save to outputfile
...@@ -41,9 +42,11 @@ def plot2DScatter(data, outputfile): ...@@ -41,9 +42,11 @@ def plot2DScatter(data, outputfile):
plt.scatter(x, y) plt.scatter(x, y)
plt.savefig(outputfile, bbox_inches='tight') plt.savefig(outputfile, bbox_inches='tight')
def CHECK_EQ(a, b): def CHECK_EQ(a, b):
assert a == b, "a=%s, b=%s" % (a, b) assert a == b, "a=%s, b=%s" % (a, b)
def copy_shared_parameters(src, dst): def copy_shared_parameters(src, dst):
''' '''
copy the parameters from src to dst copy the parameters from src to dst
...@@ -52,11 +55,9 @@ def copy_shared_parameters(src, dst): ...@@ -52,11 +55,9 @@ def copy_shared_parameters(src, dst):
:param dst: the destination of the parameters :param dst: the destination of the parameters
:type dst: GradientMachine :type dst: GradientMachine
''' '''
src_params = [src.getParameter(i) src_params = [src.getParameter(i) for i in xrange(src.getParameterSize())]
for i in xrange(src.getParameterSize())]
src_params = dict([(p.getName(), p) for p in src_params]) src_params = dict([(p.getName(), p) for p in src_params])
for i in xrange(dst.getParameterSize()): for i in xrange(dst.getParameterSize()):
dst_param = dst.getParameter(i) dst_param = dst.getParameter(i)
src_param = src_params.get(dst_param.getName(), None) src_param = src_params.get(dst_param.getName(), None)
...@@ -68,14 +69,16 @@ def copy_shared_parameters(src, dst): ...@@ -68,14 +69,16 @@ def copy_shared_parameters(src, dst):
dst_value.copyFrom(src_value) dst_value.copyFrom(src_value)
dst_param.setValueUpdated() dst_param.setValueUpdated()
def print_parameters(src): def print_parameters(src):
src_params = [src.getParameter(i) src_params = [src.getParameter(i) for i in xrange(src.getParameterSize())]
for i in xrange(src.getParameterSize())]
print "***************" print "***************"
for p in src_params: for p in src_params:
print "Name is %s" % p.getName() print "Name is %s" % p.getName()
print "value is %s \n" % p.getBuf(api.PARAMETER_VALUE).copyToNumpyArray() print "value is %s \n" % p.getBuf(api.PARAMETER_VALUE).copyToNumpyArray(
)
def load_mnist_data(imageFile): def load_mnist_data(imageFile):
f = open(imageFile, "rb") f = open(imageFile, "rb")
...@@ -87,32 +90,35 @@ def load_mnist_data(imageFile): ...@@ -87,32 +90,35 @@ def load_mnist_data(imageFile):
else: else:
n = 10000 n = 10000
data = numpy.fromfile(f, 'ubyte', count=n*28*28).reshape((n, 28*28)) data = numpy.fromfile(f, 'ubyte', count=n * 28 * 28).reshape((n, 28 * 28))
data = data / 255.0 * 2.0 - 1.0 data = data / 255.0 * 2.0 - 1.0
f.close() f.close()
return data.astype('float32') return data.astype('float32')
def load_cifar_data(cifar_path): def load_cifar_data(cifar_path):
batch_size = 10000 batch_size = 10000
data = numpy.zeros((5*batch_size, 32*32*3), dtype = "float32") data = numpy.zeros((5 * batch_size, 32 * 32 * 3), dtype="float32")
for i in range(1, 6): for i in range(1, 6):
file = cifar_path + "/data_batch_" + str(i) file = cifar_path + "/data_batch_" + str(i)
fo = open(file, 'rb') fo = open(file, 'rb')
dict = cPickle.load(fo) dict = cPickle.load(fo)
fo.close() fo.close()
data[(i - 1)*batch_size:(i*batch_size), :] = dict["data"] data[(i - 1) * batch_size:(i * batch_size), :] = dict["data"]
data = data / 255.0 * 2.0 - 1.0 data = data / 255.0 * 2.0 - 1.0
return data return data
# synthesize 2-D uniform data # synthesize 2-D uniform data
def load_uniform_data(): def load_uniform_data():
data = numpy.random.rand(1000000, 2).astype('float32') data = numpy.random.rand(1000000, 2).astype('float32')
return data return data
def merge(images, size): def merge(images, size):
if images.shape[1] == 28*28: if images.shape[1] == 28 * 28:
h, w, c = 28, 28, 1 h, w, c = 28, 28, 1
else: else:
h, w, c = 32, 32, 3 h, w, c = 32, 32, 3
...@@ -124,6 +130,7 @@ def merge(images, size): ...@@ -124,6 +130,7 @@ def merge(images, size):
((images[idx, :].reshape((h, w, c), order="F").transpose(1, 0, 2) + 1.0) / 2.0 * 255.0) ((images[idx, :].reshape((h, w, c), order="F").transpose(1, 0, 2) + 1.0) / 2.0 * 255.0)
return img.astype('uint8') return img.astype('uint8')
def save_images(images, path): def save_images(images, path):
merged_img = merge(images, [8, 8]) merged_img = merge(images, [8, 8])
if merged_img.shape[2] == 1: if merged_img.shape[2] == 1:
...@@ -132,13 +139,16 @@ def save_images(images, path): ...@@ -132,13 +139,16 @@ def save_images(images, path):
im = Image.fromarray(merged_img, mode="RGB") im = Image.fromarray(merged_img, mode="RGB")
im.save(path) im.save(path)
def get_real_samples(batch_size, data_np): def get_real_samples(batch_size, data_np):
return data_np[numpy.random.choice(data_np.shape[0], batch_size, return data_np[numpy.random.choice(
replace=False),:] data_np.shape[0], batch_size, replace=False), :]
def get_noise(batch_size, noise_dim): def get_noise(batch_size, noise_dim):
return numpy.random.normal(size=(batch_size, noise_dim)).astype('float32') return numpy.random.normal(size=(batch_size, noise_dim)).astype('float32')
def get_fake_samples(generator_machine, batch_size, noise): def get_fake_samples(generator_machine, batch_size, noise):
gen_inputs = api.Arguments.createArguments(1) gen_inputs = api.Arguments.createArguments(1)
gen_inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(noise)) gen_inputs.setSlotValue(0, api.Matrix.createDenseFromNumpy(noise))
...@@ -147,12 +157,14 @@ def get_fake_samples(generator_machine, batch_size, noise): ...@@ -147,12 +157,14 @@ def get_fake_samples(generator_machine, batch_size, noise):
fake_samples = gen_outputs.getSlotValue(0).copyToNumpyMat() fake_samples = gen_outputs.getSlotValue(0).copyToNumpyMat()
return fake_samples return fake_samples
def get_training_loss(training_machine, inputs): def get_training_loss(training_machine, inputs):
outputs = api.Arguments.createArguments(0) outputs = api.Arguments.createArguments(0)
training_machine.forward(inputs, outputs, api.PASS_TEST) training_machine.forward(inputs, outputs, api.PASS_TEST)
loss = outputs.getSlotValue(0).copyToNumpyMat() loss = outputs.getSlotValue(0).copyToNumpyMat()
return numpy.mean(loss) return numpy.mean(loss)
def prepare_discriminator_data_batch_pos(batch_size, data_np): def prepare_discriminator_data_batch_pos(batch_size, data_np):
real_samples = get_real_samples(batch_size, data_np) real_samples = get_real_samples(batch_size, data_np)
labels = numpy.ones(batch_size, dtype='int32') labels = numpy.ones(batch_size, dtype='int32')
...@@ -161,6 +173,7 @@ def prepare_discriminator_data_batch_pos(batch_size, data_np): ...@@ -161,6 +173,7 @@ def prepare_discriminator_data_batch_pos(batch_size, data_np):
inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(labels)) inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(labels))
return inputs return inputs
def prepare_discriminator_data_batch_neg(generator_machine, batch_size, noise): def prepare_discriminator_data_batch_neg(generator_machine, batch_size, noise):
fake_samples = get_fake_samples(generator_machine, batch_size, noise) fake_samples = get_fake_samples(generator_machine, batch_size, noise)
labels = numpy.zeros(batch_size, dtype='int32') labels = numpy.zeros(batch_size, dtype='int32')
...@@ -169,6 +182,7 @@ def prepare_discriminator_data_batch_neg(generator_machine, batch_size, noise): ...@@ -169,6 +182,7 @@ def prepare_discriminator_data_batch_neg(generator_machine, batch_size, noise):
inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(labels)) inputs.setSlotIds(1, api.IVector.createVectorFromNumpy(labels))
return inputs return inputs
def prepare_generator_data_batch(batch_size, noise): def prepare_generator_data_batch(batch_size, noise):
label = numpy.ones(batch_size, dtype='int32') label = numpy.ones(batch_size, dtype='int32')
inputs = api.Arguments.createArguments(2) inputs = api.Arguments.createArguments(2)
...@@ -193,10 +207,9 @@ def get_layer_size(model_conf, layer_name): ...@@ -193,10 +207,9 @@ def get_layer_size(model_conf, layer_name):
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("-d", "--data_source", help="mnist or cifar or uniform") parser.add_argument("-d", "--data_source", help="mnist or cifar or uniform")
parser.add_argument("--use_gpu", default="1", parser.add_argument(
help="1 means use gpu for training") "--use_gpu", default="1", help="1 means use gpu for training")
parser.add_argument("--gpu_id", default="0", parser.add_argument("--gpu_id", default="0", help="the gpu_id parameter")
help="the gpu_id parameter")
args = parser.parse_args() args = parser.parse_args()
data_source = args.data_source data_source = args.data_source
use_gpu = args.use_gpu use_gpu = args.use_gpu
...@@ -209,8 +222,9 @@ def main(): ...@@ -209,8 +222,9 @@ def main():
if not os.path.exists("./%s_params/" % data_source): if not os.path.exists("./%s_params/" % data_source):
os.makedirs("./%s_params/" % data_source) os.makedirs("./%s_params/" % data_source)
api.initPaddle('--use_gpu=' + use_gpu, '--dot_period=10', '--log_period=100', api.initPaddle('--use_gpu=' + use_gpu, '--dot_period=10',
'--gpu_id=' + args.gpu_id, '--save_dir=' + "./%s_params/" % data_source) '--log_period=100', '--gpu_id=' + args.gpu_id,
'--save_dir=' + "./%s_params/" % data_source)
if data_source == "uniform": if data_source == "uniform":
conf = "gan_conf.py" conf = "gan_conf.py"
...@@ -220,7 +234,8 @@ def main(): ...@@ -220,7 +234,8 @@ def main():
num_iter = 1000 num_iter = 1000
gen_conf = parse_config(conf, "mode=generator_training,data=" + data_source) gen_conf = parse_config(conf, "mode=generator_training,data=" + data_source)
dis_conf = parse_config(conf, "mode=discriminator_training,data=" + data_source) dis_conf = parse_config(conf,
"mode=discriminator_training,data=" + data_source)
generator_conf = parse_config(conf, "mode=generator,data=" + data_source) generator_conf = parse_config(conf, "mode=generator,data=" + data_source)
batch_size = dis_conf.opt_config.batch_size batch_size = dis_conf.opt_config.batch_size
noise_dim = get_layer_size(gen_conf.model_config, "noise") noise_dim = get_layer_size(gen_conf.model_config, "noise")
...@@ -245,11 +260,9 @@ def main(): ...@@ -245,11 +260,9 @@ def main():
generator_machine = api.GradientMachine.createFromConfigProto( generator_machine = api.GradientMachine.createFromConfigProto(
generator_conf.model_config) generator_conf.model_config)
dis_trainer = api.Trainer.create( dis_trainer = api.Trainer.create(dis_conf, dis_training_machine)
dis_conf, dis_training_machine)
gen_trainer = api.Trainer.create( gen_trainer = api.Trainer.create(gen_conf, gen_training_machine)
gen_conf, gen_training_machine)
dis_trainer.startTrain() dis_trainer.startTrain()
gen_trainer.startTrain() gen_trainer.startTrain()
...@@ -272,21 +285,23 @@ def main(): ...@@ -272,21 +285,23 @@ def main():
noise = get_noise(batch_size, noise_dim) noise = get_noise(batch_size, noise_dim)
data_batch_dis_pos = prepare_discriminator_data_batch_pos( data_batch_dis_pos = prepare_discriminator_data_batch_pos(
batch_size, data_np) batch_size, data_np)
dis_loss_pos = get_training_loss(dis_training_machine, data_batch_dis_pos) dis_loss_pos = get_training_loss(dis_training_machine,
data_batch_dis_pos)
data_batch_dis_neg = prepare_discriminator_data_batch_neg( data_batch_dis_neg = prepare_discriminator_data_batch_neg(
generator_machine, batch_size, noise) generator_machine, batch_size, noise)
dis_loss_neg = get_training_loss(dis_training_machine, data_batch_dis_neg) dis_loss_neg = get_training_loss(dis_training_machine,
data_batch_dis_neg)
dis_loss = (dis_loss_pos + dis_loss_neg) / 2.0 dis_loss = (dis_loss_pos + dis_loss_neg) / 2.0
# Do forward pass in generator to get the gen_loss # Do forward pass in generator to get the gen_loss
data_batch_gen = prepare_generator_data_batch( data_batch_gen = prepare_generator_data_batch(batch_size, noise)
batch_size, noise)
gen_loss = get_training_loss(gen_training_machine, data_batch_gen) gen_loss = get_training_loss(gen_training_machine, data_batch_gen)
if i % 100 == 0: if i % 100 == 0:
print "d_pos_loss is %s d_neg_loss is %s" % (dis_loss_pos, dis_loss_neg) print "d_pos_loss is %s d_neg_loss is %s" % (dis_loss_pos,
dis_loss_neg)
print "d_loss is %s g_loss is %s" % (dis_loss, gen_loss) print "d_loss is %s g_loss is %s" % (dis_loss, gen_loss)
# Decide which network to train based on the training history # Decide which network to train based on the training history
...@@ -300,7 +315,8 @@ def main(): ...@@ -300,7 +315,8 @@ def main():
curr_strike = 1 curr_strike = 1
dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_neg) dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_neg)
dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_pos) dis_trainer.trainOneDataBatch(batch_size, data_batch_dis_pos)
copy_shared_parameters(dis_training_machine, gen_training_machine) copy_shared_parameters(dis_training_machine,
gen_training_machine)
else: else:
if curr_train == "gen": if curr_train == "gen":
...@@ -311,7 +327,8 @@ def main(): ...@@ -311,7 +327,8 @@ def main():
gen_trainer.trainOneDataBatch(batch_size, data_batch_gen) gen_trainer.trainOneDataBatch(batch_size, data_batch_gen)
# TODO: add API for paddle to allow true parameter sharing between different GradientMachines # TODO: add API for paddle to allow true parameter sharing between different GradientMachines
# so that we do not need to copy shared parameters. # so that we do not need to copy shared parameters.
copy_shared_parameters(gen_training_machine, dis_training_machine) copy_shared_parameters(gen_training_machine,
dis_training_machine)
copy_shared_parameters(gen_training_machine, generator_machine) copy_shared_parameters(gen_training_machine, generator_machine)
dis_trainer.finishTrainPass() dis_trainer.finishTrainPass()
...@@ -319,11 +336,14 @@ def main(): ...@@ -319,11 +336,14 @@ def main():
# At the end of each pass, save the generated samples/images # At the end of each pass, save the generated samples/images
fake_samples = get_fake_samples(generator_machine, batch_size, noise) fake_samples = get_fake_samples(generator_machine, batch_size, noise)
if data_source == "uniform": if data_source == "uniform":
plot2DScatter(fake_samples, "./%s_samples/train_pass%s.png" % (data_source, train_pass)) plot2DScatter(fake_samples, "./%s_samples/train_pass%s.png" %
(data_source, train_pass))
else: else:
save_images(fake_samples, "./%s_samples/train_pass%s.png" % (data_source, train_pass)) save_images(fake_samples, "./%s_samples/train_pass%s.png" %
(data_source, train_pass))
dis_trainer.finishTrain() dis_trainer.finishTrain()
gen_trainer.finishTrain() gen_trainer.finishTrain()
if __name__ == '__main__': if __name__ == '__main__':
main() main()
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved #!/bin/bash
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -21,7 +21,7 @@ from paddle.trainer.PyDataProvider2 import * ...@@ -21,7 +21,7 @@ from paddle.trainer.PyDataProvider2 import *
# #
# {'img_size': 32, # {'img_size': 32,
# 'settings': <paddle.trainer.PyDataProviderWrapper.Cls instance at 0x7fea27cb6050>, # 'settings': a global object,
# 'color': True, # 'color': True,
# 'mean_img_size': 32, # 'mean_img_size': 32,
# 'meta': './data/cifar-out/batches/batches.meta', # 'meta': './data/cifar-out/batches/batches.meta',
...@@ -50,10 +50,10 @@ def hook(settings, img_size, mean_img_size, num_classes, color, meta, use_jpeg, ...@@ -50,10 +50,10 @@ def hook(settings, img_size, mean_img_size, num_classes, color, meta, use_jpeg,
settings.logger.info('Image size: %s', settings.img_size) settings.logger.info('Image size: %s', settings.img_size)
settings.logger.info('Meta path: %s', settings.meta_path) settings.logger.info('Meta path: %s', settings.meta_path)
settings.input_types = [ settings.input_types = {
dense_vector(settings.img_raw_size), # image feature 'image': dense_vector(settings.img_raw_size),
integer_value(settings.num_classes) 'label': integer_value(settings.num_classes)
] # labels }
settings.logger.info('DataProvider Initialization finished') settings.logger.info('DataProvider Initialization finished')
...@@ -83,4 +83,7 @@ def processData(settings, file_list): ...@@ -83,4 +83,7 @@ def processData(settings, file_list):
img, settings.img_mean, settings.img_size, img, settings.img_mean, settings.img_size,
settings.is_train, settings.color) settings.is_train, settings.color)
label = data['labels'][i] label = data['labels'][i]
yield img_feat.astype('float32'), int(label) yield {
'image': img_feat.astype('float32'),
'label': int(label)
}
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
dataprovider.pyc
empty.list
train.log
output
train.list
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -17,8 +17,10 @@ import random ...@@ -17,8 +17,10 @@ import random
# define data types of input: 2 real numbers # define data types of input: 2 real numbers
@provider(input_types=[dense_vector(1), dense_vector(1)], use_seq=False) @provider(
input_types={'x': dense_vector(1),
'y': dense_vector(1)}, use_seq=False)
def process(settings, input_file): def process(settings, input_file):
for i in xrange(2000): for i in xrange(2000):
x = random.random() x = random.random()
yield [x], [2 * x + 0.3] yield {'x': [x], 'y': [2 * x + 0.3]}
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -15,11 +15,8 @@ ...@@ -15,11 +15,8 @@
from paddle.trainer_config_helpers import * from paddle.trainer_config_helpers import *
# 1. read data. Suppose you saved above python code as dataprovider.py # 1. read data. Suppose you saved above python code as dataprovider.py
data_file = 'empty.list'
with open(data_file, 'w') as f:
f.writelines(' ')
define_py_data_sources2( define_py_data_sources2(
train_list=data_file, train_list=['no_matter.txt'],
test_list=None, test_list=None,
module='dataprovider', module='dataprovider',
obj='process', obj='process',
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
from paddle.trainer.PyDataProvider2 import * from paddle.trainer.PyDataProvider2 import *
import numpy
# Define a py data provider # Define a py data provider
@provider( @provider(
input_types={'pixel': dense_vector(28 * 28), input_types={'pixel': dense_vector(28 * 28),
'label': integer_value(10)}) 'label': integer_value(10)},
cache=CacheType.CACHE_PASS_IN_MEM)
def process(settings, filename): # settings is not used currently. def process(settings, filename): # settings is not used currently.
imgf = filename + "-images-idx3-ubyte" imgf = filename + "-images-idx3-ubyte"
labelf = filename + "-labels-idx1-ubyte" labelf = filename + "-labels-idx1-ubyte"
...@@ -20,12 +22,13 @@ def process(settings, filename): # settings is not used currently. ...@@ -20,12 +22,13 @@ def process(settings, filename): # settings is not used currently.
else: else:
n = 10000 n = 10000
for i in range(n): images = numpy.fromfile(
label = ord(l.read(1)) f, 'ubyte', count=n * 28 * 28).reshape((n, 28 * 28)).astype('float32')
pixels = [] images = images / 255.0 * 2.0 - 1.0
for j in range(28 * 28): labels = numpy.fromfile(l, 'ubyte', count=n).astype("int")
pixels.append(float(ord(f.read(1))) / 255.0)
yield {"pixel": pixels, 'label': label} for i in xrange(n):
yield {"pixel": images[i, :], 'label': labels[i]}
f.close() f.close()
l.close() l.close()
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#!/bin/env python #!/bin/env python
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#!/bin/env python #!/bin/env python
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
...@@ -8,6 +8,8 @@ data/test.list ...@@ -8,6 +8,8 @@ data/test.list
data/test.txt data/test.txt
data/train.list data/train.list
data/train.txt data/train.txt
data/pred.list
data/pred.txt
dataprovider_copy_1.py dataprovider_copy_1.py
train.log train.log
output output
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os, sys
import numpy as np
from optparse import OptionParser
from py_paddle import swig_paddle, DataProviderConverter
from paddle.trainer.PyDataProvider2 import sparse_binary_vector
from paddle.trainer.config_parser import parse_config
"""
Usage: run following command to show help message.
python api_predict.py -h
"""
class QuickStartPrediction():
def __init__(self, train_conf, dict_file, model_dir=None, label_file=None):
"""
train_conf: trainer configure.
dict_file: word dictionary file name.
model_dir: directory of model.
"""
self.train_conf = train_conf
self.dict_file = dict_file
self.word_dict = {}
self.dict_dim = self.load_dict()
self.model_dir = model_dir
if model_dir is None:
self.model_dir = os.path.dirname(train_conf)
self.label = None
if label_file is not None:
self.load_label(label_file)
conf = parse_config(train_conf, "is_predict=1")
self.network = swig_paddle.GradientMachine.createFromConfigProto(
conf.model_config)
self.network.loadParameters(self.model_dir)
input_types = [sparse_binary_vector(self.dict_dim)]
self.converter = DataProviderConverter(input_types)
def load_dict(self):
"""
Load dictionary from self.dict_file.
"""
for line_count, line in enumerate(open(self.dict_file, 'r')):
self.word_dict[line.strip().split('\t')[0]] = line_count
return len(self.word_dict)
def load_label(self, label_file):
"""
Load label.
"""
self.label = {}
for v in open(label_file, 'r'):
self.label[int(v.split('\t')[1])] = v.split('\t')[0]
def get_index(self, data):
"""
transform word into integer index according to the dictionary.
"""
words = data.strip().split()
word_slot = [self.word_dict[w] for w in words if w in self.word_dict]
return word_slot
def batch_predict(self, data_batch):
input = self.converter(data_batch)
output = self.network.forwardTest(input)
prob = output[0]["id"].tolist()
print("predicting labels is:")
print prob
def option_parser():
usage = "python predict.py -n config -w model_dir -d dictionary -i input_file "
parser = OptionParser(usage="usage: %s [options]" % usage)
parser.add_option(
"-n",
"--tconf",
action="store",
dest="train_conf",
help="network config")
parser.add_option(
"-d",
"--dict",
action="store",
dest="dict_file",
help="dictionary file")
parser.add_option(
"-b",
"--label",
action="store",
dest="label",
default=None,
help="dictionary file")
parser.add_option(
"-c",
"--batch_size",
type="int",
action="store",
dest="batch_size",
default=1,
help="the batch size for prediction")
parser.add_option(
"-w",
"--model",
action="store",
dest="model_path",
default=None,
help="model path")
return parser.parse_args()
def main():
options, args = option_parser()
train_conf = options.train_conf
batch_size = options.batch_size
dict_file = options.dict_file
model_path = options.model_path
label = options.label
swig_paddle.initPaddle("--use_gpu=0")
predict = QuickStartPrediction(train_conf, dict_file, model_path, label)
batch = []
labels = []
for line in sys.stdin:
[label, text] = line.split("\t")
labels.append(int(label))
batch.append([predict.get_index(text)])
print("labels is:")
print labels
predict.batch_predict(batch)
if __name__ == '__main__':
main()
#!/bin/bash
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
#Note the default model is pass-00002, you shold make sure the model path
#exists or change the mode path.
#only test on trainer_config.lr.py
model=output/pass-00001/
config=trainer_config.lr.py
label=data/labels.list
dict=data/dict.txt
batch_size=20
head -n$batch_size data/test.txt | python api_predict.py \
--tconf=$config\
--model=$model \
--label=$label \
--dict=$dict \
--batch_size=$batch_size
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -31,16 +31,16 @@ def initializer(settings, dictionary, **kwargs): ...@@ -31,16 +31,16 @@ def initializer(settings, dictionary, **kwargs):
# setting.input_types specifies what the data types the data provider # setting.input_types specifies what the data types the data provider
# generates. # generates.
settings.input_types = [ settings.input_types = {
# The first input is a sparse_binary_vector, # The first input is a sparse_binary_vector,
# which means each dimension of the vector is either 0 or 1. It is the # which means each dimension of the vector is either 0 or 1. It is the
# bag-of-words (BOW) representation of the texts. # bag-of-words (BOW) representation of the texts.
sparse_binary_vector(len(dictionary)), 'word': sparse_binary_vector(len(dictionary)),
# The second input is an integer. It represents the category id of the # The second input is an integer. It represents the category id of the
# sample. 2 means there are two labels in the dataset. # sample. 2 means there are two labels in the dataset.
# (1 for positive and 0 for negative) # (1 for positive and 0 for negative)
integer_value(2) 'label': integer_value(2)
] }
# Delaring a data provider. It has an initializer 'data_initialzer'. # Delaring a data provider. It has an initializer 'data_initialzer'.
...@@ -67,12 +67,12 @@ def process(settings, file_name): ...@@ -67,12 +67,12 @@ def process(settings, file_name):
# Return the features for the current comment. The first is a list # Return the features for the current comment. The first is a list
# of ids representing a 0-1 binary sparse vector of the text, # of ids representing a 0-1 binary sparse vector of the text,
# the second is the integer id of the label. # the second is the integer id of the label.
yield word_vector, int(label) yield {'word': word_vector, 'label': int(label)}
def predict_initializer(settings, dictionary, **kwargs): def predict_initializer(settings, dictionary, **kwargs):
settings.word_dict = dictionary settings.word_dict = dictionary
settings.input_types = [sparse_binary_vector(len(dictionary))] settings.input_types = {'word': sparse_binary_vector(len(dictionary))}
# Declaring a data provider for prediction. The difference with process # Declaring a data provider for prediction. The difference with process
...@@ -83,4 +83,4 @@ def process_predict(settings, file_name): ...@@ -83,4 +83,4 @@ def process_predict(settings, file_name):
for line in f: for line in f:
comment = line.strip().split() comment = line.strip().split()
word_vector = [settings.word_dict.get(w, UNK_IDX) for w in comment] word_vector = [settings.word_dict.get(w, UNK_IDX) for w in comment]
yield word_vector yield {'word': word_vector}
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -19,13 +19,13 @@ UNK_IDX = 0 ...@@ -19,13 +19,13 @@ UNK_IDX = 0
def initializer(settings, dictionary, **kwargs): def initializer(settings, dictionary, **kwargs):
settings.word_dict = dictionary settings.word_dict = dictionary
settings.input_types = [ settings.input_types = {
# Define the type of the first input as sequence of integer. # Define the type of the first input as sequence of integer.
# The value of the integers range from 0 to len(dictrionary)-1 # The value of the integers range from 0 to len(dictrionary)-1
integer_value_sequence(len(dictionary)), 'word': integer_value_sequence(len(dictionary)),
# Define the second input for label id # Define the second input for label id
integer_value(2) 'label': integer_value(2)
] }
@provider(init_hook=initializer, cache=CacheType.CACHE_PASS_IN_MEM) @provider(init_hook=initializer, cache=CacheType.CACHE_PASS_IN_MEM)
...@@ -35,15 +35,12 @@ def process(settings, file_name): ...@@ -35,15 +35,12 @@ def process(settings, file_name):
label, comment = line.strip().split('\t') label, comment = line.strip().split('\t')
words = comment.split() words = comment.split()
word_slot = [settings.word_dict.get(w, UNK_IDX) for w in words] word_slot = [settings.word_dict.get(w, UNK_IDX) for w in words]
yield word_slot, int(label) yield {'word': word_slot, 'label': int(label)}
def predict_initializer(settings, dictionary, **kwargs): def predict_initializer(settings, dictionary, **kwargs):
settings.word_dict = dictionary settings.word_dict = dictionary
settings.input_types = [ settings.input_types = {'word': integer_value_sequence(len(dictionary))}
integer_value(
len(dictionary), seq_type=SequenceType.SEQUENCE)
]
@provider(init_hook=predict_initializer, should_shuffle=False) @provider(init_hook=predict_initializer, should_shuffle=False)
...@@ -52,4 +49,4 @@ def process_predict(settings, file_name): ...@@ -52,4 +49,4 @@ def process_predict(settings, file_name):
for line in f: for line in f:
comment = line.strip().split() comment = line.strip().split()
word_slot = [settings.word_dict.get(w, UNK_IDX) for w in comment] word_slot = [settings.word_dict.get(w, UNK_IDX) for w in comment]
yield word_slot yield {'word': word_slot}
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# edit-mode: -*- python -*- # edit-mode: -*- python -*-
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# edit-mode: -*- python -*- # edit-mode: -*- python -*-
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# edit-mode: -*- python -*- # edit-mode: -*- python -*-
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# edit-mode: -*- python -*- # edit-mode: -*- python -*-
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# edit-mode: -*- python -*- # edit-mode: -*- python -*-
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# edit-mode: -*- python -*- # edit-mode: -*- python -*-
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# edit-mode: -*- python -*- # edit-mode: -*- python -*-
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
This configuration is a demonstration of how to implement the stacked LSTM This configuration is a demonstration of how to implement the stacked LSTM
with residual connections, i.e. an LSTM layer takes the sum of the hidden states with residual connections, i.e. an LSTM layer takes the sum of the hidden states
...@@ -46,7 +45,8 @@ is_predict = get_config_arg('is_predict', bool, False) ...@@ -46,7 +45,8 @@ is_predict = get_config_arg('is_predict', bool, False)
trn = 'data/train.list' if not is_predict else None trn = 'data/train.list' if not is_predict else None
tst = 'data/test.list' if not is_predict else 'data/pred.list' tst = 'data/test.list' if not is_predict else 'data/pred.list'
process = 'process' if not is_predict else 'process_predict' process = 'process' if not is_predict else 'process_predict'
define_py_data_sources2(train_list=trn, define_py_data_sources2(
train_list=trn,
test_list=tst, test_list=tst,
module="dataprovider_emb", module="dataprovider_emb",
obj=process, obj=process,
...@@ -58,10 +58,9 @@ settings( ...@@ -58,10 +58,9 @@ settings(
learning_rate=2e-3, learning_rate=2e-3,
learning_method=AdamOptimizer(), learning_method=AdamOptimizer(),
regularization=L2Regularization(8e-4), regularization=L2Regularization(8e-4),
gradient_clipping_threshold=25 gradient_clipping_threshold=25)
)
bias_attr = ParamAttr(initial_std=0.,l2_rate=0.) bias_attr = ParamAttr(initial_std=0., l2_rate=0.)
data = data_layer(name="word", size=len(word_dict)) data = data_layer(name="word", size=len(word_dict))
emb = embedding_layer(input=data, size=128) emb = embedding_layer(input=data, size=128)
...@@ -73,17 +72,15 @@ for i in range(3): ...@@ -73,17 +72,15 @@ for i in range(3):
# The input to the current layer is the sum of the hidden state # The input to the current layer is the sum of the hidden state
# and input of the previous layer. # and input of the previous layer.
current_input = addto_layer(input=[previous_input, previous_hidden_state]) current_input = addto_layer(input=[previous_input, previous_hidden_state])
hidden_state = simple_lstm(input=current_input, size=128, hidden_state = simple_lstm(
lstm_cell_attr=ExtraAttr(drop_rate=0.1)) input=current_input, size=128, lstm_cell_attr=ExtraAttr(drop_rate=0.1))
previous_input, previous_hidden_state = current_input, hidden_state previous_input, previous_hidden_state = current_input, hidden_state
lstm = previous_hidden_state lstm = previous_hidden_state
lstm_last = pooling_layer(input=lstm, pooling_type=MaxPooling()) lstm_last = pooling_layer(input=lstm, pooling_type=MaxPooling())
output = fc_layer(input=lstm_last, size=2, output = fc_layer(
bias_attr=bias_attr, input=lstm_last, size=2, bias_attr=bias_attr, act=SoftmaxActivation())
act=SoftmaxActivation())
if is_predict: if is_predict:
maxid = maxid_layer(output) maxid = maxid_layer(output)
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -17,13 +17,14 @@ from paddle.trainer.PyDataProvider2 import * ...@@ -17,13 +17,14 @@ from paddle.trainer.PyDataProvider2 import *
def meta_to_header(meta, name): def meta_to_header(meta, name):
metas = meta[name]['__meta__']['raw_meta'] metas = meta[name]['__meta__']['raw_meta']
for each_meta in metas: for each_meta in metas:
slot_name = each_meta.get('name', '%s_id' % name)
if each_meta['type'] == 'id': if each_meta['type'] == 'id':
yield integer_value(each_meta['max']) yield slot_name, integer_value(each_meta['max'])
elif each_meta['type'] == 'embedding': elif each_meta['type'] == 'embedding':
is_seq = each_meta['seq'] == 'sequence' is_seq = each_meta['seq'] == 'sequence'
yield integer_value( yield slot_name, integer_value(
len(each_meta['dict']), len(each_meta['dict']),
seq_type=SequenceType.SEQUENCE seq_type=SequenceType.SEQUENCE
if is_seq else SequenceType.NO_SEQUENCE) if is_seq else SequenceType.NO_SEQUENCE)
elif each_meta['type'] == 'one_hot_dense': elif each_meta['type'] == 'one_hot_dense':
yield dense_vector(len(each_meta['dict'])) yield slot_name, dense_vector(len(each_meta['dict']))
#!/bin/env python2 #!/bin/env python2
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#!/bin/env python2 #!/bin/env python2
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#!/bin/env python2 #!/bin/env python2
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -16,6 +16,14 @@ from paddle.trainer.PyDataProvider2 import * ...@@ -16,6 +16,14 @@ from paddle.trainer.PyDataProvider2 import *
import common_utils # parse import common_utils # parse
def __list_to_map__(lst):
ret_val = dict()
for each in lst:
k, v = each
ret_val[k] = v
return ret_val
def hook(settings, meta, **kwargs): def hook(settings, meta, **kwargs):
""" """
Init hook is invoked before process data. It will set obj.slots and store Init hook is invoked before process data. It will set obj.slots and store
...@@ -34,12 +42,16 @@ def hook(settings, meta, **kwargs): ...@@ -34,12 +42,16 @@ def hook(settings, meta, **kwargs):
# second part is user features. # second part is user features.
# final part is rating score. # final part is rating score.
# header is a list of [USE_SEQ_OR_NOT?, SlotType] # header is a list of [USE_SEQ_OR_NOT?, SlotType]
headers = list(common_utils.meta_to_header(meta, 'movie')) movie_headers = list(common_utils.meta_to_header(meta, 'movie'))
headers.extend(list(common_utils.meta_to_header(meta, 'user'))) settings.movie_names = [h[0] for h in movie_headers]
headers.append(dense_vector(1)) # Score headers = movie_headers
user_headers = list(common_utils.meta_to_header(meta, 'user'))
settings.user_names = [h[0] for h in user_headers]
headers.extend(user_headers)
headers.append(("rating", dense_vector(1))) # Score
# slot types. # slot types.
settings.input_types = headers settings.input_types = __list_to_map__(headers)
settings.meta = meta settings.meta = meta
...@@ -57,20 +69,20 @@ def process(settings, filename): ...@@ -57,20 +69,20 @@ def process(settings, filename):
movie_meta = settings.meta['movie'][movie_id] movie_meta = settings.meta['movie'][movie_id]
user_meta = settings.meta['user'][user_id] user_meta = settings.meta['user'][user_id]
outputs = [movie_id - 1] outputs = [('movie_id', movie_id - 1)]
# Then add movie features # Then add movie features
for each_meta in movie_meta: for i, each_meta in enumerate(movie_meta):
outputs.append(each_meta) outputs.append((settings.movie_names[i + 1], each_meta))
# Then add user id. # Then add user id.
outputs.append(user_id - 1) outputs.append(('user_id', user_id - 1))
# Then add user features. # Then add user features.
for each_meta in user_meta: for i, each_meta in enumerate(user_meta):
outputs.append(each_meta) outputs.append((settings.user_names[i + 1], each_meta))
# Finally, add score # Finally, add score
outputs.append([score]) outputs.append(('rating', [score]))
# Return data to paddle # Return data to paddle
yield outputs yield __list_to_map__(outputs)
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#!/bin/env python2 #!/bin/env python2
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -34,8 +34,8 @@ if __name__ == '__main__': ...@@ -34,8 +34,8 @@ if __name__ == '__main__':
network.loadParameters(model_path) network.loadParameters(model_path)
with open('./data/meta.bin', 'rb') as f: with open('./data/meta.bin', 'rb') as f:
meta = pickle.load(f) meta = pickle.load(f)
headers = list(meta_to_header(meta, 'movie')) headers = [h[1] for h in meta_to_header(meta, 'movie')]
headers.extend(list(meta_to_header(meta, 'user'))) headers.extend([h[1] for h in meta_to_header(meta, 'user')])
cvt = DataProviderConverter(headers) cvt = DataProviderConverter(headers)
while True: while True:
movie_id = int(raw_input("Input movie_id: ")) movie_id = int(raw_input("Input movie_id: "))
......
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -14,6 +14,15 @@ ...@@ -14,6 +14,15 @@
# limitations under the License. # limitations under the License.
set -e set -e
UNAME_STR=`uname`
if [[ ${UNAME_STR} == 'Linux' ]]; then
SHUF_PROG='shuf'
else
SHUF_PROG='gshuf'
fi
cd "$(dirname "$0")" cd "$(dirname "$0")"
delimiter='::' delimiter='::'
dir=ml-1m dir=ml-1m
...@@ -25,7 +34,7 @@ python meta_generator.py $dir meta.bin --config=meta_config.json ...@@ -25,7 +34,7 @@ python meta_generator.py $dir meta.bin --config=meta_config.json
echo 'split train/test file' echo 'split train/test file'
python split.py $dir/ratings.dat --delimiter=${delimiter} --test_ratio=0.1 python split.py $dir/ratings.dat --delimiter=${delimiter} --test_ratio=0.1
echo 'shuffle train file' echo 'shuffle train file'
shuf $dir/ratings.dat.train > ratings.dat.train ${SHUF_PROG} $dir/ratings.dat.train > ratings.dat.train
cp $dir/ratings.dat.test . cp $dir/ratings.dat.test .
echo "./data/ratings.dat.train" > train.list echo "./data/ratings.dat.train" > train.list
echo "./data/ratings.dat.test" > test.list echo "./data/ratings.dat.test" > test.list
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
...@@ -8,3 +8,7 @@ data/test.wsj.seq_pair ...@@ -8,3 +8,7 @@ data/test.wsj.seq_pair
data/test.wsj.words data/test.wsj.words
data/tgt.dict data/tgt.dict
output output
data/emb
data/targetDict.txt
data/verbDict.txt
data/wordDict.txt
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -43,13 +43,13 @@ def extract_dict_features(pair_file, feature_file): ...@@ -43,13 +43,13 @@ def extract_dict_features(pair_file, feature_file):
mark[verb_index] = 1 mark[verb_index] = 1
ctx_0 = sentence_list[verb_index] ctx_0 = sentence_list[verb_index]
if verb_index < len(labels_list) - 2: if verb_index < len(labels_list) - 1:
mark[verb_index + 1] = 1 mark[verb_index + 1] = 1
ctx_p1 = sentence_list[verb_index + 1] ctx_p1 = sentence_list[verb_index + 1]
else: else:
ctx_p1 = 'eos' ctx_p1 = 'eos'
if verb_index < len(labels_list) - 3: if verb_index < len(labels_list) - 2:
mark[verb_index + 2] = 1 mark[verb_index + 2] = 1
ctx_p2 = sentence_list[verb_index + 2] ctx_p2 = sentence_list[verb_index + 2]
else: else:
...@@ -69,7 +69,6 @@ def extract_dict_features(pair_file, feature_file): ...@@ -69,7 +69,6 @@ def extract_dict_features(pair_file, feature_file):
feature_out.write(feature_str + '\n') feature_out.write(feature_str + '\n')
if __name__ == '__main__': if __name__ == '__main__':
usage = '-p pair_file -f feature_file' usage = '-p pair_file -f feature_file'
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -66,7 +66,7 @@ def transform_labels(sentences, labels): ...@@ -66,7 +66,7 @@ def transform_labels(sentences, labels):
else: else:
verb_list = [] verb_list = []
for x in labels[i][0]: for x in labels[i][0]:
if x !='-': if x != '-':
verb_list.append(x) verb_list.append(x)
for j in xrange(1, len(labels[i])): for j in xrange(1, len(labels[i])):
...@@ -93,7 +93,7 @@ def transform_labels(sentences, labels): ...@@ -93,7 +93,7 @@ def transform_labels(sentences, labels):
is_in_bracket = True is_in_bracket = True
else: else:
print 'error:', ll print 'error:', ll
sen_lab_pair.append((sentences[i], verb_list[j-1], label_seq)) sen_lab_pair.append((sentences[i], verb_list[j - 1], label_seq))
return sen_lab_pair return sen_lab_pair
...@@ -103,7 +103,7 @@ def write_file(sen_lab_pair, output_file): ...@@ -103,7 +103,7 @@ def write_file(sen_lab_pair, output_file):
sentence = x[0] sentence = x[0]
label_seq = ' '.join(x[2]) label_seq = ' '.join(x[2])
assert len(sentence.split()) == len(x[2]) assert len(sentence.split()) == len(x[2])
fout.write(sentence + '\t' + x[1]+'\t' +label_seq + '\n') fout.write(sentence + '\t' + x[1] + '\t' + label_seq + '\n')
if __name__ == '__main__': if __name__ == '__main__':
......
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -30,8 +30,7 @@ def hook(settings, word_dict, label_dict, predicate_dict, **kwargs): ...@@ -30,8 +30,7 @@ def hook(settings, word_dict, label_dict, predicate_dict, **kwargs):
integer_value_sequence(len(word_dict)), integer_value_sequence(len(word_dict)),
integer_value_sequence(len(word_dict)), integer_value_sequence(len(word_dict)),
integer_value_sequence(len(word_dict)), integer_value_sequence(len(word_dict)),
integer_value_sequence(len(predicate_dict)), integer_value_sequence(len(predicate_dict)), integer_value_sequence(2),
integer_value_sequence(2),
integer_value_sequence(len(label_dict)) integer_value_sequence(len(label_dict))
] ]
...@@ -40,8 +39,12 @@ def get_batch_size(yeild_data): ...@@ -40,8 +39,12 @@ def get_batch_size(yeild_data):
return len(yeild_data[0]) return len(yeild_data[0])
@provider(init_hook=hook, should_shuffle=True, calc_batch_size=get_batch_size, @provider(
can_over_batch_size=False, cache=CacheType.CACHE_PASS_IN_MEM) init_hook=hook,
should_shuffle=True,
calc_batch_size=get_batch_size,
can_over_batch_size=False,
cache=CacheType.CACHE_PASS_IN_MEM)
def process(settings, file_name): def process(settings, file_name):
with open(file_name, 'r') as fdata: with open(file_name, 'r') as fdata:
for line in fdata: for line in fdata:
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -20,7 +20,7 @@ from paddle.trainer_config_helpers import * ...@@ -20,7 +20,7 @@ from paddle.trainer_config_helpers import *
#file paths #file paths
word_dict_file = './data/wordDict.txt' word_dict_file = './data/wordDict.txt'
label_dict_file = './data/targetDict.txt' label_dict_file = './data/targetDict.txt'
predicate_file= './data/verbDict.txt' predicate_file = './data/verbDict.txt'
train_list_file = './data/train.list' train_list_file = './data/train.list'
test_list_file = './data/test.list' test_list_file = './data/test.list'
...@@ -47,7 +47,6 @@ if not is_predict: ...@@ -47,7 +47,6 @@ if not is_predict:
w = line.strip() w = line.strip()
predicate_dict[w] = i predicate_dict[w] = i
if is_test: if is_test:
train_list_file = None train_list_file = None
...@@ -57,9 +56,11 @@ if not is_predict: ...@@ -57,9 +56,11 @@ if not is_predict:
test_list=test_list_file, test_list=test_list_file,
module='dataprovider', module='dataprovider',
obj='process', obj='process',
args={'word_dict': word_dict, args={
'word_dict': word_dict,
'label_dict': label_dict, 'label_dict': label_dict,
'predicate_dict': predicate_dict }) 'predicate_dict': predicate_dict
})
word_dict_len = len(word_dict) word_dict_len = len(word_dict)
label_dict_len = len(label_dict) label_dict_len = len(label_dict)
...@@ -77,24 +78,16 @@ mark_dim = 5 ...@@ -77,24 +78,16 @@ mark_dim = 5
hidden_dim = 512 hidden_dim = 512
depth = 8 depth = 8
########################### Optimizer ####################################### ########################### Optimizer #######################################
settings( settings(
batch_size=150, batch_size=150,
learning_method=MomentumOptimizer(momentum=0), learning_method=MomentumOptimizer(momentum=0),
learning_rate=2e-2, learning_rate=2e-2,
regularization=L2Regularization(8e-4), regularization=L2Regularization(8e-4),
is_async=False, is_async=False,
model_average=ModelAverage(average_window=0.5, model_average=ModelAverage(
max_average_window=10000), average_window=0.5, max_average_window=10000), )
)
####################################### network ############################## ####################################### network ##############################
#8 features and 1 target #8 features and 1 target
...@@ -108,22 +101,28 @@ ctx_p1 = data_layer(name='ctx_p1_data', size=word_dict_len) ...@@ -108,22 +101,28 @@ ctx_p1 = data_layer(name='ctx_p1_data', size=word_dict_len)
ctx_p2 = data_layer(name='ctx_p2_data', size=word_dict_len) ctx_p2 = data_layer(name='ctx_p2_data', size=word_dict_len)
mark = data_layer(name='mark_data', size=mark_dict_len) mark = data_layer(name='mark_data', size=mark_dict_len)
if not is_predict: if not is_predict:
target = data_layer(name='target', size=label_dict_len) target = data_layer(name='target', size=label_dict_len)
default_std = 1 / math.sqrt(hidden_dim) / 3.0
default_std=1/math.sqrt(hidden_dim)/3.0
emb_para = ParameterAttribute(name='emb', initial_std=0., learning_rate=0.) emb_para = ParameterAttribute(name='emb', initial_std=0., learning_rate=0.)
std_0 = ParameterAttribute(initial_std=0.) std_0 = ParameterAttribute(initial_std=0.)
std_default = ParameterAttribute(initial_std=default_std) std_default = ParameterAttribute(initial_std=default_std)
predicate_embedding = embedding_layer(size=word_dim, input=predicate, param_attr=ParameterAttribute(name='vemb',initial_std=default_std)) predicate_embedding = embedding_layer(
mark_embedding = embedding_layer(name='word_ctx-in_embedding', size=mark_dim, input=mark, param_attr=std_0) size=word_dim,
input=predicate,
word_input=[word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2] param_attr=ParameterAttribute(
emb_layers = [embedding_layer(size=word_dim, input=x, param_attr=emb_para) for x in word_input] name='vemb', initial_std=default_std))
mark_embedding = embedding_layer(
name='word_ctx-in_embedding', size=mark_dim, input=mark, param_attr=std_0)
word_input = [word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2]
emb_layers = [
embedding_layer(
size=word_dim, input=x, param_attr=emb_para) for x in word_input
]
emb_layers.append(predicate_embedding) emb_layers.append(predicate_embedding)
emb_layers.append(mark_embedding) emb_layers.append(mark_embedding)
...@@ -131,14 +130,18 @@ hidden_0 = mixed_layer( ...@@ -131,14 +130,18 @@ hidden_0 = mixed_layer(
name='hidden0', name='hidden0',
size=hidden_dim, size=hidden_dim,
bias_attr=std_default, bias_attr=std_default,
input=[ full_matrix_projection(input=emb, param_attr=std_default ) for emb in emb_layers ]) input=[
full_matrix_projection(
input=emb, param_attr=std_default) for emb in emb_layers
])
mix_hidden_lr = 1e-3 mix_hidden_lr = 1e-3
lstm_para_attr = ParameterAttribute(initial_std=0.0, learning_rate=1.0) lstm_para_attr = ParameterAttribute(initial_std=0.0, learning_rate=1.0)
hidden_para_attr = ParameterAttribute(initial_std=default_std, learning_rate=mix_hidden_lr) hidden_para_attr = ParameterAttribute(
initial_std=default_std, learning_rate=mix_hidden_lr)
lstm_0 = lstmemory(name='lstm0', lstm_0 = lstmemory(
name='lstm0',
input=hidden_0, input=hidden_0,
act=ReluActivation(), act=ReluActivation(),
gate_act=SigmoidActivation(), gate_act=SigmoidActivation(),
...@@ -149,66 +152,67 @@ lstm_0 = lstmemory(name='lstm0', ...@@ -149,66 +152,67 @@ lstm_0 = lstmemory(name='lstm0',
#stack L-LSTM and R-LSTM with direct edges #stack L-LSTM and R-LSTM with direct edges
input_tmp = [hidden_0, lstm_0] input_tmp = [hidden_0, lstm_0]
for i in range(1, depth): for i in range(1, depth):
mix_hidden = mixed_layer(name='hidden'+str(i), mix_hidden = mixed_layer(
name='hidden' + str(i),
size=hidden_dim, size=hidden_dim,
bias_attr=std_default, bias_attr=std_default,
input=[full_matrix_projection(input=input_tmp[0], param_attr=hidden_para_attr), input=[
full_matrix_projection(input=input_tmp[1], param_attr=lstm_para_attr) full_matrix_projection(
] input=input_tmp[0], param_attr=hidden_para_attr),
) full_matrix_projection(
input=input_tmp[1], param_attr=lstm_para_attr)
lstm = lstmemory(name='lstm'+str(i), ])
lstm = lstmemory(
name='lstm' + str(i),
input=mix_hidden, input=mix_hidden,
act=ReluActivation(), act=ReluActivation(),
gate_act=SigmoidActivation(), gate_act=SigmoidActivation(),
state_act=SigmoidActivation(), state_act=SigmoidActivation(),
reverse=((i % 2)==1), reverse=((i % 2) == 1),
bias_attr=std_0, bias_attr=std_0,
param_attr=lstm_para_attr) param_attr=lstm_para_attr)
input_tmp = [mix_hidden, lstm] input_tmp = [mix_hidden, lstm]
feature_out = mixed_layer(name='output', feature_out = mixed_layer(
name='output',
size=label_dict_len, size=label_dict_len,
bias_attr=std_default, bias_attr=std_default,
input=[full_matrix_projection(input=input_tmp[0], param_attr=hidden_para_attr), input=[
full_matrix_projection(input=input_tmp[1], param_attr=lstm_para_attr) full_matrix_projection(
], input=input_tmp[0], param_attr=hidden_para_attr),
) full_matrix_projection(
input=input_tmp[1], param_attr=lstm_para_attr)
], )
if not is_predict: if not is_predict:
crf_l = crf_layer( name = 'crf', crf_l = crf_layer(
size = label_dict_len, name='crf',
input = feature_out, size=label_dict_len,
label = target, input=feature_out,
param_attr=ParameterAttribute(name='crfw',initial_std=default_std, learning_rate=mix_hidden_lr) label=target,
param_attr=ParameterAttribute(
) name='crfw', initial_std=default_std, learning_rate=mix_hidden_lr))
crf_dec_l = crf_decoding_layer(name = 'crf_dec_l',
size = label_dict_len,
input = feature_out,
label = target,
param_attr=ParameterAttribute(name='crfw')
)
crf_dec_l = crf_decoding_layer(
name='crf_dec_l',
size=label_dict_len,
input=feature_out,
label=target,
param_attr=ParameterAttribute(name='crfw'))
eval = sum_evaluator(input=crf_dec_l) eval = sum_evaluator(input=crf_dec_l)
outputs(crf_l) outputs(crf_l)
else: else:
crf_dec_l = crf_decoding_layer(name = 'crf_dec_l', crf_dec_l = crf_decoding_layer(
size = label_dict_len, name='crf_dec_l',
input = feature_out, size=label_dict_len,
param_attr=ParameterAttribute(name='crfw') input=feature_out,
) param_attr=ParameterAttribute(name='crfw'))
outputs(crf_dec_l) outputs(crf_dec_l)
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -26,7 +26,8 @@ UNK_IDX = 0 ...@@ -26,7 +26,8 @@ UNK_IDX = 0
class Prediction(): class Prediction():
def __init__(self, train_conf, dict_file, model_dir, label_file, predicate_dict_file): def __init__(self, train_conf, dict_file, model_dir, label_file,
predicate_dict_file):
""" """
train_conf: trainer configure. train_conf: trainer configure.
dict_file: word dictionary file name. dict_file: word dictionary file name.
...@@ -35,7 +36,7 @@ class Prediction(): ...@@ -35,7 +36,7 @@ class Prediction():
self.dict = {} self.dict = {}
self.labels = {} self.labels = {}
self.predicate_dict={} self.predicate_dict = {}
self.labels_reverse = {} self.labels_reverse = {}
self.load_dict_label(dict_file, label_file, predicate_dict_file) self.load_dict_label(dict_file, label_file, predicate_dict_file)
...@@ -44,24 +45,17 @@ class Prediction(): ...@@ -44,24 +45,17 @@ class Prediction():
len_pred = len(self.predicate_dict) len_pred = len(self.predicate_dict)
conf = parse_config( conf = parse_config(
train_conf, train_conf, 'dict_len=' + str(len_dict) + ',label_len=' +
'dict_len=' + str(len_dict) + str(len_label) + ',pred_len=' + str(len_pred) + ',is_predict=True')
',label_len=' + str(len_label) +
',pred_len=' + str(len_pred) +
',is_predict=True')
self.network = swig_paddle.GradientMachine.createFromConfigProto( self.network = swig_paddle.GradientMachine.createFromConfigProto(
conf.model_config) conf.model_config)
self.network.loadParameters(model_dir) self.network.loadParameters(model_dir)
slots = [ slots = [
integer_value_sequence(len_dict), integer_value_sequence(len_dict), integer_value_sequence(len_dict),
integer_value_sequence(len_dict), integer_value_sequence(len_dict), integer_value_sequence(len_dict),
integer_value_sequence(len_dict), integer_value_sequence(len_dict), integer_value_sequence(len_dict),
integer_value_sequence(len_dict), integer_value_sequence(len_pred), integer_value_sequence(2)
integer_value_sequence(len_dict),
integer_value_sequence(len_dict),
integer_value_sequence(len_pred),
integer_value_sequence(2)
] ]
self.converter = DataProviderConverter(slots) self.converter = DataProviderConverter(slots)
...@@ -78,6 +72,7 @@ class Prediction(): ...@@ -78,6 +72,7 @@ class Prediction():
for line_count, line in enumerate(open(predicate_dict_file, 'r')): for line_count, line in enumerate(open(predicate_dict_file, 'r')):
self.predicate_dict[line.strip()] = line_count self.predicate_dict[line.strip()] = line_count
def get_data(self, data_file): def get_data(self, data_file):
""" """
Get input data of paddle format. Get input data of paddle format.
...@@ -90,7 +85,8 @@ class Prediction(): ...@@ -90,7 +85,8 @@ class Prediction():
sen_len = len(words) sen_len = len(words)
word_slot = [self.dict.get(w, UNK_IDX) for w in words] word_slot = [self.dict.get(w, UNK_IDX) for w in words]
predicate_slot = [self.predicate_dict.get(predicate, UNK_IDX)] * sen_len predicate_slot = [self.predicate_dict.get(predicate, UNK_IDX)
] * sen_len
ctx_n2_slot = [self.dict.get(ctx_n2, UNK_IDX)] * sen_len ctx_n2_slot = [self.dict.get(ctx_n2, UNK_IDX)] * sen_len
ctx_n1_slot = [self.dict.get(ctx_n1, UNK_IDX)] * sen_len ctx_n1_slot = [self.dict.get(ctx_n1, UNK_IDX)] * sen_len
ctx_0_slot = [self.dict.get(ctx_0, UNK_IDX)] * sen_len ctx_0_slot = [self.dict.get(ctx_0, UNK_IDX)] * sen_len
...@@ -123,7 +119,8 @@ class Prediction(): ...@@ -123,7 +119,8 @@ class Prediction():
def option_parser(): def option_parser():
usage = ("python predict.py -c config -w model_dir " usage = (
"python predict.py -c config -w model_dir "
"-d word dictionary -l label_file -i input_file -p pred_dict_file") "-d word dictionary -l label_file -i input_file -p pred_dict_file")
parser = OptionParser(usage="usage: %s [options]" % usage) parser = OptionParser(usage="usage: %s [options]" % usage)
parser.add_option( parser.add_option(
...@@ -187,8 +184,9 @@ def main(): ...@@ -187,8 +184,9 @@ def main():
output_file = options.output_file output_file = options.output_file
swig_paddle.initPaddle("--use_gpu=0") swig_paddle.initPaddle("--use_gpu=0")
predict = Prediction(train_conf, dict_file, model_path, label_file, predict_dict_file) predict = Prediction(train_conf, dict_file, model_path, label_file,
predict.predict(data_file,output_file) predict_dict_file)
predict.predict(data_file, output_file)
if __name__ == '__main__': if __name__ == '__main__':
......
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os import os, sys
import numpy as np import numpy as np
from optparse import OptionParser from optparse import OptionParser
from py_paddle import swig_paddle, DataProviderConverter from py_paddle import swig_paddle, DataProviderConverter
...@@ -66,34 +66,24 @@ class SentimentPrediction(): ...@@ -66,34 +66,24 @@ class SentimentPrediction():
for v in open(label_file, 'r'): for v in open(label_file, 'r'):
self.label[int(v.split('\t')[1])] = v.split('\t')[0] self.label[int(v.split('\t')[1])] = v.split('\t')[0]
def get_data(self, data_file): def get_index(self, data):
""" """
Get input data of paddle format. transform word into integer index according to the dictionary.
""" """
with open(data_file, 'r') as fdata: words = data.strip().split()
for line in fdata: word_slot = [self.word_dict[w] for w in words if w in self.word_dict]
words = line.strip().split() return word_slot
word_slot = [
self.word_dict[w] for w in words if w in self.word_dict def batch_predict(self, data_batch):
] input = self.converter(data_batch)
if not word_slot:
print "all words are not in dictionary: %s", line
continue
yield [word_slot]
def predict(self, data_file):
"""
data_file: file name of input data.
"""
input = self.converter(self.get_data(data_file))
output = self.network.forwardTest(input) output = self.network.forwardTest(input)
prob = output[0]["value"] prob = output[0]["value"]
lab = np.argsort(-prob) labs = np.argsort(-prob)
for idx, lab in enumerate(labs):
if self.label is None: if self.label is None:
print("%s: predicting label is %d" % (data_file, lab[0][0])) print("predicting label is %d" % (lab[0]))
else: else:
print("%s: predicting label is %s" % print("predicting label is %s" % (self.label[lab[0]]))
(data_file, self.label[lab[0][0]]))
def option_parser(): def option_parser():
...@@ -119,11 +109,13 @@ def option_parser(): ...@@ -119,11 +109,13 @@ def option_parser():
default=None, default=None,
help="dictionary file") help="dictionary file")
parser.add_option( parser.add_option(
"-i", "-c",
"--data", "--batch_size",
type="int",
action="store", action="store",
dest="data", dest="batch_size",
help="data file to predict") default=1,
help="the batch size for prediction")
parser.add_option( parser.add_option(
"-w", "-w",
"--model", "--model",
...@@ -137,13 +129,21 @@ def option_parser(): ...@@ -137,13 +129,21 @@ def option_parser():
def main(): def main():
options, args = option_parser() options, args = option_parser()
train_conf = options.train_conf train_conf = options.train_conf
data = options.data batch_size = options.batch_size
dict_file = options.dict_file dict_file = options.dict_file
model_path = options.model_path model_path = options.model_path
label = options.label label = options.label
swig_paddle.initPaddle("--use_gpu=0") swig_paddle.initPaddle("--use_gpu=0")
predict = SentimentPrediction(train_conf, dict_file, model_path, label) predict = SentimentPrediction(train_conf, dict_file, model_path, label)
predict.predict(data)
batch = []
for line in sys.stdin:
batch.append([predict.get_index(line)])
if len(batch) == batch_size:
predict.batch_predict(batch)
batch = []
if len(batch) > 0:
predict.batch_predict(batch)
if __name__ == '__main__': if __name__ == '__main__':
......
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -19,9 +19,9 @@ set -e ...@@ -19,9 +19,9 @@ set -e
model=model_output/pass-00002/ model=model_output/pass-00002/
config=trainer_config.py config=trainer_config.py
label=data/pre-imdb/labels.list label=data/pre-imdb/labels.list
python predict.py \ cat ./data/aclImdb/test/pos/10007_10.txt | python predict.py \
-n $config\ --tconf=$config\
-w $model \ --model=$model \
-b $label \ --label=$label \
-d ./data/pre-imdb/dict.txt \ --dict=./data/pre-imdb/dict.txt \
-i ./data/aclImdb/test/pos/10007_10.txt --batch_size=1
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
#!/bin/bash #!/bin/bash
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved # Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册