提交 06db7038 编写于 作者: X xzl

../../../../../paddle/api

上级 3772d27d
......@@ -155,7 +155,8 @@ op_library(parallel_do_op DEPS executor)
# Register multiple kernels to pybind
if (WITH_GPU)
op_library(conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS vol2col)
op_library(conv_op SRCS conv_op.cc conv_op.cu.cc conv_cudnn_op.cu.cc DEPS
vol2col depthwise_conv)
op_library(pool_op SRCS pool_op.cc pool_op.cu.cc pool_cudnn_op.cu.cc DEPS pooling)
op_library(conv_transpose_op SRCS conv_transpose_op.cc conv_transpose_op.cu.cc
conv_transpose_cudnn_op.cu.cc DEPS vol2col)
......
......@@ -318,15 +318,20 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
namespace ops = paddle::operators;
REGISTER_OP(conv2d, ops::ConvOp, ops::Conv2DOpMaker, conv2d_grad,
ops::ConvOpGrad);
REGISTER_OP(depthwiseConv, ops::ConvOp, ops::Conv2DOpMaker, conv2d_grad,
REGISTER_OP(depthwiseConv, ops::ConvOp, ops::Conv2DOpMaker, depthwiseConv_grad,
ops::ConvOpGrad);
REGISTER_OP(conv3d, ops::ConvOp, ops::Conv3DOpMaker, conv3d_grad,
ops::ConvOpGrad);
REGISTER_OP_CPU_KERNEL(
depthwiseConv,
ops::DepthwiseConvKernel<paddle::platform::CPUDeviceContext, float>,
ops::DepthwiseConvKernel<paddle::platform::CPUDeviceContext, double>);
ops::GemmConvKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
depthwiseConv_grad,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv2d, ops::GemmConvKernel<paddle::platform::CPUDeviceContext, float>,
......
......@@ -364,18 +364,15 @@ class DepthwiseConvKernel : public framework::OpKernel<T> {
Tensor* output = context.Output<Tensor>("Output");
output->mutable_data<T>(context.GetPlace());
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape);
math::DepthwiseConvFunctor<DeviceContext, T> depthwiseConv;
auto& dev_ctx = context.template device_context<DeviceContext>();
depthwiseConv(dev_ctx, input, filter, filter_shape_vec, strides, paddings,
depthwiseConv(dev_ctx, *input, filter, ksize, strides, paddings,
output);
}
};
......
......@@ -8,6 +8,7 @@ if(WITH_GPU)
nv_library(softmax SRCS softmax.cc softmax.cu DEPS device_context)
nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS device_context)
nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context)
nv_library(depthwise_conv SRCS depthwise_conv.cu DEPS device_context)
nv_library(sequence_pooling SRCS sequence_pooling.cc sequence_pooling.cu DEPS device_context math_function)
nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context tensor)
nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context math_function)
......
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/pooling.h"
#include "paddle/operators/math/depthwise_conv.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
......@@ -195,7 +195,7 @@ __global__ void KernelDepthwiseConvFilterGrad(const int num_i,
* ksize, strides, and paddings each contain two elements, which represent
* the height and the width, respectively.
*/
template <typename T>
template <class T>
class DepthwiseConvFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
......@@ -226,7 +226,7 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T> {
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
KernelDepthwiseConv<T><<<grid, threads, 0, STREAM_DEFAULT>>>(
KernelDepthwiseConv<T><<<grid, threads, 0, context.stream()>>>(
nthreads, input_data, filter_data, batch_size, output_channels,
output_height, output_width, input_channels, input_height, input_width,
output_channels / input_channels, ksize_height, ksize_width,
......@@ -236,7 +236,6 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T> {
};
/*
template <typename T>
class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, PoolProcess, T>
{
......@@ -254,8 +253,7 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, PoolProcess, T>
const int output_height = output.dims()[2];
const int output_width = output.dims()[3];
const int ksize_height = ksize[0];
const int ksize_width = ksize[1];
const int stride_height = strides[0];
const int ksize_width = ksize[1]; const int stride_height = strides[0];
const int stride_width = strides[1];
const int padding_height = paddings[0];
const int padding_width = paddings[1];
......@@ -321,24 +319,20 @@ class DepthwiseConvdFilterGradFunctor<platform::CUDADeviceContext, T> {
*/
template class DepthwiseConvFunctor<platform::CUDADeviceContext,
paddle::operators::math::MaxPool<float>,
float>;
template class DepthwiseConvFunctor<platform::CUDADeviceContext,
double>;
/*
template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext,
paddle::operators::math::MaxPoolGrad<float>,
float>;
template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
paddle::operators::math::MaxPoolGrad<float>,
float>;
template class DepthwiseConvFunctor<platform::CUDADeviceContext,
paddle::operators::math::MaxPool<double>, double>;
template class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext,
paddle::operators::math::MaxPoolGrad<double>,
double>;
template class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext,
paddle::operators::math::MaxPoolGrad<double>,
double>;
*/
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册