diff --git a/PaddleCV/Research/PWCNet/AverageMeter.py b/PaddleCV/Research/PWCNet/AverageMeter.py
new file mode 100644
index 0000000000000000000000000000000000000000..633e6c067d465559d2da61913342da2e521ac731
--- /dev/null
+++ b/PaddleCV/Research/PWCNet/AverageMeter.py
@@ -0,0 +1,18 @@
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
diff --git a/PaddleCV/Research/PWCNet/README.md b/PaddleCV/Research/PWCNet/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..b3335013b641836c47b61dd31f8a6f5459188254
--- /dev/null
+++ b/PaddleCV/Research/PWCNet/README.md
@@ -0,0 +1,86 @@
+# PWCNet reimplementation using PaddlePaddle DyGraph
+PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume.
+# Environment
+```
+CentOS 7
+paddle develop version (after 20191201), installed from source
+python3.7
+SciPy 1.1.0
+```
+The code will be updated for paddle v1.7 later.
+# Compile the correlation op
+```
+cd correlation_op
+sh make.sh
+```
+# Datasets
+1. Please download the `FlyingChairs dataset` and `FlyingChairs_train_val.txt` from https://lmb.informatik.uni-freiburg.de/resources/datasets
+
+Or you can use `./data/download.sh` to download the datasets.
+
+We split the data into train and val subsets using `FlyingChairs_train_val.txt`, where `1 marks train and 2 marks val`.
+# Inference
+Note that the paddle models `pwc_net_paddle.pdparams` and `pwc_net_chairs_paddle.pdparams` are transferred from the pytorch pth files `pwc_net.pth.tar` and `pwc_net_chairs.pth.tar`.
+
+Run
+```
+python infer.py
+```
+
+| Input img1 | Input img2 |
+|-------|------------|
+| | |
+
+|prediction with pwc_net_paddle.pdparams| prediction with pwc_net_chairs_paddle.pdparams|
+|-------------|-------------|
+| | |
+
+# First train with L2 loss
+A single GPU is supported. Multiple GPUs will be supported later.
+
+Check the parameters in `my_args.py` and adjust them as you like,
+then change them in `train.sh`:
+```
+--data_root
+--train_val_txt
+--batch_size
+```
+Then run
+```
+./train.sh
+```
+Some intermediate results during training can be seen in
+```
+./img1.png
+./img2.png
+./hsv_pd.png # ground truth
+./hsv_predict.png # output of model
+```
+
+# Finetune with L1 loss
+Finetune from your best pretrained model by adding `--pretrained your_best_model_name`, e.g. `--pretrained epoch_7_pwc_net_paddle`.
+
+Run
+```
+./finetune.sh
+```
+# Note
+This code reimplements PWCNet following the code of `https://github.com/NVlabs/PWC-Net`.
+If you want to train as in the paper
+```
+@InProceedings{Sun2018PWC-Net,
+  author    = {Deqing Sun and Xiaodong Yang and Ming-Yu Liu and Jan Kautz},
+  title     = {{PWC-Net}: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume},
+  booktitle = CVPR,
+  year      = {2018},
+}
+```
+please use all the datasets in `./data/download.sh` and the dataset loaders in `./data/datasets.py`.
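+
+For example, a minimal sketch of building a FlyingChairs batch reader with `./data/datasets.py` (the paths are placeholders, and `args` only needs the `crop_size`/`inference_size` fields that `my_args.py` provides):
+```python
+import argparse
+import paddle
+from data.datasets import FlyingChairs, reader_flyingchairs
+
+# inference_size=[-1, -1] lets the dataset pick a size that is a multiple of 64
+args = argparse.Namespace(crop_size=[384, 512], inference_size=[-1, -1])
+dataset = FlyingChairs('train', args, is_cropped=True,
+                       txt_file='/path/to/FlyingChairs_train_val.txt',
+                       root='/path/to/FlyingChairs_release/data')
+train_reader = paddle.batch(reader_flyingchairs(dataset), 8, drop_last=True)
+```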
+
+Reference works
+```
+https://github.com/NVlabs/PWC-Net
+https://github.com/ClementPinard/FlowNetPytorch
+https://github.com/NVIDIA/flownet2-pytorch/blob/master/datasets.py
+```
\ No newline at end of file
diff --git a/PaddleCV/Research/PWCNet/__init__.py b/PaddleCV/Research/PWCNet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/PaddleCV/Research/PWCNet/correlation_op/README.md b/PaddleCV/Research/PWCNet/correlation_op/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d83c6fe61d6fef1d01139289b69605628e689d72
--- /dev/null
+++ b/PaddleCV/Research/PWCNet/correlation_op/README.md
@@ -0,0 +1,14 @@
+Building the custom op:
+1. Use a paddle develop version from after Dec 1, 2019.
+2. Run `sh make.sh` to compile the op into the dynamic library `correlation_lib.so`.
+3. Add the library path to `LD_LIBRARY_PATH`:
+```
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:`python3.7 -c 'import paddle; print(paddle.sysconfig.get_lib())'`
+```
+4. Add the Python path of the correlation op:
+```
+export PYTHONPATH=$PYTHONPATH:`pwd`
+```
+5. Run the unit test with `python test_correlation.py` to verify that the op loads correctly.
+
+PS: If the paddle whl package was downloaded from the official website, gcc 4.8 is required, i.e. change `g++` to `g++-4.8` in `make.sh`.
diff --git a/PaddleCV/Research/PWCNet/correlation_op/correlation.py b/PaddleCV/Research/PWCNet/correlation_op/correlation.py
new file mode 100644
index 0000000000000000000000000000000000000000..05e9267d1fcb51344e096592ad86d22223b99f75
--- /dev/null
+++ b/PaddleCV/Research/PWCNet/correlation_op/correlation.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle.fluid as fluid
+import os
+file_dir = os.path.dirname(os.path.abspath(__file__))
+fluid.load_op_library(os.path.join(file_dir, 'correlation_lib.so'))
+
+from paddle.fluid.layer_helper import LayerHelper
+
+def correlation(input1, input2, pad_size, kernel_size, max_displacement, stride1, stride2, corr_type_multiply=1):
+    helper = LayerHelper("correlation", **locals())
+    output = helper.create_variable_for_type_inference(dtype=input1.dtype)
+    helper.append_op(type="correlation",
+                     inputs={"Input1": input1,
+                             "Input2": input2},
+                     attrs={"pad_size": pad_size,
+                            "kernel_size": kernel_size,
+                            "max_displacement": max_displacement,
+                            "stride1": stride1,
+                            "stride2": stride2,
+                            "corr_type_multiply": corr_type_multiply},
+                     outputs={"Output": output})
+    return output
diff --git a/PaddleCV/Research/PWCNet/correlation_op/correlation_op.cc b/PaddleCV/Research/PWCNet/correlation_op/correlation_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..4902db3ed7115d0d315ae2f2cbab5ea1a5ee6528
--- /dev/null
+++ b/PaddleCV/Research/PWCNet/correlation_op/correlation_op.cc
@@ -0,0 +1,140 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <cmath>
+#include <memory>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+inline std::vector<int64_t> CorrelationOutputSize(int batch, int input_height, int input_width, int stride1, int stride2, int kernel_size, int pad_size, int max_displacement) {
+
+  std::vector<int64_t> output_shape({batch});
+  int kernel_radius = (kernel_size - 1) / 2;
+  int border_radius = kernel_radius + max_displacement;
+  int padded_input_height = input_height + 2 * pad_size;
+  int padded_input_width = input_width + 2 * pad_size;
+  int output_channel = ((max_displacement / stride2) * 2 + 1) * ((max_displacement / stride2) * 2 + 1);
+  output_shape.push_back(output_channel);
+  int output_height = std::ceil(static_cast<float>(padded_input_height - 2 * border_radius) / static_cast<float>(stride1));
+  int output_width = std::ceil(static_cast<float>(padded_input_width - 2 * border_radius) / static_cast<float>(stride1));
+  output_shape.push_back(output_height);
+  output_shape.push_back(output_width);
+  return output_shape;
+}
+
+class CorrelationOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Input1", "input1");
+    AddInput("Input2", "input2");
+    AddOutput("Output", "output");
+    AddAttr<int>("pad_size", "pad size for input1 and input2");
+    AddAttr<int>("kernel_size", "kernel size of input1 and input2");
+    AddAttr<int>("max_displacement", "max displacement of input1 and input2");
+    AddAttr<int>("stride1", "Input1 stride");
+    AddAttr<int>("stride2", "Input2 stride");
+    AddAttr<int>("corr_type_multiply", "correlation coefficient").SetDefault(1);
+    AddComment(R"DOC(Correlation of two feature maps.
+Only support NCHW data format.)DOC");
+  }
+};
+
+class CorrelationOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Input1"), true, "Input(input1) cannot be null");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Input2"), true, "Input(input2) cannot be null");
+    int stride1 = ctx->Attrs().Get<int>("stride1");
+    int stride2 = ctx->Attrs().Get<int>("stride2");
+    int max_displacement = ctx->Attrs().Get<int>("max_displacement");
+    int pad_size = ctx->Attrs().Get<int>("pad_size");
+    int kernel_size = ctx->Attrs().Get<int>("kernel_size");
+
+    auto in_dims = ctx->GetInputDim("Input1");
+    auto in2_dims = ctx->GetInputDim("Input2");
+    PADDLE_ENFORCE_EQ(in_dims.size() == 4, true, "input1 must be 4-dims");
+    PADDLE_ENFORCE_EQ(in2_dims.size() == 4, true, "input2 must be 4-dims");
+    std::vector<int64_t> output_shape = CorrelationOutputSize(in_dims[0], in_dims[2], in_dims[3], stride1, stride2, kernel_size, pad_size, max_displacement);
+    ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input1");
+    PADDLE_ENFORCE_EQ(input_data_type, ctx.Input<Tensor>("Input2")->type(), "Input1 and Input2 should have the same type");
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
+};
+
+template <typename T>
+class CorrelationOpGradMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  std::unique_ptr<T> Apply() const override {
+    auto* op = new T();
+    op->SetType("correlation_grad");
+    op->SetInput("Input1", this->Input("Input1"));
+    op->SetInput("Input2", this->Input("Input2"));
+    op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
+    op->SetOutput(framework::GradVarName("Input1"), this->InputGrad("Input1"));
+    op->SetOutput(framework::GradVarName("Input2"), this->InputGrad("Input2"));
+    op->SetAttrMap(this->Attrs());
+
+    return std::unique_ptr<T>(op);
+  }
+};
+
+class CorrelationOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Input1"), true, "Input(Input1) should not be null");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Input2"), true, "Input(Input2) should not be null");
+    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Output")), true, "Input(Output@GRAD) should not be null");
+
+    auto in1_dims = ctx->GetInputDim("Input1");
+    auto in2_dims = ctx->GetInputDim("Input2");
+    ctx->SetOutputDim(framework::GradVarName("Input1"), in1_dims);
+    // Input2's gradient takes Input2's own dims.
+    ctx->SetOutputDim(framework::GradVarName("Input2"), in2_dims);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    const auto* var = ctx.InputVar(framework::GradVarName("Output"));
+    if (var == nullptr) {
+      PADDLE_THROW("cannot find Output@GRAD");
+    }
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(ctx, "Input1"), ctx.GetPlace());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(correlation, ops::CorrelationOp, ops::CorrelationOpMaker,
+                  ops::CorrelationOpGradMaker<paddle::framework::OpDesc>,
+                  ops::CorrelationOpGradMaker<paddle::imperative::OpBase>);
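+// Two grad-maker instantiations are registered above: OpDesc covers the static
+// graph and OpBase covers dygraph. Both emit the correlation_grad op defined
+// below, forwarding Input1/Input2, the output gradient, and all attributes.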
+REGISTER_OPERATOR(correlation_grad, ops::CorrelationOpGrad);
diff --git a/PaddleCV/Research/PWCNet/correlation_op/correlation_op.cu b/PaddleCV/Research/PWCNet/correlation_op/correlation_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..161844430fe4b9dfeaf80dbe127d802d67a6de76
--- /dev/null
+++ b/PaddleCV/Research/PWCNet/correlation_op/correlation_op.cu
@@ -0,0 +1,434 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+
+#define THREADS_PER_BLOCK 32
+#define FULL_MASK 0xffffffff
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+__forceinline__ __device__ T warpReduceSum(T val) {
+  for (int offset = 16; offset > 0; offset /= 2) {
+    val += __shfl_down_sync(FULL_MASK, val, offset);
+  }
+  return val;
+}
+
+template <typename T>
+__forceinline__ __device__ T blockReduceSum(T val) {
+  static __shared__ T shared[32];
+  int lane = threadIdx.x % warpSize;
+  int wid = threadIdx.x / warpSize;
+
+  val = warpReduceSum(val);
+  if (lane == 0)
+    shared[wid] = val;
+
+  __syncthreads();
+  val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0;
+
+  if (wid == 0)
+    val = warpReduceSum(val);
+
+  return val;
+}
+
+template <typename T>
+__global__ void set_zero(T *x, int num) {
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += blockDim.x * gridDim.x)
+    x[i] = static_cast<T>(0);
+}
+
+template <typename T>
+__global__ void channel_first(const T *input, T *rinput, const int channel, const int height, const int width, const int pad_size) {
+  int n = blockIdx.x;
+  int h = blockIdx.y;
+  int w = blockIdx.z;
+
+  int ch_off = threadIdx.x;
+  T value;
+  int dimchw = channel * height * width;
+  int dimhw = height * width;
+
+  int p_dimw = (width + 2 * pad_size);
+  int p_dimh = (height + 2 * pad_size);
+  int p_dimchw = channel * p_dimw * p_dimh;
+  int p_dimcw = channel * p_dimw;
+
+  for (int c = ch_off; c < channel; c += THREADS_PER_BLOCK) {
+    value = input[n * dimchw + c * dimhw + h * width + w];
+    rinput[n * p_dimchw + (h + pad_size) * p_dimcw + (w + pad_size) * channel + c] = value;
+  }
+}
+
+template <typename T>
+__global__ void correlation_forward(T *output, const int output_channel, const int output_height, const int output_width, const T *rinput1, const int input_channel, const int input_height, const int input_width, const T *rinput2, const int pad_size, const int kernel_size, const int max_displacement, const int stride1, const int stride2) {
+
+  int p_input_width = input_width + 2 * pad_size;
+  int p_input_height = input_height + 2 * pad_size;
+
+  int kernel_rad = (kernel_size - 1) / 2;
+  int displacement_rad = max_displacement / stride2;
+
+  int displacement_size = 2 * displacement_rad + 1;
+
+  int n = blockIdx.x;
+  int h1 = blockIdx.y * stride1 + max_displacement;
+  int w1 = blockIdx.z * stride1 + max_displacement;
+  int c = threadIdx.x;
+
+  int p_dimchw = p_input_height * p_input_width * input_channel;
+  int p_dimcw = p_input_width * input_channel;
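+  // Strides of the padded NHWC copies produced by channel_first: p_dimchw
+  // steps one batch, p_dimcw one row, p_dimc one pixel. Channels are innermost
+  // so the threads of a block read consecutive addresses in the loop below.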
+  int p_dimc = input_channel;
+
+  int t_dimchw = output_channel * output_height * output_width;
+  int t_dimhw = output_height * output_width;
+  int t_dimw = output_width;
+
+  int nelems = kernel_size * kernel_size * p_dimc;
+
+  for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) {
+    for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) {
+      int w2 = w1 + ti * stride2;
+      int h2 = h1 + tj * stride2;
+
+      T acc0 = 0;
+      for (int j = -kernel_rad; j <= kernel_rad; ++j) {
+        for (int i = -kernel_rad; i <= kernel_rad; ++i) {
+          for (int ch = c; ch < p_dimc; ch += blockDim.x) {
+            int index1 = n * p_dimchw + (h1 + j) * p_dimcw + (w1 + i) * p_dimc + ch;
+            int index2 = n * p_dimchw + (h2 + j) * p_dimcw + (w2 + i) * p_dimc + ch;
+            acc0 += static_cast<T>(rinput1[index1] * rinput2[index2]);
+          }
+        }
+      }
+      if (blockDim.x == warpSize) {
+        __syncwarp();
+        acc0 = warpReduceSum(acc0);
+      } else {
+        __syncthreads();
+        acc0 = blockReduceSum(acc0);
+      }
+
+      if (threadIdx.x == 0) {
+        int tc = (tj + displacement_rad) * displacement_size + (ti + displacement_rad);
+        const int t_index = n * t_dimchw + tc * t_dimhw + blockIdx.y * t_dimw + blockIdx.z;
+        output[t_index] = static_cast<T>(acc0 / nelems);
+      }
+    }
+  }
+}
+
+//class CorrelationKernel
+template <typename DeviceContext, typename T>
+class CorrelationKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, "It must be CUDAPlace");
+
+    auto *input1 = ctx.Input<Tensor>("Input1");
+    auto *input2 = ctx.Input<Tensor>("Input2");
+    int pad_size = ctx.Attr<int>("pad_size");
+    int kernel_size = ctx.Attr<int>("kernel_size");
+    int stride1 = ctx.Attr<int>("stride1");
+    int stride2 = ctx.Attr<int>("stride2");
+    int max_displacement = ctx.Attr<int>("max_displacement");
+    int corr_type_multiply = ctx.Attr<int>("corr_type_multiply");
+
+    auto *output = ctx.Output<Tensor>("Output");
+    output->mutable_data<T>(ctx.GetPlace());
+    auto &dev_ctx = ctx.template device_context<DeviceContext>();
+
+    // base on input1, NCHW
+    auto in_dims = input1->dims();
+    int N = in_dims[0];
+    int C = in_dims[1];
+    int H = in_dims[2];
+    int W = in_dims[3];
+
+    int padded_input_height = H + 2 * pad_size;
+    int padded_input_width = W + 2 * pad_size;
+
+    Tensor rinput1 = ctx.AllocateTmpTensor<T, DeviceContext>({N, padded_input_height, padded_input_width, C}, dev_ctx);
+    rinput1.mutable_data<T>(ctx.GetPlace());
+
+    Tensor rinput2 = ctx.AllocateTmpTensor<T, DeviceContext>({N, padded_input_height, padded_input_width, C}, dev_ctx);
+    rinput2.mutable_data<T>(ctx.GetPlace());
+
+    set_zero<<<(rinput1.numel() + 512 - 1) / 512, 512, 0, dev_ctx.stream()>>>(rinput1.data<T>(), rinput1.numel());
+    set_zero<<<(rinput2.numel() + 512 - 1) / 512, 512, 0, dev_ctx.stream()>>>(rinput2.data<T>(), rinput2.numel());
+    set_zero<<<(output->numel() + 512 - 1) / 512, 512, 0, dev_ctx.stream()>>>(output->data<T>(), output->numel());
+
+    auto out_dims = output->dims();
+    int OC = out_dims[1];
+    int OH = out_dims[2];
+    int OW = out_dims[3];
+
+    dim3 blocks_grid(N, H, W);
+    dim3 threads_block(THREADS_PER_BLOCK);
+
+    channel_first<T><<<blocks_grid, threads_block, 0, dev_ctx.stream()>>>(input1->data<T>(), rinput1.data<T>(), C, H, W, pad_size);
+    channel_first<T><<<blocks_grid, threads_block, 0, dev_ctx.stream()>>>(input2->data<T>(), rinput2.data<T>(), C, H, W, pad_size);
+
+    dim3 threadsPerBlock(THREADS_PER_BLOCK);
+    dim3 totalBlocksCorr(N, OH, OW);
+
+    correlation_forward<T><<<totalBlocksCorr, threadsPerBlock, 0, dev_ctx.stream()>>>(output->data<T>(), OC, OH, OW, rinput1.data<T>(),
+        C, H, W, rinput2.data<T>(), pad_size, kernel_size, max_displacement, stride1, stride2);
+  }
+};
+
+template <typename T>
+__global__ void correlation_backward_input1(int item, T *grad_input1, const int input_channel, const int input_height, const int input_width, const T *grad_output, const int output_channel, const int output_height, const int output_width, const T *rinput2, const int pad_size, const int kernel_size, const int max_displacement, const int stride1, const int stride2) {
+
+  int n = item;
+  int h = blockIdx.x * stride1 + pad_size;
+  int w = blockIdx.y * stride1 + pad_size;
+  int c = blockIdx.z;
+  int tch_off = threadIdx.x;
+
+  int kernel_rad = (kernel_size - 1) / 2;
+  int displacement_rad = max_displacement / stride2;
+  int displacement_size = 2 * displacement_rad + 1;
+
+  int xmin = (w - kernel_rad - max_displacement) / stride1;
+  int ymin = (h - kernel_rad - max_displacement) / stride1;
+
+  int xmax = (w + kernel_rad - max_displacement) / stride1;
+  int ymax = (h + kernel_rad - max_displacement) / stride1;
+
+  if (xmax < 0 || ymax < 0 || xmin >= output_width || ymin >= output_height) {
+    return;
+  }
+
+  if (xmin > xmax || ymin > ymax) {
+    return;
+  }
+
+  xmin = max(0, xmin);
+  xmax = min(output_width - 1, xmax);
+
+  ymin = max(0, ymin);
+  ymax = min(output_height - 1, ymax);
+
+  int p_input_width = input_width + 2 * pad_size;
+  int p_input_height = input_height + 2 * pad_size;
+  int p_dimchw = input_channel * p_input_height * p_input_width;
+  int p_dimcw = input_channel * p_input_width;
+  int p_dimc = input_channel;
+
+  int t_dimchw = output_channel * output_height * output_width;
+  int t_dimhw = output_height * output_width;
+  int t_dimw = output_width;
+
+  int o_dimchw = input_channel * input_height * input_width;
+  int o_dimhw = input_height * input_width;
+  int o_dimw = input_width;
+
+  int nelems = kernel_size * kernel_size * input_channel;
+
+  __shared__ T prod_sum[THREADS_PER_BLOCK];
+  prod_sum[tch_off] = 0;
+
+  for (int tc = tch_off; tc < output_channel; tc += THREADS_PER_BLOCK) {
+    int i2 = (tc % displacement_size - displacement_rad) * stride2;
+    int j2 = (tc / displacement_size - displacement_rad) * stride2;
+
+    int index2 = n * p_dimchw + (h + j2) * p_dimcw + (w + i2) * p_dimc + c;
+
+    T val2 = rinput2[index2];
+    for (int j = ymin; j <= ymax; ++j) {
+      for (int i = xmin; i <= xmax; ++i) {
+        int t_index = n * t_dimchw + tc * t_dimhw + j * t_dimw + i;
+        prod_sum[tch_off] += grad_output[t_index] * val2;
+      }
+    }
+  }
+
+  __syncthreads();
+
+  if (tch_off == 0) {
+    T reduce_sum = 0;
+    for (int index = 0; index < THREADS_PER_BLOCK; index++) {
+      reduce_sum += prod_sum[index];
+    }
+    const int index1 = n * o_dimchw + c * o_dimhw + (h - pad_size) * o_dimw + (w - pad_size);
+    grad_input1[index1] = static_cast<T>(reduce_sum / nelems);
+  }
+}
+
+template <typename T>
+__global__ void correlation_backward_input2(int item, T *grad_input2, const int input_channel, const int input_height, const int input_width, const T *grad_output, const int output_channel, const int output_height, const int output_width, const T *rinput1, const int pad_size, const int kernel_size, const int max_displacement, const int stride1, const int stride2) {
+
+  int n = item;
+  int h = blockIdx.x * stride1 + pad_size;
+  int w = blockIdx.y * stride1 + pad_size;
+  int c = blockIdx.z;
+
+  int tch_off = threadIdx.x;
+
+  int kernel_rad = (kernel_size - 1) / 2;
+  int displacement_rad = max_displacement / stride2;
+  int displacement_size = 2 * displacement_rad + 1;
+
+  int p_input_width = input_width + 2 * pad_size;
+  int p_input_height = input_height + 2 * pad_size;
+  int p_dimchw = input_channel * p_input_height * p_input_width;
+  int p_dimcw = input_channel * p_input_width;
+  int p_dimc = input_channel;
+
+  int t_dimchw = output_channel * output_height * output_width;
+  int t_dimhw = output_height * output_width;
+  int t_dimw = output_width;
+
+  int o_dimchw = input_channel * input_height * input_width;
+  int o_dimhw = input_height * input_width;
+  int o_dimw = input_width;
+
+  int nelems = kernel_size * kernel_size * input_channel;
+
+  __shared__ T prod_sum[THREADS_PER_BLOCK];
+  prod_sum[tch_off] = 0;
+
+  for (int tc = tch_off; tc < output_channel; tc += THREADS_PER_BLOCK) {
+    int i2 = (tc % displacement_size - displacement_rad) * stride2;
+    int j2 = (tc / displacement_size - displacement_rad) * stride2;
+
+    int xmin = (w - kernel_rad - max_displacement - i2) / stride1;
+    int ymin = (h - kernel_rad - max_displacement - j2) / stride1;
+
+    int xmax = (w + kernel_rad - max_displacement - i2) / stride1;
+    int ymax = (h + kernel_rad - max_displacement - j2) / stride1;
+
+    if (xmax < 0 || ymax < 0 || xmin >= output_width || ymin >= output_height) {
+      continue;
+    }
+
+    if (xmin > xmax || ymin > ymax) {
+      continue;
+    }
+
+    xmin = max(0, xmin);
+    xmax = min(output_width - 1, xmax);
+
+    ymin = max(0, ymin);
+    ymax = min(output_height - 1, ymax);
+
+    int index1 = n * p_dimchw + (h - j2) * p_dimcw + (w - i2) * p_dimc + c;
+    T val1 = rinput1[index1];
+    for (int j = ymin; j <= ymax; ++j) {
+      for (int i = xmin; i <= xmax; ++i) {
+        int t_index = n * t_dimchw + tc * t_dimhw + j * t_dimw + i;
+        prod_sum[tch_off] += grad_output[t_index] * val1;
+      }
+    }
+  }
+
+  __syncthreads();
+
+  if (tch_off == 0) {
+    T reduce_sum = 0;
+    for (int index = 0; index < THREADS_PER_BLOCK; index++) {
+      reduce_sum += prod_sum[index];
+    }
+    const int index2 = n * o_dimchw + c * o_dimhw + (h - pad_size) * o_dimw + (w - pad_size);
+    grad_input2[index2] = static_cast<T>(reduce_sum / nelems);
+  }
+}
+
+template <typename DeviceContext, typename T>
+class CorrelationGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, "It must use CUDAPlace.");
+    const auto *input1 = ctx.Input<Tensor>("Input1");
+    const auto *input2 = ctx.Input<Tensor>("Input2");
+    const auto *grad_output = ctx.Input<Tensor>(framework::GradVarName("Output"));
+    const int pad_size = ctx.Attr<int>("pad_size");
+    const int kernel_size = ctx.Attr<int>("kernel_size");
+    const int stride1 = ctx.Attr<int>("stride1");
+    const int stride2 = ctx.Attr<int>("stride2");
+    const int max_displacement = ctx.Attr<int>("max_displacement");
+    const int corr_type_multiply = ctx.Attr<int>("corr_type_multiply");
+
+    auto *grad_input1 = ctx.Output<Tensor>(framework::GradVarName("Input1"));
+    grad_input1->mutable_data<T>(ctx.GetPlace());
+    auto *grad_input2 = ctx.Output<Tensor>(framework::GradVarName("Input2"));
+    grad_input2->mutable_data<T>(ctx.GetPlace());
+    auto &dev_ctx = ctx.template device_context<DeviceContext>();
+
+    auto in_dims = input1->dims();
+    int N = in_dims[0];
+    int C = in_dims[1];
+    int H = in_dims[2];
+    int W = in_dims[3];
+
+    int padded_input_height = H + 2 * pad_size;
+    int padded_input_width = W + 2 * pad_size;
+
+    Tensor rinput1 = ctx.AllocateTmpTensor<T, DeviceContext>({N, padded_input_height, padded_input_width, C}, dev_ctx);
+    rinput1.mutable_data<T>(ctx.GetPlace());
+
+    Tensor rinput2 = ctx.AllocateTmpTensor<T, DeviceContext>({N, padded_input_height, padded_input_width, C}, dev_ctx);
+    rinput2.mutable_data<T>(ctx.GetPlace());
+
+    set_zero<<<(rinput1.numel() + 512 - 1) / 512, 512, 0, dev_ctx.stream()>>>(rinput1.data<T>(), rinput1.numel());
+    set_zero<<<(rinput2.numel() + 512 - 1) / 512, 512, 0, dev_ctx.stream()>>>(rinput2.data<T>(), rinput2.numel());
+    set_zero<<<(grad_input1->numel() + 512 - 1) / 512, 512, 0, dev_ctx.stream()>>>(grad_input1->data<T>(), grad_input1->numel());
+    set_zero<<<(grad_input2->numel() + 512 - 1) / 512, 512, 0, dev_ctx.stream()>>>(grad_input2->data<T>(), grad_input2->numel());
+
+    auto grad_out_dims = grad_output->dims();
+    int GOC = grad_out_dims[1];
+    int GOH = grad_out_dims[2];
+    int GOW = grad_out_dims[3];
+
+    dim3 blocks_grid(N, H, W);
+    dim3 threads_block(THREADS_PER_BLOCK);
+
+    channel_first<T><<<blocks_grid, threads_block, 0, dev_ctx.stream()>>>(input1->data<T>(), rinput1.data<T>(), C, H, W, pad_size);
+    channel_first<T><<<blocks_grid, threads_block, 0, dev_ctx.stream()>>>(input2->data<T>(), rinput2.data<T>(), C, H, W, pad_size);
+
+    dim3 threadsPerBlock(THREADS_PER_BLOCK);
+    dim3 totalBlocksCorr(H, W, C);
+
+    for (int n = 0; n < N; n++) {
+      correlation_backward_input1<T><<<totalBlocksCorr, threadsPerBlock, 0, dev_ctx.stream()>>>(n, grad_input1->data<T>(), C, H, W, grad_output->data<T>(), GOC, GOH, GOW, rinput2.data<T>(), pad_size, kernel_size, max_displacement, stride1, stride2);
+    }
+
+    for (int n = 0; n < N; n++) {
+      correlation_backward_input2<T><<<totalBlocksCorr, threadsPerBlock, 0, dev_ctx.stream()>>>(n, grad_input2->data<T>(), C, H, W, grad_output->data<T>(), GOC, GOH, GOW, rinput1.data<T>(), pad_size, kernel_size, max_displacement, stride1, stride2);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    correlation, ops::CorrelationKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::CorrelationKernel<paddle::platform::CUDADeviceContext, double>);
+REGISTER_OP_CUDA_KERNEL(
+    correlation_grad, ops::CorrelationGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::CorrelationGradKernel<paddle::platform::CUDADeviceContext, double>);
+
diff --git a/PaddleCV/Research/PWCNet/correlation_op/make.sh b/PaddleCV/Research/PWCNet/correlation_op/make.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0aa8deb6b3db2908838dbba10b976e37979bf231
--- /dev/null
+++ b/PaddleCV/Research/PWCNet/correlation_op/make.sh
@@ -0,0 +1,22 @@
+include_dir=$( python3.7 -c 'import paddle; print(paddle.sysconfig.get_include())' )
+lib_dir=$( python3.7 -c 'import paddle; print(paddle.sysconfig.get_lib())' )
+
+echo $include_dir
+echo $lib_dir
+
+OPS='correlation_op'
+for op in ${OPS}
+do
+nvcc ${op}.cu -c -o ${op}.cu.o -ccbin cc -DPADDLE_WITH_CUDA -DEIGEN_USE_GPU -DPADDLE_USE_DSO -DPADDLE_WITH_MKLDNN -Xcompiler -fPIC -std=c++11 -Xcompiler -fPIC -w --expt-relaxed-constexpr -O0 -g -DNVCC \
+    -I ${include_dir}/third_party/ \
+    -I ${include_dir}
+done
+
+##g++-4.8 correlation_op.cu.o correlation_op.cc -o correlation_lib.so -DPADDLE_WITH_MKLDNN -shared -fPIC -std=c++11 -O0 -g \
+g++ correlation_op.cu.o correlation_op.cc -o correlation_lib.so -DPADDLE_WITH_MKLDNN -shared -fPIC -std=c++11 -O0 -g \
+    -I ${include_dir}/third_party/ \
+    -I ${include_dir} \
+    -L ${lib_dir} \
+    -L /usr/local/cuda/lib64 -lpaddle_framework -lcudart
+
+rm *.cu.o
diff --git a/PaddleCV/Research/PWCNet/correlation_op/test_correlation.py b/PaddleCV/Research/PWCNet/correlation_op/test_correlation.py
new file mode 100644
index 0000000000000000000000000000000000000000..89e254adafe41465be93f98cef837cc6514bf9db
--- /dev/null
+++ b/PaddleCV/Research/PWCNet/correlation_op/test_correlation.py
@@ -0,0 +1,88 @@
+import unittest
+from correlation import correlation
+import numpy as np
+import paddle.fluid as fluid
+from paddle.fluid.dygraph.base import to_variable
+
+def corr(x_1, x_2, pad_size=4, kernel_size=1, max_displacement=4, stride1=1, stride2=1, corr_multiply=1):
+    K = kernel_size
+
+    rinput1 = np.pad(x_1, ((0, 0), (0, 0), (pad_size, pad_size), (pad_size, pad_size)), mode='constant')
+    rinput2 = np.pad(x_2, ((0, 0), (0, 0), (pad_size, pad_size), (pad_size, pad_size)), mode='constant')
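+    # Mirror the layout used by the CUDA kernels: zero-pad both inputs, then
+    # (below) move them to NHWC so a kernel_size window around each pixel can
+    # be sliced with two spatial indices and averaged over the channel axis.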
+    rinput1 = np.transpose(rinput1, (0, 2, 3, 1))
+    rinput2 = np.transpose(rinput2, (0, 2, 3, 1))
+    B = int(rinput1.shape[0])
+    H = int(x_1.shape[2])
+    W = int(x_2.shape[3])
+    d = max_displacement
+    D = 2 * d + 1
+    output = np.zeros((B, D * D, H, W), dtype=np.float32)
+
+    for b in range(B):
+        for i in range(H):
+            for j in range(W):
+                for k in range(-d, d + 1):
+                    for l in range(-d, d + 1):
+                        x1_index = i + pad_size
+                        y1_index = j + pad_size
+                        x2_index = x1_index + k
+                        y2_index = y1_index + l
+                        output[b, l + d + D * (k + d), i, j] = np.mean(
+                            rinput1[b, x1_index:x1_index + K, y1_index:y1_index + K] *
+                            rinput2[b, x2_index:x2_index + K, y2_index:y2_index + K])
+
+    return output
+
+class TestCorrelationOp(unittest.TestCase):
+    def test_check_output(self):
+        #x_shape = (1, 196, 3, 3)
+        np.random.seed(13)
+        np.set_printoptions(threshold=np.inf)
+        # the declared shape must match the arrays fed below
+        x_shape = (2, 3, 4, 5)
+        x_type = 'float32'
+        x1 = fluid.layers.data(name='x1', shape=x_shape, dtype=x_type, append_batch_size=False)
+        x2 = fluid.layers.data(name='x2', shape=x_shape, dtype=x_type, append_batch_size=False)
+
+        x1_np = np.random.randn(*x_shape).astype(x_type)
+        x2_np = np.random.randn(*x_shape).astype(x_type)
+        out_np = corr(x1_np, x2_np, pad_size=4, kernel_size=1, max_displacement=4, stride1=1, stride2=1)
+
+        out = correlation(x1, x2, pad_size=4, kernel_size=1, max_displacement=4, stride1=1, stride2=1)
+
+        place = fluid.CUDAPlace(0)
+        exe = fluid.Executor(place)
+        res = exe.run(feed={'x1': x1_np, 'x2': x2_np}, fetch_list=[out.name])
+
+        self.assertTrue(np.allclose(res[0], out_np))
+
+class Net(fluid.dygraph.Layer):
+    def __init__(self, name_scope):
+        super(Net, self).__init__(name_scope)
+
+    def forward(self, x1, x2):
+        y = correlation(x1, x2, pad_size=4, kernel_size=1, max_displacement=4, stride1=1, stride2=1)
+        return y
+
+class TestCorrelationOpDyGraph(unittest.TestCase):
+    def test_check_output(self):
+        np.random.seed(13)
+        np.set_printoptions(threshold=np.inf)
+        x_shape = (2, 3, 4, 5)
+        x_type = 'float32'
+        place = fluid.CUDAPlace(0)
+        with fluid.dygraph.guard(place):
+            x1_np = np.random.randn(*x_shape).astype(x_type)
+            x2_np = np.random.randn(*x_shape).astype(x_type)
+            out_np = corr(x1_np, x2_np, pad_size=4, kernel_size=1, max_displacement=4, stride1=1, stride2=1)
+
+            x1 = to_variable(x1_np)
+            x2 = to_variable(x2_np)
+            corr_pd = Net('corr_pd')
+            y = corr_pd(x1, x2)
+            out = y.numpy()
+            self.assertTrue(np.allclose(out, out_np))
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/PaddleCV/Research/PWCNet/data/__init__.py b/PaddleCV/Research/PWCNet/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/PaddleCV/Research/PWCNet/data/datasets.py b/PaddleCV/Research/PWCNet/data/datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..080e875df614c6ad8499822b492c85555321b338
--- /dev/null
+++ b/PaddleCV/Research/PWCNet/data/datasets.py
@@ -0,0 +1,475 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. +# @FileName: datasets.py reference https://github.com/NVIDIA/flownet2-pytorch/blob/master/datasets.py +import paddle +import paddle.fluid as fluid +import numpy as np +import argparse +import os, math, random +import sys +from os.path import * +import numpy as np +from glob import glob +sys.path.append('../') +import data.utils.frame_utils as frame_utils +from scipy.misc import imsave +from src import flow_vis +from src.read_files import read_txt_to_index + + +class StaticRandomCrop(object): + def __init__(self, image_size, crop_size): + self.th, self.tw = crop_size + h, w = image_size + self.h1 = random.randint(0, h - self.th) + self.w1 = random.randint(0, w - self.tw) + + def __call__(self, img): + return img[self.h1:(self.h1 + self.th), self.w1:(self.w1 + self.tw), :] + + +class StaticCenterCrop(object): + def __init__(self, image_size, crop_size): + self.th, self.tw = crop_size + self.h, self.w = image_size + + def __call__(self, img): + return img[(self.h - self.th) // 2:(self.h + self.th) // 2, (self.w - self.tw) // 2:(self.w + self.tw) // 2, :] + + +class MpiSintel(object): + def __init__(self, args, is_cropped=False, root='', dstype='clean', replicates=1): + self.args = args + self.is_cropped = is_cropped + self.crop_size = args.crop_size + self.render_size = args.inference_size + self.replicates = replicates + + flow_root = join(root, 'flow') + image_root = join(root, dstype) + + file_list = sorted(glob(join(flow_root, '*/*.flo'))) + + self.flow_list = [] + self.image_list = [] + + for file in file_list: + if 'test' in file: + # print file + continue + + fbase = file[len(flow_root) + 1:] + fprefix = fbase[:-8] + fnum = int(fbase[-8:-4]) + + img1 = join(image_root, fprefix + "%04d" % (fnum + 0) + '.png') + img2 = join(image_root, fprefix + "%04d" % (fnum + 1) + '.png') + + if not isfile(img1) or not isfile(img2) or not isfile(file): + continue + + self.image_list += [[img1, img2]] + self.flow_list += [file] + + self.size = len(self.image_list) + + self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape + + if (self.render_size[0] < 0) or (self.render_size[1] < 0) or (self.frame_size[0] % 64) or ( + self.frame_size[1] % 64): + self.render_size[0] = ((self.frame_size[0]) // 64) * 64 + self.render_size[1] = ((self.frame_size[1]) // 64) * 64 + + args.inference_size = self.render_size + + assert (len(self.image_list) == len(self.flow_list)) + + def __getitem__(self, index): + + index = index % self.size + + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + + flow = frame_utils.read_gen(self.flow_list[index]) + + images = [img1, img2] + image_size = img1.shape[:2] + + if self.is_cropped: + cropper = StaticRandomCrop(image_size, self.crop_size) + else: + cropper = StaticCenterCrop(image_size, self.render_size) + images = list(map(cropper, images)) + flow = cropper(flow) + + images = np.array(images).transpose(3, 0, 1, 2) + flow = flow.transpose(2, 0, 1) + return [images], [flow] + + def __len__(self): + return self.size * self.replicates + + +class MpiSintelClean(MpiSintel): + def __init__(self, args, is_cropped=False, root='', replicates=1): + super(MpiSintelClean, self).__init__(args, is_cropped=is_cropped, root=root, dstype='clean', + replicates=replicates) + + +class MpiSintelFinal(MpiSintel): + def __init__(self, args, is_cropped=False, root='', replicates=1): + super(MpiSintelFinal, self).__init__(args, 
is_cropped=is_cropped, root=root, dstype='final', + replicates=replicates) + + +class FlyingChairs(object): + def __init__(self, train_val, args, is_cropped, txt_file, root='/path/to/FlyingChairs_release/data', replicates=1): + self.args = args + self.is_cropped = is_cropped + self.crop_size = args.crop_size + self.render_size = args.inference_size + self.replicates = replicates + + images = sorted(glob(join(root, '*.ppm'))) + + flow_list = sorted(glob(join(root, '*.flo'))) + + assert (len(images) // 2 == len(flow_list)) + + image_list = [] + for i in range(len(flow_list)): + im1 = images[2 * i] + im2 = images[2 * i + 1] + image_list += [[im1, im2]] + + assert len(image_list) == len(flow_list) + if train_val == 'train': + intindex = np.array(read_txt_to_index(txt_file)) + image_list = np.array(image_list) + image_list = image_list[intindex == 1] + image_list = image_list.tolist() + flow_list = np.array(flow_list) + flow_list = flow_list[intindex == 1] + flow_list = flow_list.tolist() + assert len(image_list) == len(flow_list) + elif train_val == 'val': + intindex = np.array(read_txt_to_index(txt_file)) + image_list = np.array(image_list) + image_list = image_list[intindex == 2] + image_list = image_list.tolist() + flow_list = np.array(flow_list) + flow_list = flow_list[intindex == 2] + flow_list = flow_list.tolist() + assert len(image_list) == len(flow_list) + else: + raise ValueError('FlyingChairs_train_val.txt not found for txt_file ......') + self.flow_list = flow_list + self.image_list = image_list + + self.size = len(self.image_list) + + self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape + + if (self.render_size[0] < 0) or (self.render_size[1] < 0) or (self.frame_size[0] % 64) or ( + self.frame_size[1] % 64): + self.render_size[0] = ((self.frame_size[0]) // 64) * 64 + self.render_size[1] = ((self.frame_size[1]) // 64) * 64 + + args.inference_size = self.render_size + + def __getitem__(self, index): + index = index % self.size + + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + + flow = frame_utils.read_gen(self.flow_list[index]) + + images = [img1, img2] + image_size = img1.shape[:2] + if self.is_cropped: + cropper = StaticRandomCrop(image_size, self.crop_size) + else: + cropper = StaticCenterCrop(image_size, self.render_size) + images = list(map(cropper, images)) + flow = cropper(flow) + + images = np.array(images).transpose(3, 0, 1, 2) + flow = flow.transpose(2, 0, 1) + return [images], [flow] + + def __len__(self): + return self.size * self.replicates + + +def reader_flyingchairs(dataset): + n = len(dataset) + + def reader(): + for i in range(n): + a, b = dataset[i] + yield a[0][:,0,:,:].transpose(1,2,0), a[0][:,1,:,:].transpose(1,2,0), b[0].transpose(1, 2, 0)# a single entry of data is created each time + return reader + + +class FlyingThings(object): + def __init__(self, args, is_cropped, root='/path/to/flyingthings3d', dstype='frames_cleanpass', replicates=1): + self.args = args + self.is_cropped = is_cropped + self.crop_size = args.crop_size + self.render_size = args.inference_size + self.replicates = replicates + + image_dirs = sorted(glob(join(root, dstype, 'TRAIN/*/*'))) + image_dirs = sorted([join(f, 'left') for f in image_dirs] + [join(f, 'right') for f in image_dirs]) + + flow_dirs = sorted(glob(join(root, 'optical_flow_flo_format/TRAIN/*/*'))) + flow_dirs = sorted( + [join(f, 'into_future/left') for f in flow_dirs] + [join(f, 'into_future/right') for f in flow_dirs]) + + assert 
(len(image_dirs) == len(flow_dirs)) + + self.image_list = [] + self.flow_list = [] + + for idir, fdir in zip(image_dirs, flow_dirs): + images = sorted(glob(join(idir, '*.png'))) + flows = sorted(glob(join(fdir, '*.flo'))) + for i in range(len(flows)): + self.image_list += [[images[i], images[i + 1]]] + self.flow_list += [flows[i]] + + assert len(self.image_list) == len(self.flow_list) + + self.size = len(self.image_list) + + self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape + + if (self.render_size[0] < 0) or (self.render_size[1] < 0) or (self.frame_size[0] % 64) or ( + self.frame_size[1] % 64): + self.render_size[0] = ((self.frame_size[0]) // 64) * 64 + self.render_size[1] = ((self.frame_size[1]) // 64) * 64 + + args.inference_size = self.render_size + + def __getitem__(self, index): + index = index % self.size + + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + + flow = frame_utils.read_gen(self.flow_list[index]) + + images = [img1, img2] + image_size = img1.shape[:2] + if self.is_cropped: + cropper = StaticRandomCrop(image_size, self.crop_size) + else: + cropper = StaticCenterCrop(image_size, self.render_size) + images = list(map(cropper, images)) + flow = cropper(flow) + + images = np.array(images).transpose(3, 0, 1, 2) + flow = flow.transpose(2, 0, 1) + return [images], [flow] + + def __len__(self): + return self.size * self.replicates + + +class FlyingThingsClean(FlyingThings): + def __init__(self, args, is_cropped=False, root='', replicates=1): + super(FlyingThingsClean, self).__init__(args, is_cropped=is_cropped, root=root, dstype='frames_cleanpass', + replicates=replicates) + + +class FlyingThingsFinal(FlyingThings): + def __init__(self, args, is_cropped=False, root='', replicates=1): + super(FlyingThingsFinal, self).__init__(args, is_cropped=is_cropped, root=root, dstype='frames_finalpass', + replicates=replicates) + + +class ChairsSDHom(object): + def __init__(self, args, is_cropped, root='/path/to/chairssdhom/data', dstype='train', replicates=1): + self.args = args + self.is_cropped = is_cropped + self.crop_size = args.crop_size + self.render_size = args.inference_size + self.replicates = replicates + + image1 = sorted(glob(join(root, dstype, 't0/*.png'))) + image2 = sorted(glob(join(root, dstype, 't1/*.png'))) + self.flow_list = sorted(glob(join(root, dstype, 'flow/*.flo'))) + + assert (len(image1) == len(self.flow_list)) + + self.image_list = [] + for i in range(len(self.flow_list)): + im1 = image1[i] + im2 = image2[i] + self.image_list += [[im1, im2]] + + assert len(self.image_list) == len(self.flow_list) + + self.size = len(self.image_list) + + self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape + + if (self.render_size[0] < 0) or (self.render_size[1] < 0) or (self.frame_size[0] % 64) or ( + self.frame_size[1] % 64): + self.render_size[0] = ((self.frame_size[0]) // 64) * 64 + self.render_size[1] = ((self.frame_size[1]) // 64) * 64 + + args.inference_size = self.render_size + + def __getitem__(self, index): + index = index % self.size + + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + + flow = frame_utils.read_gen(self.flow_list[index]) + flow = flow[::-1, :, :] + + images = [img1, img2] + image_size = img1.shape[:2] + if self.is_cropped: + cropper = StaticRandomCrop(image_size, self.crop_size) + else: + cropper = StaticCenterCrop(image_size, self.render_size) + images = list(map(cropper, images)) + flow = 
cropper(flow)
+
+        images = np.array(images).transpose(3, 0, 1, 2)
+        flow = flow.transpose(2, 0, 1)
+        return [images], [flow]
+
+    def __len__(self):
+        return self.size * self.replicates
+
+
+class ChairsSDHomTrain(ChairsSDHom):
+    def __init__(self, args, is_cropped=False, root='', replicates=1):
+        super(ChairsSDHomTrain, self).__init__(args, is_cropped=is_cropped, root=root, dstype='train',
+                                               replicates=replicates)
+
+
+class ChairsSDHomTest(ChairsSDHom):
+    def __init__(self, args, is_cropped=False, root='', replicates=1):
+        super(ChairsSDHomTest, self).__init__(args, is_cropped=is_cropped, root=root, dstype='test',
+                                              replicates=replicates)
+
+
+class ImagesFromFolder(object):
+    def __init__(self, args, is_cropped, root='/path/to/frames/only/folder', iext='png', replicates=1):
+        self.args = args
+        self.is_cropped = is_cropped
+        self.crop_size = args.crop_size
+        self.render_size = args.inference_size
+        self.replicates = replicates
+
+        images = sorted(glob(join(root, '*.' + iext)))
+        self.image_list = []
+        for i in range(len(images) - 1):
+            im1 = images[i]
+            im2 = images[i + 1]
+            self.image_list += [[im1, im2]]
+
+        self.size = len(self.image_list)
+
+        self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape
+
+        if (self.render_size[0] < 0) or (self.render_size[1] < 0) or (self.frame_size[0] % 64) or (
+                self.frame_size[1] % 64):
+            self.render_size[0] = ((self.frame_size[0]) // 64) * 64
+            self.render_size[1] = ((self.frame_size[1]) // 64) * 64
+
+        args.inference_size = self.render_size
+
+    def __getitem__(self, index):
+        index = index % self.size
+
+        img1 = frame_utils.read_gen(self.image_list[index][0])
+        img2 = frame_utils.read_gen(self.image_list[index][1])
+
+        images = [img1, img2]
+        image_size = img1.shape[:2]
+        if self.is_cropped:
+            cropper = StaticRandomCrop(image_size, self.crop_size)
+        else:
+            cropper = StaticCenterCrop(image_size, self.render_size)
+        images = list(map(cropper, images))
+
+        images = np.array(images).transpose(3, 0, 1, 2)
+        # images is a numpy array here, so use .shape rather than torch's .size()
+        return [images], [np.zeros(images.shape[0:1] + (2,) + images.shape[-2:])]
+
+    def __len__(self):
+        return self.size * self.replicates
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    args = parser.parse_args()
+    args.inference_size = [1080, 1920]
+    args.crop_size = [384, 512]
+
+    index = 50
+    # txt_file must point to the FlyingChairs_train_val.txt split matching the data root
+    flyingchairs_dataset = FlyingChairs(
+        'train', args, True,
+        txt_file='/ssd2/zhenghe/DATA/FlyingChairs_release/FlyingChairs_train_val.txt',
+        root='/ssd2/zhenghe/DATA/FlyingChairs_release/data')
+    # a, b = flyingchairs_dataset[index]
+    # im1 = a[0][:,0,:,:].transpose(1,2,0)
+    # im2 = a[0][:,1,:,:].transpose(1,2,0)
+    # flo = b[0].transpose(1, 2, 0) / 20.0
+    # flow_color = flow_vis.flow_to_color(flo, convert_to_bgr=False)
+    # imsave('./hsv_pd.png', flow_color)
+    sample_num = len(flyingchairs_dataset)
+    reader = reader_flyingchairs(flyingchairs_dataset)
+    BATCH_SIZE = 8
+    train_batch_reader = paddle.batch(reader, BATCH_SIZE, drop_last=True)
+    epoch_num = 1
+
+    with fluid.dygraph.guard():
+        for epoch in range(epoch_num):
+            for batch_id, data in enumerate(train_batch_reader()):
+                im1_data = np.array(
+                    [x[0] for x in data]).astype('float32')
+                im2_data = np.array(
+                    [x[1] for x in data]).astype('float32')
+                flo_data = np.array(
+                    [x[2] for x in data]).astype('float32')
+                if batch_id % 500 == 0:
+                    # if batch_id < 10:
+                    print(batch_id)
+                    print(im1_data.shape)
+                    print(im2_data.shape)
+                    print(flo_data.shape)
+                    im1 = im1_data[0, :, :, :]
+                    im2 = im2_data[0, :, :, :]
+                    flo = flo_data[0, :, :, :]
+                    print(im1.shape)
+                    print(im2.shape)
+                    print(flo.shape)
+                    imsave('./img1.png', im1)
+                    imsave('./img2.png', im2)
+                    flow_color =
flow_vis.flow_to_color(flo, convert_to_bgr=False) + imsave('./hsv_pd.png', flow_color) + print("batch_id:", batch_id) + print(batch_id * BATCH_SIZE) + print(sample_num) + # img = fluid.dygraph.to_variable(dy_x_data) + + + + + diff --git a/PaddleCV/Research/PWCNet/data/download.sh b/PaddleCV/Research/PWCNet/data/download.sh new file mode 100755 index 0000000000000000000000000000000000000000..8a0c5dad4d5fb233be56050983bf1f0b293944d0 --- /dev/null +++ b/PaddleCV/Research/PWCNet/data/download.sh @@ -0,0 +1,17 @@ +#!/bin/bash + +#mkdir FlyingThings3D_release +#cd FlyingThings3D_release +# +#wget http://lmb.informatik.uni-freiburg.de/data/SceneFlowDatasets_CVPR16/Release_april16/data/FlyingThings3D/raw_data/flyingthings3d__frames_cleanpass.tar +#wget http://lmb.informatik.uni-freiburg.de/data/SceneFlowDatasets_CVPR16/Release_april16/data/FlyingThings3D/derived_data/flyingthings3d__optical_flow.tar.bz2 +# +#tar xvf flyingthings3d__frames_cleanpass.tar +#tar xvf flyingthings3d__optical_flow.tar.bz2 +# +#cd .. +wget http://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs/FlyingChairs.zip +unzip FlyingChairs.zip + +#wget https://lmb.informatik.uni-freiburg.de/data/FlowNet2/ChairsSDHom/ChairsSDHom.tar.gz +#tar xvzf ChairsSDHom.tar.gz diff --git a/PaddleCV/Research/PWCNet/data/frame_0010.png b/PaddleCV/Research/PWCNet/data/frame_0010.png new file mode 100644 index 0000000000000000000000000000000000000000..80df246723859bb1e0aaca2f41944537cdc18d70 Binary files /dev/null and b/PaddleCV/Research/PWCNet/data/frame_0010.png differ diff --git a/PaddleCV/Research/PWCNet/data/frame_0011.png b/PaddleCV/Research/PWCNet/data/frame_0011.png new file mode 100644 index 0000000000000000000000000000000000000000..0ee97e97a7eba203eb6f67f032f81a8fbdb2c3ed Binary files /dev/null and b/PaddleCV/Research/PWCNet/data/frame_0011.png differ diff --git a/PaddleCV/Research/PWCNet/data/utils/__init__.py b/PaddleCV/Research/PWCNet/data/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..139597f9cb07c5d48bed18984ec4747f4b4f3438 --- /dev/null +++ b/PaddleCV/Research/PWCNet/data/utils/__init__.py @@ -0,0 +1,2 @@ + + diff --git a/PaddleCV/Research/PWCNet/data/utils/flow_utils.py b/PaddleCV/Research/PWCNet/data/utils/flow_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4ee0ecbb16a92bb9f738d278b61a18862ad518a5 --- /dev/null +++ b/PaddleCV/Research/PWCNet/data/utils/flow_utils.py @@ -0,0 +1,57 @@ +import numpy as np + +TAG_CHAR = np.array([202021.25], np.float32) + + +def readFlow(fn): + """ Read .flo file in Middlebury format""" + # Code adapted from: + # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy + + # WARNING: this will work on little-endian architectures (eg Intel x86) only! + # print 'fn = %s'%(fn) + with open(fn, 'rb') as f: + magic = np.fromfile(f, np.float32, count=1) + if 202021.25 != magic: + print('Magic number incorrect. Invalid .flo file') + return None + else: + w = np.fromfile(f, np.int32, count=1) + h = np.fromfile(f, np.int32, count=1) + # print 'Reading %d x %d flo file\n' % (w, h) + data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) + # Reshape data into 3D array (columns, rows, bands) + # The reshape here is for visualization, the original code is (w,h,2) + return np.resize(data, (int(h), int(w), 2)) + + +def writeFlow(filename, uv, v=None): + """ Write optical flow to file. + + If v is None, uv is assumed to contain both u and v channels, + stacked in depth. 
+ Original code by Deqing Sun, adapted from Daniel Scharstein. + """ + nBands = 2 + + if v is None: + assert (uv.ndim == 3) + assert (uv.shape[2] == 2) + u = uv[:, :, 0] + v = uv[:, :, 1] + else: + u = uv + + assert (u.shape == v.shape) + height, width = u.shape + f = open(filename, 'wb') + # write the header + f.write(TAG_CHAR) + np.array(width).astype(np.int32).tofile(f) + np.array(height).astype(np.int32).tofile(f) + # arrange into matrix form + tmp = np.zeros((height, width * nBands)) + tmp[:, np.arange(width) * 2] = u + tmp[:, np.arange(width) * 2 + 1] = v + tmp.astype(np.float32).tofile(f) + f.close() \ No newline at end of file diff --git a/PaddleCV/Research/PWCNet/data/utils/frame_utils.py b/PaddleCV/Research/PWCNet/data/utils/frame_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..40a8ea5a206aec428241ac7674de83a1a4099de0 --- /dev/null +++ b/PaddleCV/Research/PWCNet/data/utils/frame_utils.py @@ -0,0 +1,18 @@ +import numpy as np +from os.path import * +from scipy.misc import imread +from . import flow_utils + +def read_gen(file_name): + ext = splitext(file_name)[-1] + if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': + im = imread(file_name) + if im.shape[2] > 3: + return im[:,:,:3] + else: + return im + elif ext == '.bin' or ext == '.raw': + return np.load(file_name) + elif ext == '.flo': + return flow_utils.readFlow(file_name).astype(np.float32) + return [] \ No newline at end of file diff --git a/PaddleCV/Research/PWCNet/finetune.sh b/PaddleCV/Research/PWCNet/finetune.sh new file mode 100755 index 0000000000000000000000000000000000000000..29d2e802da3cc3fa13413ab768071e19d59e3147 --- /dev/null +++ b/PaddleCV/Research/PWCNet/finetune.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash +python3 train.py --loss l1 --pretrained ./out/pwc_net_paddle --dataset FlyingChairs --train_val_txt data_dir/FlyingChairs_release/FlyingChairs_train_val.txt --data_root data_dir/FlyingChairs_release/data + +# use multi gpus NEED TO DO LATER +#python3 -m paddle.distributed.launch --selected_gpus=0,1 train.py --use_multi_gpu --batch_size 40 --loss l1 --pretrained ./out/pwc_net_paddle --dataset FlyingChairs --train_val_txt data_dir/FlyingChairs_release/FlyingChairs_train_val.txt --data_root data_dir/FlyingChairs_release/data diff --git a/PaddleCV/Research/PWCNet/infer.py b/PaddleCV/Research/PWCNet/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..717c18f02c017e910b4a86e09616386668822e8a --- /dev/null +++ b/PaddleCV/Research/PWCNet/infer.py @@ -0,0 +1,148 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Infer for PWCNet.""" +import sys +import pickle +import time +import cv2 +import numpy as np +from math import ceil +from scipy.ndimage import imread +from scipy.misc import imsave +import paddle.fluid as fluid +from models.model import PWCDCNet +from src import flow_vis + + + +def writeFlowFile(filename, uv): + """ + According to the matlab code of Deqing Sun and c++ source code of Daniel Scharstein + Contact: dqsun@cs.brown.edu + Contact: schar@middlebury.edu + """ + TAG_STRING = np.array(202021.25, dtype=np.float32) + if uv.shape[2] != 2: + sys.exit("writeFlowFile: flow must have two bands!"); + H = np.array(uv.shape[0], dtype=np.int32) + W = np.array(uv.shape[1], dtype=np.int32) + with open(filename, 'wb') as f: + f.write(TAG_STRING.tobytes()) + f.write(W.tobytes()) + f.write(H.tobytes()) + f.write(uv.tobytes()) + + +def load_dict(filename_): + with open(filename_, 'rb') as f: + ret_di = pickle.load(f) + return ret_di + + +def pad_input(x0): + intWidth = x0.shape[2] + intHeight = x0.shape[3] + if intWidth != ((intWidth >> 6) << 6): + intWidth_pad = (((intWidth >> 6) + 1) << 6) # more than necessary + intPaddingLeft = int((intWidth_pad - intWidth) / 2) + intPaddingRight = intWidth_pad - intWidth - intPaddingLeft + else: + intWidth_pad = intWidth + intPaddingLeft = 0 + intPaddingRight = 0 + + if intHeight != ((intHeight >> 6) << 6): + intHeight_pad = (((intHeight >> 6) + 1) << 6) # more than necessary + intPaddingTop = int((intHeight_pad - intHeight) / 2) + intPaddingBottom = intHeight_pad - intHeight - intPaddingTop + else: + intHeight_pad = intHeight + intPaddingTop = 0 + intPaddingBottom = 0 + + out = fluid.layers.pad2d(input=x0, + paddings=[intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom], + mode='edge') + + return out, [intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom, intWidth, intHeight] + + +def main(): + im1_fn = 'data/frame_0010.png' + im2_fn = 'data/frame_0011.png' + flow_fn = './tmp/frame_0010_pd.flo' + if len(sys.argv) > 1: + im1_fn = sys.argv[1] + if len(sys.argv) > 2: + im2_fn = sys.argv[2] + if len(sys.argv) > 3: + flow_fn = sys.argv[3] + + im_all = [imread(img) for img in [im1_fn, im2_fn]] + im_all = [im[:, :, :3] for im in im_all] + + # rescale the image size to be multiples of 64 + divisor = 64. 
+ H = im_all[0].shape[0] + W = im_all[0].shape[1] + print('origin shape : ', H, W) + + H_ = int(ceil(H / divisor) * divisor) + W_ = int(ceil(W / divisor) * divisor) + print('resize shape: ', H_, W_) + for i in range(len(im_all)): + im_all[i] = cv2.resize(im_all[i], (W_, H_)) + + for _i, _inputs in enumerate(im_all): + im_all[_i] = im_all[_i][:, :, ::-1] + im_all[_i] = 1.0 * im_all[_i] / 255.0 + im_all[_i] = np.transpose(im_all[_i], (2, 0, 1)) + im_all = np.concatenate((im_all[0], im_all[1]), axis=0).astype(np.float32) + im_all = im_all[np.newaxis, :, :, :] + + with fluid.dygraph.guard(place=fluid.CUDAPlace(0)): + im_all = fluid.dygraph.to_variable(im_all) + im_all, [intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom, intWidth, intHeight] = pad_input( + im_all) + + model = PWCDCNet("pwcnet") + model.eval() + pd_pretrain, _ = fluid.dygraph.load_dygraph("paddle_model/pwc_net_paddle") + model.set_dict(pd_pretrain) + start = time.time() + flo = model(im_all) + end = time.time() + print('Time of PWCNet model for one infer step: ', end - start) + flo = flo[0].numpy() * 20.0 + # scale the flow back to the input size + flo = np.swapaxes(np.swapaxes(flo, 0, 1), 1, 2) + flo = flo[intPaddingTop * 2:intPaddingTop * 2 + intHeight * 2, + intPaddingLeft * 2: intPaddingLeft * 2 + intWidth * 2, :] + u_ = cv2.resize(flo[:, :, 0], (W, H)) + v_ = cv2.resize(flo[:, :, 1], (W, H)) + u_ *= W / float(W_) + v_ *= H / float(H_) + flo = np.dstack((u_, v_)) + + # # Apply the coloring (for OpenCV, set convert_to_bgr=True) + flow_color = flow_vis.flow_to_color(flo, convert_to_bgr=False) + imsave('./tmp/hsv_pd.png', flow_color) + + writeFlowFile(flow_fn, flo) + + +if __name__ == '__main__': + main() + + diff --git a/PaddleCV/Research/PWCNet/models/__init__.py b/PaddleCV/Research/PWCNet/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..44a41a91f24512697caec6068c7ce1f4101c93b5 --- /dev/null +++ b/PaddleCV/Research/PWCNet/models/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import models.model diff --git a/PaddleCV/Research/PWCNet/models/model.py b/PaddleCV/Research/PWCNet/models/model.py new file mode 100644 index 0000000000000000000000000000000000000000..435e9f4dbc375251468906ca0f33ac3c79701804 --- /dev/null +++ b/PaddleCV/Research/PWCNet/models/model.py @@ -0,0 +1,277 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
diff --git a/PaddleCV/Research/PWCNet/models/model.py b/PaddleCV/Research/PWCNet/models/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..435e9f4dbc375251468906ca0f33ac3c79701804
--- /dev/null
+++ b/PaddleCV/Research/PWCNet/models/model.py
@@ -0,0 +1,277 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import paddle.fluid as fluid
+from paddle.fluid.dygraph import Conv2D, Conv2DTranspose
+from correlation_op.correlation import correlation
+
+
+class PWCDCNet(fluid.dygraph.Layer):
+    def __init__(self, name_scope, md=4):
+        super(PWCDCNet, self).__init__(name_scope)
+        self.param_attr = fluid.ParamAttr(
+            name='conv_weights',
+            regularizer=fluid.regularizer.L2DecayRegularizer(
+                regularization_coeff=0.0004),
+            initializer=fluid.initializer.MSRAInitializer(uniform=True, fan_in=None, seed=0))
+        self.md = md
+        # siamese feature pyramid: the same convolutions embed both input
+        # images (forward() applies conv1a..conv6b to im1 and im2 alike)
+        self.conv1a = Conv2D("conv1a", 16, filter_size=3, stride=2, padding=1, param_attr=self.param_attr)
+        self.conv1aa = Conv2D("conv1aa", 16, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv1b = Conv2D("conv1b", 16, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv2a = Conv2D("conv2a", 32, filter_size=3, stride=2, padding=1, param_attr=self.param_attr)
+        self.conv2aa = Conv2D("conv2aa", 32, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv2b = Conv2D("conv2b", 32, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv3a = Conv2D("conv3a", 64, filter_size=3, stride=2, padding=1, param_attr=self.param_attr)
+        self.conv3aa = Conv2D("conv3aa", 64, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv3b = Conv2D("conv3b", 64, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv4a = Conv2D("conv4a", 96, filter_size=3, stride=2, padding=1, param_attr=self.param_attr)
+        self.conv4aa = Conv2D("conv4aa", 96, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv4b = Conv2D("conv4b", 96, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv5a = Conv2D("conv5a", 128, filter_size=3, stride=2, padding=1, param_attr=self.param_attr)
+        self.conv5aa = Conv2D("conv5aa", 128, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv5b = Conv2D("conv5b", 128, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv6aa = Conv2D("conv6aa", 196, filter_size=3, stride=2, padding=1, param_attr=self.param_attr)
+        self.conv6a = Conv2D("conv6a", 196, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv6b = Conv2D("conv6b", 196, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+
+        self.conv6_0 = Conv2D("conv6_0", 128, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv6_1 = Conv2D("conv6_1", 128, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv6_2 = Conv2D("conv6_2", 96, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv6_3 = Conv2D("conv6_3", 64, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv6_4 = Conv2D("conv6_4", 32, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.predict_flow6 = Conv2D("predict_flow6", 2, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.deconv6 = Conv2DTranspose("deconv6", 2, filter_size=4, stride=2, padding=1, param_attr=self.param_attr)
+        self.upfeat6 = Conv2DTranspose("upfeat6", 2, filter_size=4, stride=2, padding=1, param_attr=self.param_attr)
+
+        self.conv5_0 = Conv2D("conv5_0", 128, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv5_1 = Conv2D("conv5_1", 128, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv5_2 = Conv2D("conv5_2", 96, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv5_3 = Conv2D("conv5_3", 64, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv5_4 = Conv2D("conv5_4", 32, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.predict_flow5 = Conv2D("predict_flow5", 2, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.deconv5 = Conv2DTranspose("deconv5", 2, filter_size=4, stride=2, padding=1, param_attr=self.param_attr)
+        self.upfeat5 = Conv2DTranspose("upfeat5", 2, filter_size=4, stride=2, padding=1, param_attr=self.param_attr)
+
+        self.conv4_0 = Conv2D("conv4_0", 128, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv4_1 = Conv2D("conv4_1", 128, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv4_2 = Conv2D("conv4_2", 96, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv4_3 = Conv2D("conv4_3", 64, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv4_4 = Conv2D("conv4_4", 32, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.predict_flow4 = Conv2D("predict_flow4", 2, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.deconv4 = Conv2DTranspose("deconv4", 2, filter_size=4, stride=2, padding=1, param_attr=self.param_attr)
+        self.upfeat4 = Conv2DTranspose("upfeat4", 2, filter_size=4, stride=2, padding=1, param_attr=self.param_attr)
+
+        self.conv3_0 = Conv2D("conv3_0", 128, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv3_1 = Conv2D("conv3_1", 128, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv3_2 = Conv2D("conv3_2", 96, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv3_3 = Conv2D("conv3_3", 64, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv3_4 = Conv2D("conv3_4", 32, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.predict_flow3 = Conv2D("predict_flow3", 2, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.deconv3 = Conv2DTranspose("deconv3", 2, filter_size=4, stride=2, padding=1, param_attr=self.param_attr)
+        self.upfeat3 = Conv2DTranspose("upfeat3", 2, filter_size=4, stride=2, padding=1, param_attr=self.param_attr)
+
+        self.conv2_0 = Conv2D("conv2_0", 128, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv2_1 = Conv2D("conv2_1", 128, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv2_2 = Conv2D("conv2_2", 96, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv2_3 = Conv2D("conv2_3", 64, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.conv2_4 = Conv2D("conv2_4", 32, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.predict_flow2 = Conv2D("predict_flow2", 2, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
+        self.deconv2 = Conv2DTranspose("deconv2", 2, filter_size=4, stride=2, padding=1, param_attr=self.param_attr)
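+        # Context network: a chain of dilated 3x3 convolutions (dilation 1, 2,
+        # 4, 8, 16, 1) that refines the finest flow estimate, as in the PWC-Net
+        # paper; dc_conv7 predicts the residual added to flow2 in forward().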
+        self.dc_conv1 = Conv2D("dc_conv1", 128, filter_size=3, stride=1, padding=1, dilation=1, param_attr=self.param_attr)
+        self.dc_conv2 = Conv2D("dc_conv2", 128, filter_size=3, stride=1, padding=2, dilation=2, param_attr=self.param_attr)
+        self.dc_conv3 = Conv2D("dc_conv3", 128, filter_size=3, stride=1, padding=4, dilation=4, param_attr=self.param_attr)
+        self.dc_conv4 = Conv2D("dc_conv4", 96, filter_size=3, stride=1, padding=8, dilation=8, param_attr=self.param_attr)
Conv2D("dc_conv5", 64, filter_size=3, stride=1, padding=16, dilation=16, param_attr=self.param_attr) + self.dc_conv6 = Conv2D("dc_conv6", 32, filter_size=3, stride=1, padding=1, dilation=1, param_attr=self.param_attr) + self.dc_conv7 = Conv2D("dc_conv7", 2, filter_size=3,stride=1,padding=1, param_attr=self.param_attr) + + def warp(self, x, flo): + """ + warp an image/tensor (im2) back to im1, according to the optical flow + + x: [B, C, H, W] (im2) + flo: [B, 2, H, W] flow + + """ + + B, C, H, W = x.shape + # mesh grid + xx_pd = fluid.layers.range(0, W, 1, 'float32') + xx_pd = fluid.layers.reshape(xx_pd, shape=[1, -1]) + xx_pd = fluid.layers.expand(x=xx_pd, expand_times=[H, 1]) + xx_pd = fluid.layers.reshape(xx_pd, shape=[1, 1, H, W]) + xx_pd = fluid.layers.expand(x=xx_pd, expand_times=[B, 1, 1, 1]) + + yy_pd = fluid.layers.range(0, H, 1, 'float32') + yy_pd = fluid.layers.reshape(yy_pd, shape=[-1, 1]) + yy_pd = fluid.layers.expand(x=yy_pd, expand_times=[1, W]) + yy_pd = fluid.layers.reshape(x=yy_pd, shape=[1, 1, H, W]) + yy_pd = fluid.layers.expand(x=yy_pd, expand_times=[B, 1, 1, 1]) + grid_pd = fluid.layers.concat(input=[xx_pd, yy_pd], axis=1) + flo_pd = flo + vgrid_pd = fluid.layers.elementwise_add(grid_pd, flo_pd) + vgrid_pd_0 = 2.0 * fluid.layers.slice(vgrid_pd, axes=[1], starts=[0], ends=[1]) / max(W - 1, 1) - 1.0 + vgrid_pd_1 = 2.0 * fluid.layers.slice(vgrid_pd, axes=[1], starts=[1], ends=[2]) / max(H - 1, 1) - 1.0 + vgrid_pd = fluid.layers.concat(input=[vgrid_pd_0, vgrid_pd_1], axis=1) + vgrid_pd = fluid.layers.transpose(vgrid_pd, [0, 2, 3, 1]) + output = fluid.layers.grid_sampler(name='grid_sample', x=x, grid=vgrid_pd) + + mask = fluid.layers.zeros_like(x) + mask = mask + 1.0 + mask = fluid.layers.grid_sampler(name='grid_sample', x=mask, grid=vgrid_pd) + mask_temp1 = fluid.layers.cast(mask < 0.9990, 'float32') + mask = mask * (1 - mask_temp1) + mask = fluid.layers.cast(mask > 0, 'float32') + outwarp = fluid.layers.elementwise_mul(output, mask) + + return outwarp + + def corr(self, x_1, x_2): + out = correlation(x_1, x_2, pad_size=self.md, kernel_size=1, max_displacement=self.md, + stride1=1, stride2=1, corr_type_multiply=1) + return out + + def forward(self, x, output_more=False): + im1 = fluid.layers.slice(x, axes=[1], starts=[0], ends=[3]) + im2 = fluid.layers.slice(x, axes=[1], starts=[3], ends=[6]) + # print("\n\n***************************PWC Net details *************** \n\n") + c11 = fluid.layers.leaky_relu(self.conv1a(im1), 0.1) + c11 = fluid.layers.leaky_relu(self.conv1aa(c11), 0.1) + c11 = fluid.layers.leaky_relu(self.conv1b(c11), 0.1) + + c21 = fluid.layers.leaky_relu(self.conv1a(im2), 0.1) + c21 = fluid.layers.leaky_relu(self.conv1aa(c21), 0.1) + c21 = fluid.layers.leaky_relu(self.conv1b(c21), 0.1) + + c12 = fluid.layers.leaky_relu(self.conv2a(c11), 0.1) + c12 = fluid.layers.leaky_relu(self.conv2aa(c12), 0.1) + c12 = fluid.layers.leaky_relu(self.conv2b(c12), 0.1) + + c22 = fluid.layers.leaky_relu(self.conv2a(c21), 0.1) + c22 = fluid.layers.leaky_relu(self.conv2aa(c22), 0.1) + c22 = fluid.layers.leaky_relu(self.conv2b(c22), 0.1) + + c13 = fluid.layers.leaky_relu(self.conv3a(c12), 0.1) + c13 = fluid.layers.leaky_relu(self.conv3aa(c13), 0.1) + c13 = fluid.layers.leaky_relu(self.conv3b(c13), 0.1) + + c23 = fluid.layers.leaky_relu(self.conv3a(c22), 0.1) + c23 = fluid.layers.leaky_relu(self.conv3aa(c23), 0.1) + c23 = fluid.layers.leaky_relu(self.conv3b(c23), 0.1) + + c14 = fluid.layers.leaky_relu(self.conv4a(c13), 0.1) + c14 = fluid.layers.leaky_relu(self.conv4aa(c14), 0.1) + 
+    def forward(self, x, output_more=False):
+        im1 = fluid.layers.slice(x, axes=[1], starts=[0], ends=[3])
+        im2 = fluid.layers.slice(x, axes=[1], starts=[3], ends=[6])
+        # print("\n\n***************************PWC Net details *************** \n\n")
+        c11 = fluid.layers.leaky_relu(self.conv1a(im1), 0.1)
+        c11 = fluid.layers.leaky_relu(self.conv1aa(c11), 0.1)
+        c11 = fluid.layers.leaky_relu(self.conv1b(c11), 0.1)
+
+        c21 = fluid.layers.leaky_relu(self.conv1a(im2), 0.1)
+        c21 = fluid.layers.leaky_relu(self.conv1aa(c21), 0.1)
+        c21 = fluid.layers.leaky_relu(self.conv1b(c21), 0.1)
+
+        c12 = fluid.layers.leaky_relu(self.conv2a(c11), 0.1)
+        c12 = fluid.layers.leaky_relu(self.conv2aa(c12), 0.1)
+        c12 = fluid.layers.leaky_relu(self.conv2b(c12), 0.1)
+
+        c22 = fluid.layers.leaky_relu(self.conv2a(c21), 0.1)
+        c22 = fluid.layers.leaky_relu(self.conv2aa(c22), 0.1)
+        c22 = fluid.layers.leaky_relu(self.conv2b(c22), 0.1)
+
+        c13 = fluid.layers.leaky_relu(self.conv3a(c12), 0.1)
+        c13 = fluid.layers.leaky_relu(self.conv3aa(c13), 0.1)
+        c13 = fluid.layers.leaky_relu(self.conv3b(c13), 0.1)
+
+        c23 = fluid.layers.leaky_relu(self.conv3a(c22), 0.1)
+        c23 = fluid.layers.leaky_relu(self.conv3aa(c23), 0.1)
+        c23 = fluid.layers.leaky_relu(self.conv3b(c23), 0.1)
+
+        c14 = fluid.layers.leaky_relu(self.conv4a(c13), 0.1)
+        c14 = fluid.layers.leaky_relu(self.conv4aa(c14), 0.1)
+        c14 = fluid.layers.leaky_relu(self.conv4b(c14), 0.1)
+
+        c24 = fluid.layers.leaky_relu(self.conv4a(c23), 0.1)
+        c24 = fluid.layers.leaky_relu(self.conv4aa(c24), 0.1)
+        c24 = fluid.layers.leaky_relu(self.conv4b(c24), 0.1)
+
+        c15 = fluid.layers.leaky_relu(self.conv5a(c14), 0.1)
+        c15 = fluid.layers.leaky_relu(self.conv5aa(c15), 0.1)
+        c15 = fluid.layers.leaky_relu(self.conv5b(c15), 0.1)
+
+        c25 = fluid.layers.leaky_relu(self.conv5a(c24), 0.1)
+        c25 = fluid.layers.leaky_relu(self.conv5aa(c25), 0.1)
+        c25 = fluid.layers.leaky_relu(self.conv5b(c25), 0.1)
+
+        c16 = fluid.layers.leaky_relu(self.conv6aa(c15), 0.1)
+        c16 = fluid.layers.leaky_relu(self.conv6a(c16), 0.1)
+        c16 = fluid.layers.leaky_relu(self.conv6b(c16), 0.1)
+
+        c26 = fluid.layers.leaky_relu(self.conv6aa(c25), 0.1)
+        c26 = fluid.layers.leaky_relu(self.conv6a(c26), 0.1)
+        c26 = fluid.layers.leaky_relu(self.conv6b(c26), 0.1)
+
+        corr6 = self.corr(c16, c26)
+        corr6 = fluid.layers.leaky_relu(corr6, alpha=0.1)
+
+        x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv6_0(corr6), 0.1), corr6], axis=1)
+        x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv6_1(x), 0.1), x], axis=1)
+        x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv6_2(x), 0.1), x], axis=1)
+        x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv6_3(x), 0.1), x], axis=1)
+        x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv6_4(x), 0.1), x], axis=1)
+
+        flow6 = self.predict_flow6(x)
+        up_flow6 = self.deconv6(flow6)
+        up_feat6 = self.upfeat6(x)
+
+        warp5 = self.warp(c25, up_flow6 * 0.625)
+        corr5 = self.corr(c15, warp5)
+        corr5 = fluid.layers.leaky_relu(corr5, alpha=0.1)
+
+        x = fluid.layers.concat(input=[corr5, c15, up_flow6, up_feat6], axis=1)
+        x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv5_0(x), 0.1), x], axis=1)
+        x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv5_1(x), 0.1), x], axis=1)
+        x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv5_2(x), 0.1), x], axis=1)
+        x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv5_3(x), 0.1), x], axis=1)
+        x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv5_4(x), 0.1), x], axis=1)
+
+        flow5 = self.predict_flow5(x)
+        up_flow5 = self.deconv5(flow5)
+        up_feat5 = self.upfeat5(x)
+
+        warp4 = self.warp(c24, up_flow5 * 1.25)
+        corr4 = self.corr(c14, warp4)
+        corr4 = fluid.layers.leaky_relu(corr4, alpha=0.1)
+
+        x = fluid.layers.concat(input=[corr4, c14, up_flow5, up_feat5], axis=1)
+        x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv4_0(x), 0.1), x], axis=1)
+        x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv4_1(x), 0.1), x], axis=1)
+        x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv4_2(x), 0.1), x], axis=1)
+        x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv4_3(x), 0.1), x], axis=1)
+        x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv4_4(x), 0.1), x], axis=1)
+
+        flow4 = self.predict_flow4(x)
+        up_flow4 = self.deconv4(flow4)
+        up_feat4 = self.upfeat4(x)
+
+        warp3 = self.warp(c23, up_flow4 * 2.5)
+        corr3 = self.corr(c13, warp3)
+        corr3 = fluid.layers.leaky_relu(corr3, alpha=0.1)
+
+        x = fluid.layers.concat(input=[corr3, c13, up_flow4, up_feat4], axis=1)
+        x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv3_0(x), 0.1), x], axis=1)
+        x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv3_1(x), 0.1), x], axis=1)
+        x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv3_2(x), 0.1), x], axis=1)
+        x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv3_3(x), 0.1), x], axis=1)
+        x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv3_4(x), 0.1), x], axis=1)
+
+        flow3 = self.predict_flow3(x)
+        up_flow3 = self.deconv3(flow3)
+        up_feat3 = self.upfeat3(x)
+
+        warp2 = self.warp(c22, up_flow3 * 5.0)
+        corr2 = self.corr(c12, warp2)
+        corr2 = fluid.layers.leaky_relu(corr2, alpha=0.1)
+
+        x = fluid.layers.concat(input=[corr2, c12, up_flow3, up_feat3], axis=1)
+        x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv2_0(x), 0.1), x], axis=1)
+        x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv2_1(x), 0.1), x], axis=1)
+        x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv2_2(x), 0.1), x], axis=1)
+        x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv2_3(x), 0.1), x], axis=1)
+        x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv2_4(x), 0.1), x], axis=1)
+
+        flow2 = self.predict_flow2(x)
+
+        x = fluid.layers.leaky_relu(self.dc_conv4(fluid.layers.leaky_relu(
+            self.dc_conv3(fluid.layers.leaky_relu(self.dc_conv2(fluid.layers.leaky_relu(self.dc_conv1(x), 0.1)), 0.1)),
+            0.1)), 0.1)
+        flow2 += self.dc_conv7(
+            fluid.layers.leaky_relu(self.dc_conv6(fluid.layers.leaky_relu(self.dc_conv5(x), 0.1)), 0.1))
+        if not output_more:
+            return flow2
+        else:
+            return [flow2, flow3, flow4, flow5, flow6]
+
diff --git a/PaddleCV/Research/PWCNet/my_args.py b/PaddleCV/Research/PWCNet/my_args.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb673efe10534ba319fa240c09f05d044be76d4b
--- /dev/null
+++ b/PaddleCV/Research/PWCNet/my_args.py
@@ -0,0 +1,17 @@
+import argparse
+
+parser = argparse.ArgumentParser(description='PWCNet_paddle')
+parser.add_argument('--dataset', default='FlyingChairs', help='dataset type : FlyingChairs')
+parser.add_argument('--data_root', default='', help='the path of the selected dataset')
+parser.add_argument('--model_out_dir', default='./out', help='directory to save model checkpoints')
+parser.add_argument('--loss', default='l2', help='loss type : first train with l2 and finetune with l1')
+parser.add_argument('--train_val_txt', default='', help='the path of the train/val split txt of the dataset')
+parser.add_argument('--numEpoch', '-e', type=int, default=100, help='Number of epochs to train')
+parser.add_argument('--batch_size', '-b', type=int, default=40, help='batch size')
+parser.add_argument('--pretrained', default=None, help='path to the pretrained model weights')
+parser.add_argument('--optimize', default=None, help='path to the pretrained optimizer weights')
+parser.add_argument('--use_multi_gpu', action='store_true', help='Enable multi gpu mode')
+
+args = parser.parse_args()
+args.inference_size = [384, 512]
+args.crop_size = [384, 448]
\ No newline at end of file
diff --git a/PaddleCV/Research/PWCNet/paddle_model/pwc_net_chairs_paddle.pdparams b/PaddleCV/Research/PWCNet/paddle_model/pwc_net_chairs_paddle.pdparams
new file mode 100755
index 0000000000000000000000000000000000000000..1b8a626b6bd1c5d30e65154bc6bb54f336716b25
Binary files /dev/null and b/PaddleCV/Research/PWCNet/paddle_model/pwc_net_chairs_paddle.pdparams differ
diff --git a/PaddleCV/Research/PWCNet/paddle_model/pwc_net_paddle.pdparams b/PaddleCV/Research/PWCNet/paddle_model/pwc_net_paddle.pdparams
new file mode 100755
index 0000000000000000000000000000000000000000..6e947b41ca33f8871bb72d3ad1e8f0b709c8f354
Binary files /dev/null and b/PaddleCV/Research/PWCNet/paddle_model/pwc_net_paddle.pdparams differ
diff --git a/PaddleCV/Research/PWCNet/src/__init__.py b/PaddleCV/Research/PWCNet/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/PaddleCV/Research/PWCNet/src/flow_vis.py b/PaddleCV/Research/PWCNet/src/flow_vis.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2fe36828f829151ec307f1b4e1dc687b4ecc8b3
--- /dev/null
+++ b/PaddleCV/Research/PWCNet/src/flow_vis.py
@@ -0,0 +1,163 @@
+# MIT License
+#
+# Copyright (c) 2018 Tom Runia
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to conditions.
+#
+# Author: Tom Runia
+# Date Created: 2018-08-03
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import numpy as np
+
+
+def make_colorwheel():
+    '''
+    Generates a color wheel for optical flow visualization as presented in:
+    Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
+    URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
+
+    According to the C++ source code of Daniel Scharstein
+    According to the Matlab source code of Deqing Sun
+    '''
+
+    RY = 15
+    YG = 6
+    GC = 4
+    CB = 11
+    BM = 13
+    MR = 6
+
+    ncols = RY + YG + GC + CB + BM + MR
+    colorwheel = np.zeros((ncols, 3))
+    col = 0
+
+    # RY
+    colorwheel[0:RY, 0] = 255
+    colorwheel[0:RY, 1] = np.floor(255*np.arange(0, RY)/RY)
+    col = col+RY
+    # YG
+    colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0, YG)/YG)
+    colorwheel[col:col+YG, 1] = 255
+    col = col+YG
+    # GC
+    colorwheel[col:col+GC, 1] = 255
+    colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0, GC)/GC)
+    col = col+GC
+    # CB
+    colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
+    colorwheel[col:col+CB, 2] = 255
+    col = col+CB
+    # BM
+    colorwheel[col:col+BM, 2] = 255
+    colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0, BM)/BM)
+    col = col+BM
+    # MR
+    colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
+    colorwheel[col:col+MR, 0] = 255
+    return colorwheel
+
+
+def flow_compute_color(u, v, convert_to_bgr=False):
+    '''
+    Applies the flow color wheel to (possibly clipped) flow components u and v.
+
+    According to the C++ source code of Daniel Scharstein
+    According to the Matlab source code of Deqing Sun
+
+    :param u: np.ndarray, input horizontal flow
+    :param v: np.ndarray, input vertical flow
+    :param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB
+    :return:
+    '''
+
+    flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
+
+    colorwheel = make_colorwheel()  # shape [55x3]
+    ncols = colorwheel.shape[0]
+
+    rad = np.sqrt(np.square(u) + np.square(v))
+    a = np.arctan2(-v, -u)/np.pi
+
+    fk = (a+1) / 2*(ncols-1)
+    k0 = np.floor(fk).astype(np.int32)
+    k1 = k0 + 1
+    k1[k1 == ncols] = 0
+    f = fk - k0
+
+    for i in range(colorwheel.shape[1]):
+
+        tmp = colorwheel[:, i]
+        col0 = tmp[k0] / 255.0
+        col1 = tmp[k1] / 255.0
+        col = (1-f)*col0 + f*col1
+
+        idx = (rad <= 1)
+        col[idx] = 1 - rad[idx] * (1-col[idx])
+        col[~idx] = col[~idx] * 0.75  # out of range?
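+        # (flow vectors whose radius exceeds 1 fall outside the color wheel;
+        # the 0.75 factor dims them instead of clipping, as in the Middlebury
+        # flow color code)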
+
+        # Note the 2-i => BGR instead of RGB
+        ch_idx = 2-i if convert_to_bgr else i
+        flow_image[:, :, ch_idx] = np.floor(255 * col)
+
+    return flow_image
+
+
+def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False):
+    '''
+    Expects a two dimensional flow image of shape [H,W,2]
+
+    According to the C++ source code of Daniel Scharstein
+    According to the Matlab source code of Deqing Sun
+
+    :param flow_uv: np.ndarray of shape [H,W,2]
+    :param clip_flow: float, maximum clipping value for flow
+    :return:
+    '''
+    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
+    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
+
+    if clip_flow is not None:
+        flow_uv = np.clip(flow_uv, 0, clip_flow)
+
+    u = flow_uv[:, :, 0]
+    v = flow_uv[:, :, 1]
+
+    rad = np.sqrt(np.square(u) + np.square(v))
+    rad_max = np.max(rad)
+
+    epsilon = 1e-5
+    u = u / (rad_max + epsilon)
+    v = v / (rad_max + epsilon)
+    return flow_compute_color(u, v, convert_to_bgr)
+
+
+def read_flow(filename):
+    """
+    https://github.com/sampepose/flownet2-tf/blob/master/src/flowlib.py
+    read optical flow from Middlebury .flo file
+    :param filename: name of the flow file
+    :return: optical flow data in matrix
+    """
+    f = open(filename, 'rb')
+    magic = np.fromfile(f, np.float32, count=1)
+    data2d = None
+
+    if 202021.25 != magic:
+        print('Magic number incorrect. Invalid .flo file')
+    else:
+        w = np.fromfile(f, np.int32, count=1)
+        h = np.fromfile(f, np.int32, count=1)
+        print("Reading %d x %d flo file" % (h, w))
+        data2d = np.fromfile(f, np.float32, count=2 * w[0] * h[0])
+        # reshape data into 3D array (columns, rows, channels)
+        data2d = np.resize(data2d, (h[0], w[0], 2))
+    f.close()
+    return data2d
diff --git a/PaddleCV/Research/PWCNet/src/multiscaleloss.py b/PaddleCV/Research/PWCNet/src/multiscaleloss.py
new file mode 100644
index 0000000000000000000000000000000000000000..a52a74acf278fde4a99335af21459050fd28a7ef
--- /dev/null
+++ b/PaddleCV/Research/PWCNet/src/multiscaleloss.py
@@ -0,0 +1,85 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
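+# Training and validation losses for PWC-Net: `multiscaleEPE` compares every
+# pyramid-level prediction with a correspondingly resized ground truth and sums
+# the per-level errors with fixed weights; `realEPE` upsamples the final
+# prediction back to ground-truth resolution for validation.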
+import paddle
+import paddle.fluid as fluid
+
+
+def EPE(input_flow, target_flow, loss_type, sparse=False, mean=True):
+    if loss_type == 'l1':
+        EPE_map = fluid.layers.abs(input_flow - target_flow)
+    else:
+        # 'l2' is the squared difference; the l1 branch is used for finetuning
+        EPE_map = fluid.layers.square(input_flow - target_flow)
+    if sparse:
+        # TODO: ideally mask = (target_flow[:, 0] == 0) & (target_flow[:, 1] == 0)
+        # and EPE_map = EPE_map[~mask]; emulated below with cast/elementwise ops
+        mask_temp1 = fluid.layers.cast(target_flow[:, 0] == 0, 'float32')
+        mask_temp2 = fluid.layers.cast(target_flow[:, 1] == 0, 'float32')
+        mask = 1 - fluid.layers.elementwise_mul(mask_temp1, mask_temp2)
+        mask = fluid.layers.reshape(mask, [mask.shape[0], 1, mask.shape[1], mask.shape[2]])
+        mask = fluid.layers.concat([mask, mask], 1)
+        EPE_map = EPE_map * mask
+
+    if mean:
+        return fluid.layers.mean(EPE_map)
+    else:
+        batch_size = EPE_map.shape[0]
+        res_sum = fluid.layers.reduce_sum(EPE_map)
+        res = res_sum / batch_size
+        return res
+
+
+def sparse_max_pool(input, size):
+    '''Downsample the input by considering 0 values as invalid.
+
+    Unfortunately, no generic interpolation mode can resize a sparse map correctly,
+    so the strategy here is to use max pooling for positive values and "min pooling"
+    for negative values; the two results are then summed.
+    This technique allows sparsity to be minimized, contrary to nearest interpolation,
+    which could potentially lose information for isolated data points.'''
+
+    positive = fluid.layers.cast(input > 0, 'float32')
+    negative = fluid.layers.cast(input < 0, 'float32')
+    output = fluid.layers.adaptive_pool2d(input * positive, size) - fluid.layers.adaptive_pool2d(-input * negative,
+                                                                                                 size)
+    return output
+
+
+def multiscaleEPE(network_output, target_flow, loss_type, weights=None, sparse=False):
+    def one_scale(output, target, sparse, loss_type):
+        if sparse:
+            h = output.shape[2]
+            w = output.shape[3]
+            target_scaled = sparse_max_pool(target, [h, w])
+        else:
+            target_scaled = fluid.layers.resize_bilinear(target, out_shape=[output.shape[2],
+                                                                            output.shape[3]],
+                                                         align_corners=False, align_mode=False)
+        return EPE(output, target_scaled, loss_type=loss_type, sparse=sparse, mean=False)
+
+    if not isinstance(network_output, (tuple, list)):
+        network_output = [network_output]
+    if weights is None:
+        weights = [0.005, 0.01, 0.02, 0.08, 0.32]  # as in original article
+    assert len(weights) == len(network_output)
+
+    loss = 0
+    for output, weight in zip(network_output, weights):
+        loss += weight * one_scale(output, target_flow, sparse, loss_type)
+    return loss
+
+
+def realEPE(output, target, sparse=False):
+    upsampled_output = fluid.layers.resize_bilinear(output, out_shape=[target.shape[2],
+                                                                       target.shape[3]],
+                                                    align_corners=False, align_mode=False)
+    # pass sparse by keyword: positionally it would be consumed as loss_type
+    return EPE(upsampled_output, target, loss_type='l2', sparse=sparse, mean=True)
+
diff --git a/PaddleCV/Research/PWCNet/src/read_files.py b/PaddleCV/Research/PWCNet/src/read_files.py
new file mode 100644
index 0000000000000000000000000000000000000000..743a57ddc2552c668c5a76b3511659c861ab160f
--- /dev/null
+++ b/PaddleCV/Research/PWCNet/src/read_files.py
@@ -0,0 +1,22 @@
+def read_txt(videoTxt):
+    with open(videoTxt, 'r') as f:
+        videolist = f.readlines()
+    return videolist
+
+
+def read_txt_to_index(file):
+    data = read_txt(file)
+    data = list(map(int, data))
+    return data
+
+
+def main():
+    file = 'data_dir/FlyingChairs_release/FlyingChairs_train_val.txt'
+    data = read_txt_to_index(file)
+    data = list(map(int, data))
+    print(data)
+    print(len(data))
+
+
+if __name__ == '__main__':
+    main()
diff --git a/PaddleCV/Research/PWCNet/tmp/hsv_pd.png b/PaddleCV/Research/PWCNet/tmp/hsv_pd.png
new file mode 100755
index 0000000000000000000000000000000000000000..0ebc10300e6d3e93260ddbec59bc9d002958c01a
Binary files /dev/null and b/PaddleCV/Research/PWCNet/tmp/hsv_pd.png differ
diff --git a/PaddleCV/Research/PWCNet/tmp/hsv_pd_chairs.png b/PaddleCV/Research/PWCNet/tmp/hsv_pd_chairs.png
new file mode 100755
index 0000000000000000000000000000000000000000..cc3249bf0ca991502f715f26e29a723b9319b8ce
Binary files /dev/null and b/PaddleCV/Research/PWCNet/tmp/hsv_pd_chairs.png differ
diff --git a/PaddleCV/Research/PWCNet/train.py b/PaddleCV/Research/PWCNet/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dc3b05edf1ccd2b59594e5c4a157e90b9390735
--- /dev/null
+++ b/PaddleCV/Research/PWCNet/train.py
@@ -0,0 +1,275 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Trainer for PWCNet."""
+import sys
+import os
+os.environ['FLAGS_fraction_of_gpu_memory_to_use'] = "0.99999"
+os.environ["FLAGS_eager_delete_tensor_gb"] = "0"
+import pickle
+import time
+import cv2
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+from scipy.misc import imsave
+from src import flow_vis
+from models.model import PWCDCNet
+from data.datasets import FlyingChairs, reader_flyingchairs
+from src.multiscaleloss import multiscaleEPE, realEPE
+from AverageMeter import AverageMeter
+from my_args import args
+
+
+def writeFlowFile(filename, uv):
+    """Write optical flow `uv` ([H, W, 2]) to a Middlebury .flo file.
+
+    Based on the Matlab code of Deqing Sun and the C++ code of Daniel Scharstein.
+    Contact: dqsun@cs.brown.edu
+    Contact: schar@middlebury.edu
+    """
+    TAG_STRING = np.array(202021.25, dtype=np.float32)
+    if uv.shape[2] != 2:
+        sys.exit("writeFlowFile: flow must have two bands!")
+    H = np.array(uv.shape[0], dtype=np.int32)
+    W = np.array(uv.shape[1], dtype=np.int32)
+    with open(filename, 'wb') as f:
+        f.write(TAG_STRING.tobytes())
+        f.write(W.tobytes())
+        f.write(H.tobytes())
+        f.write(uv.tobytes())
+
+
+def load_dict(filename_):
+    with open(filename_, 'rb') as f:
+        ret_di = pickle.load(f)
+    return ret_di
+
+
+def pad_input(x0):
+    # see the note on the swapped width/height naming in infer.py's pad_input;
+    # the names and the pad2d padding order are swapped consistently, so the
+    # padded result is correct
+    intWidth = x0.shape[2]
+    intHeight = x0.shape[3]
+    if intWidth != ((intWidth >> 6) << 6):
+        intWidth_pad = (((intWidth >> 6) + 1) << 6)  # more than necessary
+        intPaddingLeft = int((intWidth_pad - intWidth) / 2)
+        intPaddingRight = intWidth_pad - intWidth - intPaddingLeft
+    else:
+        intWidth_pad = intWidth
+        intPaddingLeft = 0
+        intPaddingRight = 0
+
+    if intHeight != ((intHeight >> 6) << 6):
+        intHeight_pad = (((intHeight >> 6) + 1) << 6)  # more than necessary
+        intPaddingTop = int((intHeight_pad - intHeight) / 2)
+        intPaddingBottom = intHeight_pad - intHeight - intPaddingTop
+    else:
+        intHeight_pad = intHeight
+        intPaddingTop = 0
+        intPaddingBottom = 0
+
+    out = fluid.layers.pad2d(input=x0,
+                             paddings=[intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom],
+                             mode='edge')
+
+    return out, [intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom, intWidth, intHeight]
+
+
+def val(model, batch_reader, epoch, batch_num):
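+    """Run one validation epoch and return the average realEPE loss.
+
+    `batch_reader` yields (im1, im2, flow) triples; inputs are scaled to [0, 1]
+    and flow targets divided by 20, matching the training pipeline.
+    """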
+    model.eval()
+    loss_cnt = AverageMeter()
+    for batch_id, data in enumerate(batch_reader()):
+        start = time.time()
+        im1_data = np.array([x[0] for x in data]).astype('float32')
+        im2_data = np.array([x[1] for x in data]).astype('float32')
+        flo_data = np.array([x[2] for x in data]).astype('float32')
+        step = im1_data.shape[0]
+
+        im_all = np.concatenate((im1_data, im2_data), axis=3).astype(np.float32)
+        im_all = im_all / 255.0
+        im_all = np.swapaxes(np.swapaxes(im_all, 1, 2), 1, 3)
+        label = flo_data / 20.0
+        label = np.swapaxes(np.swapaxes(label, 1, 2), 1, 3)
+
+        im_all = fluid.dygraph.to_variable(im_all)
+        label = fluid.dygraph.to_variable(label)
+        # im_all, [intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom, intWidth, intHeight] = pad_input(
+        #     im_all)
+
+        end = time.time()
+        read_data_time = end - start
+        start = time.time()
+        network_output = model(im_all, output_more=False)
+        loss = realEPE(network_output, label)
+        end = time.time()
+        loss_cnt.update(loss.numpy()[0], step)
+        print('val epoch {} batch {}/{} run time: {}s read data time {}s loss {}'.format(
+            epoch, batch_id, batch_num, round(end - start, 2), round(read_data_time, 2), loss.numpy()))
+    return round(loss_cnt.avg, 4)
+
+
+def train(model, train_batch_reader, adam, epoch, batch_num, args):
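+    """Train for one epoch with the multiscale loss selected by args.loss.
+
+    Every 10 batches the current input pair and the ground-truth/predicted flow
+    visualizations are written to ./img1.png, ./img2.png, ./hsv_pd.png and
+    ./hsv_predict.png for quick inspection.
+    """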
+    loss_type = args.loss
+    model.train()
+    for batch_id, data in enumerate(train_batch_reader()):
+        start = time.time()
+        im1_data = np.array([x[0] for x in data]).astype('float32')
+        im2_data = np.array([x[1] for x in data]).astype('float32')
+        flo_data = np.array([x[2] for x in data]).astype('float32')
+        im_all = np.concatenate((im1_data, im2_data), axis=3).astype(np.float32)
+        im_all = im_all / 255.0
+        im_all = np.swapaxes(np.swapaxes(im_all, 1, 2), 1, 3)
+        label = flo_data / 20.0
+        label = np.swapaxes(np.swapaxes(label, 1, 2), 1, 3)
+        if batch_id % 10 == 0:
+            im1 = im_all[0, :3, :, :] * 255
+            im2 = im_all[0, 3:, :, :] * 255
+            im1 = np.swapaxes(np.swapaxes(im1, 0, 1), 1, 2).astype(np.uint8)
+            im2 = np.swapaxes(np.swapaxes(im2, 0, 1), 1, 2).astype(np.uint8)
+
+            flo = label[0, :, :, :] * 20
+            flo = np.swapaxes(np.swapaxes(flo, 0, 1), 1, 2)
+            imsave('./img1.png', im1)
+            imsave('./img2.png', im2)
+            flow_color = flow_vis.flow_to_color(flo, convert_to_bgr=False)
+            imsave('./hsv_pd.png', flow_color)
+        H = im_all[0].shape[1]
+        W = im_all[0].shape[2]
+
+        im_all = fluid.dygraph.to_variable(im_all)
+        label = fluid.dygraph.to_variable(label)
+        im_all, [intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom, intWidth, intHeight] = pad_input(
+            im_all)
+        label, _ = pad_input(label)
+        end = time.time()
+        read_data_time = end - start
+        start = time.time()
+        network_output = model(im_all, output_more=True)
+        if batch_id % 10 == 0:
+            flo = network_output[0][0].numpy() * 20.0
+            # scale the flow back to the input size
+            flo = np.swapaxes(np.swapaxes(flo, 0, 1), 1, 2)
+            flo = flo[intPaddingTop * 2:intPaddingTop * 2 + intHeight * 2,
+                      intPaddingLeft * 2: intPaddingLeft * 2 + intWidth * 2, :]
+            u_ = cv2.resize(flo[:, :, 0], (W, H))
+            v_ = cv2.resize(flo[:, :, 1], (W, H))
+            flo = np.dstack((u_, v_))
+            flow_color = flow_vis.flow_to_color(flo, convert_to_bgr=False)
+            imsave('./hsv_predict.png', flow_color)
+        loss = multiscaleEPE(network_output, label, loss_type, weights=None, sparse=False)
+
+        end = time.time()
+        loss.backward()
+        if args.use_multi_gpu:
+            model.apply_collective_grads()
+        adam.minimize(loss)
+        model.clear_gradients()
+        print('epoch {} batch {}/{} run time: {}s read data time {}s loss {}'.format(
+            epoch, batch_id, batch_num, round(end - start, 2), round(read_data_time, 2), loss.numpy()))
+
+
+def main():
+    print(args)
+    if args.use_multi_gpu:
+        place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
+    else:
+        place = fluid.CUDAPlace(0)
+
+    with fluid.dygraph.guard(place=place):
+        if args.use_multi_gpu:
+            strategy = fluid.dygraph.parallel.prepare_context()
+        model = PWCDCNet("pwcnet")
+        if args.pretrained:
+            print('-----------load pretrained model:', args.pretrained)
+            pd_pretrain, _ = fluid.dygraph.load_dygraph(args.pretrained)
+            model.set_dict(pd_pretrain)
+
+        adam = fluid.optimizer.AdamOptimizer(learning_rate=0.0001,
+                                             regularization=fluid.regularizer.L2DecayRegularizer(
+                                                 regularization_coeff=0.0004))
+        if args.optimize:
+            print('--------------load pretrained optimizer:', args.optimize)
+            adam_pretrain, _ = fluid.dygraph.load_dygraph(args.optimize)
+            adam.set_dict(adam_pretrain)
+        if args.use_multi_gpu:
+            model = fluid.dygraph.parallel.DataParallel(model, strategy)
+
+        if args.dataset == 'FlyingChairs':
+            train_flyingchairs_dataset = FlyingChairs('train', args, is_cropped=True, txt_file=args.train_val_txt,
+                                                      root=args.data_root)
+            val_flyingchairs_dataset = FlyingChairs('val', args, is_cropped=False, txt_file=args.train_val_txt,
+                                                    root=args.data_root)
+        else:
+            raise ValueError('unsupported dataset %s: only FlyingChairs is available' % args.dataset)
+
+        train_sample_num = len(train_flyingchairs_dataset)
+        val_sample_num = len(val_flyingchairs_dataset)
+        print('train sample num: ', train_sample_num)
+        print('val sample num: ', val_sample_num)
+        train_reader = reader_flyingchairs(train_flyingchairs_dataset)
+        val_reader = reader_flyingchairs(val_flyingchairs_dataset)
+        if args.use_multi_gpu:
+            train_reader = fluid.contrib.reader.distributed_batch_reader(train_reader)
+            val_reader = fluid.contrib.reader.distributed_batch_reader(val_reader)
+        BATCH_SIZE = args.batch_size
+        train_batch_num = round(train_sample_num / BATCH_SIZE)
+        val_batch_num = round(val_sample_num / BATCH_SIZE)
+        train_batch_reader = paddle.batch(paddle.reader.shuffle(train_reader, buf_size=BATCH_SIZE * 100), BATCH_SIZE,
+                                          drop_last=True)
+        val_batch_reader = paddle.batch(val_reader, BATCH_SIZE, drop_last=False)
+        epoch_num = args.numEpoch
+        val_value = 100000000  # best (lowest) validation loss seen so far
+        rm_best_model = ""
+
+        for epoch in range(epoch_num):
+            train(model, train_batch_reader, adam, epoch, train_batch_num, args)
+            pd_save_dir = args.model_out_dir
+            if not os.path.exists(pd_save_dir):
+                os.makedirs(pd_save_dir)
+            pd_model_save = os.path.join(pd_save_dir, 'epoch_' + str(epoch) + "_pwc_net_paddle")
+            rm_dir = os.path.join(pd_save_dir, 'epoch_' + str(epoch - 1) + "_pwc_net_paddle.pdparams")
+            if os.path.exists(rm_dir):
+                os.remove(rm_dir)
+            if args.use_multi_gpu:
+                if fluid.dygraph.parallel.Env().local_rank == 0:
+                    fluid.dygraph.save_dygraph(model.state_dict(), pd_model_save)
+                    fluid.dygraph.save_dygraph(adam.state_dict(), os.path.join(pd_save_dir, 'adam'))
+            else:
+                fluid.dygraph.save_dygraph(model.state_dict(), pd_model_save)
+                fluid.dygraph.save_dygraph(adam.state_dict(), os.path.join(pd_save_dir, 'adam'))
+            val_loss_value = val(model, val_batch_reader, epoch, val_batch_num)
+            if val_loss_value < val_value:
+                best_model = os.path.join(pd_save_dir, "pwc_net_paddle_" + str(val_loss_value) + '.pdparams')
+                os.link(pd_model_save + '.pdparams', best_model)
+                if os.path.exists(rm_best_model):
+                    os.remove(rm_best_model)
+                rm_best_model = best_model
+                val_value = val_loss_value
+
+
+if __name__ == '__main__':
+    main()
+
+
diff --git a/PaddleCV/Research/PWCNet/train.sh b/PaddleCV/Research/PWCNet/train.sh
new file mode 100755
index 0000000000000000000000000000000000000000..7c2b7226bef96ebdbbe6c768255a8419e0d32de0
--- /dev/null
+++ b/PaddleCV/Research/PWCNet/train.sh
@@ -0,0 +1,4 @@
+#!/usr/bin/env bash
+python3 train.py --dataset FlyingChairs --train_val_txt data_dir/FlyingChairs_release/FlyingChairs_train_val.txt --data_root data_dir/FlyingChairs_release/data
+# multi-GPU training is not supported yet; the command below is kept for reference
+#python3 -m paddle.distributed.launch --selected_gpus=0,1 --log_dir ./mylog train.py --use_multi_gpu --batch_size 20 --dataset FlyingChairs --train_val_txt data_dir/FlyingChairs_release/FlyingChairs_train_val.txt --data_root data_dir/FlyingChairs_release/data
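As a quick way to exercise the flow I/O helpers added in this change, the sketch below round-trips a synthetic flow field through the Middlebury `.flo` layout that `writeFlowFile` emits and `read_flow` parses, then colorizes it with `flow_vis.flow_to_color`. It is illustrative only: the file name `flo_check.py` is hypothetical, and it assumes it is run from `PaddleCV/Research/PWCNet` so that `src.flow_vis` imports as above (NumPy is the only dependency).
```
# flo_check.py -- hypothetical sanity check, not part of the change above.
import numpy as np

from src.flow_vis import read_flow, flow_to_color

# a small synthetic flow field, shape [H, W, 2], float32 as the format requires
flow = np.random.uniform(-10.0, 10.0, size=(64, 128, 2)).astype(np.float32)

# write it in the .flo layout used by writeFlowFile in train.py and infer.py:
# float32 tag 202021.25, then int32 width, int32 height, then raw float32 data
with open('check.flo', 'wb') as f:
    f.write(np.float32(202021.25).tobytes())
    f.write(np.int32(flow.shape[1]).tobytes())  # width first
    f.write(np.int32(flow.shape[0]).tobytes())  # then height
    f.write(flow.tobytes())

# .flo stores raw float32, so the round trip should be bit-exact
flow_back = read_flow('check.flo')
assert flow_back.shape == flow.shape
assert np.allclose(flow, flow_back)

# flow_to_color maps the 2-channel flow to an RGB uint8 image of the same size
rgb = flow_to_color(flow_back, convert_to_bgr=False)
assert rgb.shape == (64, 128, 3) and rgb.dtype == np.uint8
print('round trip and colorization OK')
```
If the shape assertion fails, the usual culprit is swapping width and height in the header, which is why `writeFlowFile` writes `W` before `H`.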