提交 5ec55e79 编写于 作者: Z zchen0211

deconv impl

上级 80ebc8d5
......@@ -31,22 +31,23 @@ void Deconv2DOp::InferShape(framework::InferShapeContext* ctx) const {
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
int groups = ctx->Attrs().Get<int>("groups");
int input_channels = in_dims[1];
int output_channels = filter_dims[0];
PADDLE_ENFORCE_EQ(in_dims.size(), 4, "Conv2DOp input should be 4-D.");
PADDLE_ENFORCE_EQ(filter_dims.size(), 4, "Conv2DOp filter should be 4-D.");
PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups,
"The number of input channels should be equal to filter "
"channels * groups.");
PADDLE_ENFORCE_EQ(
output_channels % groups, 0,
"The number of output channels should be divided by groups.");
for (int i = 0; i < paddings.size(); ++i) {
PADDLE_ENFORCE_EQ(paddings[i], 0, "No Padding allowed in deconv op.");
}
PADDLE_ENFORCE_EQ(in_dims.size(), 4, "Deconv2DOp input should be 4-D.");
PADDLE_ENFORCE_EQ(filter_dims.size(), 4, "Deconv2DOp filter should be 4-D.");
PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
"input and kernel input dimension should be equal.");
PADDLE_ENFORCE_EQ(groups, 1,
"The number of groups should be 1 in case of deconv op.");
auto output_height = (in_dims[2] - 1) * strides[0] + filter_dims[2];
auto output_width = (in_dims[3] - 1) * strides[1] + filter_dims[3];
ctx->SetOutputDim("Output",
{in_dims[0], filter_dims[0], output_height, output_width});
{in_dims[0], filter_dims[1], output_height, output_width});
}
Deconv2DOpMaker::Deconv2DOpMaker(framework::OpProto* proto,
......@@ -55,12 +56,12 @@ Deconv2DOpMaker::Deconv2DOpMaker(framework::OpProto* proto,
AddInput(
"Input",
"The input tensor of deconvolution operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of image.");
"The format of input tensor is NMHW. Where N is batch size, M is the "
"number of input channels, H and W is the height and width of image.");
AddInput("Filter",
"The filter tensor of deconvolution operator."
"The format of the filter tensor is MCHW, where M is the number of "
"output image channels, C is the number of input image channels, "
"input image channels, C is the number of output image channels, "
"H and W is height and width of filter. "
"We enforce groups number == 1 and padding == 0 in our "
"deconvolution Scenario.");
......@@ -97,6 +98,6 @@ REGISTER_OP(deconv2d, ops::Deconv2DOp, ops::Deconv2DOpMaker, deconv2d_grad,
ops::Deconv2DOpGrad);
REGISTER_OP_CPU_KERNEL(
deconv2d, ops::GemmConvGrad2DKernel<paddle::platform::CPUPlace, float>);
deconv2d, ops::GemmDeconv2DKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
deconv2d_grad, ops::GemmConv2DKernel<paddle::platform::CPUPlace, float>);
......@@ -23,6 +23,7 @@ namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
// Define Op classes in .h file so that other deconv
// operator implementations can reuse the code.
......@@ -48,5 +49,167 @@ class Deconv2DOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override;
};
template <typename Place, typename T>
class GemmDeconv2DKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input");
// filter will be reshaped, so we do not use constant pointer here
Tensor filter = *context.Input<Tensor>("Filter");
Tensor* output = context.Output<Tensor>("Output");
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
// no paddings and groups allowed in deconv
int N = input->dims()[0];
int M = input->dims()[1];
int H = input->dims()[2];
int W = input->dims()[3];
int K_H = filter.dims()[2];
int K_W = filter.dims()[3];
int C = output->dims()[1]; // output channels
int O_H = output->dims()[2];
int O_W = output->dims()[3];
paddle::operators::math::Col2ImFunctor<
paddle::operators::math::ColFormat::kCFO, Place, T>
col2im;
// use col_shape in the im2col and col2im calculation
framework::DDim col_shape = {C, K_H, K_W, H, W};
// use col_matrix_shape in the gemm calculation
framework::DDim col_matrix_shape = {M * K_H * K_W, H * W};
Tensor col;
col.mutable_data<T>(col_shape, context.GetPlace());
// col_matrix shares the same piece of data with col,
// but will be reshaped into a two-dimensional matrix shape
// to call the matrix multiplication interface.
Tensor col_matrix = col;
col_matrix.Resize(col_matrix_shape);
DDim output_shape = {C, O_H, O_W};
DDim input_matrix_shape = {M, H * W};
DDim filter_matrix_shape = {M, C * K_H * K_W};
filter.Resize(filter_matrix_shape);
// deconvolution: gemm + col2im (similar to conv-backward on input)
output->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*output);
t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
for (int i = 0; i < N; i++) {
// batch with size (M, H * W)
Tensor input_batch = input->Slice<T>(i, i + 1).Resize(input_matrix_shape);
// output size: (C, O_H, O_W)
Tensor output_batch = output->Slice<T>(i, i + 1).Resize(output_shape);
// filter size: (Co, Ci * Hf * Wf)
// col_matrix = filter * input_batch
// of shape (C * K_H * K_W, H * W)
math::matmul<Place, T>(context.device_context(), filter, true,
input_batch, false, T(1.0), &col_matrix, T(0.0));
col2im(context.device_context(), output_batch, col_matrix, strides[0],
strides[1], 0, 0);
}
}
};
/*
template <typename Place, typename T>
class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input");
const Tensor* output_grad =
context.Input<Tensor>(framework::GradVarName("Output"));
// For filter, we do not use const pointer
// but we should avoid
Tensor filter = *context.Input<Tensor>("Filter");
Tensor* input_grad =
context.Output<Tensor>(framework::GradVarName("Input"));
Tensor* filter_grad =
context.Output<Tensor>(framework::GradVarName("Filter"));
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
// no paddings and groups allowed in deconv
int N = input->dims()[0];
int M = input->dims()[1];
int H = input->dims()[2];
int W = input->dims()[3];
int K_H = filter.dims()[2];
int K_W = filter.dims()[3];
int C = output->dims()[1]; // output channels
int O_H = output->dims()[2];
int O_W = output->dims()[3];
paddle::operators::math::Col2ImFunctor<
paddle::operators::math::ColFormat::kCFO, Place, T>
col2im;
// use col_shape in the im2col and col2im calculation
framework::DDim col_shape = {C, K_H, K_W, H, W};
// use col_matrix_shape in the gemm calculation
framework::DDim col_matrix_shape = {M * K_H * K_W, H * W};
Tensor col;
col.mutable_data<T>(col_shape, context.GetPlace());
// col_matrix shares the same piece of data with col,
// but will be reshaped into a two-dimensional matrix shape
// to call the matrix multiplication interface.
Tensor col_matrix = col;
col_matrix.Resize(col_matrix_shape);
DDim output_shape = {C, O_H, O_W};
DDim input_matrix_shape = {M, H * W};
DDim filter_matrix_shape = {M, C* K_H * K_W};
filter.Resize(filter_matrix_shape);
// deconvolution: gemm + col2im (similar to conv-backward on input)
output->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*output);
t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
for (int i = 0; i < N; i++) {
// batch with size (M, H * W)
Tensor input_batch =
input->Slice<T>(i, i + 1).Resize(input_matrix_shape);
// output size: (C, O_H, O_W)
Tensor output_batch =
output->Slice<T>(i, i + 1).Resize(output_shape);
// filter size: (Co, Ci * Hf * Wf)
// col_matrix = filter * input_batch
// of shape (C * K_H * K_W, H * W)
math::matmul<Place, T>(context.device_context(), filter, true,
input_batch, false, T(1.0), &col_matrix,
T(0.0));
col2im(context.device_context(), output_batch, col_matrix, strides[0],
strides[1], 0, 0);
}
}
};
*/
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册