提交 cadee843 编写于 作者: C chengduoZH

follow comments

上级 df48b43b
...@@ -195,6 +195,14 @@ std::vector<int64_t> vectorize(const DDim& ddim) { ...@@ -195,6 +195,14 @@ std::vector<int64_t> vectorize(const DDim& ddim) {
return result; return result;
} }
// Returns the extents of `ddim` as a std::vector<int>.
// NOTE: framework::vectorize converts to type int64_t, which does not fit
// cudnn inputs, so every extent is narrowed to int here. The narrowing is
// unchecked: no range validation is performed on the int64_t values.
std::vector<int> vectorize2int(const DDim& ddim) {
  const std::vector<int64_t> wide = vectorize(ddim);
  std::vector<int> narrow;
  narrow.reserve(wide.size());
  for (const int64_t extent : wide) {
    narrow.push_back(static_cast<int>(extent));
  }
  return narrow;
}
struct ProductVisitor : public boost::static_visitor<int64_t> { struct ProductVisitor : public boost::static_visitor<int64_t> {
template <int D> template <int D>
int64_t operator()(const Dim<D>& dim) { int64_t operator()(const Dim<D>& dim) {
......
...@@ -93,6 +93,7 @@ int64_t get(const DDim& dim, int idx); ...@@ -93,6 +93,7 @@ int64_t get(const DDim& dim, int idx);
void set(DDim& dim, int idx, int val); void set(DDim& dim, int idx, int val);
std::vector<int64_t> vectorize(const DDim& ddim); std::vector<int64_t> vectorize(const DDim& ddim);
std::vector<int> vectorize2int(const DDim& ddim);
int64_t product(const DDim& ddim); int64_t product(const DDim& ddim);
......
...@@ -31,16 +31,6 @@ using CUDADeviceContext = platform::CUDADeviceContext; ...@@ -31,16 +31,6 @@ using CUDADeviceContext = platform::CUDADeviceContext;
static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = 1024 * 1024 * 1024; static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = 1024 * 1024 * 1024;
// Copies the extents of `dims` into a std::vector<int>, in index order.
// NOTE: framework::vectorize converts to type int64_t, which does not fit
// cudnn inputs; this helper produces the int vector handed to the cudnn
// descriptor wrappers instead.
// The result is sized once up front to avoid repeated reallocation, and the
// element conversion to int is spelled out explicitly (it is unchecked).
std::vector<int> Dims2Vector(const framework::DDim& dims) {
  const auto rank = dims.size();
  std::vector<int> ret;
  ret.reserve(rank);
  for (int i = 0; i < rank; ++i) {
    ret.push_back(static_cast<int>(dims[i]));
  }
  return ret;
}
template <typename T> template <typename T>
class CudnnConvOpKernel : public framework::OpKernel<T> { class CudnnConvOpKernel : public framework::OpKernel<T> {
public: public:
...@@ -68,12 +58,12 @@ class CudnnConvOpKernel : public framework::OpKernel<T> { ...@@ -68,12 +58,12 @@ class CudnnConvOpKernel : public framework::OpKernel<T> {
ScopedConvolutionDescriptor conv_desc; ScopedConvolutionDescriptor conv_desc;
DataLayout layout = DataLayout::kNCHW; DataLayout layout = DataLayout::kNCHW;
cudnnTensorDescriptor_t cudnn_input_desc = cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
input_desc.descriptor<T>(layout, Dims2Vector(input->dims()), groups); layout, framework::vectorize2int(input->dims()), groups);
cudnnTensorDescriptor_t cudnn_output_desc = cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
output_desc.descriptor<T>(layout, Dims2Vector(output->dims()), groups); layout, framework::vectorize2int(output->dims()), groups);
cudnnFilterDescriptor_t cudnn_filter_desc = cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
filter_desc.descriptor<T>(layout, Dims2Vector(filter->dims()), groups); layout, framework::vectorize2int(filter->dims()), groups);
cudnnConvolutionDescriptor_t cudnn_conv_desc = cudnnConvolutionDescriptor_t cudnn_conv_desc =
conv_desc.descriptor<T>(paddings, strides, dilations); conv_desc.descriptor<T>(paddings, strides, dilations);
...@@ -156,13 +146,13 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> { ...@@ -156,13 +146,13 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
ScopedConvolutionDescriptor conv_desc; ScopedConvolutionDescriptor conv_desc;
DataLayout layout = DataLayout::kNCHW; DataLayout layout = DataLayout::kNCHW;
cudnnTensorDescriptor_t cudnn_input_desc = cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
input_desc.descriptor<T>(layout, Dims2Vector(input->dims()), groups); layout, framework::vectorize2int(input->dims()), groups);
cudnnTensorDescriptor_t cudnn_output_grad_desc = cudnnTensorDescriptor_t cudnn_output_grad_desc =
output_grad_desc.descriptor<T>(layout, Dims2Vector(output_grad->dims()), output_grad_desc.descriptor<T>(
groups); layout, framework::vectorize2int(output_grad->dims()), groups);
cudnnFilterDescriptor_t cudnn_filter_desc = cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
filter_desc.descriptor<T>(layout, Dims2Vector(filter->dims()), groups); layout, framework::vectorize2int(filter->dims()), groups);
cudnnTensorDescriptor_t cudnn_input_grad_desc = nullptr; cudnnTensorDescriptor_t cudnn_input_grad_desc = nullptr;
cudnnFilterDescriptor_t cudnn_filter_grad_desc = nullptr; cudnnFilterDescriptor_t cudnn_filter_grad_desc = nullptr;
...@@ -192,7 +182,7 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> { ...@@ -192,7 +182,7 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
auto handle = ctx.cuda_device_context().cudnn_handle(); auto handle = ctx.cuda_device_context().cudnn_handle();
if (input_grad) { if (input_grad) {
cudnn_input_grad_desc = input_grad_desc.descriptor<T>( cudnn_input_grad_desc = input_grad_desc.descriptor<T>(
layout, Dims2Vector(input_grad->dims()), groups); layout, framework::vectorize2int(input_grad->dims()), groups);
PADDLE_ENFORCE( PADDLE_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
handle, cudnn_filter_desc, handle, cudnn_filter_desc,
...@@ -213,7 +203,7 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> { ...@@ -213,7 +203,7 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
if (filter_grad) { if (filter_grad) {
cudnn_filter_grad_desc = filter_grad_desc.descriptor<T>( cudnn_filter_grad_desc = filter_grad_desc.descriptor<T>(
layout, Dims2Vector(filter_grad->dims()), groups); layout, framework::vectorize2int(filter_grad->dims()), groups);
PADDLE_ENFORCE( PADDLE_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm( platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc, handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc,
......
...@@ -24,15 +24,6 @@ using ScopedPoolingDescriptor = platform::ScopedPoolingDescriptor; ...@@ -24,15 +24,6 @@ using ScopedPoolingDescriptor = platform::ScopedPoolingDescriptor;
using DataLayout = platform::DataLayout; using DataLayout = platform::DataLayout;
using PoolingMode = platform::PoolingMode; using PoolingMode = platform::PoolingMode;
// Copies the extents of `dims` into a std::vector<int>, in index order.
// NOTE: copy from conv_cudnn (kept under a distinct name in this file).
// The result is sized once up front to avoid repeated reallocation, and the
// element conversion to int is spelled out explicitly (it is unchecked).
std::vector<int> Dims2VectorPool(const framework::DDim &dims) {
  const auto rank = dims.size();
  std::vector<int> ret;
  ret.reserve(rank);
  for (int i = 0; i < rank; ++i) {
    ret.push_back(static_cast<int>(dims[i]));
  }
  return ret;
}
template <typename T> template <typename T>
class PoolCudnnOpKernel : public framework::OpKernel<T> { class PoolCudnnOpKernel : public framework::OpKernel<T> {
public: public:
...@@ -62,10 +53,10 @@ class PoolCudnnOpKernel : public framework::OpKernel<T> { ...@@ -62,10 +53,10 @@ class PoolCudnnOpKernel : public framework::OpKernel<T> {
ScopedPoolingDescriptor pool_desc; ScopedPoolingDescriptor pool_desc;
DataLayout layout = DataLayout::kNCHW; DataLayout layout = DataLayout::kNCHW;
cudnnTensorDescriptor_t cudnn_input_desc = cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
input_desc.descriptor<T>(layout, Dims2VectorPool(input->dims())); layout, framework::vectorize2int(input->dims()));
cudnnTensorDescriptor_t cudnn_output_desc = cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
output_desc.descriptor<T>(layout, Dims2VectorPool(output->dims())); layout, framework::vectorize2int(output->dims()));
PoolingMode pooling_mode; PoolingMode pooling_mode;
if (pooling_type == "max") { if (pooling_type == "max") {
...@@ -120,10 +111,10 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> { ...@@ -120,10 +111,10 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
ScopedPoolingDescriptor pool_desc; ScopedPoolingDescriptor pool_desc;
DataLayout layout = DataLayout::kNCHW; DataLayout layout = DataLayout::kNCHW;
cudnnTensorDescriptor_t cudnn_input_desc = cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
input_desc.descriptor<T>(layout, Dims2VectorPool(input->dims())); layout, framework::vectorize2int(input->dims()));
cudnnTensorDescriptor_t cudnn_output_desc = cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
output_desc.descriptor<T>(layout, Dims2VectorPool(output->dims())); layout, framework::vectorize2int(output->dims()));
PoolingMode pooling_mode; PoolingMode pooling_mode;
if (pooling_type == "max") { if (pooling_type == "max") {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......
...@@ -81,8 +81,8 @@ Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto, ...@@ -81,8 +81,8 @@ Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto,
"width of feature."); "width of feature.");
AddAttr<std::string>("poolingType", AddAttr<std::string>("poolingType",
"(string), poolingType of pooling operator." "(string), pooling type, can be \"max\" for max-pooling "
"Str constant equal to 'max' or 'avg'.") "and \"avg\" for average-pooling.")
.InEnum({"max", "avg"}); .InEnum({"max", "avg"});
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"ksize", "ksize",
...@@ -90,10 +90,9 @@ Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto, ...@@ -90,10 +90,9 @@ Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto,
"If globalPooling = true, ksize is ignored and need not be " "If globalPooling = true, ksize is ignored and need not be "
"specified."); // TODO(Chengduo): Add checker. (Currently, "specified."); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
AddAttr<bool>( AddAttr<bool>("globalPooling",
"globalPooling", "(bool default: false), whether to use the global pooling."
"(bool default: false), whether to use the global pooling." "If globalPooling = true, ksize is ignored.")
"If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"strides", "strides",
...@@ -143,8 +142,8 @@ Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto, ...@@ -143,8 +142,8 @@ Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto,
"width of feature."); "width of feature.");
AddAttr<std::string>("poolingType", AddAttr<std::string>("poolingType",
"(string), poolingType of pooling operator." "(string), pooling type, can be \"max\" for max-pooling "
"Str constant equal to 'max' or 'avg'.") "and \"avg\" for average-pooling.")
.InEnum({"max", "avg"}); .InEnum({"max", "avg"});
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"ksize", "ksize",
...@@ -153,10 +152,9 @@ Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto, ...@@ -153,10 +152,9 @@ Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto,
"If globalPooling = true, ksize is ignored and need not be " "If globalPooling = true, ksize is ignored and need not be "
"specified."); // TODO(Chengduo): Add checker. (Currently, "specified."); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
AddAttr<bool>( AddAttr<bool>("globalPooling",
"globalPooling", "(bool default: false), whether to use the global pooling."
"(bool default: false), whether to use the global pooling." "If globalPooling = true, ksize is ignored.")
"If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<int>>("strides", AddAttr<std::vector<int>>("strides",
"(vector, default:{1,1,1}), strides(depth, height, " "(vector, default:{1,1,1}), strides(depth, height, "
......
...@@ -109,10 +109,9 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -109,10 +109,9 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
"If globalPooling = true, ksize is ignored and need not be " "If globalPooling = true, ksize is ignored and need not be "
"specified."); // TODO(Chengduo): Add checker. (Currently, "specified."); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
AddAttr<bool>( AddAttr<bool>("globalPooling",
"globalPooling", "(bool default: false), whether to use the global pooling."
"(bool default: false), whether to use the global pooling." "If globalPooling = true, ksize is ignored.")
"If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"strides", "strides",
...@@ -178,10 +177,9 @@ class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -178,10 +177,9 @@ class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
"If globalPooling = true, ksize is ignored and need not be " "If globalPooling = true, ksize is ignored and need not be "
"specified."); // TODO(Chengduo): Add checker. (Currently, "specified."); // TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.) // TypedAttrChecker don't support vector type.)
AddAttr<bool>( AddAttr<bool>("globalPooling",
"globalPooling", "(bool default: false), whether to use the global pooling."
"(bool default: false), whether to use the global pooling." "If globalPooling = true, ksize is ignored.")
"If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<int>>("strides", AddAttr<std::vector<int>>("strides",
"(vector, default:{1,1,1}), strides(depth, " "(vector, default:{1,1,1}), strides(depth, "
......
...@@ -266,9 +266,9 @@ def pool2d(input, ...@@ -266,9 +266,9 @@ def pool2d(input,
inputs={"X": input}, inputs={"X": input},
outputs={"Out": pool_out}, outputs={"Out": pool_out},
attrs={ attrs={
"pooling_type": pool_type, "poolingType": pool_type,
"ksize": pool_size, "ksize": pool_size,
"global_pooling": global_pooling, "globalPooling": global_pooling,
"strides": pool_stride, "strides": pool_stride,
"paddings": pool_padding "paddings": pool_padding
}) })
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册