diff --git a/benchmark/paddle/image/alexnet.py b/benchmark/paddle/image/alexnet.py index 77d130ae34059d1e87040d00346ac1dadd86b0d8..cad6051f1413a5bb95f87a940f3aa81e49e5d282 100644 --- a/benchmark/paddle/image/alexnet.py +++ b/benchmark/paddle/image/alexnet.py @@ -19,7 +19,11 @@ args = { 'num_samples': num_samples } define_py_data_sources2( - "train.list", None, module="provider", obj="process", args=args) + "train.list" if not is_infer else None, + "test.list" if is_infer else None, + module="provider", + obj="process", + args=args) settings( batch_size=batch_size, diff --git a/benchmark/paddle/image/run_openblas_infer.sh b/benchmark/paddle/image/run_openblas_infer.sh index da034f3b9dff794e22086a5295ad2b0c2361c356..71a49231a5527ebee9f45d5f4650ce2a4f6a1c31 100755 --- a/benchmark/paddle/image/run_openblas_infer.sh +++ b/benchmark/paddle/image/run_openblas_infer.sh @@ -8,15 +8,19 @@ function clock_to_seconds() { } function infer() { - unset OMP_NUM_THREADS MKL_NUM_THREADS OMP_DYNAMIC KMP_AFFINITY topology=$1 layer_num=$2 bs=$3 - thread=`nproc` - if [ $thread -gt $bs ]; then - thread=$bs + trainers=`nproc` + if [ $trainers -gt $bs ]; then + trainers=$bs fi - log="logs/infer-${topology}-${layer_num}-${thread}openblas-${bs}.log" + log="logs/infer-${topology}-${layer_num}-${trainers}openblas-${bs}.log" + threads=$((`nproc` / trainers)) + if [ $threads -eq 0 ]; then + threads=1 + fi + export OPENBLAS_NUM_THREADS=$threads models_in="models/${topology}-${layer_num}/pass-00000/" if [ ! -d $models_in ]; then @@ -28,7 +32,7 @@ function infer() { --config="${topology}.py" \ --use_mkldnn=False \ --use_gpu=False \ - --trainer_count=$thread \ + --trainer_count=$trainers \ --log_period=$log_period \ --config_args="batch_size=${bs},layer_num=${layer_num},is_infer=True,num_samples=256" \ --init_model_path=$models_in \ diff --git a/benchmark/paddle/image/run_openblas_train.sh b/benchmark/paddle/image/run_openblas_train.sh index e9df83fee2a3f796b7234b39619364f6ee4d5dc9..935cff6f2c97d25d6de556cfee25e27dbe49b5b6 100755 --- a/benchmark/paddle/image/run_openblas_train.sh +++ b/benchmark/paddle/image/run_openblas_train.sh @@ -1,7 +1,7 @@ set -e function train() { - unset OMP_NUM_THREADS MKL_NUM_THREADS OMP_DYNAMIC KMP_AFFINITY + export OPENBLAS_NUM_THREADS=1 topology=$1 layer_num=$2 bs=$3 diff --git a/doc/getstarted/build_and_install/docker_install_cn.rst b/doc/getstarted/build_and_install/docker_install_cn.rst index fa1b6a372728ccac128d2e6e79a6514b8884ea3f..bae42593ddc6f7a7eb47d603752ad6efa9820b45 100644 --- a/doc/getstarted/build_and_install/docker_install_cn.rst +++ b/doc/getstarted/build_and_install/docker_install_cn.rst @@ -15,7 +15,7 @@ 获取PaddlePaddle的Docker镜像 ------------------------------ -执行下面的命令获取最新的PaddlePaddle Docker镜像 +执行下面的命令获取最新的PaddlePaddle Docker镜像,版本为cpu_avx_mkl: .. code-block:: bash @@ -27,7 +27,7 @@ docker pull docker.paddlepaddle.org/paddle -下载GPU版本的Docker镜像: +下载GPU版本(cuda8.0_cudnn5_avx_mkl)的Docker镜像: .. code-block:: bash @@ -54,7 +54,7 @@ .. _docker_run: 在Docker中执行PaddlePaddle训练程序 ------------------------------- +---------------------------------- 假设您已经在当前目录(比如在/home/work)编写了一个PaddlePaddle的程序 :code:`train.py` (可以参考 `PaddlePaddleBook `_ @@ -82,7 +82,7 @@ .. 
_docker_run_book: 使用Docker启动PaddlePaddle Book教程 ------------------------------- +----------------------------------- 使用Docker可以快速在本地启动一个包含了PaddlePaddle官方Book教程的Jupyter Notebook,可以通过网页浏览。 PaddlePaddle Book是为用户和开发者制作的一个交互式的Jupyter Notebook。 diff --git a/doc/getstarted/build_and_install/docker_install_en.rst b/doc/getstarted/build_and_install/docker_install_en.rst index 06012bf65e75c32957516f6b7f62e09480871b84..56a7c68e4d39c45249fa55a964dc48b7081596a6 100644 --- a/doc/getstarted/build_and_install/docker_install_en.rst +++ b/doc/getstarted/build_and_install/docker_install_en.rst @@ -16,7 +16,7 @@ After you've read above tutorials you may proceed the following steps. Pull PaddlePaddle Docker Image ------------------------------ -Run the following command to download the latest Docker images: +Run the following command to download the latest Docker images, the version is cpu_avx_mkl: .. code-block:: bash @@ -28,7 +28,7 @@ For users in China, we provide a faster mirror: docker pull docker.paddlepaddle.org/paddle -Download GPU version images: +Download GPU version (cuda8.0_cudnn5_avx_mkl) images: .. code-block:: bash @@ -58,7 +58,7 @@ and run: .. _docker_run: Launch your training program in Docker ------------------------------- +-------------------------------------- Assume that you have already written a PaddlePaddle program named :code:`train.py` under directory :code:`/home/work` (refer to diff --git a/doc/getstarted/build_and_install/pip_install_cn.rst b/doc/getstarted/build_and_install/pip_install_cn.rst index a4587f82a984acf243f49834e707fcd66d5b1252..0c741e936b46eda5e7165e4ee54b545b14a28a19 100644 --- a/doc/getstarted/build_and_install/pip_install_cn.rst +++ b/doc/getstarted/build_and_install/pip_install_cn.rst @@ -11,14 +11,14 @@ PaddlePaddle可以使用常用的Python包管理工具 ------------------------------ -执行下面的命令即可在当前机器上安装PaddlePaddle的运行时环境,并自动下载安装依赖软件。 +执行下面的命令即可在当前机器上安装PaddlePaddle的运行时环境,并自动下载安装依赖软件,版本为cpu_avx_openblas。 .. code-block:: bash pip install paddlepaddle -如果需要安装支持GPU的版本,需要执行: +如果需要安装支持GPU的版本(cuda7.5_cudnn5_avx_openblas),需要执行: .. code-block:: bash diff --git a/doc/getstarted/build_and_install/pip_install_en.rst b/doc/getstarted/build_and_install/pip_install_en.rst index 55e31560a0f5087ab69966a6281c6c8573c04204..285ed09805b09790beaef014f6813c227aff33ac 100644 --- a/doc/getstarted/build_and_install/pip_install_en.rst +++ b/doc/getstarted/build_and_install/pip_install_en.rst @@ -12,14 +12,14 @@ Install Using pip ------------------------------ Run the following command to install PaddlePaddle on the current -machine, it will also download requirements. +machine, it will also download requirements, the version is cpu_avx_openblas. .. code-block:: bash pip install paddlepaddle -If you wish to install GPU version, just run: +If you wish to install GPU version (cuda7.5_cudnn5_avx_openblas), just run: .. code-block:: bash diff --git a/doc/getstarted/index_cn.rst b/doc/getstarted/index_cn.rst index a9087be6f350c5656cabb0c64ba0f200d1c666cc..9f6ee25987d51dcca3a37cf0f62a70a5a5a2d89a 100644 --- a/doc/getstarted/index_cn.rst +++ b/doc/getstarted/index_cn.rst @@ -7,13 +7,13 @@ ++++++++ PaddlePaddle支持使用pip快速安装,目前支持CentOS 6以上, Ubuntu 14.04以及MacOS 10.12,并安装有Python2.7。 -执行下面的命令完成快速安装: +执行下面的命令完成快速安装,版本为cpu_avx_openblas: .. code-block:: bash pip install paddlepaddle -如果需要安装支持GPU的版本,需要执行: +如果需要安装支持GPU的版本(cuda7.5_cudnn5_avx_openblas),需要执行: .. 
code-block:: bash diff --git a/doc/getstarted/index_en.rst b/doc/getstarted/index_en.rst index d14e3f5c0cc90792fce9cb82e65da482c44dc433..063d9d880c82550f7f5d47d3d0b1fff59865bca7 100644 --- a/doc/getstarted/index_en.rst +++ b/doc/getstarted/index_en.rst @@ -8,13 +8,13 @@ Quick Install You can use pip to install PaddlePaddle with a single command, supports CentOS 6 above, Ubuntu 14.04 above or MacOS 10.12, with Python 2.7 installed. -Simply run the following command to install: +Simply run the following command to install, the version is cpu_avx_openblas: .. code-block:: bash pip install paddlepaddle -If you need to install GPU version, run: +If you need to install GPU version (cuda7.5_cudnn5_avx_openblas), run: .. code-block:: bash diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 8bfa41715ffa82c8f6fb321e1a562649d429672f..6788cb34fbaf5941cbb1537c7a83577c623bf76a 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -5,10 +5,18 @@ cc_library(ddim SRCS ddim.cc DEPS eigen3) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) nv_test(dim_test SRCS dim_test.cu DEPS ddim) -cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory device_context framework_proto) +if (WITH_GPU) + nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS ddim place paddle_memory device_context framework_proto) +else() + cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS ddim place paddle_memory device_context framework_proto) +endif () cc_test(tensor_test SRCS tensor_test.cc DEPS tensor) -cc_test(tensor_util_test SRCS tensor_util_test.cc DEPS tensor) +if (WITH_GPU) + nv_test(tensor_util_test SRCS tensor_util_test.cc tensor_util_test.cu DEPS tensor) +else() + cc_test(tensor_util_test SRCS tensor_util_test.cc DEPS tensor) +endif() cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h index 2191dd3783d5ed7bb59b96c70d38a72bb0b2fee7..bd6d301c12e0611c5b01c3ff58869dbeb96b268e 100644 --- a/paddle/framework/data_transform.h +++ b/paddle/framework/data_transform.h @@ -27,9 +27,8 @@ limitations under the License. 
*/ namespace paddle { namespace framework { -using DataTransformFn = - std::function ctx, - const Variable& in, Variable* out)>; +using DataTransformFn = std::function; using KernelTypePair = std::pair; struct KernelTypePairHash { diff --git a/paddle/framework/data_transform_test.cc b/paddle/framework/data_transform_test.cc index 4e2141ecd2ebe35402a8a04613702a2f79f6a179..5f05e881fa16eead1dc690f85375706bf3cd3e6d 100644 --- a/paddle/framework/data_transform_test.cc +++ b/paddle/framework/data_transform_test.cc @@ -54,18 +54,18 @@ auto kernel1 = GenFromBit({0, 0, 0, 1}); auto kernel2 = GenFromBit({0, 0, 1, 0}); auto kernel3 = GenFromBit({0, 0, 1, 1}); -void TransDataType_t(std::vector ctx, - const Variable& in, Variable* out) { +void TransDataType_t(const platform::DeviceContext* ctx, const Variable& in, + Variable* out) { test_value++; } -void TransDataLayout_t(std::vector ctx, - const Variable& in, Variable* out) { +void TransDataLayout_t(const platform::DeviceContext* ctx, const Variable& in, + Variable* out) { test_value--; } -void TransLibraryType_t(std::vector ctx, - const Variable& in, Variable* out) { +void TransLibraryType_t(const platform::DeviceContext* ctx, const Variable& in, + Variable* out) { test_value += 2; } @@ -83,7 +83,8 @@ TEST(DataTransform, Register) { using namespace paddle::platform; auto& instance = DataTransformFnMap::Instance(); - std::vector ctx; + ASSERT_EQ(instance.Map().size(), 3UL); + DeviceContext* ctx = nullptr; paddle::framework::Variable in; paddle::framework::Variable out; diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index 31749743a58835ee2209d5a448f28b011cb3f7af..bf1f0471ccbfccf13cb6f74c8088da7acd68ec0b 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -14,18 +14,17 @@ limitations under the License. */ #include "paddle/framework/executor.h" -#include -#include -#include #include -#include +#include "gflags/gflags.h" #include "paddle/framework/feed_fetch_type.h" #include "paddle/framework/lod_rank_table.h" -#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor_array.h" #include "paddle/framework/op_registry.h" -#include "paddle/framework/scope.h" + +DEFINE_bool(check_nan_inf, false, + "Checking whether operator produce NAN/INF or not. 
It will be " + "extremely slow so please use this flag wisely."); namespace paddle { namespace framework { @@ -58,6 +57,19 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) { } } +static void CheckTensorNANOrInf(const std::string& name, + const framework::Tensor& tensor) { + if (tensor.memory_size() == 0) { + return; + } + if (tensor.type().hash_code() != typeid(float).hash_code() && + tensor.type().hash_code() != typeid(double).hash_code()) { + return; + } + PADDLE_ENFORCE(!framework::HasInf(tensor), "Tensor %s has Inf", name); + PADDLE_ENFORCE(!framework::HasNAN(tensor), "Tensor %s has NAN", name); +} + void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, bool create_local_scope, bool create_vars) { // TODO(tonyyang-svail): @@ -101,6 +113,15 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id, auto op = paddle::framework::OpRegistry::CreateOp(*op_desc); VLOG(3) << op->DebugString(); op->Run(*local_scope, place_); + if (FLAGS_check_nan_inf) { + for (auto& vname : op->OutputVars(true)) { + auto* var = local_scope->FindVar(vname); + if (var == nullptr) continue; + if (var->IsType()) { + CheckTensorNANOrInf(vname, var->Get()); + } + } + } } if (create_vars && create_local_scope) { scope->DeleteScope(local_scope); diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index c0be11294c4a6b49ae4bc2f805f76e9f04508349..a3ce96c409675ad52a811586c736ca22b5c7e99e 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -384,6 +384,24 @@ class RuntimeInferShapeContext : public InferShapeContext { const Scope& scope_; }; +const platform::DeviceContext* GetDeviceContext( + framework::KernelTypePair& kernel_pair) { + auto& actual_kernel_key = kernel_pair.first; + auto& expected_kernel_key = kernel_pair.second; + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + + if (platform::is_gpu_place(actual_kernel_key.place_) && + platform::is_cpu_place(expected_kernel_key.place_)) { + return pool.Get(actual_kernel_key.place_); + } else if (platform::is_cpu_place(actual_kernel_key.place_) && + platform::is_gpu_place(expected_kernel_key.place_)) { + return pool.Get(expected_kernel_key.place_); + } else { + PADDLE_THROW( + "Currently, model parallelism is only supported between CPU and CUDA"); + } +} + void OperatorWithKernel::Run(const Scope& scope, const platform::Place& place) const { RuntimeInferShapeContext infer_shape_ctx(*this, scope); @@ -418,9 +436,9 @@ void OperatorWithKernel::Run(const Scope& scope, "CPU and other devices. 
For example, multi-GPU model " "parallelism will failed."); } else { + auto kernel_pair = std::make_pair(actual_kernel_key, expected_kernel_key); const DataTransformFn* trans_fun = - DataTransformFnMap::Instance().GetNullable( - std::make_pair(actual_kernel_key, expected_kernel_key)); + DataTransformFnMap::Instance().GetNullable(kernel_pair); if (trans_fun) { auto input_vars = this->InputVars(); // TODO(qijun) filter the input vars that do not need to be transformed @@ -437,22 +455,18 @@ void OperatorWithKernel::Run(const Scope& scope, } if (!need_trans.empty()) { - // TODO(qijun) get appropriate DeviceContext from DeviceContext pool - platform::DeviceContext* trans_dev_ctx = nullptr; - std::vector trans_dev_ctx_vec{trans_dev_ctx}; + auto trans_dev_ctx = GetDeviceContext(kernel_pair); // Wait for transform starting dev_ctx->Wait(); for (auto var_name : need_trans) { - (*trans_fun)(trans_dev_ctx_vec, *(scope.FindVar(var_name)), + (*trans_fun)(trans_dev_ctx, *(scope.FindVar(var_name)), scope.FindVar(var_name + framework::KernelTypeToString( expected_kernel_key))); } // Wait for data transform finishing - for (auto ctx : trans_dev_ctx_vec) { - ctx->Wait(); - } + trans_dev_ctx->Wait(); } } } diff --git a/paddle/framework/tensor_util.cc b/paddle/framework/tensor_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..7efc649d0bcda67c663d148e83bcbb6789b0f371 --- /dev/null +++ b/paddle/framework/tensor_util.cc @@ -0,0 +1,119 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/framework/tensor_util.h" + +namespace paddle { +namespace framework { +template +struct AnyDTypeVisitor { + Predicate predicate_; + const Tensor& tensor_; + const DevCtx& ctx_; + Tensor* out_; + + AnyDTypeVisitor(Predicate predicate, const Tensor& tensor, const DevCtx& ctx, + Tensor* out) + : predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {} + + template + void operator()() const { + auto t = EigenVector::Flatten(tensor_); + auto o = EigenScalar::From(*out_); + // return any of predicate_(t) is true. 
+ o.device(*ctx_.eigen_device()) = predicate_(t).any(); + } +}; + +template +inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor, + const DevCtx& ctx, framework::Tensor* out) { + VisitDataType(ToDataType(tensor.type()), AnyDTypeVisitor( + predicate, tensor, ctx, out)); +} + +template +struct AnyVisitor : public boost::static_visitor { + const framework::Tensor& tensor_; + Predicate predicate_; + + AnyVisitor(const framework::Tensor& tensor, Predicate predicate) + : tensor_(tensor), predicate_(std::move(predicate)) {} + + template + bool operator()(const Place& place) const { + framework::Tensor out; + out.Resize({1}); + out.mutable_data(place); + auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(place); + AnyImpl(predicate_, tensor_, *ctx, &out); + return this->GetResult(out, place); + } + + bool GetResult(const framework::Tensor& out, + const platform::CUDAPlace& gpu) const { + platform::CPUPlace cpu; + framework::Tensor tmp; + tmp.Resize({1}); + tmp.mutable_data(cpu); + auto gpuctx = platform::DeviceContextPool::Instance().Get(gpu); + gpuctx->Wait(); + CopyFrom(out, cpu, *gpuctx, &tmp); + gpuctx->Wait(); + return GetResult(tmp, cpu); + } + + bool GetResult(const framework::Tensor& out, + const platform::CPUPlace& cpu) const { + return *out.data(); + } +}; + +template +inline bool Any(const framework::Tensor& tensor, Predicate predicate) { + AnyVisitor visitor(tensor, predicate); + auto place = tensor.place(); + return platform::VisitPlace(place, visitor); +} + +struct HasNANPredicate { + template + auto operator()(const T& eigen_vec) const + -> decltype(std::declval().isnan()) { + // Cast eigen_vector to vector of bool. true if is inf. + return eigen_vec.isnan(); + } +}; + +bool HasNAN(const framework::Tensor& tensor) { + HasNANPredicate predicate; + return Any(tensor, predicate); +} + +struct HasInfPredicate { + template + auto operator()(const T& eigen_vec) const + -> decltype(std::declval().isinf()) { + // Cast eigen_vector to vector of bool. true if is inf. + return eigen_vec.isinf(); + } +}; + +bool HasInf(const framework::Tensor& tensor) { + HasInfPredicate predicate; + return Any(tensor, predicate); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/tensor_util.cu b/paddle/framework/tensor_util.cu new file mode 120000 index 0000000000000000000000000000000000000000..b00e6e59d93328bf3142597ea4de0dc225501e56 --- /dev/null +++ b/paddle/framework/tensor_util.cu @@ -0,0 +1 @@ +./tensor_util.cc \ No newline at end of file diff --git a/paddle/framework/tensor_util.h b/paddle/framework/tensor_util.h index 108006911a5f30691a01c6579a62c8111b653986..6a21f8db1e3966fd23eee0da2346b2d61f9321fb 100644 --- a/paddle/framework/tensor_util.h +++ b/paddle/framework/tensor_util.h @@ -14,8 +14,10 @@ limitations under the License. */ #pragma once #include "paddle/framework/data_type.h" +#include "paddle/framework/eigen.h" #include "paddle/framework/framework.pb.h" #include "paddle/framework/tensor.h" +#include "paddle/platform/device_context.h" namespace paddle { namespace framework { @@ -207,6 +209,12 @@ inline void CopyToVector(const Tensor& src, std::vector* dst) { src_ptr, size); } +// Returns true if a tensor contains NAN, i.e., Not A Number. +bool HasNAN(const framework::Tensor& tensor); + +// Returns true if a tensor contains Inf, i.e., Infinity. 
+bool HasInf(const framework::Tensor& tensor); + inline void SerializeToStream(std::ostream& os, const Tensor& tensor, const platform::DeviceContext& dev_ctx) { // TODO(typhoonzero): serialize to ostream diff --git a/paddle/framework/tensor_util_test.cc b/paddle/framework/tensor_util_test.cc index 1281e9c2a45858810d505aca6b2a2587e7a7280c..0dc5166fcabf77b48b8681ab1f050e2bc88f44ab 100644 --- a/paddle/framework/tensor_util_test.cc +++ b/paddle/framework/tensor_util_test.cc @@ -13,6 +13,7 @@ #include "paddle/framework/tensor_util.h" #include +#include #include namespace paddle { @@ -230,6 +231,29 @@ TEST(CopyToVector, Tensor) { #endif } +TEST(HasNAN, CPU) { + using namespace paddle::framework; + using namespace paddle::platform; + Tensor src; + float* buf = src.mutable_data({3}, CPUPlace()); + buf[0] = 0.0; + buf[1] = NAN; + buf[2] = 0.0; + + ASSERT_TRUE(HasNAN(src)); +} + +TEST(HasInf, CPU) { + using namespace paddle::framework; + using namespace paddle::platform; + Tensor src; + double* buf = src.mutable_data({3}, CPUPlace()); + buf[0] = 1.0; + buf[1] = INFINITY; + buf[2] = 0.0; + ASSERT_TRUE(HasInf(src)); +} + TEST(Tensor, SerializeAndDeserialize) { framework::Tensor src_tensor; int array[6] = {1, 2, 3, 4, 5, 6}; diff --git a/paddle/framework/tensor_util_test.cu b/paddle/framework/tensor_util_test.cu new file mode 100644 index 0000000000000000000000000000000000000000..ebd35fdf6c2a1388fec23057070f723c8ef9da9c --- /dev/null +++ b/paddle/framework/tensor_util_test.cu @@ -0,0 +1,57 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "gtest/gtest.h" +#include "paddle/framework/tensor_util.h" +#include "paddle/platform/device_context.h" +#include "paddle/platform/place.h" + +namespace paddle { +namespace framework { + +static __global__ void FillNAN(float* buf) { + buf[0] = 0.0; + buf[1] = 0.1; + buf[2] = NAN; +} +static __global__ void FillInf(float* buf) { + buf[0] = 0.0; + buf[1] = INFINITY; + buf[2] = 0.5; +} + +TEST(HasNAN, GPU) { + Tensor tensor; + platform::CUDAPlace gpu(0); + auto& pool = platform::DeviceContextPool::Instance(); + auto* cuda_ctx = pool.GetByPlace(gpu); + float* buf = tensor.mutable_data({3}, gpu); + FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf); + cuda_ctx->Wait(); + ASSERT_TRUE(HasNAN(tensor)); +} + +TEST(HasInf, GPU) { + Tensor tensor; + platform::CUDAPlace gpu(0); + auto& pool = platform::DeviceContextPool::Instance(); + auto* cuda_ctx = pool.GetByPlace(gpu); + float* buf = tensor.mutable_data({3}, gpu); + FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf); + cuda_ctx->Wait(); + ASSERT_TRUE(HasInf(tensor)); +} + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/threadpool.h b/paddle/framework/threadpool.h index 5f6b2d458f7ee764c22d203f285b78023b6012f3..bcd8190755083ec30687675602a1c95a9c15c69e 100644 --- a/paddle/framework/threadpool.h +++ b/paddle/framework/threadpool.h @@ -16,6 +16,7 @@ limitations under the License. 
*/ #include #include +#include #include #include #include @@ -25,10 +26,11 @@ limitations under the License. */ namespace paddle { namespace framework { -typedef std::function Task; - class ThreadPool { public: + typedef std::packaged_task Task; + typedef std::function Fun; + /** * @brief Get a instance of threadpool, the thread number will * be specified as the number of hardware thread contexts @@ -61,13 +63,18 @@ class ThreadPool { /** * @brief Push a function to the queue, and will be scheduled and * executed if a thread is available. - * @param[in] Task will be pushed to the task queue. + * @param[in] Task, will be pushed to the task queue. + * @return std::future, we could wait for the task finished by + * f.wait(). */ - void Run(const Task& fn) { + std::future Run(const Fun& fn) { std::unique_lock lock(mutex_); - tasks_.push(fn); + Task task(std::bind(fn)); + std::future f = task.get_future(); + tasks_.push(std::move(task)); lock.unlock(); scheduled_.notify_one(); + return f; } /** @@ -110,7 +117,7 @@ class ThreadPool { break; } // pop a task from the task queue - auto task = tasks_.front(); + auto task = std::move(tasks_.front()); tasks_.pop(); --available_; diff --git a/paddle/framework/threadpool_test.cc b/paddle/framework/threadpool_test.cc index 012d92a5edc415f0bb2f8a0ea38ffeb9549d54fa..50b6238cd8786be9d8cf2d5f821daadea12bd208 100644 --- a/paddle/framework/threadpool_test.cc +++ b/paddle/framework/threadpool_test.cc @@ -20,16 +20,21 @@ limitations under the License. */ namespace framework = paddle::framework; void do_sum(framework::ThreadPool* pool, std::atomic& sum, int cnt) { + std::vector> fs; for (int i = 0; i < cnt; ++i) { - pool->Run([&sum]() { sum.fetch_add(1); }); + auto f = pool->Run([&sum]() { sum.fetch_add(1); }); + fs.push_back(std::move(f)); + } + for (auto& f : fs) { + f.wait(); } } TEST(ThreadPool, ConcurrentInit) { framework::ThreadPool* pool; - int concurrent_cnt = 50; + int n = 50; std::vector threads; - for (int i = 0; i < concurrent_cnt; ++i) { + for (int i = 0; i < n; ++i) { std::thread t([&pool]() { pool = framework::ThreadPool::GetInstance(); }); threads.push_back(std::move(t)); } @@ -38,13 +43,13 @@ TEST(ThreadPool, ConcurrentInit) { } } -TEST(ThreadPool, ConcurrentStart) { +TEST(ThreadPool, ConcurrentRun) { framework::ThreadPool* pool = framework::ThreadPool::GetInstance(); std::atomic sum(0); std::vector threads; - int concurrent_cnt = 50; + int n = 50; // sum = (n * (n + 1)) / 2 - for (int i = 1; i <= concurrent_cnt; ++i) { + for (int i = 1; i <= n; ++i) { std::thread t(do_sum, pool, std::ref(sum), i); threads.push_back(std::move(t)); } @@ -52,5 +57,5 @@ TEST(ThreadPool, ConcurrentStart) { t.join(); } pool->Wait(); - EXPECT_EQ(sum, ((concurrent_cnt + 1) * concurrent_cnt) / 2); + EXPECT_EQ(sum, ((n + 1) * n) / 2); } diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index de7b70e271b38ebe3a4c38704d0cced47d010788..cbdbf5335d32d55a0221728758025c9d2cb3e7d1 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -126,14 +126,165 @@ public: inputData += inputChannels * inputHeight * inputWidth; outputData += outputChannels * outputHeight * outputWidth; } + } +}; + #ifdef PADDLE_MOBILE_INFERENCE - if (Device == DEVICE_TYPE_CPU) { - memory_.reset(); + +/* + * \brief Forward calculation of convolution, optimized for mobile. 
+ */ +template +class GemmConvMobileFunction : public ConvFunctionBase { +public: + void init(const FuncConfig& config) override { + ConvFunctionBase::init(config); + } + + void check(const BufferArgs& inputs, const BufferArgs& outputs) override { + const TensorShape& input = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& output = outputs[0].shape(); + checkShape(input, filter, output); + } + + void calc(const BufferArgs& inputs, const BufferArgs& outputs) override { + CHECK_EQ(numInputs_, inputs.size()); + CHECK_EQ(numOutputs_, outputs.size()); + check(inputs, outputs); + // TODO(hedaoyuan): Need to define some index macros, + // to avoid useing 0 and 1. + const TensorShape& input = inputs[0].shape(); + const TensorShape& filter = inputs[1].shape(); + const TensorShape& output = outputs[0].shape(); + + real beta; + if (outputs[0].getArgType() == ADD_TO) { + beta = 1.0; + } else { + beta = 0.0; } -#endif + + size_t batchSize = input[0]; + size_t inputChannels = input[1]; + size_t inputHeight = input[2]; + size_t inputWidth = input[3]; + size_t filterHeight = getFilterHeight(filter); + size_t filterWidth = getFilterWidth(filter); + size_t outputChannels = output[1]; + size_t outputHeight = output[2]; + size_t outputWidth = output[3]; + + real* inputData = inputs[0].data(); + real* filterData = inputs[1].data(); + real* outputData = outputs[0].data(); + bool needIm2col = isNeedIm2col(filter); + + TensorShape imShape = + TensorShape({inputChannels / groups_, inputHeight, inputWidth}); + + TensorShape colShape; + real* colData = NULL; + + size_t colHeight = inputChannels / groups_ * filterHeight * filterWidth; + size_t colWidth = outputHeight * outputWidth; + // Max col matrix height 256, Max col matrix width 1024 + size_t stepColHeight = std::min(colHeight, static_cast(256)); + size_t stepColWidth = std::min(colWidth, static_cast(2048)); + + if (needIm2col) { + colShape = TensorShape({inputChannels / groups_, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + + resizeBuffer(stepColHeight * stepColWidth * sizeof(real)); + colData = reinterpret_cast(memory_->getBuf()); + } + + Im2ColMobileFunctor im2col; + size_t inputOffset = imShape.getElements(); + size_t outputOffset = + (outputChannels / groups_) * outputHeight * outputWidth; + size_t filterOffset = filter.getElements() / groups_; + + int nStride = colWidth; + int kStride = colHeight; + for (size_t i = 0; i < batchSize; i++) { + for (size_t g = 0; g < groups_; g++) { + if (needIm2col) { + real beta_ = beta; + for (size_t colHeightStart = 0; colHeightStart < colHeight; + colHeightStart += stepColHeight) { + for (size_t colWidthStart = 0; colWidthStart < colWidth; + colWidthStart += stepColWidth) { + int N = std::min(colWidth - colWidthStart, stepColWidth); + int K = std::min(colHeight - colHeightStart, stepColHeight); + // im2col + im2col(inputData + g * inputOffset, + imShape, + colData, + colShape, + strideH(), + strideW(), + paddingH(), + paddingW(), + dilationH(), + dilationW(), + colHeightStart, + K, + colWidthStart, + N); + + // gemm + int M = outputChannels / groups_; + BlasGemm::compute( + false, + false, + M, + N, + K, + 1.0f, + filterData + g * filterOffset + colHeightStart, + kStride, + colData, + N, + beta_, + outputData + g * outputOffset + colWidthStart, + nStride); + } + beta_ = 1.0; + } + } else { + int M = outputChannels / groups_; + int N = outputHeight * outputWidth; + int K = inputChannels / groups_ * filterHeight * filterWidth; + BlasGemm::compute(false, + false, 
+ M, + N, + K, + 1.0f, + filterData + g * filterOffset, + K, + inputData + g * inputOffset, + N, + beta, + outputData + g * outputOffset, + N); + } + } + inputData += inputChannels * inputHeight * inputWidth; + outputData += outputChannels * outputHeight * outputWidth; + } + + memory_.reset(); } }; +#endif + /* * \brief Backward input calculation of convolution. */ @@ -348,7 +499,11 @@ public: } }; +#ifdef PADDLE_MOBILE_INFERENCE +REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvMobileFunction); +#else REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction); +#endif REGISTER_TYPED_FUNC(GemmConvGradInput, CPU, GemmConvGradInputFunction); REGISTER_TYPED_FUNC(GemmConvGradFilter, CPU, GemmConvGradFilterFunction); #ifdef PADDLE_WITH_CUDA diff --git a/paddle/function/Im2Col.h b/paddle/function/Im2Col.h index 0c37fc972484bfbede01d23652e384071bf883af..36a9bcf84e4b14965c83627821b71d1c7c0da1b2 100644 --- a/paddle/function/Im2Col.h +++ b/paddle/function/Im2Col.h @@ -98,4 +98,54 @@ public: int dilationWidth = 1); }; +template +class Im2ColMobileFunctor { +public: + void operator()(const T* imData, + const TensorShape& imShape, + T* colData, + const TensorShape& colShape, + int strideHeight, + int strideWidth, + int paddingHeight, + int paddingWidth, + int dilationHeight, + int dilationWidth, + int colHeightStart, + int colHeightSize, + int colWidthStart, + int colWidthSize) { + int inputHeight = imShape[1]; + int inputWidth = imShape[2]; + int filterHeight = colShape[1]; + int filterWidth = colShape[2]; + int outputWidth = colShape[4]; + + for (int colh = 0; colh < colHeightSize; colh++) { + int wOffset = (colHeightStart + colh) % filterWidth; + int hOffset = ((colHeightStart + colh) / filterWidth) % filterHeight; + int c_im = (colHeightStart + colh) / filterWidth / filterHeight; + + for (int colw = 0; colw < colWidthSize; colw++) { + int h = (colWidthStart + colw) / outputWidth; + int w = (colWidthStart + colw) % outputWidth; + + int imRowIdx = h * strideHeight + hOffset * dilationHeight; + int imColIdx = w * strideWidth + wOffset * dilationWidth; + if ((imRowIdx - paddingHeight) < 0 || + (imRowIdx - paddingHeight) >= inputHeight || + (imColIdx - paddingWidth) < 0 || + (imColIdx - paddingWidth) >= inputWidth) { + colData[colh * colWidthSize + colw] = static_cast(0); + } else { + imRowIdx += c_im * inputHeight - paddingHeight; + imColIdx -= paddingWidth; + colData[colh * colWidthSize + colw] = + imData[imRowIdx * inputWidth + imColIdx]; + } + } + } + } +}; + } // namespace paddle diff --git a/paddle/function/Im2ColTest.cpp b/paddle/function/Im2ColTest.cpp index 1f085538d81904dbd5b5d6bcd014adaed22e37d7..3ba866dcdd845403d52f7a85adfef08cbb11c305 100644 --- a/paddle/function/Im2ColTest.cpp +++ b/paddle/function/Im2ColTest.cpp @@ -138,4 +138,86 @@ TEST(Im2ColFunctor, GPU) { TestIm2ColFunctor(); } #endif +template +void TestIm2ColMobileFunctor() { + for (size_t channels : {32}) { + for (size_t inputHeight : {33, 100}) { + for (size_t inputWidth : {32, 96}) { + for (size_t filterHeight : {5}) { + for (size_t filterWidth : {7}) { + for (size_t stride : {2}) { + for (size_t padding : {1}) { + for (size_t dilation : {1, 3}) { + size_t filterSizeH = (filterHeight - 1) * dilation + 1; + size_t filterSizeW = (filterWidth - 1) * dilation + 1; + if (inputHeight + 2 * padding < filterSizeH || + inputWidth + 2 * padding < filterSizeW) + break; + if (padding >= filterSizeH || padding >= filterSizeW) break; + size_t outputHeight = + (inputHeight - filterSizeH + 2 * padding) / stride + 1; + size_t outputWidth = + (inputWidth - 
filterSizeW + 2 * padding) / stride + 1; + + TensorShape imShape = + TensorShape({channels, inputHeight, inputWidth}); + TensorShape colShape1 = TensorShape({channels, + filterHeight, + filterWidth, + outputHeight, + outputWidth}); + + size_t height = channels * filterHeight * filterWidth; + size_t width = outputHeight * outputWidth; + VectorPtr input1 = + Vector::create(imShape.getElements(), false); + VectorPtr input2 = + Vector::create(imShape.getElements(), false); + MatrixPtr output1 = + Matrix::create(height, width, false, false); + MatrixPtr output2 = + Matrix::create(height, width, false, false); + input1->uniform(0.001, 1); + input2->copyFrom(*input1); + + Im2ColFunctor im2Col1; + Im2ColMobileFunctor im2Col2; + im2Col1(input1->getData(), + imShape, + output1->getData(), + colShape1, + stride, + stride, + padding, + padding, + dilation, + dilation); + im2Col2(input2->getData(), + imShape, + output2->getData(), + colShape1, + stride, + stride, + padding, + padding, + dilation, + dilation, + 0, + height, + 0, + width); + + autotest::TensorCheckEqual(*output1, *output2); + } + } + } + } + } + } + } + } +} + +TEST(Im2ColFunctor, Mobile) { TestIm2ColMobileFunctor(); } + } // namespace paddle diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h index dfef2c16d8f2277d57cbcfe51d108402e518799b..2b366e6383d23e2d31a194edd04412892a8311eb 100644 --- a/paddle/platform/device_context.h +++ b/paddle/platform/device_context.h @@ -52,6 +52,14 @@ class CPUDeviceContext : public DeviceContext { std::unique_ptr eigen_device_; }; +template +struct DefaultDeviceContextType; + +template <> +struct DefaultDeviceContextType { + using TYPE = CPUDeviceContext; +}; + #ifdef PADDLE_WITH_CUDA class EigenCudaStreamDevice; @@ -90,6 +98,11 @@ class CUDADeviceContext : public DeviceContext { cublasHandle_t cublas_handle_; }; +template <> +struct DefaultDeviceContextType { + using TYPE = CUDADeviceContext; +}; + class CUDNNDeviceContext : public CUDADeviceContext { public: explicit CUDNNDeviceContext(CUDAPlace place); @@ -125,6 +138,13 @@ class DeviceContextPool { /*! \brief Return handle of single device context. */ const platform::DeviceContext* Get(const platform::Place& place); + template + const typename DefaultDeviceContextType::TYPE* GetByPlace( + const Place& place) { + return reinterpret_cast< + const typename DefaultDeviceContextType::TYPE*>(Get(place)); + } + private: static DeviceContextPool* pool; constexpr static int LEFT_SHIFT = 8; diff --git a/paddle/platform/place.h b/paddle/platform/place.h index d25eaa689f4a4baa951db5c61bbf99288e365ee1..76b5c502cc48431a4e9b13b07505978884576e1d 100644 --- a/paddle/platform/place.h +++ b/paddle/platform/place.h @@ -15,7 +15,7 @@ limitations under the License. */ #pragma once #include - +#include "paddle/platform/enforce.h" #include "paddle/platform/variant.h" namespace paddle { @@ -64,5 +64,31 @@ bool places_are_same_class(const Place &, const Place &); std::ostream &operator<<(std::ostream &, const Place &); +template +struct PlaceVisitorWrapper + : public boost::static_visitor { + const Visitor &visitor_; + explicit PlaceVisitorWrapper(const Visitor &visitor) : visitor_(visitor) {} + + typename Visitor::result_type operator()(const CPUPlace &cpu) const { + return visitor_(cpu); + } + + typename Visitor::result_type operator()(const CUDAPlace &cuda) const { +#ifdef PADDLE_WITH_CUDA + return visitor_(cuda); +#else + PADDLE_THROW("Paddle is not compiled with CUDA. 
Cannot visit cuda device"); + return typename Visitor::result_type(); +#endif + } +}; + +template +typename Visitor::result_type VisitPlace(const Place &place, + const Visitor &visitor) { + return boost::apply_visitor(PlaceVisitorWrapper(visitor), place); +} + } // namespace platform } // namespace paddle diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index 6afed7eec7001b646d55cef0bc3f59782b80b15f..ced75cbfd899980390d41610d863d6cf154570b0 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -3,6 +3,7 @@ if(WITH_PYTHON) SRCS pybind.cc exception.cc protobuf.cc const_value.cc DEPS pybind python backward proto_desc paddle_memory executor prune init ${GLOB_OP_LIB}) + target_link_libraries(paddle_pybind rt) endif(WITH_PYTHON) if(WITH_DOC) diff --git a/paddle/scripts/submit_local.sh.in b/paddle/scripts/submit_local.sh.in index a94bc01b358c508132eb85920a2d4c0aa934dd51..8a352b0078d701f797f7202c85bd0e08201ac9b8 100755 --- a/paddle/scripts/submit_local.sh.in +++ b/paddle/scripts/submit_local.sh.in @@ -71,9 +71,7 @@ function threads_config() { # auto set OMP_NUM_THREADS and MKL_NUM_THREADS # according to trainer_count and total processors # only when MKL enabled - if [ "@WITH_MKL@" == "OFF" ]; then - return 0 - fi + # auto set OPENBLAS_NUM_THREADS when do not use MKL processors=`grep "processor" /proc/cpuinfo|sort -u|wc -l` trainers=`grep -Eo 'trainer_count.[0-9]+' <<< "$@" |grep -Eo '[0-9]+'|xargs` if [ -z $trainers ]; then @@ -83,12 +81,19 @@ function threads_config() { if [ $threads -eq 0 ]; then threads=1 fi - if [ -z "$OMP_NUM_THREADS" ]; then - export OMP_NUM_THREADS=$threads - fi - if [ -z "$MKL_NUM_THREADS" ]; then - export MKL_NUM_THREADS=$threads + if [ "@WITH_MKL@" == "ON" ]; then + if [ -z "$OMP_NUM_THREADS" ]; then + export OMP_NUM_THREADS=$threads + fi + if [ -z "$MKL_NUM_THREADS" ]; then + export MKL_NUM_THREADS=$threads + fi + else + if [ -z "$OPENBLAS_NUM_THREADS" ]; then + export OPENBLAS_NUM_THREADS=$threads + fi fi + } PADDLE_CONF_HOME="$HOME/.config/paddle" @@ -150,7 +155,7 @@ fi case "$1" in "train") threads_config $@ - # echo $OMP_NUM_THREADS $MKL_NUM_THREADS + # echo $OMP_NUM_THREADS $MKL_NUM_THREADS $OPENBLAS_NUM_THREADS ${DEBUGGER} $PADDLE_BIN_PATH/paddle_trainer ${@:2} ;; "merge_model") diff --git a/python/paddle/v2/dataset/flowers.py b/python/paddle/v2/dataset/flowers.py index 634388094c804827657dc83d5c205e680625b156..7bdddeaabec733ef26b3f766c6437f5c53d65044 100644 --- a/python/paddle/v2/dataset/flowers.py +++ b/python/paddle/v2/dataset/flowers.py @@ -44,7 +44,7 @@ __all__ = ['train', 'test', 'valid'] DATA_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz' LABEL_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat' SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat' -DATA_MD5 = '52808999861908f626f3c1f4e79d11fa' +DATA_MD5 = '33bfc11892f1e405ca193ae9a9f2a118' LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d' SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c' # In official 'readme', tstid is the flag of test data diff --git a/python/paddle/v2/fluid/__init__.py b/python/paddle/v2/fluid/__init__.py index c72b5730695dbc4f772015f1fb8dec6814cd1837..225b41c5043b5792abb90bbad53cbbfce9a3156e 100644 --- a/python/paddle/v2/fluid/__init__.py +++ b/python/paddle/v2/fluid/__init__.py @@ -36,7 +36,7 @@ def __read_gflags_from_env__(): """ import sys import core - read_env_flags = ['use_pinned_memory'] + read_env_flags = ['use_pinned_memory', 'check_nan_inf'] if core.is_compile_gpu(): 
read_env_flags.append('fraction_of_gpu_memory_to_use') core.init_gflags([sys.argv[0]] +
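A minimal usage sketch for the newly registered check_nan_inf flag, under the assumption that init_gflags forwards the names listed in read_env_flags from FLAGS_-prefixed environment variables (for example via gflags' --tryfromenv); the exact forwarding call is truncated above, so treat the variable name as hypothetical:

    # Sketch: enable NaN/Inf checking before fluid is first imported, assuming
    # FLAGS_check_nan_inf is the environment form picked up by init_gflags.
    import os
    os.environ['FLAGS_check_nan_inf'] = 'true'

    import paddle.v2.fluid as fluid
    # Executor::Run now calls CheckTensorNANOrInf on every float/double output
    # variable of each op; a tensor containing NaN or Inf triggers a
    # PADDLE_ENFORCE failure that names the offending variable.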