diff --git a/doc/api/index_cn.rst b/doc/api/index_cn.rst index 9be0b370ee5e301aee4a6e31b1cfa905754968e8..84f9097a6cdc2da269bd6a0685796e14e26da37e 100644 --- a/doc/api/index_cn.rst +++ b/doc/api/index_cn.rst @@ -7,3 +7,4 @@ API 模型配置 数据访问 训练与应用 + v2/fluid.rst diff --git a/doc/howto/read_source.md b/doc/howto/read_source.md index 383acb0c8251043c3c6bbf309d2e07bf0074cd4f..e4211abb3be9cace80bc14dbe3db3e0a31221dd0 100644 --- a/doc/howto/read_source.md +++ b/doc/howto/read_source.md @@ -6,10 +6,10 @@ Core: https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/framework Operator: https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/operators -Optimizer: https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/optimizer - Memory: https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/memory +Platform: https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/platform + # Compile Time The following **defines** the NN. The definition goes into this [protocol buffer](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto). diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 4b0eff3adb6fff0c9599b8613c5f19daea840674..206e298eb27a2daaec5c674d45cfe4b81a6b522d 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -58,3 +58,6 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry proto_desc) cc_library(selected_rows SRCS selected_rows.cc DEPS tensor) cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows) + +cc_library(init SRCS init.cc DEPS gflags executor place stringpiece) +cc_test(init_test SRCS init_test.cc DEPS init) diff --git a/paddle/framework/ddim_test.cc b/paddle/framework/ddim_test.cc index 756232b1b56a49d2c91cc2cac950ca508c54fb3f..bd5ea09d7da700479aa387283d7bde77c64c1293 100644 --- a/paddle/framework/ddim_test.cc +++ b/paddle/framework/ddim_test.cc @@ -1,3 +1,16 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/
 #include <sstream>
 #include <vector>
 
diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc
index 83aa927c293676c3800ed945c175e4f3dc5629d6..a8b8a6f8e82525bd9a1f709516483de6f44142dc 100644
--- a/paddle/framework/executor.cc
+++ b/paddle/framework/executor.cc
@@ -33,32 +33,12 @@ namespace framework {
 const std::string kFeedOpType = "feed";
 const std::string kFetchOpType = "fetch";
 
-Executor::Executor(const std::vector<platform::Place>& places) : own_(true) {
-  PADDLE_ENFORCE_GT(places.size(), 0);
-  device_contexts_.resize(places.size());
-  for (size_t i = 0; i < places.size(); i++) {
-    if (platform::is_cpu_place(places[i])) {
-      device_contexts_[i] = new platform::CPUDeviceContext(
-          boost::get<platform::CPUPlace>(places[i]));
-    } else if (platform::is_gpu_place(places[i])) {
-#ifdef PADDLE_WITH_CUDA
-      device_contexts_[i] = new platform::CUDADeviceContext(
-          boost::get<platform::GPUPlace>(places[i]));
-#else
-      PADDLE_THROW(
-          "'GPUPlace' is not supported, Please re-compile with WITH_GPU "
-          "option");
-#endif
-    }
-  }
-}
+DeviceContextPool* DeviceContextPool::pool = nullptr;
 
-Executor::~Executor() {
-  if (own_) {
-    for (auto& device_context : device_contexts_) {
-      delete device_context;
-    }
-  }
+Executor::Executor(const std::vector<platform::Place>& places) {
+  DeviceContextPool& pool = DeviceContextPool::Get();
+  auto borrowed_contexts = pool.Borrow(places);
+  device_contexts_.swap(borrowed_contexts);
 }
 
 static void CreateTensor(Variable* var, VarDesc::VarType var_type) {
@@ -132,8 +112,5 @@ void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id,
   }
 }
 
-Executor::Executor(const platform::DeviceContext& device)
-    : device_contexts_({&device}), own_(false) {}
-
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/executor.h b/paddle/framework/executor.h
index b745f4f6474ef688774f4c833a3958942e9aa8cb..073e04729b1166f1cabd16709d161fda0d580f1c 100644
--- a/paddle/framework/executor.h
+++ b/paddle/framework/executor.h
@@ -14,19 +14,98 @@ limitations under the License. */
 
 #pragma once
 
+#include <map>
+#include <unordered_map>
+
 #include "paddle/framework/op_info.h"
 #include "paddle/framework/program_desc.h"
 #include "paddle/framework/scope.h"
 #include "paddle/framework/tensor.h"
+#include "paddle/platform/device_context.h"
 
 namespace paddle {
 namespace framework {
 
+class DeviceContextPool {
+ public:
+  static DeviceContextPool& Get() {
+    PADDLE_ENFORCE_NOT_NULL(pool, "Need to Create DeviceContextPool first!");
+    return *pool;
+  }
+
+  static DeviceContextPool& Create(const std::vector<platform::Place>& places) {
+    if (pool == nullptr) {
+      pool = new DeviceContextPool(places);
+    }
+    return *pool;
+  }
+
+  std::vector<const platform::DeviceContext*> Borrow(
+      const std::vector<platform::Place>& places) {
+    PADDLE_ENFORCE_GT(places.size(), 0);
+    PADDLE_ENFORCE_LE(places.size(), device_contexts_.size());
+    std::vector<const platform::DeviceContext*> borrowed_contexts;
+    for (auto& place : places) {
+      auto range = device_contexts_.equal_range(place);
+      if (range.first == range.second) {
+        PADDLE_THROW(
+            "'Place' is not supported, Please re-compile with WITH_GPU "
+            "option");
+      }
+      // TODO(dzhwinter) : assign the first found device. Will be enhanced
+      // later; a device load balancer may be useful here.
+      borrowed_contexts.emplace_back(range.first->second);
+    }
+    return borrowed_contexts;
+  }
+
+  explicit DeviceContextPool(const std::vector<platform::Place>& places) {
+    PADDLE_ENFORCE_GT(places.size(), 0);
+    for (size_t i = 0; i < places.size(); i++) {
+      if (platform::is_cpu_place(places[i])) {
+        device_contexts_.emplace(
+            places[i], new platform::CPUDeviceContext(
+                           boost::get<platform::CPUPlace>(places[i])));
+      } else if (platform::is_gpu_place(places[i])) {
+#ifdef PADDLE_WITH_CUDA
+        device_contexts_.emplace(
+            places[i], new platform::CUDADeviceContext(
+                           boost::get<platform::GPUPlace>(places[i])));
+#else
+        PADDLE_THROW(
+            "'GPUPlace' is not supported, Please re-compile with WITH_GPU "
+            "option");
+#endif
+      }
+    }
+  }
+
+  ~DeviceContextPool() {}
+
+ private:
+  static DeviceContextPool* pool;
+  struct Hash {
+    std::hash<int> hash_;
+    size_t operator()(const platform::Place& place) const {
+      return hash_(place.which());
+    }
+  };
+  std::unordered_multimap<const platform::Place,
+                          const platform::DeviceContext*, Hash>
+      device_contexts_;
+  DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
+};
+
 class Executor {
  public:
+  // TODO(dzhwinter) : Do not rely on this function, it will be removed
+  explicit Executor(const platform::DeviceContext& device)
+      : Executor(std::vector<platform::Place>({device.GetPlace()})) {}
+
+  explicit Executor(const platform::Place& place)
+      : Executor(std::vector<platform::Place>({place})) {}
+
   explicit Executor(const std::vector<platform::Place>& places);
-  explicit Executor(const platform::DeviceContext& devices);
-  ~Executor();
 
   /* @Brief
    * Runtime evaluation of the given ProgramDesc under certain Scope
@@ -39,7 +118,6 @@ class Executor {
 
  private:
   std::vector<const platform::DeviceContext*> device_contexts_;
-  bool own_;
 };
 
 }  // namespace framework
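Note: DeviceContextPool is a process-wide singleton. Create() must run exactly once during initialization (InitDevices below does this), after which every Executor merely borrows contexts and no longer owns or deletes them; that is what makes the own_ flag and the custom destructor removable. A minimal standalone sketch of the same create-once / borrow-many pattern, using stand-in types rather than Paddle's real ones:

    // Standalone sketch of the pool pattern above; Context and Pool are
    // simplified stand-ins, not Paddle's actual types.
    #include <cassert>
    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <vector>

    struct Context { std::string place; };

    class Pool {
     public:
      static Pool& Get() {
        assert(pool != nullptr);  // mirrors PADDLE_ENFORCE_NOT_NULL above
        return *pool;
      }
      static Pool& Create(const std::vector<std::string>& places) {
        if (pool == nullptr) pool = new Pool(places);
        return *pool;
      }
      // Hands out the first context registered for each requested place.
      std::vector<const Context*> Borrow(const std::vector<std::string>& places) {
        std::vector<const Context*> borrowed;
        for (const auto& p : places) borrowed.push_back(&contexts_.find(p)->second);
        return borrowed;
      }
     private:
      explicit Pool(const std::vector<std::string>& places) {
        for (const auto& p : places) contexts_.emplace(p, Context{p});
      }
      static Pool* pool;
      std::unordered_multimap<std::string, Context> contexts_;
    };

    Pool* Pool::pool = nullptr;

    int main() {
      Pool::Create({"CPU", "GPU:0"});           // once, at program startup
      auto ctxs = Pool::Get().Borrow({"CPU"});  // executors borrow, never own
      std::cout << ctxs[0]->place << "\n";      // prints: CPU
    }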
diff --git a/paddle/framework/init.cc b/paddle/framework/init.cc
new file mode 100644
index 0000000000000000000000000000000000000000..1c4476f4b30aebf094eb27b45fb435c24a9061c1
--- /dev/null
+++ b/paddle/framework/init.cc
@@ -0,0 +1,80 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+#include <algorithm>
+#include <string>
+
+#include "paddle/framework/executor.h"
+#include "paddle/framework/init.h"
+#include "paddle/platform/place.h"
+#include "paddle/string/piece.h"
+
+namespace paddle {
+namespace framework {
+
+std::once_flag gflags_init_flag;
+
+// TODO(qijun) move init gflags to init.cc
+void InitGflags(std::vector<std::string> &argv) {
+  std::call_once(gflags_init_flag, [&]() {
+    int argc = argv.size();
+    char **arr = new char *[argv.size()];
+    std::string line;
+    for (size_t i = 0; i < argv.size(); i++) {
+      arr[i] = &argv[i][0];
+      line += argv[i];
+      line += ' ';
+    }
+    google::ParseCommandLineFlags(&argc, &arr, true);
+    VLOG(1) << "Init commandline: " << line;
+  });
+}
+
+bool InitDevices(const std::vector<std::string> &devices) {
+  // device format
+  // CPU
+  // GPU:1
+  // TODO(dzhwinter) : add device format annotation for users.
+  std::vector<platform::Place> places;
+  for (auto &device : devices) {
+    auto p = string::Piece(device);
+    if (string::Find(p, ':', 0) == string::Piece::npos) {
+      places.emplace_back(platform::CPUPlace());
+    } else if (string::HasPrefix(p, "GPU")) {
+#ifdef PADDLE_WITH_CUDA
+      auto pos = string::RFind(p, ':', string::Piece::npos);
+      auto number = device.substr(pos + 1);
+      places.emplace_back(platform::GPUPlace(std::stoi(number)));
+#else
+      LOG(WARNING)
+          << "'GPU' is not supported, Please re-compile with WITH_GPU option";
+#endif
+    } else {
+      return false;
+    }
+  }
+
+  if (std::find_if(places.begin(), places.end(),
+                   [&](const platform::Place &place) {
+                     return platform::is_cpu_place(place);
+                   }) == places.end()) {
+    places.emplace_back(platform::CPUPlace());
+    LOG(WARNING) << "No devices specified, using CPU by default.";
+  }
+  DeviceContextPool::Create(places);
+  return true;
+}
+
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/framework/init.h b/paddle/framework/init.h
new file mode 100644
index 0000000000000000000000000000000000000000..1715cd81e6647158e269e39d4d91fbe065cd0008
--- /dev/null
+++ b/paddle/framework/init.h
@@ -0,0 +1,28 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+#pragma once
+#include <mutex>
+
+#include "gflags/gflags.h"
+#include "glog/logging.h"
+
+namespace paddle {
+namespace framework {
+
+void InitGflags(std::vector<std::string> &argv);
+
+bool InitDevices(const std::vector<std::string> &devices);
+
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/framework/init_test.cc b/paddle/framework/init_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f65e881a761e0a546d595eced26dd5b12475a763
--- /dev/null
+++ b/paddle/framework/init_test.cc
@@ -0,0 +1,27 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+#include "gtest/gtest.h"
+
+#include "paddle/framework/init.h"
+
+TEST(Init, InitDevices) {
+  using paddle::framework::InitDevices;
+  std::vector<std::string> ds1 = {"CPU"};
+  ASSERT_EQ(InitDevices(ds1), true);
+
+#ifdef PADDLE_WITH_CUDA
+  std::vector<std::string> ds2 = {"CPU", "GPU:0", "GPU:1"};
+  ASSERT_EQ(InitDevices(ds2), true);
+#endif
+}
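Note: InitDevices accepts the device strings documented in the comment above: a string without a colon is treated as a CPU, "GPU:<n>" selects a CUDA device, and a CPU context is always appended if none was requested. A minimal sketch of that parsing rule, using plain std::string instead of paddle::string::Piece:

    #include <iostream>
    #include <string>
    #include <vector>

    // Returns -1 for a CPU device string, otherwise the GPU ordinal.
    // Unrecognized strings return -2, mirroring InitDevices returning false.
    int ParseDevice(const std::string& device) {
      auto pos = device.rfind(':');
      if (pos == std::string::npos) return -1;          // e.g. "CPU"
      if (device.compare(0, 3, "GPU") != 0) return -2;  // unsupported prefix
      return std::stoi(device.substr(pos + 1));         // e.g. "GPU:1" -> 1
    }

    int main() {
      for (const char* d : {"CPU", "GPU:0", "GPU:1"}) {
        std::cout << d << " -> " << ParseDevice(d) << "\n";
      }
    }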
diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp
index 8d34eee886a6202691e5dec2ab62e7c5b0ac7fb1..de7b70e271b38ebe3a4c38704d0cced47d010788 100644
--- a/paddle/function/GemmConvOp.cpp
+++ b/paddle/function/GemmConvOp.cpp
@@ -126,6 +126,11 @@ public:
       inputData += inputChannels * inputHeight * inputWidth;
       outputData += outputChannels * outputHeight * outputWidth;
     }
+#ifdef PADDLE_MOBILE_INFERENCE
+    if (Device == DEVICE_TYPE_CPU) {
+      memory_.reset();
+    }
+#endif
   }
 };
diff --git a/paddle/gserver/layers/ROIPoolLayer.cpp b/paddle/gserver/layers/ROIPoolLayer.cpp
index 2c8256b91c97b513ce7237b8174c522430094926..7d7c30b4d89e2dd137e7fc7de3159c07bbab9fb4 100644
--- a/paddle/gserver/layers/ROIPoolLayer.cpp
+++ b/paddle/gserver/layers/ROIPoolLayer.cpp
@@ -84,12 +84,15 @@ void ROIPoolLayer::forward(PassType passType) {
   size_t poolChannelOffset = pooledHeight_ * pooledWidth_;
   real* outputData = outputValue->getData();
-  Matrix::resizeOrCreate(maxIdxs_,
-                         numROIs,
-                         channels_ * pooledHeight_ * pooledWidth_,
-                         false,
-                         false);
-  real* argmaxData = maxIdxs_->getData();
+  real* argmaxData = nullptr;
+  if (passType != PASS_TEST) {
+    Matrix::resizeOrCreate(maxIdxs_,
+                           numROIs,
+                           channels_ * pooledHeight_ * pooledWidth_,
+                           false,
+                           false);
+    argmaxData = maxIdxs_->getData();
+  }
 
   for (size_t n = 0; n < numROIs; ++n) {
     // the first five elements of each RoI should be:
@@ -128,14 +131,18 @@ void ROIPoolLayer::forward(PassType passType) {
           bool isEmpty = (hend <= hstart) || (wend <= wstart);
           size_t poolIndex = ph * pooledWidth_ + pw;
           outputData[poolIndex] = isEmpty ? 0 : -FLT_MAX;
-          argmaxData[poolIndex] = -1;
+          if (argmaxData) {
+            argmaxData[poolIndex] = -1;
+          }
 
           for (size_t h = hstart; h < hend; ++h) {
             for (size_t w = wstart; w < wend; ++w) {
               size_t index = h * width_ + w;
               if (batchData[index] > outputData[poolIndex]) {
                 outputData[poolIndex] = batchData[index];
-                argmaxData[poolIndex] = index;
+                if (argmaxData) {
+                  argmaxData[poolIndex] = index;
+                }
               }
             }
           }
@@ -143,7 +150,9 @@ void ROIPoolLayer::forward(PassType passType) {
       }
       batchData += channelOffset;
       outputData += poolChannelOffset;
-      argmaxData += poolChannelOffset;
+      if (argmaxData) {
+        argmaxData += poolChannelOffset;
+      }
     }
     bottomROIs += roiOffset;
   }
diff --git a/paddle/gserver/layers/SequenceToBatch.cpp b/paddle/gserver/layers/SequenceToBatch.cpp
index 5fa7b6f4881b9582b540a5b1bfe849220cc2a4ea..6b769378d24838364701d0f128a7308c6195cc41 100644
--- a/paddle/gserver/layers/SequenceToBatch.cpp
+++ b/paddle/gserver/layers/SequenceToBatch.cpp
@@ -171,12 +171,31 @@ void SequenceToBatch::sequence2BatchCopy(Matrix &batch,
     hl_sequence2batch_copy(
         batchData, seqData, idxData, seqWidth, batchCount, seq2batch);
   } else {
-    for (int i = 0; i < batchCount; ++i) {
-      if (seq2batch) {
+    if (seq2batch) {
+#ifdef PADDLE_USE_MKLML
+      const int blockMemSize = 8 * 1024;
+      const int blockSize = blockMemSize / sizeof(real);
+#pragma omp parallel for collapse(2)
+      for (int i = 0; i < batchCount; ++i) {
+        for (int j = 0; j < seqWidth; j += blockSize) {
+          memcpy(batch.rowBuf(i) + j,
+                 sequence.rowBuf(idxData[i]) + j,
+                 (j + blockSize > seqWidth) ? (seqWidth - j) * sizeof(real)
+                                            : blockMemSize);
+        }
+      }
+#else
+      for (int i = 0; i < batchCount; ++i) {
         memcpy(batch.rowBuf(i),
                sequence.rowBuf(idxData[i]),
                seqWidth * sizeof(real));
-      } else {
+      }
+#endif
+    } else {
+#ifdef PADDLE_USE_MKLML
+#pragma omp parallel for
+#endif
+      for (int i = 0; i < batchCount; ++i) {
         memcpy(sequence.rowBuf(idxData[i]),
                batch.rowBuf(i),
                seqWidth * sizeof(real));
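Note: the MKLML path above splits each row copy into 8 KB blocks so that the omp parallel for collapse(2) pragma can parallelize over both rows and blocks; the ternary copies only the ragged tail of the last block. A standalone sketch of that tail handling:

    #include <cstring>
    #include <iostream>
    #include <vector>

    int main() {
      using real = float;
      const int blockMemSize = 8 * 1024;                  // bytes per block
      const int blockSize = blockMemSize / sizeof(real);  // 2048 floats
      const int seqWidth = 5000;                          // not a multiple of 2048
      std::vector<real> src(seqWidth, 1.0f), dst(seqWidth, 0.0f);
      for (int j = 0; j < seqWidth; j += blockSize) {
        // The last block copies only the remaining (seqWidth - j) elements.
        std::memcpy(dst.data() + j, src.data() + j,
                    (j + blockSize > seqWidth) ? (seqWidth - j) * sizeof(real)
                                               : blockMemSize);
      }
      std::cout << dst[seqWidth - 1] << "\n";  // prints: 1
    }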
diff --git a/paddle/memory/detail/system_allocator.cc b/paddle/memory/detail/system_allocator.cc
index 6a815a1b57db1d833781ca224f34e4559af9b9a5..509250debc2b2fd2e87078ab5f233ae2db6fd898 100644
--- a/paddle/memory/detail/system_allocator.cc
+++ b/paddle/memory/detail/system_allocator.cc
@@ -19,6 +19,7 @@ limitations under the License. */
 #include <stdlib.h>    // for malloc and free
 #include <sys/mman.h>  // for mlock and munlock
+#include <algorithm>   // for std::max
 
 #include "gflags/gflags.h"
 
@@ -28,7 +29,7 @@ limitations under the License. */
 // of memory available to the system for paging.  So, by default, we
 // should set false to use_pinned_memory.
 DEFINE_bool(use_pinned_memory, true, "If set, allocate cpu pinned memory.");
-
+DECLARE_double(fraction_of_gpu_memory_to_use);
 namespace paddle {
 namespace memory {
 namespace detail {
@@ -77,45 +78,20 @@ void* GPUAllocator::Alloc(size_t& index, size_t size) {
   // CUDA documentation doesn't explain if cudaMalloc returns nullptr
   // if size is 0.  We just make sure it does.
   if (size <= 0) return nullptr;
-
-  size_t available = 0;
-  size_t capacity = 0;
-  paddle::platform::GpuMemoryUsage(available, capacity);
-
-  // Reserve memory for page tables, etc.
-  size_t reserving = 0.05 * capacity + paddle::platform::GpuMinChunkSize();
-  size_t usable = available > reserving ? available - reserving : 0;
-
-  // If remaining size no less than expected size, using general
-  // cudaMalloc to allocate GPU memory.
-  void* p = 0;
-  if (size <= usable) {
-    cudaError_t result = cudaMalloc(&p, size);
-    if (result == cudaSuccess) {
-      index = 0;
-      gpu_alloc_size_ += size;
-      return p;
-    }
-  }
-
-  // If remaining size less than expected size or cudaMalloc failed,
-  // cudaMallocHost will be considered as a fallback allocator.
-  //
-  // NOTE: here, we use GpuMaxAllocSize() as the maximum memory size
-  // of host fallback allocation. Allocates too much would reduce
-  // the amount of memory available to the underlying system for paging.
-  usable = paddle::platform::GpuMaxAllocSize() - fallback_alloc_size_;
-
-  if (size > usable) return nullptr;
-
-  cudaError_t result = cudaMallocHost(&p, size);
+  void* p;
+  cudaError_t result = cudaMalloc(&p, size);
   if (result == cudaSuccess) {
-    index = 1;
-    fallback_alloc_size_ += size;
+    index = 0;
+    gpu_alloc_size_ += size;
     return p;
+  } else {
+    LOG(WARNING)
+        << "Cannot malloc " << size / 1024.0 / 1024.0
+        << " MB of GPU memory. Please decrease the "
+           "FLAGS_fraction_of_gpu_memory_to_use flag. Current value is "
+        << FLAGS_fraction_of_gpu_memory_to_use;
+    return nullptr;
   }
-
-  return nullptr;
 }
 
 void GPUAllocator::Free(void* p, size_t size, size_t index) {
diff --git a/paddle/operators/crop_op.cc b/paddle/operators/crop_op.cc
index 7c2a0ac7a705e5aac3d181545f8dfc8881e811f2..5c973fbb3cf9513d82a5b87719cb947466082424 100644
--- a/paddle/operators/crop_op.cc
+++ b/paddle/operators/crop_op.cc
@@ -88,7 +88,8 @@ There are two ways to set shape:
 The input should be a k-D tensor(k > 0 and k < 7). As an example:
 
-Given:
+Case 1:
+Given
 
     X = [[0, 1, 2, 0, 0]
          [0, 3, 4, 0, 0]
@@ -107,6 +108,27 @@ we get:
 
     Out = [[1, 2],
            [3, 4]].
+
+Case 2:
+Given
+
+    X = [[0, 1, 2, 5, 0]
+         [0, 3, 4, 6, 0]
+         [0, 0, 0, 0, 0]],
+
+and
+
+    offsets = [0, 1],
+
+and
+
+    Y = [[0, 0, 0]
+         [0, 0, 0]],
+
+we get:
+
+    Out = [[1, 2, 5],
+           [3, 4, 6]].
 )DOC");
   }
 };
diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu
index e33070c40fbfa7f2794426247ef77b8fcaee4ec6..7852bb53a9035f71f52a51529c8e3cea22b0d4aa 100644
--- a/paddle/operators/math/math_function.cu
+++ b/paddle/operators/math/math_function.cu
@@ -274,7 +274,7 @@ void set_constant_with_place(
 }
 
 template <>
-void set_constant_with_place<platform::CudnnPlace>(
+void set_constant_with_place<platform::CUDNNPlace>(
     const platform::DeviceContext& context, framework::Tensor* tensor,
     float value) {
   set_constant_with_place<platform::GPUPlace>(context, tensor, value);
diff --git a/paddle/operators/reshape_op.cc b/paddle/operators/reshape_op.cc
index 7fd33bf662a1d0b7b6fa4e772bdadbf34b2f4fdd..d82d828747c0c822195b699359b8e62d1cf7e92d 100644
--- a/paddle/operators/reshape_op.cc
+++ b/paddle/operators/reshape_op.cc
@@ -34,21 +34,33 @@ class ReshapeOp : public framework::OperatorWithKernel {
     auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
     PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty.");
     auto x_dims = ctx->GetInputDim("X");
-    // TODO(qiao) change batch_size
-    for (size_t i = 1; i < shape.size(); ++i) {
-      PADDLE_ENFORCE(shape[i] > 0,
-                     "Each dimension of Attr(shape) "
-                     "must be positive except the first one.");
-    }
-    if (shape[0] < 0) {
-      shape[0] = x_dims[0];
+
+    std::vector<size_t> neg_dims_idx;
+    // set some dimension to -1 if it is unknown
+    const int unknown_size = -1;
+    for (size_t i = 0; i < shape.size(); ++i) {
+      PADDLE_ENFORCE(shape[i] > 0 || shape[i] == unknown_size,
+                     "Each dimension of Attr(shape) must be positive or %d.",
+                     unknown_size);
+      if (shape[i] == unknown_size) {
+        neg_dims_idx.push_back(i);
+        PADDLE_ENFORCE(neg_dims_idx.size() <= 1,
+                       "Only one dimension of Attr(shape) can be unknown.");
+      }
     }
-    // capacity check
+
     int64_t capacity =
         std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
     int64_t in_size = framework::product(x_dims);
-    PADDLE_ENFORCE_EQ(capacity, in_size,
-                      "The size of Input(X) mismatches with Attr(shape).");
+    if (neg_dims_idx.size() == 1) {
+      // dim infer
+      shape[neg_dims_idx[0]] = in_size / (-capacity);
+      // recalculate capacity
+      capacity = shape[neg_dims_idx[0]] * (-capacity);
+    }
+    // capacity check
+    PADDLE_ENFORCE(capacity == in_size,
+                   "The size of Input(X) mismatches with Attr(shape).");
     // resize output
     std::vector<int64_t> shape_int64(shape.size(), 0);
     std::transform(shape.begin(), shape.end(), shape_int64.begin(),
@@ -88,6 +100,9 @@ the tensor X into a 2-D tensor:
 
     [[1, 2, 3, 4]]
 
+One dimension in the target shape can be set to -1, representing that its
+size is unknown. In this case, the real dimension will be inferred from
+the original shape of Input(X) and the other dimensions in the target shape.
 )DOC");
   }
 };
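Note: since at most one entry of Attr(shape) may be -1, the raw product of all entries (capacity) is the negative of the product of the known dimensions, so in_size / (-capacity) recovers the unknown dimension. A worked sketch of the arithmetic, matching the (10, 20) -> [4, -1, 5] case in the new Python test further below:

    #include <cassert>
    #include <cstdint>
    #include <functional>
    #include <iostream>
    #include <numeric>
    #include <vector>

    int main() {
      int64_t in_size = 10 * 20;            // number of elements in Input(X)
      std::vector<int> shape = {4, -1, 5};  // target shape with one unknown dim
      int64_t capacity = std::accumulate(shape.begin(), shape.end(), int64_t{1},
                                         std::multiplies<int64_t>());  // -20
      int64_t inferred = in_size / (-capacity);   // 200 / 20 = 10
      assert(inferred * (-capacity) == in_size);  // the capacity check still holds
      std::cout << "inferred dim = " << inferred << "\n";  // prints: 10
    }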
diff --git a/paddle/operators/spp_op.cc b/paddle/operators/spp_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b1807b62616b80ea8a9e48409e0760c1c7b36a38
--- /dev/null
+++ b/paddle/operators/spp_op.cc
@@ -0,0 +1,99 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/spp_op.h"
+namespace paddle {
+namespace operators {
+
+class SppOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  SppOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput(
+        "X",
+        "(Tensor) The input tensor of spp operator. "
+        "The format of input tensor is NCHW. Where N is batch size, C is the "
+        "number of channels, H and W is the height and width of feature.");
+    AddOutput("Out",
+              "(Tensor) The output tensor of spp operator, "
+              "with shape N * M, where M = C * H * W.");
+    AddAttr<int>("pyramid_height", "(int) The height of the pooling pyramid.");
+    AddAttr<std::string>(
+        "pooling_type",
+        "(string), pooling type, can be \"max\" for max-pooling "
+        "and \"avg\" for average-pooling.")
+        .InEnum({"max", "avg"});
+    AddComment(R"DOC(
+        With spatial pyramid pooling, the input image can
+        be of any size. This not only allows arbitrary aspect
+        ratios, but also allows arbitrary scales. We can resize
+        the input image to any scale (e.g., min(w, h)=180, 224,
+        ...) and apply the same deep network. When the
+        input image is at different scales, the network (with
+        the same filter sizes) will extract features at different
+        scales. The scales play important roles in traditional
+        methods.
+        Input shape: $(N, C_{in}, H_{in}, W_{in})$
+        Output shape: $(H_{out}, W_{out})$
+        Where
+          $$
+          H_{out} = N \\
+          W_{out} = \frac{4^{pyramid\_height} - 1}{4 - 1} * C_{in}
+          $$
+        Reference: https://arxiv.org/pdf/1406.4729v4.pdf
+        )DOC");
+  }
+};
+
+class SppOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of SppOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of SppOp should not be null.");
+    auto in_x_dims = ctx->GetInputDim("X");
+    int pyramid_height = ctx->Attrs().Get<int>("pyramid_height");
+    PADDLE_ENFORCE(in_x_dims.size() == 4,
+                   "Input of SppOp must be 4-dimensional.");
+    int outlen = ((std::pow(4, pyramid_height) - 1) / (4 - 1)) * in_x_dims[1];
+    std::vector<int64_t> output_shape({in_x_dims[0], outlen});
+    ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
+  }
+};
+
+class SppOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
+                   "Output(X@GRAD) should not be null.");
+    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(spp, ops::SppOp, ops::SppOpMaker, spp_grad, ops::SppOpGrad);
+REGISTER_OP_CPU_KERNEL(
+    spp, ops::SppKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SppKernel<paddle::platform::CPUDeviceContext, double>);
+REGISTER_OP_CPU_KERNEL(
+    spp_grad, ops::SppGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SppGradKernel<paddle::platform::CPUDeviceContext, double>);
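Note: level p of the pyramid pools the feature map into a 2^p x 2^p grid, i.e. 4^p cells, so the flattened output width is C_in * (4^0 + ... + 4^(h-1)) = C_in * (4^h - 1) / 3, which is exactly the outlen computed in InferShape above. A quick check with the shapes used in the new test_spp_op.py further below:

    #include <cmath>
    #include <iostream>

    int main() {
      int channels = 2, pyramid_height = 3;  // matches test_spp_op.py
      int outlen = ((std::pow(4, pyramid_height) - 1) / (4 - 1)) * channels;
      std::cout << outlen << "\n";  // 2 * (1 + 4 + 16) = 42
    }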
diff --git a/paddle/operators/spp_op.cu.cc b/paddle/operators/spp_op.cu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..761e4d6c4a9639898ba548d56bed3c8817720c1b
--- /dev/null
+++ b/paddle/operators/spp_op.cu.cc
@@ -0,0 +1,23 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/spp_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    spp, ops::SppKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SppKernel<paddle::platform::CUDADeviceContext, double>);
+REGISTER_OP_CUDA_KERNEL(
+    spp_grad, ops::SppGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SppGradKernel<paddle::platform::CUDADeviceContext, double>);
diff --git a/paddle/operators/spp_op.h b/paddle/operators/spp_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..f35b305d02c73bcae6e72b8afa5ce55148ea98b8
--- /dev/null
+++ b/paddle/operators/spp_op.h
@@ -0,0 +1,161 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/math_function.h"
+#include "paddle/operators/math/pooling.h"
+#include "paddle/operators/strided_memcpy.h"
+
+namespace paddle {
+namespace operators {
+template <typename DeviceContext, typename T>
+class SppKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
+    auto* out = context.Output<framework::Tensor>("Out");
+    int pyramid_height = context.template Attr<int>("pyramid_height");
+    std::string pooling_type =
+        context.template Attr<std::string>("pooling_type");
+    out->mutable_data<T>(context.GetPlace());
+    auto out_stride = framework::stride(out->dims());
+    int input_h = in_x->dims()[2];
+    int input_w = in_x->dims()[3];
+    size_t output_offset = 0;
+    for (int p = 0; p < pyramid_height; ++p) {
+      int bins = std::pow(2, p);
+      int kernel_size_h = std::ceil(input_h / static_cast<double>(bins));
+      int kernel_size_w = std::ceil(input_w / static_cast<double>(bins));
+      int padding_h = (kernel_size_h * bins - input_h + 1) / 2;
+      int padding_w = (kernel_size_w * bins - input_w + 1) / 2;
+      std::vector<int> kernel_size({kernel_size_h, kernel_size_w});
+      std::vector<int> strides({kernel_size_h, kernel_size_w});
+      std::vector<int> paddings({padding_h, padding_w});
+      // pooling output shape
+      framework::Tensor out_level;
+      std::vector<int64_t> output_shape_vec(
+          {in_x->dims()[0], in_x->dims()[1], bins, bins});
+      framework::DDim output_shape(framework::make_ddim(output_shape_vec));
+      out_level.mutable_data<T>(output_shape, context.GetPlace());
+      // pooling
+      if (pooling_type == "max") {
+        math::Pool2dFunctor<DeviceContext, math::MaxPool<T>, T> pool_forward;
+        math::MaxPool<T> max_process;
+        pool_forward(context.template device_context<DeviceContext>(), *in_x,
+                     kernel_size, strides, paddings, max_process, &out_level);
+      } else if (pooling_type == "avg") {
+        math::Pool2dFunctor<DeviceContext, math::AvgPool<T>, T> pool_forward;
+        math::AvgPool<T> avg_process;
+        pool_forward(context.template device_context<DeviceContext>(), *in_x,
+                     kernel_size, strides, paddings, avg_process, &out_level);
+      }
+      // flatten pooling output shape
+      int output_flatten_w = in_x->dims()[1] * bins * bins;
+      std::vector<int64_t> output_flatten_shape_vec(
+          {in_x->dims()[0], output_flatten_w});
+      framework::DDim output_flatten_shape(
+          framework::make_ddim(output_flatten_shape_vec));
+      out_level.Resize(output_flatten_shape);
+      // concat
+      auto out_level_stride = framework::stride(out_level.dims());
+      StridedMemcpy<T>(context.template device_context<DeviceContext>(),
+                       out_level.data<T>(), out_level_stride, out_level.dims(),
+                       out_stride, out->data<T>() + output_offset);
+      output_offset += out_level.dims()[1] * out_level_stride[1];
+    }
+  }
+};
+template <typename DeviceContext, typename T>
+class SppGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
+    const framework::Tensor* out = context.Input<framework::Tensor>("Out");
+    const framework::Tensor* out_grad =
+        context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    framework::Tensor* in_x_grad =
+        context.Output<framework::Tensor>(framework::GradVarName("X"));
+    int pyramid_height = context.template Attr<int>("pyramid_height");
+    std::string pooling_type =
+        context.template Attr<std::string>("pooling_type");
+    auto& device_ctx = context.template device_context<DeviceContext>();
+    math::SetConstant<DeviceContext, T> zero;
+    in_x_grad->mutable_data<T>(context.GetPlace());
+    zero(device_ctx, in_x_grad, static_cast<T>(0));
+    auto out_stride = framework::stride(out->dims());
+    int input_h = in_x->dims()[2];
+    int input_w = in_x->dims()[3];
+    size_t out_offset = 0;
+    for (int p = 0; p < pyramid_height; ++p) {
+      int bins = std::pow(2, p);
+      int kernel_size_h = std::ceil(input_h / static_cast<double>(bins));
+      int kernel_size_w = std::ceil(input_w / static_cast<double>(bins));
+      int padding_h = (kernel_size_h * bins - input_h + 1) / 2;
+      int padding_w = (kernel_size_w * bins - input_w + 1) / 2;
+      std::vector<int> kernel_size({kernel_size_h, kernel_size_w});
+      std::vector<int> strides({kernel_size_h, kernel_size_w});
+      std::vector<int> paddings({padding_h, padding_w});
+      // split out and outgrad ... to flatten
+      framework::Tensor out_level;
+      framework::Tensor outgrad_level;
+      int out_flatten_w = in_x->dims()[1] * bins * bins;
+      std::vector<int64_t> out_flatten_shape_vec(
+          {in_x->dims()[0], out_flatten_w});
+      framework::DDim out_flatten_shape(
+          framework::make_ddim(out_flatten_shape_vec));
+      out_level.mutable_data<T>(out_flatten_shape, context.GetPlace());
+      outgrad_level.mutable_data<T>(out_flatten_shape, context.GetPlace());
+      auto flatten_stride = framework::stride(out_level.dims());
+      // memcpy
+      StridedMemcpy<T>(context.template device_context<DeviceContext>(),
+                       out->data<T>() + out_offset, out_stride,
+                       out_level.dims(), flatten_stride, out_level.data<T>());
+
+      StridedMemcpy<T>(context.template device_context<DeviceContext>(),
+                       out_grad->data<T>() + out_offset, out_stride,
+                       outgrad_level.dims(), flatten_stride,
+                       outgrad_level.data<T>());
+      out_offset += out_level.dims()[1] * out_stride[1];
+      // flatten backward to nchw
+
+      std::vector<int64_t> out_shape_vec({in_x->dims()[0], in_x->dims()[1]});
+      out_shape_vec.push_back(
+          (input_h - kernel_size_h + 2 * padding_h) / kernel_size_h + 1);
+      out_shape_vec.push_back(
+          (input_w - kernel_size_w + 2 * padding_w) / kernel_size_w + 1);
+      framework::DDim out_shape(framework::make_ddim(out_shape_vec));
+      out_level.ShareDataWith(out_level);
+      out_level.Resize(out_shape);
+      outgrad_level.ShareDataWith(outgrad_level);
+      outgrad_level.Resize(out_shape);
+      // pooling backward
+      if (pooling_type == "max") {
+        math::MaxPool2dGradFunctor<DeviceContext, T> pool2d_backward;
+        pool2d_backward(context.template device_context<DeviceContext>(),
+                        *in_x, *&out_level, *&outgrad_level, kernel_size,
+                        strides, paddings, in_x_grad);
+      } else if (pooling_type == "avg") {
+        math::Pool2dGradFunctor<DeviceContext, math::AvgPoolGrad<T>, T>
+            pool_backward;
+        math::AvgPoolGrad<T> avg_process;
+        pool_backward(context.template device_context<DeviceContext>(), *in_x,
+                      *&out_level, *&outgrad_level, kernel_size, strides,
+                      paddings, avg_process, in_x_grad);
+      }
+    }
+  }
+};
+}  // namespace operators
+}  // namespace paddle
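Note: for each pyramid level the kernels derive the pooling geometry from the input size alone: bins = 2^p windows per spatial dimension, a kernel of ceil(input / bins), and just enough padding that bins * kernel covers the input. A standalone sketch of that computation for the 4 x 4 inputs used in the tests:

    #include <cmath>
    #include <iostream>

    int main() {
      int input_h = 4;
      for (int p = 0; p < 3; ++p) {
        int bins = 1 << p;  // 2^p
        int kernel_h = static_cast<int>(
            std::ceil(input_h / static_cast<double>(bins)));
        int padding_h = (kernel_h * bins - input_h + 1) / 2;
        std::cout << "level " << p << ": bins=" << bins
                  << " kernel_h=" << kernel_h << " padding_h=" << padding_h
                  << "\n";
      }
      // level 0: bins=1 kernel_h=4 padding_h=0
      // level 1: bins=2 kernel_h=2 padding_h=0
      // level 2: bins=4 kernel_h=1 padding_h=0
    }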
diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc
index 1c72b5055971e73c7aa560a61ca9d3c48dc56fbc..8cdc5f43403b0c54d3f1f01a3e97405fd5b2f434 100644
--- a/paddle/platform/device_context.cc
+++ b/paddle/platform/device_context.cc
@@ -125,21 +125,21 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
 
 cudaStream_t CUDADeviceContext::stream() const { return stream_; }
 
-CudnnDeviceContext::CudnnDeviceContext(CudnnPlace place)
+CUDNNDeviceContext::CUDNNDeviceContext(CUDNNPlace place)
     : CUDADeviceContext(place), place_(place) {
   PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
   PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream()));
 }
 
-CudnnDeviceContext::~CudnnDeviceContext() {
+CUDNNDeviceContext::~CUDNNDeviceContext() {
   SetDeviceId(place_.device);
   Wait();
   PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
 }
 
-Place CudnnDeviceContext::GetPlace() const { return CudnnPlace(); }
+Place CUDNNDeviceContext::GetPlace() const { return CUDNNPlace(); }
 
-cudnnHandle_t CudnnDeviceContext::cudnn_handle() const { return cudnn_handle_; }
+cudnnHandle_t CUDNNDeviceContext::cudnn_handle() const { return cudnn_handle_; }
 
 #endif
diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h
index f67194993db1f4160bd6894b2c845a82f4da2354..56813a1d5b3c2a7f4ff7b4eddc6fa47ed861700c 100644
--- a/paddle/platform/device_context.h
+++ b/paddle/platform/device_context.h
@@ -86,10 +86,10 @@ class CUDADeviceContext : public DeviceContext {
   cublasHandle_t cublas_handle_;
 };
 
-class CudnnDeviceContext : public CUDADeviceContext {
+class CUDNNDeviceContext : public CUDADeviceContext {
  public:
-  explicit CudnnDeviceContext(CudnnPlace place);
-  virtual ~CudnnDeviceContext();
+  explicit CUDNNDeviceContext(CUDNNPlace place);
+  virtual ~CUDNNDeviceContext();
 
   /*! \brief  Return place in the device context. */
   Place GetPlace() const final;
@@ -99,7 +99,7 @@ class CudnnDeviceContext : public CUDADeviceContext {
 
  private:
   cudnnHandle_t cudnn_handle_;
-  CudnnPlace place_;
+  CUDNNPlace place_;
 };
 
 #endif
diff --git a/paddle/platform/device_context_test.cc b/paddle/platform/device_context_test.cc
index be3b2af5af09cb18f5156412ff60a7fc15a16487..109c13a8812dffac10d202cbc9d85c4e601bf197 100644
--- a/paddle/platform/device_context_test.cc
+++ b/paddle/platform/device_context_test.cc
@@ -47,14 +47,14 @@ TEST(Device, CUDADeviceContext) {
   }
 }
 
-TEST(Device, CudnnDeviceContext) {
-  using paddle::platform::CudnnDeviceContext;
-  using paddle::platform::CudnnPlace;
+TEST(Device, CUDNNDeviceContext) {
+  using paddle::platform::CUDNNDeviceContext;
+  using paddle::platform::CUDNNPlace;
   if (paddle::platform::dynload::HasCUDNN()) {
     int count = paddle::platform::GetCUDADeviceCount();
     for (int i = 0; i < count; ++i) {
-      CudnnDeviceContext* device_context =
-          new CudnnDeviceContext(CudnnPlace(i));
+      CUDNNDeviceContext* device_context =
+          new CUDNNDeviceContext(CUDNNPlace(i));
       cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
       ASSERT_NE(nullptr, cudnn_handle);
       ASSERT_NE(nullptr, device_context->stream());
diff --git a/paddle/platform/gpu_info.cc b/paddle/platform/gpu_info.cc
index 4fa2eaed31c6e9368459c2da6f8b0667b453d58c..541eca5f39c2e6a4b464aec79fd8a920ab4c7732 100644
--- a/paddle/platform/gpu_info.cc
+++ b/paddle/platform/gpu_info.cc
@@ -73,19 +73,20 @@ size_t GpuMaxChunkSize() {
   size_t available = 0;
 
   GpuMemoryUsage(available, total);
-
-  // Reserving the rest memory for page tables, etc.
-  size_t reserving = 0.05 * total;
-
+  VLOG(10) << "GPU Usage " << available / 1024 / 1024 << "M/"
+           << total / 1024 / 1024 << "M";
+  // Reserve memory for page tables, etc.
+  size_t reserving = static_cast<size_t>(0.05 * total);
   // If available less than minimum chunk size, no usable memory exists.
   available =
-      std::max(std::max(available, GpuMinChunkSize()) - GpuMinChunkSize(),
-               reserving) -
-      reserving;
+      std::min(std::max(available, GpuMinChunkSize()) - GpuMinChunkSize(),
+               total - reserving);
 
-  size_t allocating = FLAGS_fraction_of_gpu_memory_to_use * total;
+  size_t allocating = static_cast<size_t>(FLAGS_fraction_of_gpu_memory_to_use *
+                                          (total - reserving));
 
-  PADDLE_ENFORCE_LT(allocating, available);
+  PADDLE_ENFORCE_LE(allocating, available);
 
   return allocating;
 }
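Note: the chunk size is now a fraction of (total - reserving) rather than of total, so 5% of device memory always stays free for page tables and the like, and the check was relaxed from LT to LE because allocating may now equal available exactly. A rough sketch of the arithmetic; the 0.92 flag value and the 12 GB card are assumptions for illustration only:

    #include <cstdint>
    #include <iostream>

    int main() {
      const double fraction = 0.92;        // assumed value of the gflag
      const uint64_t total = 12ull << 30;  // hypothetical 12 GB card
      uint64_t reserving = static_cast<uint64_t>(0.05 * total);
      uint64_t allocating =
          static_cast<uint64_t>(fraction * (total - reserving));
      std::cout << allocating / (1ull << 20) << " MB\n";  // prints: 10739 MB
    }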
diff --git a/paddle/platform/place.h b/paddle/platform/place.h
index 4526945792b2ea96cc4e9df11d8f35897cba7526..ca98920d414bc87ce243995a42e5672d0e61e108 100644
--- a/paddle/platform/place.h
+++ b/paddle/platform/place.h
@@ -51,9 +51,9 @@ struct GPUPlace {
   int device;
 };
 
-struct CudnnPlace : public GPUPlace {
-  CudnnPlace() : GPUPlace() {}
-  explicit CudnnPlace(int d) : GPUPlace(d) {}
+struct CUDNNPlace : public GPUPlace {
+  CUDNNPlace() : GPUPlace() {}
+  explicit CUDNNPlace(int d) : GPUPlace(d) {}
 };
 
 struct IsGPUPlace : public boost::static_visitor<bool> {
@@ -72,7 +72,7 @@ struct IsMKLDNNPlace : public boost::static_visitor<bool> {
 // should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
 #define NUM_PLACE_TYPE_LIMIT_IN_BIT 4
 
-typedef boost::variant<GPUPlace, CPUPlace, MKLDNNPlace> Place;
+typedef boost::variant<CUDNNPlace, GPUPlace, CPUPlace, MKLDNNPlace> Place;
 
 // static check number of place types is less equal than
 // 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
diff --git a/paddle/platform/place_test.cc b/paddle/platform/place_test.cc
index 184af12c230f1ccd7826e507f16f4e91ca380a45..c536b59ed8f71bd078bd09c5bd5afeab74c71b28 100644
--- a/paddle/platform/place_test.cc
+++ b/paddle/platform/place_test.cc
@@ -5,16 +5,22 @@
 TEST(Place, Equality) {
   paddle::platform::CPUPlace cpu;
   paddle::platform::GPUPlace g0(0), g1(1), gg0(0);
+  paddle::platform::CUDNNPlace d0(0), d1(1), dd0(0);
 
   EXPECT_EQ(cpu, cpu);
   EXPECT_EQ(g0, g0);
   EXPECT_EQ(g1, g1);
   EXPECT_EQ(g0, gg0);
+  EXPECT_EQ(d0, dd0);
 
   EXPECT_NE(g0, g1);
+  EXPECT_NE(d0, d1);
 
   EXPECT_TRUE(paddle::platform::places_are_same_class(g0, gg0));
   EXPECT_FALSE(paddle::platform::places_are_same_class(g0, cpu));
+
+  EXPECT_TRUE(paddle::platform::is_gpu_place(d0));
+  EXPECT_FALSE(paddle::platform::places_are_same_class(g0, d0));
 }
 
 TEST(Place, Default) {
diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt
index fd55f410d3f0fee418e7efffa927e46c38d23a07..1fb69de90d2fb5386dffdd95825c496a8fa559d3 100644
--- a/paddle/pybind/CMakeLists.txt
+++ b/paddle/pybind/CMakeLists.txt
@@ -1,7 +1,7 @@
 if(WITH_PYTHON)
   cc_library(paddle_pybind SHARED
     SRCS pybind.cc exception.cc protobuf.cc
-    DEPS pybind python backward proto_desc paddle_memory executor prune
+    DEPS pybind python backward proto_desc paddle_memory executor prune init
          ${GLOB_OP_LIB})
 endif(WITH_PYTHON)
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 1faf24bcb8828596ec37abde9e699f46526e41df..4248db34c6345bd62e63628c7794b40d8a1adab6 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -16,11 +16,11 @@ limitations under the License. */
 #include <mutex>  // for call_once
 #include <unordered_map>
-#include "gflags/gflags.h"
 #include "paddle/framework/backward.h"
 #include "paddle/framework/executor.h"
 #include "paddle/framework/feed_fetch_method.h"
 #include "paddle/framework/framework.pb.h"
+#include "paddle/framework/init.h"
 #include "paddle/framework/lod_rank_table.h"
 #include "paddle/framework/lod_tensor.h"
 #include "paddle/framework/lod_tensor_array.h"
@@ -51,24 +51,6 @@ static size_t UniqueIntegerGenerator(const std::string &prefix) {
   return generators[prefix].fetch_add(1);
 }
 
-std::once_flag gflags_init_flag;
-
-// TODO(qijun) move init gflags to init.cc
-void InitGflags(std::vector<std::string> &argv) {
-  std::call_once(gflags_init_flag, [&]() {
-    int argc = argv.size();
-    char **arr = new char *[argv.size()];
-    std::string line;
-    for (size_t i = 0; i < argv.size(); i++) {
-      arr[i] = &argv[i][0];
-      line += argv[i];
-      line += ' ';
-    }
-    google::ParseCommandLineFlags(&argc, &arr, true);
-    VLOG(1) << "Init commandline: " << line;
-  });
-}
-
 bool IsCompileGPU() {
 #ifndef PADDLE_WITH_CUDA
   return false;
@@ -438,7 +420,8 @@ All parameter, weight, gradient are variables in Paddle.
       .def("run", &Executor::Run);
 
   m.def("unique_integer", UniqueIntegerGenerator);
-  m.def("init_gflags", InitGflags);
+  m.def("init_gflags", framework::InitGflags);
+  m.def("init_devices", &framework::InitDevices);
 
   m.def("is_compile_gpu", IsCompileGPU);
   m.def("set_feed_variable", framework::SetFeedVariable);
diff --git a/python/paddle/trainer_config_helpers/networks.py b/python/paddle/trainer_config_helpers/networks.py
index 9776ae18057d57dd994fac8b62090258252922c6..8bfe56d795e394efffabb61f145b1a20d806447d 100644
--- a/python/paddle/trainer_config_helpers/networks.py
+++ b/python/paddle/trainer_config_helpers/networks.py
@@ -1119,8 +1119,9 @@ def simple_gru2(input,
     :param gru_bias_attr: bias parameter attribute of gru layer,
                           False means no bias, None means default bias.
     :type gru_bias_attr: ParameterAttribute|False|None
-    :param gru_layer_attr: Extra attribute of the gru layer.
-    :type gru_layer_attr: ExtraLayerAttribute
+    :param gru_param_attr: parameter attribute of the gru layer,
+                           None means the default parameter attribute.
+    :type gru_param_attr: ParameterAttribute|None
     :return: the gru group.
     :rtype: LayerOutput
     """
diff --git a/python/paddle/v2/fluid/executor.py b/python/paddle/v2/fluid/executor.py
index bdc82eede9d93a7cf904999a6b869ce2d23c90dc..9a99b045dc70a9e4662a6f4da141183ffc8f1846 100644
--- a/python/paddle/v2/fluid/executor.py
+++ b/python/paddle/v2/fluid/executor.py
@@ -46,6 +46,13 @@ class Executor(object):
             p.set_place(each)
             act_places.append(p)
 
+        # TODO(dzhwinter) : consider that our fluid tests are all written
+        # with GPUPlace(gpu_id); this will be changed in a later PR.
+        if core.is_compile_gpu():
+            core.init_devices(["CPU", "GPU:0"])
+        else:
+            core.init_devices(["CPU"])
+
         self.executor = core.Executor(act_places)
         self.places = places
diff --git a/python/paddle/v2/fluid/tests/test_reshape_op.py b/python/paddle/v2/fluid/tests/test_reshape_op.py
index 16bb6bb2af67f7d32a2fafc1cb37412084ec0829..18ee3aece656276fec9671df9baf298b7fd3c9b1 100644
--- a/python/paddle/v2/fluid/tests/test_reshape_op.py
+++ b/python/paddle/v2/fluid/tests/test_reshape_op.py
@@ -17,5 +17,19 @@ class TestReshapeOp(OpTest):
         self.check_grad(["X"], "Out")
 
 
+class TestReshapeOpDimInfer(OpTest):
+    def setUp(self):
+        self.op_type = "reshape"
+        self.inputs = {'X': np.random.random((10, 20)).astype("float32")}
+        self.attrs = {'shape': [4, -1, 5]}
+        self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/v2/fluid/tests/test_spp_op.py b/python/paddle/v2/fluid/tests/test_spp_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..007723f0e35ad194c427401337bc9b13756576de
--- /dev/null
+++ b/python/paddle/v2/fluid/tests/test_spp_op.py
@@ -0,0 +1,68 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+from test_pool2d_op import max_pool2D_forward_naive
+from test_pool2d_op import avg_pool2D_forward_naive
+
+
+class TestSppOp(OpTest):
+    def setUp(self):
+        self.op_type = "spp"
+        self.init_test_case()
+        input = np.random.random(self.shape).astype("float32")
+        nsize, csize, hsize, wsize = input.shape
+        out_level_flatten = []
+        for i in xrange(self.pyramid_height):
+            bins = np.power(2, i)
+            kernel_size = [0, 0]
+            padding = [0, 0]
+            kernel_size[0] = np.ceil(hsize /
+                                     bins.astype("double")).astype("int32")
+            padding[0] = (
+                (kernel_size[0] * bins - hsize + 1) / 2).astype("int32")
+
+            kernel_size[1] = np.ceil(wsize /
+                                     bins.astype("double")).astype("int32")
+            padding[1] = (
+                (kernel_size[1] * bins - wsize + 1) / 2).astype("int32")
+            out_level = self.pool2D_forward_naive(input, kernel_size,
+                                                  kernel_size, padding)
+            out_level_flatten.append(
+                out_level.reshape(nsize, bins * bins * csize))
+            if i == 0:
+                output = out_level_flatten[i]
+            else:
+                output = np.concatenate((output, out_level_flatten[i]), 1)
+        self.inputs = {'X': input.astype('float32'), }
+        self.attrs = {
+            'pyramid_height': self.pyramid_height,
+            'pooling_type': self.pool_type
+        }
+
+        self.outputs = {'Out': output.astype('float32')}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        if self.pool_type != "avg":
+            self.check_grad(['X'], 'Out', max_relative_error=0.05)
+
+    def init_test_case(self):
+        self.shape = [3, 2, 4, 4]
+        self.pyramid_height = 3
+        self.pool2D_forward_naive = max_pool2D_forward_naive
+        self.pool_type = "max"
+
+
+class TestCase2(TestSppOp):
+    def init_test_case(self):
+        self.shape = [3, 2, 4, 4]
+        self.pyramid_height = 3
+        self.pool2D_forward_naive = avg_pool2D_forward_naive
+        self.pool_type = "avg"
+
+
+if __name__ == '__main__':
+    unittest.main()