Unverified commit b76343c3, authored by lvmengsi, committed by GitHub

cpu Conv double grad (#19672)

* cpu conv_grad_grad
Parent 754fd57e
@@ -510,3 +510,8 @@ REGISTER_OP_KERNEL(conv3d, CUDNN, plat::CUDAPlace,
REGISTER_OP_KERNEL(conv3d_grad, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvGradOpKernel<float>,
paddle::operators::CUDNNConvGradOpKernel<double>);
REGISTER_OP_KERNEL(
conv3d_grad_grad, CUDNN, plat::CUDAPlace,
paddle::operators::CUDNNConvDoubleGradOpKernel<float>,
paddle::operators::CUDNNConvDoubleGradOpKernel<double>,
paddle::operators::CUDNNConvDoubleGradOpKernel<plat::float16>);
@@ -565,6 +565,40 @@ class Conv2DDoubleGradMaker : public framework::SingleGradOpDescMaker {
}
};
/*
* Inputs: I, W, dO, ddI, ddW
* Outputs: ddO, dW, dI
*/
class Conv3DDoubleGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op = new framework::OpDesc();
op->SetType(this->ForwardOpType() + "_grad");
// I, W, dO, ddI, ddW
op->SetInput("Input", Input("Input"));
op->SetInput("Filter", Input("Filter"));
op->SetInput("DOutput", Input(framework::GradVarName("Output")));
op->SetInput("DDInput", OutputGrad(framework::GradVarName("Input")));
op->SetInput("DDFilter", OutputGrad(framework::GradVarName("Filter")));
auto ddx = OutputGrad(framework::GradVarName("Input"));
auto ddw = OutputGrad(framework::GradVarName("Filter"));
std::vector<std::string> empty_str = {};
op->SetOutput(
"DDOutput",
ddx.empty() ? empty_str : InputGrad(framework::GradVarName("Output")));
op->SetOutput("DFilter", ddx.empty() ? empty_str : InputGrad("Filter"));
op->SetOutput("DInput", ddw.empty() ? empty_str : InputGrad("Input"));
op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDesc>(op);
}
};
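/*
 * For reference, the second-order relations the double-grad op computes, in
 * the maker's variable names (these match the block comments in
 * GemmConvDoubleGradKernel below):
 *   dI  = ddW * dO            (input grad recomputed from the filter perturbation)
 *   dW  = ddI * dO            (filter grad recomputed from the input perturbation)
 *   ddO = W * ddI + I * ddW
 */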
void ConvOpDoubleGrad::InferShape(framework::InferShapeContext* ctx) const {
auto x_dims = ctx->GetInputDim("Input");
auto w_dims = ctx->GetInputDim("Filter");
@@ -592,8 +626,14 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType(
#ifdef PADDLE_WITH_CUDA
if (platform::CanCUDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kCUDNN;
-  } else {
-    PADDLE_THROW("Now ConvDoubleGrad only supports cuDNN.");
   }
 #endif
+#ifdef PADDLE_WITH_MKLDNN
+  if (library_ == framework::LibraryType::kPlain &&
+      platform::CanMKLDNNBeUsed(ctx)) {
+    library_ = framework::LibraryType::kMKLDNN;
+    layout_ = framework::DataLayout::kMKLDNN;
+    customized_type_value = kConvMKLDNNFP32;
+  }
+#endif
auto type = framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
@@ -637,7 +677,8 @@ REGISTER_OPERATOR(depthwise_conv2d_grad, ops::ConvOpGrad);
REGISTER_OPERATOR(conv3d, ops::ConvOp, ops::Conv3DOpMaker,
ops::ConvOpInferVarType, ops::Conv3DGradMaker);
-REGISTER_OPERATOR(conv3d_grad, ops::ConvOpGrad);
+REGISTER_OPERATOR(conv3d_grad, ops::ConvOpGrad, ops::Conv3DDoubleGradMaker);
+REGISTER_OPERATOR(conv3d_grad_grad, ops::ConvOpDoubleGrad);
// depthwise conv kernel
// TODO(xingzhaolong): neon kernel for mobile
@@ -658,6 +699,10 @@ REGISTER_OP_CPU_KERNEL(
conv2d_grad,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv2d_grad_grad,
ops::GemmConvDoubleGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvDoubleGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv3d, ops::GemmConvKernel<paddle::platform::CPUDeviceContext, float>,
@@ -666,3 +711,7 @@ REGISTER_OP_CPU_KERNEL(
conv3d_grad,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvGradKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
conv3d_grad_grad,
ops::GemmConvDoubleGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::GemmConvDoubleGradKernel<paddle::platform::CPUDeviceContext, double>);
@@ -19,6 +19,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/depthwise_conv.h"
#include "paddle/fluid/operators/math/im2col.h"
@@ -393,6 +394,218 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
}
};
template <typename DeviceContext, typename T>
class GemmConvDoubleGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
"It must use CPUPlace.");
const Tensor* X = ctx.Input<Tensor>("Input");
const Tensor* dY = ctx.Input<Tensor>("DOutput");
const Tensor* ddX = ctx.Input<Tensor>("DDInput");
const Tensor* ddW_in = ctx.Input<Tensor>("DDFilter");
Tensor* ddY = ctx.Output<Tensor>("DDOutput");
Tensor* dW = ctx.Output<Tensor>("DFilter");
Tensor* dX = ctx.Output<Tensor>("DInput");
Tensor W = detail::Ref(ctx.Input<Tensor>("Filter"),
"Cannot find input Filter(%s) in scope)",
ctx.Inputs("Filter")[0]);
if (!ddY && !dW && !dX) return;
int groups = ctx.Attr<int>("groups");
std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
const int batch_size = static_cast<int>(X->dims()[0]);
std::vector<int64_t> filter_shape_vec(framework::vectorize(W.dims()));
std::vector<int64_t> output_shape_vec(framework::vectorize(dY->dims()));
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
// col_shape [in_channel/group, kh, kw, oh, ow]
col_shape_vec[0] = X->dims()[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
col_shape_vec[j + data_dim + 1] = output_shape_vec[j + 2];
}
framework::DDim col_shape(framework::make_ddim(col_shape_vec));
// col_matrix_shape [in_channel/group * kh * kw, oh * ow]
framework::DDim col_matrix_shape =
framework::flatten_to_2d(col_shape, data_dim + 1);
// input_shape [Cin, H, W]
framework::DDim input_shape =
framework::slice_ddim(X->dims(), 1, X->dims().size());
// filter_matrix_shape [Cout, Cin * kh * kw]
framework::DDim filter_matrix_shape = {W.dims()[0],
W.numel() / W.dims()[0]};
W.Resize(filter_matrix_shape);
framework::DDim output_matrix_shape = {
dY->dims()[1], dY->numel() / (dY->dims()[0] * dY->dims()[1])};
int in_step = static_cast<int>(X->dims()[1]) / groups;
int out_step = static_cast<int>(dY->dims()[1]) / groups;
bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
Tensor col;
Tensor col_matrix;
if (is_expand) {
col = ctx.AllocateTmpTensor<T, DeviceContext>(col_shape, dev_ctx);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
math::SetConstant<DeviceContext, T> set_zero;
auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
// dx convolution double grad: gemm + col2im(col2vol)
// dx = ddw * dy ==> dx(N, Cin, H, W), ddw(Cout, Cin, kh, kw), dy(N, Cout,
// oH, oW)
if (dX && ddW_in) {
Tensor ddW;
ddW.ShareDataWith(*ddW_in).Resize(filter_matrix_shape);
dX->mutable_data<T>(ctx.GetPlace());
      // if is_expand is false, set_zero is unnecessary because col_matrix
      // aliases dx_slice and blas.MatMul overwrites dX (beta == 0)
if (is_expand) {
set_zero(dev_ctx, dX, static_cast<T>(0));
}
math::Col2VolFunctor<DeviceContext, T> col2vol;
math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
for (int i = 0; i < batch_size; i++) {
Tensor dy_batch = dY->Slice(i, i + 1).Resize(output_matrix_shape);
Tensor dx_batch = dX->Slice(i, i + 1).Resize(input_shape);
for (int g = 0; g < groups; g++) {
// gemm
Tensor dy_slice = dy_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor ddw_slice = ddW.Slice(g * out_step, (g + 1) * out_step);
Tensor dx_slice = dx_batch.Slice(g * in_step, (g + 1) * in_step);
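          // When is_expand is false (1x1 filter, unit stride/dilation, zero
          // padding), col2im is the identity, so col_matrix aliases dx_slice
          // and the GEMM below writes its result directly into dX.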
if (!is_expand) {
col_matrix.ShareDataWith(dx_slice);
col_matrix.Resize(col_matrix_shape);
}
blas.MatMul(ddw_slice, true, dy_slice, false, T(1.0), &col_matrix,
T(0.0));
if (is_expand && data_dim == 2U) {
col2im(dev_ctx, col, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&dx_slice);
} else if (is_expand && data_dim == 3U) {
col2vol(dev_ctx, col, dilations, strides, paddings, &dx_slice);
}
}
}
}
// dw = ddx * dy ==> dw(Cout, Cin, kh, kw), ddx(N, Cin, H, W), dy(N, Cout,
// oH, oW)
// dw convolution double grad: im2col(vol2col) + gemm
if (dW) {
dW->mutable_data<T>(ctx.GetPlace());
set_zero(dev_ctx, dW, static_cast<T>(0));
Tensor dW_arr = *dW;
dW_arr.Resize(filter_matrix_shape);
math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
math::Vol2ColFunctor<DeviceContext, T> vol2col;
for (int i = 0; i < batch_size; ++i) {
Tensor dy_batch = dY->Slice(i, i + 1).Resize(output_matrix_shape);
Tensor ddx_batch = ddX->Slice(i, i + 1).Resize(input_shape);
for (int g = 0; g < groups; ++g) {
// im2col
Tensor dy_slice = dy_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor ddx_slice = ddx_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col.ShareDataWith(ddx_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
im2col(dev_ctx, ddx_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (data_dim == 3U) {
vol2col(dev_ctx, ddx_slice, dilations, strides, paddings, &col);
}
Tensor dw_slice = dW_arr.Slice(g * out_step, (g + 1) * out_step);
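          // beta == 1.0 accumulates this sample's contribution into dW
          // across the batch loop (dW was zero-initialized above).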
blas.MatMul(dy_slice, false, col_matrix, true, T(1.0), &dw_slice,
T(1.0));
}
}
}
// ddy = w * ddx + x * ddw ==> ddy(N, Cout, oH, oW), x/ddx(N, Cin, H, W),
// w/ddw(Cout, Cin, kh, kw)
// ddy convolution double grad: im2col(vol2col) + gemm
if (ddY) {
ddY->mutable_data<T>(ctx.GetPlace());
set_zero(dev_ctx, ddY, static_cast<T>(0));
math::Im2ColFunctor<math::ColFormat::kCFO, DeviceContext, T> im2col;
math::Vol2ColFunctor<DeviceContext, T> vol2col;
for (int i = 0; i < batch_size; ++i) {
Tensor ddx_batch = ddX->Slice(i, i + 1).Resize(input_shape);
Tensor x_batch = X->Slice(i, i + 1).Resize(input_shape);
Tensor ddy_batch = ddY->Slice(i, i + 1).Resize(output_matrix_shape);
for (int g = 0; g < groups; ++g) {
Tensor x_slice = x_batch.Slice(g * in_step, (g + 1) * in_step);
Tensor ddx_slice = ddx_batch.Slice(g * in_step, (g + 1) * in_step);
if (!is_expand) {
col.ShareDataWith(ddx_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
// im2col
im2col(dev_ctx, ddx_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (data_dim == 3U) {
// vol2col
vol2col(dev_ctx, ddx_slice, dilations, strides, paddings, &col);
}
// gemm
Tensor ddy_slice = ddy_batch.Slice(g * out_step, (g + 1) * out_step);
Tensor w_slice = W.Slice(g * out_step, (g + 1) * out_step);
blas.MatMul(w_slice, false, col_matrix, false, T(1.0), &ddy_slice,
T(0.0));
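          // The GEMM above writes W * ddX into ddy_slice (beta == 0); if ddW
          // is provided, the GEMM below accumulates ddW * X on top
          // (beta == 1), giving ddY = W * ddX + X * ddW.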
if (ddW_in) {
Tensor ddW;
ddW.ShareDataWith(*ddW_in).Resize(filter_matrix_shape);
if (!is_expand) {
col.ShareDataWith(x_slice);
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
} else if (data_dim == 2U) {
// im2col
im2col(dev_ctx, x_slice, dilations, strides,
std::vector<int>{paddings[0], paddings[1], paddings[0],
paddings[1]},
&col);
} else if (data_dim == 3U) {
// vol2col
vol2col(dev_ctx, x_slice, dilations, strides, paddings, &col);
}
// gemm
Tensor ddw_slice = ddW.Slice(g * out_step, (g + 1) * out_step);
blas.MatMul(ddw_slice, false, col_matrix, false, T(1.0), &ddy_slice,
T(1.0));
}
}
}
}
}
};
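The three branches above implement dI = ddW * dO, dW = ddI * dO, and
ddO = W * ddI + I * ddW via im2col/col2im plus GEMM. As a minimal numeric
sketch of the same relations (illustration only, not part of this patch) for
a single-channel, stride-1, unpadded 2-D case, written with SciPy:

import numpy as np
from scipy.signal import convolve2d, correlate2d

rng = np.random.RandomState(0)
I, ddI = rng.randn(5, 5), rng.randn(5, 5)  # input and its second-order grad
W, ddW = rng.randn(3, 3), rng.randn(3, 3)  # filter and its second-order grad
dO = rng.randn(3, 3)                       # upstream gradient of the output

# ddO = W * ddI + I * ddW (conv2d's forward pass is cross-correlation)
ddO = correlate2d(ddI, W, mode='valid') + correlate2d(I, ddW, mode='valid')
# dW = ddI * dO (filter-gradient contraction)
dW = correlate2d(ddI, dO, mode='valid')
# dI = ddW * dO (input gradient: full convolution with the filter)
dI = convolve2d(dO, ddW, mode='full')

assert ddO.shape == dO.shape and dW.shape == W.shape and dI.shape == I.shape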
template <typename DeviceContext, typename T>
class DepthwiseConvKernel : public framework::OpKernel<T> {
public:
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core
import gradient_checker
from decorator_helper import prog_scope
class TestConvDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [2, 4, 7, 8]
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.conv2d(x, 4, 1, bias_attr=False)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
w = fluid.default_main_program().global_block().all_parameters()
w_arr = []
for p in w:
w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
gradient_checker.double_grad_check(
[x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestConvDoubleGradCheckTest1(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [2, 3, 4, 5]
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.conv2d(x, 4, 1, padding=1, bias_attr=False)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
w = fluid.default_main_program().global_block().all_parameters()
w_arr = []
for p in w:
w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
gradient_checker.double_grad_check(
[x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestConv3DDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [2, 4, 3, 4, 2]
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.conv3d(x, 4, 1, bias_attr=False)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
w = fluid.default_main_program().global_block().all_parameters()
w_arr = []
for p in w:
w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
gradient_checker.double_grad_check(
[x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
class TestConv3DDoubleGradCheckTest1(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [2, 4, 5, 3, 2]
eps = 0.005
dtype = np.float64
x = layers.data('x', shape, False, dtype)
y = layers.conv3d(x, 4, 1, padding=1, bias_attr=False)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
w = fluid.default_main_program().global_block().all_parameters()
w_arr = []
for p in w:
w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
gradient_checker.double_grad_check(
[x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
if __name__ == "__main__":
unittest.main()
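The tests above exercise conv2d_grad_grad / conv3d_grad_grad indirectly via
gradient_checker.double_grad_check. A hedged sketch of triggering the new
double-grad op directly (assuming the fluid.gradients API of this release
behaves as it does inside gradient_checker):

import paddle.fluid as fluid

main, startup = fluid.Program(), fluid.Program()
with fluid.program_guard(main, startup):
    x = fluid.layers.data('x', [2, 4, 7, 8], False, 'float64')
    x.stop_gradient = False
    y = fluid.layers.conv2d(x, 4, 1, bias_attr=False)
    dx = fluid.gradients(y, x)    # inserts conv2d_grad
    ddx = fluid.gradients(dx, x)  # inserts conv2d_grad_grad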
@@ -43,30 +43,6 @@ class TestMulGradCheck(unittest.TestCase):
self.func(p)
-class TestConvDoubleGradCheck(unittest.TestCase):
-    @prog_scope()
-    def func(self, place):
-        shape = [2, 4, 14, 16]
-        eps = 0.005
-        dtype = np.float64
-        x = layers.data('x', shape, False, dtype)
-        y = layers.conv2d(x, 4, 1, bias_attr=False)
-        x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
-        w = fluid.default_main_program().global_block().all_parameters()
-        w_arr = []
-        for p in w:
-            w_arr.append(np.random.uniform(-1, 1, p.shape).astype(dtype))
-        gradient_checker.double_grad_check(
-            [x] + w, y, x_init=[x_arr] + w_arr, place=place, eps=eps)
-    def test_grad(self):
-        if core.is_compiled_with_cuda():
-            places = [fluid.CUDAPlace(0)]
-        for p in places:
-            self.func(p)
class TestReduceMeanWithDimDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
......