Commit cadee843, authored by chengduoZH

follow comments

Parent df48b43b
@@ -195,6 +195,14 @@ std::vector<int64_t> vectorize(const DDim& ddim) {
   return result;
 }
 
+// NOTE: framework::vectorize converts to type int64_t
+// which does not fit cudnn inputs.
+std::vector<int> vectorize2int(const DDim& ddim) {
+  std::vector<int64_t> temp = vectorize(ddim);
+  std::vector<int> result(temp.begin(), temp.end());
+  return result;
+}
+
 struct ProductVisitor : public boost::static_visitor<int64_t> {
   template <int D>
   int64_t operator()(const Dim<D>& dim) {
......
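For readers outside the Paddle tree, here is a minimal, self-contained sketch of the narrowing that vectorize2int performs. The DDim source is stubbed out as a plain int64_t vector, and vectorize_stub/to_int_dims are hypothetical names, not Paddle API. Note that the range constructor relies on implicit int64_t-to-int conversion, so a dimension above INT_MAX would silently truncate.

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Stand-in for framework::vectorize(): dims come back as int64_t.
std::vector<int64_t> vectorize_stub() { return {64, 3, 224, 224}; }

// Same narrowing as vectorize2int: construct an int vector from the
// int64_t range; each element is implicitly converted to int.
std::vector<int> to_int_dims(const std::vector<int64_t>& temp) {
  return std::vector<int>(temp.begin(), temp.end());
}

int main() {
  std::vector<int> dims = to_int_dims(vectorize_stub());
  for (int d : dims) std::cout << d << ' ';  // prints: 64 3 224 224
}
```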
@@ -93,6 +93,7 @@ int64_t get(const DDim& dim, int idx);
 void set(DDim& dim, int idx, int val);
 
 std::vector<int64_t> vectorize(const DDim& ddim);
+std::vector<int> vectorize2int(const DDim& ddim);
 
 int64_t product(const DDim& ddim);
......
@@ -31,16 +31,6 @@ using CUDADeviceContext = platform::CUDADeviceContext;
 static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = 1024 * 1024 * 1024;
 
-// NOTE: framework::vectorize converts to type int64_t
-// which does not fit cudnn inputs.
-std::vector<int> Dims2Vector(const framework::DDim& dims) {
-  std::vector<int> ret;
-  for (int i = 0; i < dims.size(); i++) {
-    ret.push_back(dims[i]);
-  }
-  return ret;
-}
-
 template <typename T>
 class CudnnConvOpKernel : public framework::OpKernel<T> {
  public:
@@ -68,12 +58,12 @@ class CudnnConvOpKernel : public framework::OpKernel<T> {
     ScopedConvolutionDescriptor conv_desc;
     DataLayout layout = DataLayout::kNCHW;
 
-    cudnnTensorDescriptor_t cudnn_input_desc =
-        input_desc.descriptor<T>(layout, Dims2Vector(input->dims()), groups);
-    cudnnTensorDescriptor_t cudnn_output_desc =
-        output_desc.descriptor<T>(layout, Dims2Vector(output->dims()), groups);
-    cudnnFilterDescriptor_t cudnn_filter_desc =
-        filter_desc.descriptor<T>(layout, Dims2Vector(filter->dims()), groups);
+    cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
+        layout, framework::vectorize2int(input->dims()), groups);
+    cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
+        layout, framework::vectorize2int(output->dims()), groups);
+    cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
+        layout, framework::vectorize2int(filter->dims()), groups);
     cudnnConvolutionDescriptor_t cudnn_conv_desc =
         conv_desc.descriptor<T>(paddings, strides, dilations);
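The reason the dims must arrive as std::vector<int> is that cuDNN's Nd descriptor API takes plain int arrays. The sketch below shows roughly what a descriptor wrapper does with the narrowed dims; it is not Paddle's actual ScopedTensorDescriptor, MakeNchwDescriptor is a hypothetical name, and status checking on the cuDNN calls is omitted.

```cpp
#include <cudnn.h>
#include <vector>

// cudnnSetTensorNdDescriptor wants `const int dimA[]` / `const int
// strideA[]`, which is why int64_t dims cannot be passed through directly.
cudnnTensorDescriptor_t MakeNchwDescriptor(const std::vector<int>& dims) {
  // Packed NCHW strides: stride[i] = product of dims[i+1..n-1].
  std::vector<int> strides(dims.size());
  int s = 1;
  for (int i = static_cast<int>(dims.size()) - 1; i >= 0; --i) {
    strides[i] = s;
    s *= dims[i];
  }
  cudnnTensorDescriptor_t desc;
  cudnnCreateTensorDescriptor(&desc);
  cudnnSetTensorNdDescriptor(desc, CUDNN_DATA_FLOAT,
                             static_cast<int>(dims.size()), dims.data(),
                             strides.data());
  return desc;
}
```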
@@ -156,13 +146,13 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
     ScopedConvolutionDescriptor conv_desc;
     DataLayout layout = DataLayout::kNCHW;
 
-    cudnnTensorDescriptor_t cudnn_input_desc =
-        input_desc.descriptor<T>(layout, Dims2Vector(input->dims()), groups);
+    cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
+        layout, framework::vectorize2int(input->dims()), groups);
     cudnnTensorDescriptor_t cudnn_output_grad_desc =
-        output_grad_desc.descriptor<T>(layout, Dims2Vector(output_grad->dims()),
-                                       groups);
-    cudnnFilterDescriptor_t cudnn_filter_desc =
-        filter_desc.descriptor<T>(layout, Dims2Vector(filter->dims()), groups);
+        output_grad_desc.descriptor<T>(
+            layout, framework::vectorize2int(output_grad->dims()), groups);
+    cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor<T>(
+        layout, framework::vectorize2int(filter->dims()), groups);
 
     cudnnTensorDescriptor_t cudnn_input_grad_desc = nullptr;
     cudnnFilterDescriptor_t cudnn_filter_grad_desc = nullptr;
@@ -192,7 +182,7 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
     auto handle = ctx.cuda_device_context().cudnn_handle();
     if (input_grad) {
       cudnn_input_grad_desc = input_grad_desc.descriptor<T>(
-          layout, Dims2Vector(input_grad->dims()), groups);
+          layout, framework::vectorize2int(input_grad->dims()), groups);
       PADDLE_ENFORCE(
           platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
               handle, cudnn_filter_desc,
@@ -213,7 +203,7 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
     if (filter_grad) {
       cudnn_filter_grad_desc = filter_grad_desc.descriptor<T>(
-          layout, Dims2Vector(filter_grad->dims()), groups);
+          layout, framework::vectorize2int(filter_grad->dims()), groups);
       PADDLE_ENFORCE(
           platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
               handle, cudnn_input_desc, cudnn_output_grad_desc, cudnn_conv_desc,
......
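For context, the backward calls wrapped in PADDLE_ENFORCE above follow the cuDNN 7-era algorithm-selection pattern, where a workspace byte limit constrains which algorithm cuDNN may choose. A minimal sketch without Paddle's dynload/enforce plumbing (PickBackwardDataAlgo is a hypothetical name; this particular API was removed in cuDNN 8):

```cpp
#include <cudnn.h>
#include <cstddef>

// The limit (kCONV_CUDNN_WORKSPACE_LIMIT_BYTES above, 1 GB) caps the
// scratch memory of the backward-data algorithm cuDNN is allowed to pick.
cudnnConvolutionBwdDataAlgo_t PickBackwardDataAlgo(
    cudnnHandle_t handle, cudnnFilterDescriptor_t w,
    cudnnTensorDescriptor_t dy, cudnnConvolutionDescriptor_t conv,
    cudnnTensorDescriptor_t dx, size_t workspace_limit_bytes) {
  cudnnConvolutionBwdDataAlgo_t algo;
  cudnnGetConvolutionBackwardDataAlgorithm(
      handle, w, dy, conv, dx,
      CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
      workspace_limit_bytes, &algo);
  return algo;
}
```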
@@ -24,15 +24,6 @@ using ScopedPoolingDescriptor = platform::ScopedPoolingDescriptor;
 using DataLayout = platform::DataLayout;
 using PoolingMode = platform::PoolingMode;
 
-// NOTE: copy from conv_cudnn
-std::vector<int> Dims2VectorPool(const framework::DDim &dims) {
-  std::vector<int> ret;
-  for (int i = 0; i < dims.size(); i++) {
-    ret.push_back(dims[i]);
-  }
-  return ret;
-}
-
 template <typename T>
 class PoolCudnnOpKernel : public framework::OpKernel<T> {
  public:
@@ -62,10 +53,10 @@ class PoolCudnnOpKernel : public framework::OpKernel<T> {
     ScopedPoolingDescriptor pool_desc;
     DataLayout layout = DataLayout::kNCHW;
 
-    cudnnTensorDescriptor_t cudnn_input_desc =
-        input_desc.descriptor<T>(layout, Dims2VectorPool(input->dims()));
-    cudnnTensorDescriptor_t cudnn_output_desc =
-        output_desc.descriptor<T>(layout, Dims2VectorPool(output->dims()));
+    cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
+        layout, framework::vectorize2int(input->dims()));
+    cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
+        layout, framework::vectorize2int(output->dims()));
 
     PoolingMode pooling_mode;
     if (pooling_type == "max") {
@@ -120,10 +111,10 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
     ScopedPoolingDescriptor pool_desc;
     DataLayout layout = DataLayout::kNCHW;
 
-    cudnnTensorDescriptor_t cudnn_input_desc =
-        input_desc.descriptor<T>(layout, Dims2VectorPool(input->dims()));
-    cudnnTensorDescriptor_t cudnn_output_desc =
-        output_desc.descriptor<T>(layout, Dims2VectorPool(output->dims()));
+    cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
+        layout, framework::vectorize2int(input->dims()));
+    cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor<T>(
+        layout, framework::vectorize2int(output->dims()));
 
     PoolingMode pooling_mode;
     if (pooling_type == "max") {
......
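The pooling kernels above branch on the pooling_type string to select a PoolingMode. A rough cuDNN-level equivalent of that dispatch, bypassing Paddle's own PoolingMode enum and ScopedPoolingDescriptor wrapper (MakePoolDescriptor is a hypothetical name, square windows are assumed for brevity, and status checks are omitted):

```cpp
#include <cudnn.h>
#include <string>

// Map the "max"/"avg" attribute string onto cuDNN's pooling modes and
// build a 2-D pooling descriptor from it.
cudnnPoolingDescriptor_t MakePoolDescriptor(const std::string& pooling_type,
                                            int ksize, int padding,
                                            int stride) {
  cudnnPoolingMode_t mode =
      (pooling_type == "max")
          ? CUDNN_POOLING_MAX
          : CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
  cudnnPoolingDescriptor_t desc;
  cudnnCreatePoolingDescriptor(&desc);
  cudnnSetPooling2dDescriptor(desc, mode, CUDNN_PROPAGATE_NAN,
                              /*windowH=*/ksize, /*windowW=*/ksize,
                              /*padH=*/padding, /*padW=*/padding,
                              /*strideH=*/stride, /*strideW=*/stride);
  return desc;
}
```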
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......
@@ -81,8 +81,8 @@ Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto,
                 "width of feature.");
 
   AddAttr<std::string>("poolingType",
-                       "(string), poolingType of pooling operator."
-                       "Str constant equal to 'max' or 'avg'.")
+                       "(string), pooling type, can be \"max\" for max-pooling "
+                       "and \"avg\" for average-pooling.")
       .InEnum({"max", "avg"});
   AddAttr<std::vector<int>>(
       "ksize",
@@ -90,10 +90,9 @@ Pool2dOpMaker::Pool2dOpMaker(framework::OpProto *proto,
       "If globalPooling = true, ksize is ignored and need not be "
       "specified.");  // TODO(Chengduo): Add checker. (Currently,
                       // TypedAttrChecker don't support vector type.)
-  AddAttr<bool>(
-      "globalPooling",
-      "(bool default: false), whether to use the global pooling."
-      "If globalPooling = true, ksize is ignored and need not be specified.")
+  AddAttr<bool>("globalPooling",
+                "(bool default: false), whether to use the global pooling."
+                "If globalPooling = true, ksize is ignored.")
       .SetDefault(false);
   AddAttr<std::vector<int>>(
       "strides",
@@ -143,8 +142,8 @@ Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto,
                 "width of feature.");
 
   AddAttr<std::string>("poolingType",
-                       "(string), poolingType of pooling operator."
-                       "Str constant equal to 'max' or 'avg'.")
+                       "(string), pooling type, can be \"max\" for max-pooling "
+                       "and \"avg\" for average-pooling.")
       .InEnum({"max", "avg"});
   AddAttr<std::vector<int>>(
       "ksize",
@@ -153,10 +152,9 @@ Pool3dOpMaker::Pool3dOpMaker(framework::OpProto *proto,
       "If globalPooling = true, ksize is ignored and need not be "
       "specified.");  // TODO(Chengduo): Add checker. (Currently,
                       // TypedAttrChecker don't support vector type.)
-  AddAttr<bool>(
-      "globalPooling",
-      "(bool default: false), whether to use the global pooling."
-      "If globalPooling = true, ksize is ignored and need not be specified.")
+  AddAttr<bool>("globalPooling",
+                "(bool default: false), whether to use the global pooling."
+                "If globalPooling = true, ksize is ignored.")
       .SetDefault(false);
   AddAttr<std::vector<int>>("strides",
                             "(vector, default:{1,1,1}), strides(depth, height, "
......
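The repeated "ksize is ignored" wording in these attribute docs corresponds to the usual global-pooling behavior: the effective window is the input's full spatial extent. A simplified sketch of that rule on plain int vectors (EffectiveKsize is a hypothetical helper; Paddle applies the equivalent logic on its DDim type inside the op):

```cpp
#include <vector>

// With globalPooling enabled, any user-supplied ksize is overwritten by
// the spatial dims of the input, so a 2-D pool covers the whole H x W
// plane and a 3-D pool the whole D x H x W volume.
std::vector<int> EffectiveKsize(bool global_pooling, std::vector<int> ksize,
                                const std::vector<int>& input_dims) {
  if (global_pooling) {
    // input_dims is NCHW / NCDHW; spatial dims start at index 2.
    for (size_t i = 0; i < ksize.size(); ++i) {
      ksize[i] = input_dims[i + 2];
    }
  }
  return ksize;
}
```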
@@ -109,10 +109,9 @@ class MaxPool2dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
       "If globalPooling = true, ksize is ignored and need not be "
       "specified.");  // TODO(Chengduo): Add checker. (Currently,
                       // TypedAttrChecker don't support vector type.)
-  AddAttr<bool>(
-      "globalPooling",
-      "(bool default: false), whether to use the global pooling."
-      "If globalPooling = true, ksize is ignored and need not be specified.")
+  AddAttr<bool>("globalPooling",
+                "(bool default: false), whether to use the global pooling."
+                "If globalPooling = true, ksize is ignored.")
       .SetDefault(false);
   AddAttr<std::vector<int>>(
       "strides",
@@ -178,10 +177,9 @@ class MaxPool3dWithIndexOpMaker : public framework::OpProtoAndCheckerMaker {
       "If globalPooling = true, ksize is ignored and need not be "
       "specified.");  // TODO(Chengduo): Add checker. (Currently,
                       // TypedAttrChecker don't support vector type.)
-  AddAttr<bool>(
-      "globalPooling",
-      "(bool default: false), whether to use the global pooling."
-      "If globalPooling = true, ksize is ignored and need not be specified.")
+  AddAttr<bool>("globalPooling",
+                "(bool default: false), whether to use the global pooling."
+                "If globalPooling = true, ksize is ignored.")
       .SetDefault(false);
   AddAttr<std::vector<int>>("strides",
                             "(vector, default:{1,1,1}), strides(depth, "
......
@@ -266,9 +266,9 @@ def pool2d(input,
         inputs={"X": input},
         outputs={"Out": pool_out},
         attrs={
-            "pooling_type": pool_type,
+            "poolingType": pool_type,
             "ksize": pool_size,
-            "global_pooling": global_pooling,
+            "globalPooling": global_pooling,
             "strides": pool_stride,
             "paddings": pool_padding
         })
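The Python-side rename matters because attributes are looked up by the exact strings registered in the OpMaker. A hedged sketch of the lookup on the C++ side, assuming an Attr<T>(name) accessor like the one Paddle's ExecutionContext provides (ReadPoolAttrs and the Context template parameter are illustrative, not Paddle API):

```cpp
#include <string>

// The kernel fetches attributes by the registered names, so the Python
// layer must emit "poolingType"/"globalPooling", not the old snake_case
// names; a stale name fails at attribute lookup, not at compile time.
template <typename Context>
void ReadPoolAttrs(const Context& ctx) {
  std::string pooling_type = ctx.template Attr<std::string>("poolingType");
  bool global_pooling = ctx.template Attr<bool>("globalPooling");
  (void)pooling_type;  // silence unused-variable warnings in this sketch
  (void)global_pooling;
}
```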
......