提交 ed56ed9f 编写于 作者: Y yangyaming

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix-6581

......@@ -7,3 +7,4 @@ API
模型配置 <v2/model_configs.rst>
数据访问 <v2/data.rst>
训练与应用 <v2/run_logic.rst>
v2/fluid.rst
......@@ -6,10 +6,10 @@ Core: https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/framework
Operator: https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/operators
Optimizer: https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/optimizer
Memory: https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/memory
Platform: https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/platform
# Compile Time
The following **defines** the NN. The definition goes into this [protocol buffer](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto).
......
......@@ -58,3 +58,6 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
proto_desc)
cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
cc_library(init SRCS init.cc DEPS gflags executor place stringpiece)
cc_test(init_test SRCS init_test.cc DEPS init)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <sstream>
#include <vector>
......
......@@ -33,32 +33,12 @@ namespace framework {
const std::string kFeedOpType = "feed";
const std::string kFetchOpType = "fetch";
Executor::Executor(const std::vector<platform::Place>& places) : own_(true) {
PADDLE_ENFORCE_GT(places.size(), 0);
device_contexts_.resize(places.size());
for (size_t i = 0; i < places.size(); i++) {
if (platform::is_cpu_place(places[i])) {
device_contexts_[i] = new platform::CPUDeviceContext(
boost::get<platform::CPUPlace>(places[i]));
} else if (platform::is_gpu_place(places[i])) {
#ifdef PADDLE_WITH_CUDA
device_contexts_[i] = new platform::CUDADeviceContext(
boost::get<platform::GPUPlace>(places[i]));
#else
PADDLE_THROW(
"'GPUPlace' is not supported, Please re-compile with WITH_GPU "
"option");
#endif
}
}
}
DeviceContextPool* DeviceContextPool::pool = nullptr;
Executor::~Executor() {
if (own_) {
for (auto& device_context : device_contexts_) {
delete device_context;
}
}
Executor::Executor(const std::vector<platform::Place>& places) {
DeviceContextPool& pool = DeviceContextPool::Get();
auto borrowed_contexts = pool.Borrow(places);
device_contexts_.swap(borrowed_contexts);
}
static void CreateTensor(Variable* var, VarDesc::VarType var_type) {
......@@ -132,8 +112,5 @@ void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id,
}
}
Executor::Executor(const platform::DeviceContext& device)
: device_contexts_({&device}), own_(false) {}
} // namespace framework
} // namespace paddle
......@@ -14,19 +14,98 @@ limitations under the License. */
#pragma once
#include <map>
#include <unordered_map>
#include "paddle/framework/op_info.h"
#include "paddle/framework/program_desc.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace framework {
class DeviceContextPool {
public:
static DeviceContextPool& Get() {
PADDLE_ENFORCE_NOT_NULL(pool, "Need to Create DeviceContextPool first!");
return *pool;
}
static DeviceContextPool& Create(const std::vector<platform::Place>& places) {
if (pool == nullptr) {
pool = new DeviceContextPool(places);
}
return *pool;
}
std::vector<const platform::DeviceContext*> Borrow(
const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(places.size(), 0);
PADDLE_ENFORCE_LE(places.size(), device_contexts_.size());
std::vector<const platform::DeviceContext*> borrowed_contexts;
for (auto& place : places) {
auto range = device_contexts_.equal_range(place);
if (range.first == range.second) {
PADDLE_THROW(
"'Place' is not supported, Please re-compile with WITH_GPU "
"option");
}
// TODO(dzhwinter) : assign the first found device. Will enhanced later.
// device load balancer maybe useful here.
borrowed_contexts.emplace_back(range.first->second);
}
return borrowed_contexts;
}
explicit DeviceContextPool(const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(places.size(), 0);
for (size_t i = 0; i < places.size(); i++) {
if (platform::is_cpu_place(places[i])) {
device_contexts_.emplace(
places[i], new platform::CPUDeviceContext(
boost::get<platform::CPUPlace>(places[i])));
} else if (platform::is_gpu_place(places[i])) {
#ifdef PADDLE_WITH_CUDA
device_contexts_.emplace(
places[i], new platform::CUDADeviceContext(
boost::get<platform::GPUPlace>(places[i])));
#else
PADDLE_THROW(
"'GPUPlace' is not supported, Please re-compile with WITH_GPU "
"option");
#endif
}
}
}
~DeviceContextPool() {}
private:
static DeviceContextPool* pool;
struct Hash {
std::hash<int> hash_;
size_t operator()(const platform::Place& place) const {
return hash_(place.which());
}
};
std::unordered_multimap<const platform::Place, const platform::DeviceContext*,
Hash>
device_contexts_;
DISABLE_COPY_AND_ASSIGN(DeviceContextPool);
};
class Executor {
public:
// TODO(dzhwinter) : Do not rely on this function, it will be removed
explicit Executor(const platform::DeviceContext& device)
: Executor(std::vector<platform::Place>({device.GetPlace()})) {}
explicit Executor(const platform::Place& place)
: Executor(std::vector<platform::Place>({place})) {}
explicit Executor(const std::vector<platform::Place>& places);
explicit Executor(const platform::DeviceContext& devices);
~Executor();
/* @Brief
* Runtime evaluation of the given ProgramDesc under certain Scope
......@@ -39,7 +118,6 @@ class Executor {
private:
std::vector<const platform::DeviceContext*> device_contexts_;
bool own_;
};
} // namespace framework
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <string>
#include "paddle/framework/executor.h"
#include "paddle/framework/init.h"
#include "paddle/platform/place.h"
#include "paddle/string/piece.h"
namespace paddle {
namespace framework {
std::once_flag gflags_init_flag;
// TODO(qijun) move init gflags to init.cc
void InitGflags(std::vector<std::string> &argv) {
std::call_once(gflags_init_flag, [&]() {
int argc = argv.size();
char **arr = new char *[argv.size()];
std::string line;
for (size_t i = 0; i < argv.size(); i++) {
arr[i] = &argv[i][0];
line += argv[i];
line += ' ';
}
google::ParseCommandLineFlags(&argc, &arr, true);
VLOG(1) << "Init commandline: " << line;
});
}
bool InitDevices(const std::vector<std::string> &devices) {
// device format
// CPU
// GPU:1
// TODO(dzhwinter) : add device format annotation for users.
std::vector<platform::Place> places;
for (auto &device : devices) {
auto p = string::Piece(device);
if (string::Find(p, ':', 0) == string::Piece::npos) {
places.emplace_back(platform::CPUPlace());
} else if (string::HasPrefix(p, "GPU")) {
#ifdef PADDLE_WITH_CUDA
auto pos = string::RFind(p, ':', string::Piece::npos);
auto number = device.substr(pos + 1);
places.emplace_back(platform::GPUPlace(std::stoi(number)));
#else
LOG(WARNING)
<< "'GPU' is not supported, Please re-compile with WITH_GPU option";
#endif
} else {
return false;
}
}
if (std::find_if(places.begin(), places.end(),
[&](const platform::Place &place) {
return platform::is_cpu_place(place);
}) == places.end()) {
places.emplace_back(platform::CPUPlace());
LOG(WARNING) << "Not specified any device, use CPU by Default.";
}
DeviceContextPool::Create(places);
return true;
return true;
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <mutex>
#include "gflags/gflags.h"
#include "glog/logging.h"
namespace paddle {
namespace framework {
void InitGflags(std::vector<std::string> &argv);
bool InitDevices(const std::vector<std::string> &devices);
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/framework/init.h"
TEST(Init, InitDevices) {
using paddle::framework::InitDevices;
std::vector<std::string> ds1 = {"CPU"};
ASSERT_EQ(InitDevices(ds1), true);
#ifdef PADDLE_WITH_CUDA
std::vector<std::string> ds2 = {"CPU", "GPU:0", "GPU:1"};
ASSERT_EQ(InitDevices(ds2), true);
#endif
}
......@@ -126,6 +126,11 @@ public:
inputData += inputChannels * inputHeight * inputWidth;
outputData += outputChannels * outputHeight * outputWidth;
}
#ifdef PADDLE_MOBILE_INFERENCE
if (Device == DEVICE_TYPE_CPU) {
memory_.reset();
}
#endif
}
};
......
......@@ -84,12 +84,15 @@ void ROIPoolLayer::forward(PassType passType) {
size_t poolChannelOffset = pooledHeight_ * pooledWidth_;
real* outputData = outputValue->getData();
Matrix::resizeOrCreate(maxIdxs_,
numROIs,
channels_ * pooledHeight_ * pooledWidth_,
false,
false);
real* argmaxData = maxIdxs_->getData();
real* argmaxData = nullptr;
if (passType != PASS_TEST) {
Matrix::resizeOrCreate(maxIdxs_,
numROIs,
channels_ * pooledHeight_ * pooledWidth_,
false,
false);
argmaxData = maxIdxs_->getData();
}
for (size_t n = 0; n < numROIs; ++n) {
// the first five elememts of each RoI should be:
......@@ -128,14 +131,18 @@ void ROIPoolLayer::forward(PassType passType) {
bool isEmpty = (hend <= hstart) || (wend <= wstart);
size_t poolIndex = ph * pooledWidth_ + pw;
outputData[poolIndex] = isEmpty ? 0 : -FLT_MAX;
argmaxData[poolIndex] = -1;
if (argmaxData) {
argmaxData[poolIndex] = -1;
}
for (size_t h = hstart; h < hend; ++h) {
for (size_t w = wstart; w < wend; ++w) {
size_t index = h * width_ + w;
if (batchData[index] > outputData[poolIndex]) {
outputData[poolIndex] = batchData[index];
argmaxData[poolIndex] = index;
if (argmaxData) {
argmaxData[poolIndex] = index;
}
}
}
}
......@@ -143,7 +150,9 @@ void ROIPoolLayer::forward(PassType passType) {
}
batchData += channelOffset;
outputData += poolChannelOffset;
argmaxData += poolChannelOffset;
if (argmaxData) {
argmaxData += poolChannelOffset;
}
}
bottomROIs += roiOffset;
}
......
......@@ -171,12 +171,31 @@ void SequenceToBatch::sequence2BatchCopy(Matrix &batch,
hl_sequence2batch_copy(
batchData, seqData, idxData, seqWidth, batchCount, seq2batch);
} else {
for (int i = 0; i < batchCount; ++i) {
if (seq2batch) {
if (seq2batch) {
#ifdef PADDLE_USE_MKLML
const int blockMemSize = 8 * 1024;
const int blockSize = blockMemSize / sizeof(real);
#pragma omp parallel for collapse(2)
for (int i = 0; i < batchCount; ++i) {
for (int j = 0; j < seqWidth; j += blockSize) {
memcpy(batch.rowBuf(i) + j,
sequence.rowBuf(idxData[i]) + j,
(j + blockSize > seqWidth) ? (seqWidth - j) * sizeof(real)
: blockMemSize);
}
}
#else
for (int i = 0; i < batchCount; ++i) {
memcpy(batch.rowBuf(i),
sequence.rowBuf(idxData[i]),
seqWidth * sizeof(real));
} else {
}
#endif
} else {
#ifdef PADDLE_USE_MKLML
#pragma omp parallel for
#endif
for (int i = 0; i < batchCount; ++i) {
memcpy(sequence.rowBuf(idxData[i]),
batch.rowBuf(i),
seqWidth * sizeof(real));
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include <stdlib.h> // for malloc and free
#include <sys/mman.h> // for mlock and munlock
#include <algorithm> // for std::max
#include "gflags/gflags.h"
......@@ -28,7 +29,7 @@ limitations under the License. */
// of memory available to the system for paging. So, by default, we
// should set false to use_pinned_memory.
DEFINE_bool(use_pinned_memory, true, "If set, allocate cpu pinned memory.");
DECLARE_double(fraction_of_gpu_memory_to_use);
namespace paddle {
namespace memory {
namespace detail {
......@@ -77,45 +78,20 @@ void* GPUAllocator::Alloc(size_t& index, size_t size) {
// CUDA documentation doesn't explain if cudaMalloc returns nullptr
// if size is 0. We just make sure it does.
if (size <= 0) return nullptr;
size_t available = 0;
size_t capacity = 0;
paddle::platform::GpuMemoryUsage(available, capacity);
// Reserve memory for page tables, etc.
size_t reserving = 0.05 * capacity + paddle::platform::GpuMinChunkSize();
size_t usable = available > reserving ? available - reserving : 0;
// If remaining size no less than expected size, using general
// cudaMalloc to allocate GPU memory.
void* p = 0;
if (size <= usable) {
cudaError_t result = cudaMalloc(&p, size);
if (result == cudaSuccess) {
index = 0;
gpu_alloc_size_ += size;
return p;
}
}
// If remaining size less than expected size or cudaMalloc failed,
// cudaMallocHost will be considered as a fallback allocator.
//
// NOTE: here, we use GpuMaxAllocSize() as the maximum memory size
// of host fallback allocation. Allocates too much would reduce
// the amount of memory available to the underlying system for paging.
usable = paddle::platform::GpuMaxAllocSize() - fallback_alloc_size_;
if (size > usable) return nullptr;
cudaError_t result = cudaMallocHost(&p, size);
void* p;
cudaError_t result = cudaMalloc(&p, size);
if (result == cudaSuccess) {
index = 1;
fallback_alloc_size_ += size;
index = 0;
gpu_alloc_size_ += size;
return p;
} else {
LOG(WARNING)
<< "Cannot malloc " << size / 1024.0 / 1024.0
<< " MB GPU memory. Please shrink FLAGS_fraction_of_gpu_memory_to_use "
"environment variable to a lower value. Current value is "
<< FLAGS_fraction_of_gpu_memory_to_use;
return nullptr;
}
return nullptr;
}
void GPUAllocator::Free(void* p, size_t size, size_t index) {
......
......@@ -88,7 +88,8 @@ There are two ways to set shape:
The input should be a k-D tensor(k > 0 and k < 7). As an example:
Given:
Case 1:
Given
X = [[0, 1, 2, 0, 0]
[0, 3, 4, 0, 0]
......@@ -107,6 +108,27 @@ we get:
Out = [[1, 2],
[3, 4]].
Case 2:
Given
X = [[0, 1, 2, 5, 0]
[0, 3, 4, 6, 0]
[0, 0, 0, 0, 0]],
and
offsets = [0, 1],
and
Y = [[0, 0, 0]
[0, 0, 0]],
we get:
Out = [[1, 2, 5],
[3, 4, 6]].
)DOC");
}
};
......
......@@ -274,7 +274,7 @@ void set_constant_with_place<platform::GPUPlace>(
}
template <>
void set_constant_with_place<platform::CudnnPlace>(
void set_constant_with_place<platform::CUDNNPlace>(
const platform::DeviceContext& context, framework::Tensor* tensor,
float value) {
set_constant_with_place<platform::GPUPlace>(context, tensor, value);
......
......@@ -34,21 +34,33 @@ class ReshapeOp : public framework::OperatorWithKernel {
auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty.");
auto x_dims = ctx->GetInputDim("X");
// TODO(qiao) change batch_size
for (size_t i = 1; i < shape.size(); ++i) {
PADDLE_ENFORCE(shape[i] > 0,
"Each dimension of Attr(shape) "
"must be positive except the first one.");
}
if (shape[0] < 0) {
shape[0] = x_dims[0];
std::vector<size_t> neg_dims_idx;
// set some dimension to -1 if it is unknown
const int unknown_size = -1;
for (size_t i = 0; i < shape.size(); ++i) {
PADDLE_ENFORCE(shape[i] > 0 || shape[i] == unknown_size,
"Each dimension of Attr(shape) must be positive or %d.",
unknown_size);
if (shape[i] == unknown_size) {
neg_dims_idx.push_back(i);
PADDLE_ENFORCE(neg_dims_idx.size() <= 1,
"Only one dimension of Attr(shape) can be unknown.");
}
}
// capacity check
int64_t capacity =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
int64_t in_size = framework::product(x_dims);
PADDLE_ENFORCE_EQ(capacity, in_size,
"The size of Input(X) mismatches with Attr(shape).");
if (neg_dims_idx.size() == 1) {
// dim infer
shape[neg_dims_idx[0]] = in_size / (-capacity);
// recalculate capacity
capacity = shape[neg_dims_idx[0]] * (-capacity);
}
// capacity check
PADDLE_ENFORCE(capacity == in_size,
"The size of Input(X) mismatches with Attr(shape).");
// resize output
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
......@@ -88,6 +100,9 @@ the tensor X into a 2-D tensor:
[[1, 2, 3, 4]]
One dimension in the target shape can be set -1, representing that its
size is unknown. In this case, the real dimension will be infered from
the original shape of Input(X) and other dimensions in the target shape.
)DOC");
}
};
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/spp_op.h"
namespace paddle {
namespace operators {
class SppOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SppOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
"(Tensor) The input tensor of spp operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of feature.");
AddOutput("Out",
"(Tensor) The output tensor of spp operator."
"N * M."
"M = C * H * W");
AddAttr<int>("pyramid_height", "(int), multi level pooling");
AddAttr<std::string>(
"pooling_type",
"(string), pooling type, can be \"max\" for max-pooling "
"and \"avg\" for average-pooling.")
.InEnum({"max", "avg"});
AddComment(R"DOC(
"With spatial pyramid pooling, the input image can
be of any sizes. This not only allows arbitrary aspect
ratios, but also allows arbitrary scales. We can resize
the input image to any scale (e.g., min(w, h)=180, 224,
...) and apply the same deep network. When the
input image is at different scales, the network (with
the same filter sizes) will extract features at different
scales. The scales play important roles in traditional
methods.
Input shape: $(N, C_{in}, H_{in}, W_{in})$
Output shape: $(H_{out}, W_{out})$
Where
$$
H_{out} = N \\
W_{out} = (((4^pyramid_height) - 1) / (4 - 1))$ * C_{in}
$$
paper https://arxiv.org/pdf/1406.4729v4.pdf
)DOC");
}
};
class SppOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SppOp"
"should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SppOp should not be null.");
auto in_x_dims = ctx->GetInputDim("X");
int pyramid_height = ctx->Attrs().Get<int>("pyramid_height");
PADDLE_ENFORCE(in_x_dims.size() == 4,
"Spping intput must be of 4-dimensional.");
int outlen = ((std::pow(4, pyramid_height) - 1) / (4 - 1)) * in_x_dims[1];
std::vector<int64_t> output_shape({in_x_dims[0], outlen});
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
}
};
class SppOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Input(X@GRAD) should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(spp, ops::SppOp, ops::SppOpMaker, spp_grad, ops::SppOpGrad);
REGISTER_OP_CPU_KERNEL(
spp, ops::SppKernel<paddle::platform::CPUDeviceContext, float>,
ops::SppKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
spp_grad, ops::SppGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SppGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/spp_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
spp, ops::SppKernel<paddle::platform::CUDADeviceContext, float>,
ops::SppKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
spp_grad, ops::SppGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SppGradKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/pooling.h"
#include "paddle/operators/strided_memcpy.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class SppKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
int pyramid_height = context.template Attr<int>("pyramid_height");
std::string pooling_type =
context.template Attr<std::string>("pooling_type");
out->mutable_data<T>(context.GetPlace());
auto out_stride = framework::stride(out->dims());
int input_h = in_x->dims()[2];
int input_w = in_x->dims()[3];
size_t output_offset = 0;
for (int p = 0; p < pyramid_height; ++p) {
int bins = std::pow(2, p);
int kernel_size_h = std::ceil(input_h / static_cast<double>(bins));
int kernel_size_w = std::ceil(input_w / static_cast<double>(bins));
int padding_h = (kernel_size_h * bins - input_h + 1) / 2;
int padding_w = (kernel_size_w * bins - input_w + 1) / 2;
std::vector<int> kernel_size({kernel_size_h, kernel_size_w});
std::vector<int> strides({kernel_size_h, kernel_size_w});
std::vector<int> paddings({padding_h, padding_w});
// pooling output shape
framework::Tensor out_level;
std::vector<int64_t> output_shape_vec(
{in_x->dims()[0], in_x->dims()[1], bins, bins});
framework::DDim output_shape(framework::make_ddim(output_shape_vec));
out_level.mutable_data<T>(output_shape, context.GetPlace());
// pooling
if (pooling_type == "max") {
math::Pool2dFunctor<DeviceContext, math::MaxPool<T>, T> pool_forward;
math::MaxPool<T> max_process;
pool_forward(context.template device_context<DeviceContext>(), *in_x,
kernel_size, strides, paddings, max_process, &out_level);
} else if (pooling_type == "avg") {
math::Pool2dFunctor<DeviceContext, math::AvgPool<T>, T> pool_forward;
math::AvgPool<T> avg_process;
pool_forward(context.template device_context<DeviceContext>(), *in_x,
kernel_size, strides, paddings, avg_process, &out_level);
}
// flatten pooling output shape
int output_flatten_w = in_x->dims()[1] * bins * bins;
std::vector<int64_t> output_flatten_shape_vec(
{in_x->dims()[0], output_flatten_w});
framework::DDim output_flatten_shape(
framework::make_ddim(output_flatten_shape_vec));
out_level.Resize(output_flatten_shape);
// concat
auto out_level_stride = framework::stride(out_level.dims());
StridedMemcpy<T>(context.template device_context<DeviceContext>(),
out_level.data<T>(), out_level_stride, out_level.dims(),
out_stride, out->data<T>() + output_offset);
output_offset += out_level.dims()[1] * out_level_stride[1];
}
}
};
template <typename DeviceContext, typename T>
class SppGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
const framework::Tensor* out = context.Input<framework::Tensor>("Out");
const framework::Tensor* out_grad =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
framework::Tensor* in_x_grad =
context.Output<framework::Tensor>(framework::GradVarName("X"));
int pyramid_height = context.template Attr<int>("pyramid_height");
std::string pooling_type =
context.template Attr<std::string>("pooling_type");
auto& device_ctx = context.template device_context<DeviceContext>();
math::SetConstant<DeviceContext, T> zero;
in_x_grad->mutable_data<T>(context.GetPlace());
zero(device_ctx, in_x_grad, static_cast<T>(0));
auto out_stride = framework::stride(out->dims());
int input_h = in_x->dims()[2];
int input_w = in_x->dims()[3];
size_t out_offset = 0;
for (int p = 0; p < pyramid_height; ++p) {
int bins = std::pow(2, p);
int kernel_size_h = std::ceil(input_h / static_cast<double>(bins));
int kernel_size_w = std::ceil(input_w / static_cast<double>(bins));
int padding_h = (kernel_size_h * bins - input_h + 1) / 2;
int padding_w = (kernel_size_w * bins - input_w + 1) / 2;
std::vector<int> kernel_size({kernel_size_h, kernel_size_w});
std::vector<int> strides({kernel_size_h, kernel_size_w});
std::vector<int> paddings({padding_h, padding_w});
// split out and outgrad ... to flatten
framework::Tensor out_level;
framework::Tensor outgrad_level;
int out_flatten_w = in_x->dims()[1] * bins * bins;
std::vector<int64_t> out_flatten_shape_vec(
{in_x->dims()[0], out_flatten_w});
framework::DDim out_flatten_shape(
framework::make_ddim(out_flatten_shape_vec));
out_level.mutable_data<T>(out_flatten_shape, context.GetPlace());
outgrad_level.mutable_data<T>(out_flatten_shape, context.GetPlace());
auto flatten_stride = framework::stride(out_level.dims());
// memcpy
StridedMemcpy<T>(context.template device_context<DeviceContext>(),
out->data<T>() + out_offset, out_stride,
out_level.dims(), flatten_stride, out_level.data<T>());
StridedMemcpy<T>(context.template device_context<DeviceContext>(),
out_grad->data<T>() + out_offset, out_stride,
outgrad_level.dims(), flatten_stride,
outgrad_level.data<T>());
out_offset += out_level.dims()[1] * out_stride[1];
// flatten backward to nchw
std::vector<int64_t> out_shape_vec({in_x->dims()[0], in_x->dims()[1]});
out_shape_vec.push_back(
(input_h - kernel_size_h + 2 * padding_h) / kernel_size_h + 1);
out_shape_vec.push_back(
(input_w - kernel_size_w + 2 * padding_w) / kernel_size_w + 1);
framework::DDim out_shape(framework::make_ddim(out_shape_vec));
out_level.ShareDataWith(out_level);
out_level.Resize(out_shape);
outgrad_level.ShareDataWith(outgrad_level);
outgrad_level.Resize(out_shape);
// pooling backward
if (pooling_type == "max") {
math::MaxPool2dGradFunctor<DeviceContext, T> pool2d_backward;
pool2d_backward(context.template device_context<DeviceContext>(), *in_x,
*&out_level, *&outgrad_level, kernel_size, strides,
paddings, in_x_grad);
} else if (pooling_type == "avg") {
math::Pool2dGradFunctor<DeviceContext, math::AvgPoolGrad<T>, T>
pool_backward;
math::AvgPoolGrad<T> avg_process;
pool_backward(context.template device_context<DeviceContext>(), *in_x,
*&out_level, *&outgrad_level, kernel_size, strides,
paddings, avg_process, in_x_grad);
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -125,21 +125,21 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
CudnnDeviceContext::CudnnDeviceContext(CudnnPlace place)
CUDNNDeviceContext::CUDNNDeviceContext(CUDNNPlace place)
: CUDADeviceContext(place), place_(place) {
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream()));
}
CudnnDeviceContext::~CudnnDeviceContext() {
CUDNNDeviceContext::~CUDNNDeviceContext() {
SetDeviceId(place_.device);
Wait();
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
}
Place CudnnDeviceContext::GetPlace() const { return CudnnPlace(); }
Place CUDNNDeviceContext::GetPlace() const { return CUDNNPlace(); }
cudnnHandle_t CudnnDeviceContext::cudnn_handle() const { return cudnn_handle_; }
cudnnHandle_t CUDNNDeviceContext::cudnn_handle() const { return cudnn_handle_; }
#endif
......
......@@ -86,10 +86,10 @@ class CUDADeviceContext : public DeviceContext {
cublasHandle_t cublas_handle_;
};
class CudnnDeviceContext : public CUDADeviceContext {
class CUDNNDeviceContext : public CUDADeviceContext {
public:
explicit CudnnDeviceContext(CudnnPlace place);
virtual ~CudnnDeviceContext();
explicit CUDNNDeviceContext(CUDNNPlace place);
virtual ~CUDNNDeviceContext();
/*! \brief Return place in the device context. */
Place GetPlace() const final;
......@@ -99,7 +99,7 @@ class CudnnDeviceContext : public CUDADeviceContext {
private:
cudnnHandle_t cudnn_handle_;
CudnnPlace place_;
CUDNNPlace place_;
};
#endif
......
......@@ -47,14 +47,14 @@ TEST(Device, CUDADeviceContext) {
}
}
TEST(Device, CudnnDeviceContext) {
using paddle::platform::CudnnDeviceContext;
using paddle::platform::CudnnPlace;
TEST(Device, CUDNNDeviceContext) {
using paddle::platform::CUDNNDeviceContext;
using paddle::platform::CUDNNPlace;
if (paddle::platform::dynload::HasCUDNN()) {
int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) {
CudnnDeviceContext* device_context =
new CudnnDeviceContext(CudnnPlace(i));
CUDNNDeviceContext* device_context =
new CUDNNDeviceContext(CUDNNPlace(i));
cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
ASSERT_NE(nullptr, cudnn_handle);
ASSERT_NE(nullptr, device_context->stream());
......
......@@ -73,19 +73,20 @@ size_t GpuMaxChunkSize() {
size_t available = 0;
GpuMemoryUsage(available, total);
// Reserving the rest memory for page tables, etc.
size_t reserving = 0.05 * total;
VLOG(10) << "GPU Usage " << available / 1024 / 1024 << "M/"
<< total / 1024 / 1024 << "M";
size_t reserving = static_cast<size_t>(0.05 * total);
// If available less than minimum chunk size, no usable memory exists.
available =
std::max(std::max(available, GpuMinChunkSize()) - GpuMinChunkSize(),
reserving) -
reserving;
std::min(std::max(available, GpuMinChunkSize()) - GpuMinChunkSize(),
total - reserving);
// Reserving the rest memory for page tables, etc.
size_t allocating = FLAGS_fraction_of_gpu_memory_to_use * total;
size_t allocating = static_cast<size_t>(FLAGS_fraction_of_gpu_memory_to_use *
(total - reserving));
PADDLE_ENFORCE_LT(allocating, available);
PADDLE_ENFORCE_LE(allocating, available);
return allocating;
}
......
......@@ -51,9 +51,9 @@ struct GPUPlace {
int device;
};
struct CudnnPlace : public GPUPlace {
CudnnPlace() : GPUPlace() {}
explicit CudnnPlace(int d) : GPUPlace(d) {}
struct CUDNNPlace : public GPUPlace {
CUDNNPlace() : GPUPlace() {}
explicit CUDNNPlace(int d) : GPUPlace(d) {}
};
struct IsGPUPlace : public boost::static_visitor<bool> {
......@@ -72,7 +72,7 @@ struct IsMKLDNNPlace : public boost::static_visitor<bool> {
// should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
#define NUM_PLACE_TYPE_LIMIT_IN_BIT 4
typedef boost::variant<CudnnPlace, GPUPlace, CPUPlace, MKLDNNPlace> Place;
typedef boost::variant<CUDNNPlace, GPUPlace, CPUPlace, MKLDNNPlace> Place;
// static check number of place types is less equal than
// 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
......
......@@ -5,16 +5,22 @@
TEST(Place, Equality) {
paddle::platform::CPUPlace cpu;
paddle::platform::GPUPlace g0(0), g1(1), gg0(0);
paddle::platform::CUDNNPlace d0(0), d1(1), dd0(0);
EXPECT_EQ(cpu, cpu);
EXPECT_EQ(g0, g0);
EXPECT_EQ(g1, g1);
EXPECT_EQ(g0, gg0);
EXPECT_EQ(d0, dd0);
EXPECT_NE(g0, g1);
EXPECT_NE(d0, d1);
EXPECT_TRUE(paddle::platform::places_are_same_class(g0, gg0));
EXPECT_FALSE(paddle::platform::places_are_same_class(g0, cpu));
EXPECT_TRUE(paddle::platform::is_gpu_place(d0));
EXPECT_FALSE(paddle::platform::places_are_same_class(g0, d0));
}
TEST(Place, Default) {
......
if(WITH_PYTHON)
cc_library(paddle_pybind SHARED
SRCS pybind.cc exception.cc protobuf.cc
DEPS pybind python backward proto_desc paddle_memory executor prune
DEPS pybind python backward proto_desc paddle_memory executor prune init
${GLOB_OP_LIB})
endif(WITH_PYTHON)
......
......@@ -16,11 +16,11 @@ limitations under the License. */
#include <mutex> // for call_once
#include <unordered_map>
#include "gflags/gflags.h"
#include "paddle/framework/backward.h"
#include "paddle/framework/executor.h"
#include "paddle/framework/feed_fetch_method.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/init.h"
#include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/lod_tensor_array.h"
......@@ -51,24 +51,6 @@ static size_t UniqueIntegerGenerator(const std::string &prefix) {
return generators[prefix].fetch_add(1);
}
std::once_flag gflags_init_flag;
// TODO(qijun) move init gflags to init.cc
void InitGflags(std::vector<std::string> &argv) {
std::call_once(gflags_init_flag, [&]() {
int argc = argv.size();
char **arr = new char *[argv.size()];
std::string line;
for (size_t i = 0; i < argv.size(); i++) {
arr[i] = &argv[i][0];
line += argv[i];
line += ' ';
}
google::ParseCommandLineFlags(&argc, &arr, true);
VLOG(1) << "Init commandline: " << line;
});
}
bool IsCompileGPU() {
#ifndef PADDLE_WITH_CUDA
return false;
......@@ -438,7 +420,8 @@ All parameter, weight, gradient are variables in Paddle.
.def("run", &Executor::Run);
m.def("unique_integer", UniqueIntegerGenerator);
m.def("init_gflags", InitGflags);
m.def("init_gflags", framework::InitGflags);
m.def("init_devices", &framework::InitDevices);
m.def("is_compile_gpu", IsCompileGPU);
m.def("set_feed_variable", framework::SetFeedVariable);
......
......@@ -1119,8 +1119,9 @@ def simple_gru2(input,
:param gru_bias_attr: bias parameter attribute of gru layer,
False means no bias, None means default bias.
:type gru_bias_attr: ParameterAttribute|False|None
:param gru_layer_attr: Extra attribute of the gru layer.
:type gru_layer_attr: ExtraLayerAttribute
:param gru_param_attr: param parameter attribute of gru layer,
None means default param.
:type gru_param_attr: ParameterAttribute|None
:return: the gru group.
:rtype: LayerOutput
"""
......
......@@ -46,6 +46,13 @@ class Executor(object):
p.set_place(each)
act_places.append(p)
# TODO(dzhwinter) : consider that our fluid tests all written in
# GPUPlace(gpu_id), this will be changed in next PR.
if core.is_compile_gpu():
core.init_devices(["CPU", "GPU:0"])
else:
core.init_devices(["CPU"])
self.executor = core.Executor(act_places)
self.places = places
......
......@@ -17,5 +17,19 @@ class TestReshapeOp(OpTest):
self.check_grad(["X"], "Out")
class TestReshapeOpDimInfer(OpTest):
def setUp(self):
self.op_type = "reshape"
self.inputs = {'X': np.random.random((10, 20)).astype("float32")}
self.attrs = {'shape': [4, -1, 5]}
self.outputs = {'Out': self.inputs['X'].reshape(self.attrs['shape'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
if __name__ == '__main__':
unittest.main()
import unittest
import numpy as np
from op_test import OpTest
from test_pool2d_op import max_pool2D_forward_naive
from test_pool2d_op import avg_pool2D_forward_naive
class TestSppOp(OpTest):
def setUp(self):
self.op_type = "spp"
self.init_test_case()
input = np.random.random(self.shape).astype("float32")
nsize, csize, hsize, wsize = input.shape
out_level_flatten = []
for i in xrange(self.pyramid_height):
bins = np.power(2, i)
kernel_size = [0, 0]
padding = [0, 0]
kernel_size[0] = np.ceil(hsize /
bins.astype("double")).astype("int32")
padding[0] = (
(kernel_size[0] * bins - hsize + 1) / 2).astype("int32")
kernel_size[1] = np.ceil(wsize /
bins.astype("double")).astype("int32")
padding[1] = (
(kernel_size[1] * bins - wsize + 1) / 2).astype("int32")
out_level = self.pool2D_forward_naive(input, kernel_size,
kernel_size, padding)
out_level_flatten.append(
out_level.reshape(nsize, bins * bins * csize))
if i == 0:
output = out_level_flatten[i]
else:
output = np.concatenate((output, out_level_flatten[i]), 1)
# output = np.concatenate(out_level_flatten.tolist(), 0);
self.inputs = {'X': input.astype('float32'), }
self.attrs = {
'pyramid_height': self.pyramid_height,
'pooling_type': self.pool_type
}
self.outputs = {'Out': output.astype('float32')}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
if self.pool_type != "avg":
self.check_grad(['X'], 'Out', max_relative_error=0.05)
def init_test_case(self):
self.shape = [3, 2, 4, 4]
self.pyramid_height = 3
self.pool2D_forward_naive = max_pool2D_forward_naive
self.pool_type = "max"
class TestCase2(TestSppOp):
def init_test_case(self):
self.shape = [3, 2, 4, 4]
self.pyramid_height = 3
self.pool2D_forward_naive = avg_pool2D_forward_naive
self.pool_type = "avg"
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册