Unverified commit b756063c authored by qingqing01, committed by GitHub

Speed depthwise transposed conv2d. (#11740)

* Speed depthwise transposed conv2d.
Parent 8630ba2e
paddle/fluid/operators/conv_transpose_op.cc
@@ -302,6 +302,7 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
namespace ops = paddle::operators;
// conv2d_transpose
REGISTER_OPERATOR(conv2d_transpose, ops::ConvTransposeOp,
                  ops::Conv2DTransposeOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
@@ -317,6 +318,7 @@ REGISTER_OP_CPU_KERNEL(
    ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
                                     double>);
// conv3d_transpose
REGISTER_OPERATOR(conv3d_transpose, ops::ConvTransposeOp,
                  ops::Conv3DTransposeOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
@@ -331,3 +333,19 @@ REGISTER_OP_CPU_KERNEL(
    ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
                                     double>);
// depthwise conv2d_transpose
REGISTER_OPERATOR(depthwise_conv2d_transpose, ops::ConvTransposeOp,
                  ops::Conv2DTransposeOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(depthwise_conv2d_transpose_grad, ops::ConvTransposeOpGrad);

REGISTER_OP_CPU_KERNEL(
    depthwise_conv2d_transpose,
    ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, float>,
    ops::GemmConvTransposeKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    depthwise_conv2d_transpose_grad,
    ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::GemmConvTransposeGradKernel<paddle::platform::CPUDeviceContext,
                                     double>);
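Note that on the CPU side the new depthwise_conv2d_transpose op simply reuses
the GEMM-based conv-transpose kernels; the dedicated depthwise kernels
registered in conv_transpose_op.cu.cc below are CUDA-only, so the speedup
applies on GPU.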
paddle/fluid/operators/conv_transpose_op.cu.cc
@@ -15,25 +15,28 @@ limitations under the License. */
#include "paddle/fluid/operators/conv_transpose_op.h" #include "paddle/fluid/operators/conv_transpose_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;

// conv2d
REGISTER_OP_CUDA_KERNEL(conv2d_transpose,
                        ops::GemmConvTransposeKernel<CUDA, float>,
                        ops::GemmConvTransposeKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(conv2d_transpose_grad,
                        ops::GemmConvTransposeGradKernel<CUDA, float>,
                        ops::GemmConvTransposeGradKernel<CUDA, double>);

// conv3d
REGISTER_OP_CUDA_KERNEL(conv3d_transpose,
                        ops::GemmConvTransposeKernel<CUDA, float>,
                        ops::GemmConvTransposeKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(conv3d_transpose_grad,
                        ops::GemmConvTransposeGradKernel<CUDA, float>,
                        ops::GemmConvTransposeGradKernel<CUDA, double>);

// depthwise conv2d
REGISTER_OP_CUDA_KERNEL(depthwise_conv2d_transpose,
                        ops::DepthwiseConvTransposeKernel<CUDA, float>,
                        ops::DepthwiseConvTransposeKernel<CUDA, double>);
REGISTER_OP_CUDA_KERNEL(depthwise_conv2d_transpose_grad,
                        ops::DepthwiseConvTransposeGradKernel<CUDA, float>,
                        ops::DepthwiseConvTransposeGradKernel<CUDA, double>);
paddle/fluid/operators/conv_transpose_op.h
@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/operators/math/im2col.h" #include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/operators/math/vol2col.h" #include "paddle/fluid/operators/math/vol2col.h"
@@ -316,5 +317,74 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
    }
  }
};

template <typename DeviceContext, typename T>
class DepthwiseConvTransposeKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
    Tensor filter = *context.Input<Tensor>("Filter");
    Tensor* output = context.Output<Tensor>("Output");
    output->mutable_data<T>(context.GetPlace());

    // The depthwise path requires one filter group per input channel.
    int groups = context.Attr<int>("groups");
    PADDLE_ENFORCE_EQ(groups, filter.dims()[0]);

    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
    std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
    // The depthwise functors below do not support dilation.
    for (auto v : dilations) {
      PADDLE_ENFORCE_EQ(v, 1);
    }

    auto& dev_ctx = context.template device_context<DeviceContext>();
    math::SetConstant<DeviceContext, T> set_zero;
    set_zero(dev_ctx, output, static_cast<T>(0));

    // The forward pass of a transposed conv is the input gradient of an
    // ordinary conv, so reuse the depthwise conv input-grad functor:
    // scatter *input through the filter into the zero-initialized output.
    math::DepthwiseConvInputGradFunctor<DeviceContext, T>
        depthwiseConvInputGrad;
    depthwiseConvInputGrad(dev_ctx, *output, filter, *input, strides, paddings,
                           output);
  }
};

template <typename DeviceContext, typename T>
class DepthwiseConvTransposeGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
    const Tensor* output_grad =
        context.Input<Tensor>(framework::GradVarName("Output"));
    Tensor* input_grad =
        context.Output<Tensor>(framework::GradVarName("Input"));
    Tensor* filter_grad =
        context.Output<Tensor>(framework::GradVarName("Filter"));
    Tensor filter = *context.Input<Tensor>("Filter");

    if (!input_grad && !filter_grad) return;

    auto& dev_ctx = context.template device_context<DeviceContext>();
    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");

    if (input_grad) {
      // By the same duality, the input gradient of a transposed conv is a
      // plain depthwise conv applied to the output gradient.
      math::DepthwiseConvFunctor<DeviceContext, T> depthwiseConv;
      depthwiseConv(dev_ctx, *output_grad, filter, strides, paddings,
                    input_grad);
    }

    if (filter_grad) {
      math::SetConstant<DeviceContext, T> set_zero;
      filter_grad->mutable_data<T>(context.GetPlace());
      set_zero(dev_ctx, filter_grad, static_cast<T>(0));

      // The filter gradient correlates the output gradient with the input,
      // with the roles swapped relative to an ordinary conv.
      math::DepthwiseConvFilterGradFunctor<DeviceContext, T>
          depthwiseConvFilterGrad;
      depthwiseConvFilterGrad(dev_ctx, *output_grad, *input, strides, paddings,
                              filter_grad);
    }
  }
};

}  // namespace operators
}  // namespace paddle
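The two kernels above work because a transposed convolution's forward pass is
exactly the input gradient of an ordinary convolution with the same filter,
which is what lets them reuse the existing depthwise functors instead of
needing a new CUDA kernel. A minimal 1-D NumPy sketch of that identity
(stride 1, no padding; the helper names are illustrative, not part of this
patch):

import numpy as np

def conv1d_valid(x, w):
    # Ordinary conv (valid cross-correlation): y[i] = sum_k x[i+k] * w[k]
    n = len(x) - len(w) + 1
    return np.array([np.dot(x[i:i + len(w)], w) for i in range(n)])

def conv1d_input_grad(dy, w, input_len):
    # Input gradient of conv1d_valid: scatter each dy[i] through the filter.
    dx = np.zeros(input_len)
    for i, g in enumerate(dy):
        dx[i:i + len(w)] += g * w
    return dx

x, w = np.random.rand(5), np.random.rand(3)

# A stride-1 transposed conv scatters x through w, i.e. it is the conv input
# gradient with x in the role of the output gradient (a 'full' convolution).
out = conv1d_input_grad(x, w, len(x) + len(w) - 1)
assert np.allclose(out, np.convolve(x, w, mode='full'))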
python/paddle/fluid/layers/nn.py
@@ -2334,10 +2334,17 @@ def conv2d_transpose(input,
          data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32')
          conv2d_transpose = fluid.layers.conv2d_transpose(input=data, num_filters=2, filter_size=3)
    """
    input_channel = input.shape[1]

    op_type = 'conv2d_transpose'
    if (input_channel == groups and num_filters == input_channel and
            not use_cudnn):
        op_type = 'depthwise_conv2d_transpose'

    helper = LayerHelper(op_type, **locals())
    if not isinstance(input, Variable):
        raise TypeError("Input of conv2d_transpose must be Variable")

    padding = utils.convert_to_list(padding, 2, 'padding')
    stride = utils.convert_to_list(stride, 2, 'stride')
@@ -2371,7 +2378,7 @@ def conv2d_transpose(input,
    pre_bias = helper.create_tmp_variable(dtype=input.dtype)
    helper.append_op(
        type=op_type,
        inputs={'Input': [input],
                'Filter': [img_filter]},
        outputs={'Output': pre_bias},
......
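For reference, a usage sketch of the dispatch above (variable names are
illustrative; it assumes the Fluid layers API of this commit's era): when
groups equals both the input channel count and num_filters, and cuDNN is
disabled, the layer emits the new depthwise op.

import paddle.fluid as fluid

data = fluid.layers.data(name='data', shape=[8, 16, 16], dtype='float32')
# input_channel == groups == num_filters and use_cudnn=False, so op_type
# resolves to 'depthwise_conv2d_transpose' rather than 'conv2d_transpose'.
out = fluid.layers.conv2d_transpose(
    input=data, num_filters=8, filter_size=4, stride=2, padding=1,
    groups=8, use_cudnn=False)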
python/paddle/fluid/tests/unittests/test_conv2d_transpose_op.py
@@ -242,6 +242,19 @@ class TestCUDNNWithGroups(TestWithGroups):
self.op_type = "conv2d_transpose" self.op_type = "conv2d_transpose"


class TestDepthwiseConvTranspose(TestConv2dTransposeOp):
    def init_test_case(self):
        self.pad = [1, 1]
        self.stride = [2, 2]
        self.dilations = [1, 1]
        self.input_size = [2, 8, 16, 16]  # NCHW
        self.groups = 8
        assert np.mod(self.input_size[1], self.groups) == 0
        f_c = self.input_size[1] // self.groups
        self.filter_size = [self.input_size[1], f_c, 4, 4]
        self.op_type = "depthwise_conv2d_transpose"


# Please don't remove the following code.
# Currently, CI uses cuDNN v5.0, which does not support dilated conv.
# class TestCUDNNWithDilation(TestWithDilation):
......
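As a sanity check on the test dimensions: a transposed conv produces
H_out = (H_in - 1) * stride - 2 * pad + ksize. A small hypothetical helper
(not part of this patch) confirming the case above:

def conv_transpose_out_size(in_size, ksize, stride, pad):
    # H_out = (H_in - 1) * stride - 2 * pad + ksize
    return (in_size - 1) * stride - 2 * pad + ksize

# 16x16 input, 4x4 filter, stride 2, pad 1 -> 32x32 output.
assert conv_transpose_out_size(16, 4, 2, 1) == 32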