From 8df5b4d6084205a6e58d23ada78ad0ebc931226d Mon Sep 17 00:00:00 2001
From: LielinJiang <50691816+LielinJiang@users.noreply.github.com>
Date: Tue, 8 Sep 2020 15:37:05 +0800
Subject: [PATCH] Add correlation api to contrib (#27015)

* add correlation api to contrib
---
 paddle/fluid/operators/correlation_op.cc      | 181 +++++++
 paddle/fluid/operators/correlation_op.cu      | 483 ++++++++++++++++++
 python/paddle/fluid/contrib/layers/nn.py      |  83 ++-
 .../fluid/contrib/tests/test_correlation.py   | 163 ++++++
 4 files changed, 908 insertions(+), 2 deletions(-)
 create mode 100644 paddle/fluid/operators/correlation_op.cc
 create mode 100644 paddle/fluid/operators/correlation_op.cu
 create mode 100644 python/paddle/fluid/contrib/tests/test_correlation.py

diff --git a/paddle/fluid/operators/correlation_op.cc b/paddle/fluid/operators/correlation_op.cc
new file mode 100644
index 0000000000..a2e6ff214b
--- /dev/null
+++ b/paddle/fluid/operators/correlation_op.cc
@@ -0,0 +1,181 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <algorithm>
+#include <cmath>
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+inline std::vector<int64_t> CorrelationOutputSize(int batch, int input_height,
+                                                  int input_width, int stride1,
+                                                  int stride2, int kernel_size,
+                                                  int pad_size,
+                                                  int max_displacement) {
+  std::vector<int64_t> output_shape({batch});
+  int kernel_radius = (kernel_size - 1) / 2;
+  int border_radius = kernel_radius + max_displacement;
+  int padded_input_height = input_height + 2 * pad_size;
+  int padded_input_width = input_width + 2 * pad_size;
+  int output_channel = ((max_displacement / stride2) * 2 + 1) *
+                       ((max_displacement / stride2) * 2 + 1);
+  output_shape.push_back(output_channel);
+  int output_height =
+      std::ceil(static_cast<float>(padded_input_height - 2 * border_radius) /
+                static_cast<float>(stride1));
+  int output_width =
+      std::ceil(static_cast<float>(padded_input_width - 2 * border_radius) /
+                static_cast<float>(stride1));
+  output_shape.push_back(output_height);
+  output_shape.push_back(output_width);
+  return output_shape;
+}
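+
+// Worked example (values taken from test_correlation.py): with H = W = 3,
+// pad_size = 4, kernel_size = 1, max_displacement = 4, stride1 = stride2 = 1:
+//   kernel_radius = 0, border_radius = 4, padded H = W = 3 + 8 = 11,
+//   output_channel = (4 / 1 * 2 + 1)^2 = 81,
+//   output_height = output_width = ceil((11 - 2 * 4) / 1) = 3,
+// so the inferred output shape is [N, 81, 3, 3].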
+
+class CorrelationOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Input1", "Input is a 4-D Tensor with shape [N, C, H, W]");
+    AddInput("Input2", "Input is a 4-D Tensor with shape [N, C, H, W]");
+    AddOutput("Output",
+              "(Tensor) The output tensor of correlation operator. "
+              "It has the same data format and data type as Input1.");
+    AddAttr<int>("pad_size", "pad size for input1 and input2");
+    AddAttr<int>("kernel_size", "kernel size of input1 and input2");
+    AddAttr<int>("max_displacement", "max displacement of input1 and input2");
+    AddAttr<int>("stride1", "Input1 stride");
+    AddAttr<int>("stride2", "Input2 stride");
+    AddAttr<int>("corr_type_multiply", "correlation coefficient")
+        .SetDefault(1);
+    AddComment(
+        R"DOC(Correlation of two feature maps.
+Only supports the NCHW data format.)DOC");
+  }
+};
+
+class CorrelationOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("Input1"), "Input", "X", "CorrelationOp");
+    OP_INOUT_CHECK(ctx->HasInput("Input2"), "Input", "Y", "CorrelationOp");
+    int stride1 = ctx->Attrs().Get<int>("stride1");
+    int stride2 = ctx->Attrs().Get<int>("stride2");
+    int max_displacement = ctx->Attrs().Get<int>("max_displacement");
+    int pad_size = ctx->Attrs().Get<int>("pad_size");
+    int kernel_size = ctx->Attrs().Get<int>("kernel_size");
+
+    auto in_dims = ctx->GetInputDim("Input1");
+    auto in2_dims = ctx->GetInputDim("Input2");
+
+    PADDLE_ENFORCE_EQ(in_dims.size() == 4, true,
+                      platform::errors::InvalidArgument(
+                          "Input(X) of CorrelationOp must be 4 dims."
+                          "But received dims is %d.",
+                          in_dims.size()));
+
+    PADDLE_ENFORCE_EQ(in2_dims.size() == 4, true,
+                      platform::errors::InvalidArgument(
+                          "Input(Y) of CorrelationOp must be 4 dims."
+                          "But received dims is %d.",
+                          in2_dims.size()));
+    std::vector<int64_t> output_shape =
+        CorrelationOutputSize(in_dims[0], in_dims[2], in_dims[3], stride1,
+                              stride2, kernel_size, pad_size,
+                              max_displacement);
+    ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto input_data_type =
+        OperatorWithKernel::IndicateVarDataType(ctx, "Input1");
+    PADDLE_ENFORCE_EQ(input_data_type, ctx.Input<Tensor>("Input2")->type(),
+                      platform::errors::InvalidArgument(
+                          "X and Y should have the same datatype"));
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string& var_name, const Tensor& tensor,
+      const framework::OpKernelType& expected_kernel_type) const override {
+    return framework::OpKernelType(expected_kernel_type.data_type_,
+                                   tensor.place(), tensor.layout());
+  }
+};
+
+template <typename T>
+class CorrelationOpGradMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> op) const override {
+    op->SetType("correlation_grad");
+    op->SetInput("Input1", this->Input("Input1"));
+    op->SetInput("Input2", this->Input("Input2"));
+    op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
+    op->SetOutput(framework::GradVarName("Input1"), this->InputGrad("Input1"));
+    op->SetOutput(framework::GradVarName("Input2"), this->InputGrad("Input2"));
+    op->SetAttrMap(this->Attrs());
+  }
+};
+
+class CorrelationOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("Input1"), "Input", "X", "CorrelationOp");
+    OP_INOUT_CHECK(ctx->HasInput("Input2"), "Input", "Y", "CorrelationOp");
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Output")), "Input",
+                   "Output@GRAD", "CorrelationGradOp");
+
+    auto in1_dims = ctx->GetInputDim("Input1");
+    auto in2_dims = ctx->GetInputDim("Input2");
+    ctx->SetOutputDim(framework::GradVarName("Input1"), in1_dims);
+    ctx->SetOutputDim(framework::GradVarName("Input2"), in2_dims);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Input1"),
+        ctx.GetPlace());
+  }
+};
+
+template <typename T>
+class CorrelationKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE_EQ(
+        platform::is_gpu_place(ctx.GetPlace()), true,
+        platform::errors::Unimplemented("Correlation only supports GPU now."));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(correlation, ops::CorrelationOp, ops::CorrelationOpMaker,
+                  ops::CorrelationOpGradMaker<paddle::framework::OpDesc>,
+                  ops::CorrelationOpGradMaker<paddle::imperative::OpBase>);
+REGISTER_OPERATOR(correlation_grad, ops::CorrelationOpGrad);
+REGISTER_OP_CPU_KERNEL(correlation, ops::CorrelationKernel<float>,
+                       ops::CorrelationKernel<double>);

diff --git a/paddle/fluid/operators/correlation_op.cu b/paddle/fluid/operators/correlation_op.cu
new file mode 100644
index 0000000000..0d177f653e
--- /dev/null
+++ b/paddle/fluid/operators/correlation_op.cu
@@ -0,0 +1,483 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <algorithm>
+#include <string>
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+#define THREADS_PER_BLOCK 32
+#define FULL_MASK 0xffffffff
+
+using framework::Tensor;
+using DataLayout = framework::DataLayout;
+
+template <typename T>
+__forceinline__ __device__ T warpReduceSum(T val) {
+  for (int offset = 16; offset > 0; offset /= 2) {
+    val += __shfl_down_sync(FULL_MASK, val, offset);
+  }
+  return val;
+}
+
+template <typename T>
+__forceinline__ __device__ T blockReduceSum(T val) {
+  static __shared__ T shared[32];
+  int lane = threadIdx.x % warpSize;
+  int wid = threadIdx.x / warpSize;
+
+  val = warpReduceSum(val);
+  if (lane == 0) shared[wid] = val;
+
+  __syncthreads();
+  val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0;
+
+  if (wid == 0) val = warpReduceSum(val);
+
+  return val;
+}
+
+template <typename T>
+__global__ void set_zero(T *x, int num) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
+       i += blockDim.x * gridDim.x)
+    x[i] = static_cast<T>(0);
+}
+
+template <typename T>
+__global__ void channel_first(const T *input, T *rinput, const int channel,
+                              const int height, const int width,
+                              const int pad_size) {
+  int n = blockIdx.x;
+  int h = blockIdx.y;
+  int w = blockIdx.z;
+
+  int ch_off = threadIdx.x;
+  T value;
+  int dimchw = channel * height * width;
+  int dimhw = height * width;
+
+  int p_dimw = (width + 2 * pad_size);
+  int p_dimh = (height + 2 * pad_size);
+  int p_dimchw = channel * p_dimw * p_dimh;
+  int p_dimcw = channel * p_dimw;
+
+  for (int c = ch_off; c < channel; c += THREADS_PER_BLOCK) {
+    value = input[n * dimchw + c * dimhw + h * width + w];
+    rinput[n * p_dimchw + (h + pad_size) * p_dimcw + (w + pad_size) * channel +
+           c] = value;
+  }
+}
+
+template <typename T>
+__global__ void correlation_forward(
+    T *output, const int output_channel, const int output_height,
+    const int output_width, const T *rinput1, const int input_channel,
+    const int input_height, const int input_width, const T *rinput2,
+    const int pad_size, const int kernel_size, const int max_displacement,
+    const int stride1, const int stride2) {
+  int p_input_width = input_width + 2 * pad_size;
+  int p_input_height = input_height + 2 * pad_size;
+
+  int kernel_rad = (kernel_size - 1) / 2;
+  int displacement_rad = max_displacement / stride2;
+
+  int displacement_size = 2 * displacement_rad + 1;
+
+  int n = blockIdx.x;
+  int h1 = blockIdx.y * stride1 + max_displacement;
+  int w1 = blockIdx.z * stride1 + max_displacement;
+  int c = threadIdx.x;
+
+  int p_dimchw = p_input_height * p_input_width * input_channel;
+  int p_dimcw = p_input_width * input_channel;
+  int p_dimc = input_channel;
+
+  int t_dimchw = output_channel * output_height * output_width;
+  int t_dimhw = output_height * output_width;
+  int t_dimw = output_width;
+
+  int nelems = kernel_size * kernel_size * p_dimc;
+
+  for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) {
+    for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) {
+      int w2 = w1 + ti * stride2;
+      int h2 = h1 + tj * stride2;
+
+      T acc0 = 0;
+      for (int j = -kernel_rad; j <= kernel_rad; ++j) {
+        for (int i = -kernel_rad; i <= kernel_rad; ++i) {
+          for (int ch = c; ch < p_dimc; ch += blockDim.x) {
+            int index1 =
+                n * p_dimchw + (h1 + j) * p_dimcw + (w1 + i) * p_dimc + ch;
+            int index2 =
+                n * p_dimchw + (h2 + j) * p_dimcw + (w2 + i) * p_dimc + ch;
+            acc0 += static_cast<T>(rinput1[index1] * rinput2[index2]);
+          }
+        }
+      }
+      if (blockDim.x == warpSize) {
+        __syncwarp();
+        acc0 = warpReduceSum(acc0);
+      } else {
+        __syncthreads();
+        acc0 = blockReduceSum(acc0);
+      }
+
+      if (threadIdx.x == 0) {
+        int tc = (tj + displacement_rad) * displacement_size +
+                 (ti + displacement_rad);
+        const int t_index =
+            n * t_dimchw + tc * t_dimhw + blockIdx.y * t_dimw + blockIdx.z;
+        output[t_index] = static_cast<T>(acc0 / nelems);
+      }
+    }
+  }
+}
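+
+// Launch geometry used below: the forward kernel is launched with
+// grid = (N, output_height, output_width), i.e. one thread block per output
+// spatial position; the THREADS_PER_BLOCK threads of a block cooperate over
+// the channel dimension and their partial sums are combined by the
+// warp/block reductions above.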
+
+template <typename T>
+class CorrelationCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
+                      platform::errors::InvalidArgument(
+                          "Correlation only supports GPU now."));
+
+    auto *input1 = ctx.Input<Tensor>("Input1");
+    auto *input2 = ctx.Input<Tensor>("Input2");
+    int pad_size = ctx.Attr<int>("pad_size");
+    int kernel_size = ctx.Attr<int>("kernel_size");
+    int stride1 = ctx.Attr<int>("stride1");
+    int stride2 = ctx.Attr<int>("stride2");
+    int max_displacement = ctx.Attr<int>("max_displacement");
+    int corr_type_multiply = ctx.Attr<int>("corr_type_multiply");
+
+    auto *output = ctx.Output<Tensor>("Output");
+    output->mutable_data<T>(ctx.GetPlace());
+    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+
+    // based on input1, NCHW
+    auto in_dims = input1->dims();
+    int N = in_dims[0];
+    int C = in_dims[1];
+    int H = in_dims[2];
+    int W = in_dims[3];
+
+    int padded_input_height = H + 2 * pad_size;
+    int padded_input_width = W + 2 * pad_size;
+
+    Tensor rinput1 = ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>(
+        {N, padded_input_height, padded_input_width, C}, dev_ctx);
+    rinput1.mutable_data<T>(ctx.GetPlace());
+
+    Tensor rinput2 = ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>(
+        {N, padded_input_height, padded_input_width, C}, dev_ctx);
+    rinput2.mutable_data<T>(ctx.GetPlace());
+
+    set_zero<<<(rinput1.numel() + 512 - 1) / 512, 512, 0, dev_ctx.stream()>>>(
+        rinput1.data<T>(), rinput1.numel());
+    set_zero<<<(rinput2.numel() + 512 - 1) / 512, 512, 0, dev_ctx.stream()>>>(
+        rinput2.data<T>(), rinput2.numel());
+    set_zero<<<(output->numel() + 512 - 1) / 512, 512, 0, dev_ctx.stream()>>>(
+        output->data<T>(), output->numel());
+
+    auto out_dims = output->dims();
+    int OC = out_dims[1];
+    int OH = out_dims[2];
+    int OW = out_dims[3];
+
+    dim3 blocks_grid(N, H, W);
+    dim3 threads_block(THREADS_PER_BLOCK);
+
+    channel_first<T><<<blocks_grid, threads_block, 0, dev_ctx.stream()>>>(
+        input1->data<T>(), rinput1.data<T>(), C, H, W, pad_size);
+    channel_first<T><<<blocks_grid, threads_block, 0, dev_ctx.stream()>>>(
+        input2->data<T>(), rinput2.data<T>(), C, H, W, pad_size);
+
+    dim3 threadsPerBlock(THREADS_PER_BLOCK);
+    dim3 totalBlocksCorr(N, OH, OW);
+
+    correlation_forward<T>
+        <<<totalBlocksCorr, threadsPerBlock, 0, dev_ctx.stream()>>>(
+            output->data<T>(), OC, OH, OW, rinput1.data<T>(), C, H, W,
+            rinput2.data<T>(), pad_size, kernel_size, max_displacement,
+            stride1, stride2);
+  }
+};
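+
+// Backward kernels: one thread block per input position and channel
+// (grid = (H, W, C)), launched once per batch item. Each block walks over
+// all output channels tc and accumulates grad_output contributions from
+// every output pixel whose correlation window covered this input position,
+// then reduces the partial sums held in shared memory.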
+
+template <typename T>
+__global__ void correlation_backward_input1(
+    int item, T *grad_input1, const int input_channel, const int input_height,
+    const int input_width, const T *grad_output, const int output_channel,
+    const int output_height, const int output_width, const T *rinput2,
+    const int pad_size, const int kernel_size, const int max_displacement,
+    const int stride1, const int stride2) {
+  int n = item;
+  int h = blockIdx.x * stride1 + pad_size;
+  int w = blockIdx.y * stride1 + pad_size;
+  int c = blockIdx.z;
+  int tch_off = threadIdx.x;
+
+  int kernel_rad = (kernel_size - 1) / 2;
+  int displacement_rad = max_displacement / stride2;
+  int displacement_size = 2 * displacement_rad + 1;
+
+  int xmin = (w - kernel_rad - max_displacement) / stride1;
+  int ymin = (h - kernel_rad - max_displacement) / stride1;
+
+  int xmax = (w + kernel_rad - max_displacement) / stride1;
+  int ymax = (h + kernel_rad - max_displacement) / stride1;
+
+  if (xmax < 0 || ymax < 0 || xmin >= output_width || ymin >= output_height) {
+    return;
+  }
+
+  if (xmin > xmax || ymin > ymax) {
+    return;
+  }
+
+  xmin = max(0, xmin);
+  xmax = min(output_width - 1, xmax);
+
+  ymin = max(0, ymin);
+  ymax = min(output_height - 1, ymax);
+
+  int p_input_width = input_width + 2 * pad_size;
+  int p_input_height = input_height + 2 * pad_size;
+  int p_dimchw = input_channel * p_input_height * p_input_width;
+  int p_dimcw = input_channel * p_input_width;
+  int p_dimc = input_channel;
+
+  int t_dimchw = output_channel * output_height * output_width;
+  int t_dimhw = output_height * output_width;
+  int t_dimw = output_width;
+
+  int o_dimchw = input_channel * input_height * input_width;
+  int o_dimhw = input_height * input_width;
+  int o_dimw = input_width;
+
+  int nelems = kernel_size * kernel_size * input_channel;
+
+  __shared__ T prod_sum[THREADS_PER_BLOCK];
+  prod_sum[tch_off] = 0;
+
+  for (int tc = tch_off; tc < output_channel; tc += THREADS_PER_BLOCK) {
+    int i2 = (tc % displacement_size - displacement_rad) * stride2;
+    int j2 = (tc / displacement_size - displacement_rad) * stride2;
+
+    int index2 = n * p_dimchw + (h + j2) * p_dimcw + (w + i2) * p_dimc + c;
+
+    T val2 = rinput2[index2];
+    for (int j = ymin; j <= ymax; ++j) {
+      for (int i = xmin; i <= xmax; ++i) {
+        int t_index = n * t_dimchw + tc * t_dimhw + j * t_dimw + i;
+        prod_sum[tch_off] += grad_output[t_index] * val2;
+      }
+    }
+  }
+
+  __syncthreads();
+
+  if (tch_off == 0) {
+    T reduce_sum = 0;
+    for (int index = 0; index < THREADS_PER_BLOCK; index++) {
+      reduce_sum += prod_sum[index];
+    }
+    const int index1 =
+        n * o_dimchw + c * o_dimhw + (h - pad_size) * o_dimw + (w - pad_size);
+    grad_input1[index1] = static_cast<T>(reduce_sum / nelems);
+  }
+}
+
+template <typename T>
+__global__ void correlation_backward_input2(
+    int item, T *grad_input2, const int input_channel, const int input_height,
+    const int input_width, const T *grad_output, const int output_channel,
+    const int output_height, const int output_width, const T *rinput1,
+    const int pad_size, const int kernel_size, const int max_displacement,
+    const int stride1, const int stride2) {
+  int n = item;
+  int h = blockIdx.x * stride1 + pad_size;
+  int w = blockIdx.y * stride1 + pad_size;
+  int c = blockIdx.z;
+
+  int tch_off = threadIdx.x;
+
+  int kernel_rad = (kernel_size - 1) / 2;
+  int displacement_rad = max_displacement / stride2;
+  int displacement_size = 2 * displacement_rad + 1;
+
+  int p_input_width = input_width + 2 * pad_size;
+  int p_input_height = input_height + 2 * pad_size;
+  int p_dimchw = input_channel * p_input_height * p_input_width;
+  int p_dimcw = input_channel * p_input_width;
+  int p_dimc = input_channel;
+
+  int t_dimchw = output_channel * output_height * output_width;
+  int t_dimhw = output_height * output_width;
+  int t_dimw = output_width;
+
+  int o_dimchw = input_channel * input_height * input_width;
+  int o_dimhw = input_height * input_width;
+  int o_dimw = input_width;
+
+  int nelems = kernel_size * kernel_size * input_channel;
+
+  __shared__ T prod_sum[THREADS_PER_BLOCK];
+  prod_sum[tch_off] = 0;
+
+  for (int tc = tch_off; tc < output_channel; tc += THREADS_PER_BLOCK) {
+    int i2 = (tc % displacement_size - displacement_rad) * stride2;
+    int j2 = (tc / displacement_size - displacement_rad) * stride2;
+
+    int xmin = (w - kernel_rad - max_displacement - i2) / stride1;
+    int ymin = (h - kernel_rad - max_displacement - j2) / stride1;
+
+    int xmax = (w + kernel_rad - max_displacement - i2) / stride1;
+    int ymax = (h + kernel_rad - max_displacement - j2) / stride1;
+
+    if (xmax < 0 || ymax < 0 || xmin >= output_width ||
+        ymin >= output_height) {
+      continue;
+    }
+
+    if (xmin > xmax || ymin > ymax) {
+      continue;
+    }
+
+    xmin = max(0, xmin);
+    xmax = min(output_width - 1, xmax);
+
+    ymin = max(0, ymin);
+    ymax = min(output_height - 1, ymax);
+
+    int index1 = n * p_dimchw + (h - j2) * p_dimcw + (w - i2) * p_dimc + c;
+    T val1 = rinput1[index1];
+    for (int j = ymin; j <= ymax; ++j) {
+      for (int i = xmin; i <= xmax; ++i) {
+        int t_index = n * t_dimchw + tc * t_dimhw + j * t_dimw + i;
+        prod_sum[tch_off] += grad_output[t_index] * val1;
+      }
+    }
+  }
+
+  __syncthreads();
+
+  if (tch_off == 0) {
+    T reduce_sum = 0;
+    for (int index = 0; index < THREADS_PER_BLOCK; index++) {
+      reduce_sum += prod_sum[index];
+    }
+    const int index2 =
+        n * o_dimchw + c * o_dimhw + (h - pad_size) * o_dimw + (w - pad_size);
+    grad_input2[index2] = static_cast<T>(reduce_sum / nelems);
+  }
+}
+
+template <typename T>
+class CorrelationCUDAGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
+                      platform::errors::InvalidArgument(
+                          "Correlation only supports GPU now."));
+    const auto *input1 = ctx.Input<Tensor>("Input1");
+    const auto *input2 = ctx.Input<Tensor>("Input2");
+    const auto *grad_output =
+        ctx.Input<Tensor>(framework::GradVarName("Output"));
+    const int pad_size = ctx.Attr<int>("pad_size");
+    const int kernel_size = ctx.Attr<int>("kernel_size");
+    const int stride1 = ctx.Attr<int>("stride1");
+    const int stride2 = ctx.Attr<int>("stride2");
+    const int max_displacement = ctx.Attr<int>("max_displacement");
+    const int corr_type_multiply = ctx.Attr<int>("corr_type_multiply");
+
+    auto *grad_input1 = ctx.Output<Tensor>(framework::GradVarName("Input1"));
+    grad_input1->mutable_data<T>(ctx.GetPlace());
+    auto *grad_input2 = ctx.Output<Tensor>(framework::GradVarName("Input2"));
+    grad_input2->mutable_data<T>(ctx.GetPlace());
+    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+
+    auto in_dims = input1->dims();
+    int N = in_dims[0];
+    int C = in_dims[1];
+    int H = in_dims[2];
+    int W = in_dims[3];
+
+    int padded_input_height = H + 2 * pad_size;
+    int padded_input_width = W + 2 * pad_size;
+
+    Tensor rinput1 = ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>(
+        {N, padded_input_height, padded_input_width, C}, dev_ctx);
+    rinput1.mutable_data<T>(ctx.GetPlace());
+
+    Tensor rinput2 = ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>(
+        {N, padded_input_height, padded_input_width, C}, dev_ctx);
+    rinput2.mutable_data<T>(ctx.GetPlace());
+
+    set_zero<<<(rinput1.numel() + 512 - 1) / 512, 512, 0, dev_ctx.stream()>>>(
+        rinput1.data<T>(), rinput1.numel());
+    set_zero<<<(rinput2.numel() + 512 - 1) / 512, 512, 0, dev_ctx.stream()>>>(
+        rinput2.data<T>(), rinput2.numel());
+    set_zero<<<(grad_input1->numel() + 512 - 1) / 512, 512, 0,
+               dev_ctx.stream()>>>(grad_input1->data<T>(),
+                                   grad_input1->numel());
+    set_zero<<<(grad_input2->numel() + 512 - 1) / 512, 512, 0,
+               dev_ctx.stream()>>>(grad_input2->data<T>(),
+                                   grad_input2->numel());
+
+    auto grad_out_dims = grad_output->dims();
+    int GOC = grad_out_dims[1];
+    int GOH = grad_out_dims[2];
+    int GOW = grad_out_dims[3];
+
+    dim3 blocks_grid(N, H, W);
+    dim3 threads_block(THREADS_PER_BLOCK);
+
+    channel_first<T><<<blocks_grid, threads_block, 0, dev_ctx.stream()>>>(
+        input1->data<T>(), rinput1.data<T>(), C, H, W, pad_size);
+    channel_first<T><<<blocks_grid, threads_block, 0, dev_ctx.stream()>>>(
+        input2->data<T>(), rinput2.data<T>(), C, H, W, pad_size);
+
+    dim3 threadsPerBlock(THREADS_PER_BLOCK);
+    dim3 totalBlocksCorr(H, W, C);
+
+    for (int n = 0; n < N; n++) {
+      correlation_backward_input1<T>
+          <<<totalBlocksCorr, threadsPerBlock, 0, dev_ctx.stream()>>>(
+              n, grad_input1->data<T>(), C, H, W, grad_output->data<T>(), GOC,
+              GOH, GOW, rinput2.data<T>(), pad_size, kernel_size,
+              max_displacement, stride1, stride2);
+    }
+
+    for (int n = 0; n < N; n++) {
+      correlation_backward_input2<T>
+          <<<totalBlocksCorr, threadsPerBlock, 0, dev_ctx.stream()>>>(
+              n, grad_input2->data<T>(), C, H, W, grad_output->data<T>(), GOC,
+              GOH, GOW, rinput1.data<T>(), pad_size, kernel_size,
+              max_displacement, stride1, stride2);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(correlation, ops::CorrelationCUDAKernel<float>,
+                        ops::CorrelationCUDAKernel<double>);
+REGISTER_OP_CUDA_KERNEL(correlation_grad,
+                        ops::CorrelationCUDAGradKernel<float>,
+                        ops::CorrelationCUDAGradKernel<double>);
diff --git a/python/paddle/fluid/contrib/layers/nn.py b/python/paddle/fluid/contrib/layers/nn.py
index 0e187d4174..7b564b3f83 100644
--- a/python/paddle/fluid/contrib/layers/nn.py
+++ b/python/paddle/fluid/contrib/layers/nn.py
@@ -37,7 +37,7 @@ import warnings
 import inspect
 
 import numpy as np
-
+import paddle
 from paddle.fluid.layer_helper import LayerHelper
 from paddle.fluid.layers import utils
 from ... import unique_name
@@ -56,7 +56,8 @@ __all__ = [
     'match_matrix_tensor', 'tree_conv', 'fused_embedding_seq_pool',
     'multiclass_nms2', 'search_pyramid_hash', 'shuffle_batch', 'partial_concat',
     'sparse_embedding', 'partial_sum', 'tdm_child', 'rank_attention',
-    'tdm_sampler', 'batch_fc', '_pull_box_extended_sparse', 'bilateral_slice'
+    'tdm_sampler', 'batch_fc', '_pull_box_extended_sparse', 'bilateral_slice',
+    'correlation'
 ]
@@ -1546,3 +1547,81 @@ def bilateral_slice(x, guide, grid, has_offset, name=None):
         attrs={'has_offset': has_offset},
         outputs={'Out': out})
     return out
+
+
+def correlation(x,
+                y,
+                pad_size,
+                kernel_size,
+                max_displacement,
+                stride1,
+                stride2,
+                corr_type_multiply=1):
+    """
+
+    This operation computes the correlation of two 4-D input tensors.
+    For more information about correlation, please refer to PWC-Net:
+    CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume
+    <https://arxiv.org/pdf/1709.02371.pdf>_
+
+    Args:
+        x(Tensor): The input x is a 4-D Tensor with shape [N, C, H, W]. The data type is float32 or float64.
+        y(Tensor): The input y is a 4-D Tensor with shape [N, C, H, W]. The data type is float32 or float64.
+        pad_size(int): Pad size. The data type is int.
+        kernel_size(int): Kernel size of the correlation; must be an odd number.
+        max_displacement(int): Max displacement. The data type is int.
+        stride1(int): stride size of x. The data type is int.
+        stride2(int): stride size of y. The data type is int.
+        corr_type_multiply(int, optional): The type of multiply. The data type is int. Default: 1.
+
+    Returns:
+        Tensor: The output of correlation. The data type is the same as the input tensors.
+
+    Examples:
+
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+
+            x1 = fluid.layers.data(name='x1',
+                                   shape=[2, 3, 4, 5],
+                                   dtype='float32',
+                                   append_batch_size=False)
+            x2 = fluid.layers.data(name='x2',
+                                   shape=[2, 3, 4, 5],
+                                   dtype='float32',
+                                   append_batch_size=False)
+
+            out = fluid.contrib.correlation(
+                x1,
+                x2,
+                pad_size=4,
+                kernel_size=1,
+                max_displacement=4,
+                stride1=1,
+                stride2=1)
+
+    """
+
+    helper = LayerHelper("correlation", **locals())
+    output = helper.create_variable_for_type_inference(dtype=x.dtype)
+    if paddle.fluid.in_dygraph_mode():
+        attrs = ("pad_size", pad_size, "kernel_size", kernel_size,
+                 "max_displacement", max_displacement, "stride1", stride1,
+                 "stride2", stride2, "corr_type_multiply", corr_type_multiply)
+        output = getattr(core.ops, "correlation")(x, y, *attrs)
+    else:
+        helper.append_op(
+            type="correlation",
+            inputs={"Input1": x,
+                    "Input2": y},
+            attrs={
+                "pad_size": pad_size,
+                "kernel_size": kernel_size,
+                "max_displacement": max_displacement,
+                "stride1": stride1,
+                "stride2": stride2,
+                "corr_type_multiply": corr_type_multiply
+            },
+            outputs={"Output": output})
+    return output
diff --git a/python/paddle/fluid/contrib/tests/test_correlation.py b/python/paddle/fluid/contrib/tests/test_correlation.py
new file mode 100644
index 0000000000..7fcef4dbcd
--- /dev/null
+++ b/python/paddle/fluid/contrib/tests/test_correlation.py
@@ -0,0 +1,163 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+import paddle.fluid as fluid
+from paddle.fluid.dygraph.base import to_variable
+
+
+def corr(x_1,
+         x_2,
+         pad_size=4,
+         kernel_size=1,
+         max_displacement=4,
+         stride1=1,
+         stride2=1,
+         corr_multiply=1):
+    K = kernel_size
+
+    rinput1 = np.pad(x_1, ((0, 0), (0, 0), (pad_size, pad_size),
+                           (pad_size, pad_size)),
+                     mode='constant')
+    rinput2 = np.pad(x_2, ((0, 0), (0, 0), (pad_size, pad_size),
+                           (pad_size, pad_size)),
+                     mode='constant')
+    rinput1 = np.transpose(rinput1, (0, 2, 3, 1))
+    rinput2 = np.transpose(rinput2, (0, 2, 3, 1))
+    B = int(rinput1.shape[0])
+    H = int(x_1.shape[2])
+    W = int(x_2.shape[3])
+    d = max_displacement
+    D = 2 * d + 1
+    output = np.zeros((B, D * D, H, W), dtype=np.float32)
+
+    for b in range(B):
+        for i in range(H):
+            for j in range(W):
+                for k in range(-d, d + 1):
+                    for l in range(-d, d + 1):
+                        x1_index = i + pad_size
+                        y1_index = j + pad_size
+                        x2_index = x1_index + k
+                        y2_index = y1_index + l
+                        output[b, l + d + D * (k + d), i, j] = np.mean(
+                            rinput1[b, x1_index:x1_index + K,
+                                    y1_index:y1_index + K] *
+                            rinput2[b, x2_index:x2_index + K,
+                                    y2_index:y2_index + K])
+
+    return output
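+
+# NOTE: corr() above is a plain NumPy transcription of the CUDA kernel: for
+# every displacement (k, l) within max_displacement, each output value is the
+# mean over the K x K x C window of the elementwise product of the padded,
+# NHWC-transposed inputs, matching acc0 / nelems in correlation_forward.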
+
+
+class TestCorrelationOp(unittest.TestCase):
+    def test_check_output(self):
+        if not fluid.core.is_compiled_with_cuda():
+            return
+        np.random.seed(13)
+        np.set_printoptions(threshold=np.inf)
+        x_shape = (2, 10, 3, 3)
+        x_type = 'float32'
+        x1 = fluid.layers.data(
+            name='x1',
+            shape=x_shape,
+            dtype=x_type,
+            append_batch_size=False,
+            stop_gradient=False)
+        x2 = fluid.layers.data(
+            name='x2',
+            shape=x_shape,
+            dtype=x_type,
+            append_batch_size=False,
+            stop_gradient=False)
+
+        # Feed arrays with the same shape as the declared data layers.
+        x1_np = np.random.randn(*x_shape).astype(x_type)
+        x2_np = np.random.randn(*x_shape).astype(x_type)
+        out_np = corr(
+            x1_np,
+            x2_np,
+            pad_size=4,
+            kernel_size=1,
+            max_displacement=4,
+            stride1=1,
+            stride2=1)
+
+        out = fluid.contrib.correlation(
+            x1,
+            x2,
+            pad_size=4,
+            kernel_size=1,
+            max_displacement=4,
+            stride1=1,
+            stride2=1)
+
+        loss = fluid.layers.reduce_mean(out)
+        optimizer = fluid.optimizer.Momentum(0.0001, 0.9)
+        optimizer.minimize(loss)
+
+        place = fluid.CUDAPlace(0)
+        exe = fluid.Executor(place)
+        res = exe.run(feed={'x1': x1_np,
+                            'x2': x2_np},
+                      fetch_list=[out.name, loss.name])
+
+        self.assertTrue(np.allclose(res[0], out_np))
+
+
+class Net(fluid.dygraph.Layer):
+    def __init__(self, name_scope):
+        super(Net, self).__init__(name_scope)
+
+    def forward(self, x1, x2):
+        y = fluid.contrib.correlation(
+            x1,
+            x2,
+            pad_size=4,
+            kernel_size=1,
+            max_displacement=4,
+            stride1=1,
+            stride2=1)
+        return y
+
+
+class TestCorrelationOpDyGraph(unittest.TestCase):
+    def test_check_output(self):
+        if not fluid.core.is_compiled_with_cuda():
+            return
+        np.random.seed(13)
+        np.set_printoptions(threshold=np.inf)
+        x_shape = (2, 10, 3, 3)
+        x_type = 'float32'
+        place = fluid.CUDAPlace(0)
+        with fluid.dygraph.guard(place):
+            x1_np = np.random.randn(*x_shape).astype(x_type)
+            x2_np = np.random.randn(*x_shape).astype(x_type)
+            out_np = corr(
+                x1_np,
+                x2_np,
+                pad_size=4,
+                kernel_size=1,
+                max_displacement=4,
+                stride1=1,
+                stride2=1)
+
+            x1 = to_variable(x1_np)
+            x2 = to_variable(x2_np)
+            corr_pd = Net('corr_pd')
+            y = corr_pd(x1, x2)
+            out = y.numpy()
+            self.assertTrue(np.allclose(out, out_np))
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab
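
A minimal usage sketch (not part of the patch): this mirrors the dygraph unit
test above and assumes a CUDA-enabled Paddle build of the same era, where the
new API is exposed as fluid.contrib.correlation.

    import numpy as np
    import paddle.fluid as fluid

    with fluid.dygraph.guard(fluid.CUDAPlace(0)):
        x1 = fluid.dygraph.to_variable(
            np.random.randn(2, 10, 3, 3).astype('float32'))
        x2 = fluid.dygraph.to_variable(
            np.random.randn(2, 10, 3, 3).astype('float32'))
        # With max_displacement=4 and stride2=1 the cost volume has
        # (4 * 2 + 1) ** 2 = 81 channels.
        out = fluid.contrib.correlation(
            x1, x2, pad_size=4, kernel_size=1, max_displacement=4,
            stride1=1, stride2=1)
        print(out.shape)  # [2, 81, 3, 3]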