Commit ac4aa52d, authored by: HeZheng, committed by: ceci3

PWCNet reimplement using paddlepaddle DyGraph (#4105)

* add pwcnet
Parent eb60742f
class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
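# Usage sketch (illustrative values): track a size-weighted running mean.
#
#   meter = AverageMeter()
#   for batch_loss, n in [(0.9, 8), (0.7, 8), (0.5, 4)]:
#       meter.update(batch_loss, n=n)
#   # meter.val is the last value, meter.avg the running mean over all samples.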
# PWCNet reimplemented using PaddlePaddle DyGraph
PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume.
# Environment
```
CentOS 7
paddle develop version (after 20191201), installed from source
python3.7
SciPy 1.1.0
```
The code will be updated for Paddle v1.7 later.
# Compile correlation op
```
cd correlation_op
sh make.sh
```
# Datasets
1. Please download the `FlyingChairs dataset` and `FlyingChairs_train_val.txt` from https://lmb.informatik.uni-freiburg.de/resources/datasets, or use `./data/download.sh` to download the datasets.
We split the data into train and val using `FlyingChairs_train_val.txt` (`1` for train, `2` for val), as sketched below.
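The split itself is simple; here is a standalone sketch, assuming the txt file lists one integer label per sample (the role `read_txt_to_index` in `src/read_files.py` plays in `./data/datasets.py`):
```
import numpy as np

# FlyingChairs_train_val.txt: one integer per sample, 1 = train, 2 = val.
labels = np.loadtxt('FlyingChairs_train_val.txt', dtype=int)
train_ids = np.flatnonzero(labels == 1)
val_ids = np.flatnonzero(labels == 2)
```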
# Inference
Note that the paddle models `pwc_net_paddle.pdparams` and `pwc_net_chairs_paddle.pdparams` are converted from the PyTorch files `pwc_net.pth.tar` and `pwc_net_chairs.pth.tar`.
Run
```
python infer.py
```
| Input img1 | Input img2 |
|-------|------------|
| <img src='data/frame_0010.png' width=500> | <img src='data/frame_0011.png' width=500> |
|prediction with pwc_net_paddle.pdparams| prediction with pwc_net_chairs_paddle.pdparams|
|-------------|-------------|
|<img src='tmp/hsv_pd.png' width=500> | <img src='tmp/hsv_pd_chairs.png' width=500> |
# First Train with L2 loss
A single GPU is supported; multi-GPU training will be supported later.
Check the parameters in `my_args.py` and adjust them in `train.sh` as needed:
```
--data_root
--train_val_txt
--batch_size
```
Then run
```
./train.sh
```
Some intermediate results during training are written to:
```
./img1.png
./img2.png
./hsv_pd.png # ground truth
./hsv_predict.png # output of model
```
# Finetune with L1 loss
Finetune from your best pretrained model by adding `--pretrained your_best_model_name`, e.g. `--pretrained epoch_7_pwc_net_paddle`.
Run
```
./finetune.sh
```
# Note
This code reimplements PWC-Net following the code at `https://github.com/NVlabs/PWC-Net`.
If you want to train as in the paper
```
@InProceedings{Sun2018PWC-Net,
author = {Deqing Sun and Xiaodong Yang and Ming-Yu Liu and Jan Kautz},
title = {{PWC-Net}: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume},
booktitle = CVPR,
year = {2018},
}
```
please use all the datasets in `./data/download.sh`, together with the code in `./data/datasets.py`.
References:
```
https://github.com/NVlabs/PWC-Net
https://github.com/ClementPinard/FlowNetPytorch
https://github.com/NVIDIA/flownet2-pytorch/blob/master/datasets.py
```
Compiling the custom OP:
1. Use a paddle develop build from after December 1st, 2019.
2. Run `sh make.sh` to compile the `correlation_lib.so` dynamic library.
3. Add the library path to LD_LIBRARY_PATH:
```
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:`python3.7 -c 'import paddle; print(paddle.sysconfig.get_lib())'`
```
4. Add the python path for the correlation op:
```
export PYTHONPATH=$PYTHONPATH:`pwd`
```
5. Run the unit test with `python test_correlation.py` to verify that the op loads correctly.
PS: if the paddle whl package was downloaded from the official site, gcc 4.8 is required, i.e. change `g++` to `g++-4.8` in `make.sh`.
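A hypothetical minimal load check (run from `correlation_op/` after the exports above; `test_correlation.py` remains the full unit test):
```
import paddle.fluid as fluid

fluid.load_op_library('correlation_lib.so')
print('correlation op loaded')
```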
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import os
file_dir = os.path.dirname(os.path.abspath(__file__))
fluid.load_op_library(os.path.join(file_dir, 'correlation_lib.so'))
from paddle.fluid.layer_helper import LayerHelper
def correlation(input1, input2, pad_size, kernel_size, max_displacement,
                stride1, stride2, corr_type_multiply=1):
    helper = LayerHelper("correlation", **locals())
    output = helper.create_variable_for_type_inference(dtype=input1.dtype)
    helper.append_op(
        type="correlation",
        inputs={"Input1": input1, "Input2": input2},
        attrs={
            "pad_size": pad_size,
            "kernel_size": kernel_size,
            "max_displacement": max_displacement,
            "stride1": stride1,
            "stride2": stride2,
            "corr_type_multiply": corr_type_multiply
        },
        outputs={"Output": output})
    return output
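# Usage sketch in dygraph mode (shapes illustrative; assumes numpy imported
# as np -- see test_correlation.py for the full test):
#
#   with fluid.dygraph.guard(fluid.CUDAPlace(0)):
#       x1 = fluid.dygraph.to_variable(np.random.randn(2, 3, 4, 5).astype('float32'))
#       x2 = fluid.dygraph.to_variable(np.random.randn(2, 3, 4, 5).astype('float32'))
#       out = correlation(x1, x2, pad_size=4, kernel_size=1,
#                         max_displacement=4, stride1=1, stride2=1)
#       # out has (2 * max_displacement / stride2 + 1)^2 = 81 channels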
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
inline std::vector<int64_t> CorrelationOutputSize(int batch, int input_height, int input_width, int stride1, int stride2, int kernel_size, int pad_size, int max_displacement) {
std::vector<int64_t> output_shape({batch});
int kernel_radius = (kernel_size - 1) / 2;
int border_radius = kernel_radius + max_displacement;
int padded_input_height = input_height + 2 * pad_size;
int padded_input_width = input_width + 2 * pad_size;
int output_channel = ((max_displacement/stride2) * 2 + 1) * ((max_displacement/stride2) * 2 + 1);
output_shape.push_back(output_channel);
int output_height = std::ceil(static_cast<float>(padded_input_height - 2 * border_radius) / static_cast<float>(stride1));
int output_width = std::ceil(static_cast<float>(padded_input_width - 2 * border_radius) / static_cast<float>(stride1));
output_shape.push_back(output_height);
output_shape.push_back(output_width);
return output_shape;
}
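// Worked example with the defaults used elsewhere in this repo (pad_size = 4,
// kernel_size = 1, max_displacement = 4, stride1 = stride2 = 1) and a 2x3x4x5
// NCHW input: kernel_radius = 0, border_radius = 4, the padded input is 12x13,
// output_channel = (4 * 2 + 1)^2 = 81, output_height = ceil((12 - 8) / 1) = 4,
// output_width = ceil((13 - 8) / 1) = 5, so the output shape is {2, 81, 4, 5}
// (matching the shapes exercised in test_correlation.py).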
class CorrelationOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override{
AddInput("Input1", "input1");
AddInput("Input2", "input2");
AddOutput("Output", "output");
AddAttr<int>("pad_size", "pad size for input1 and input2");
AddAttr<int>("kernel_size", "kernel size of input1 and input2");
AddAttr<int>("max_displacement", "max displacement of input1 and input2");
AddAttr<int>("stride1", "Input1 stride");
AddAttr<int>("stride2", "Input2 stride");
AddAttr<int>("corr_type_multiply", "correlation coefficient").SetDefault(1);
AddComment(R"DOC(Correlation of two feature map. Only support NCHW data format.)DOC");
}
};
class CorrelationOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override{
PADDLE_ENFORCE_EQ(ctx->HasInput("Input1"), true, "Input(input1) cannot be null");
PADDLE_ENFORCE_EQ(ctx->HasInput("Input2"), true, "Input(input2) cannot be null");
int stride1 = ctx->Attrs().Get<int>("stride1");
int stride2 = ctx->Attrs().Get<int>("stride2");
int max_displacement = ctx->Attrs().Get<int>("max_displacement");
int pad_size = ctx->Attrs().Get<int>("pad_size");
int kernel_size = ctx->Attrs().Get<int>("kernel_size");
auto in_dims = ctx->GetInputDim("Input1");
auto in2_dims = ctx->GetInputDim("Input2");
PADDLE_ENFORCE_EQ(in_dims.size() == 4, true, "input1 must be 4-dims");
PADDLE_ENFORCE_EQ(in2_dims.size() == 4, true, "input2 must be 4-dims");
std::vector<int64_t> output_shape = CorrelationOutputSize(in_dims[0], in_dims[2], in_dims[3], stride1, stride2, kernel_size, pad_size, max_displacement);
ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override{
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input1");
PADDLE_ENFORCE_EQ(input_data_type, ctx.Input<Tensor>("Input2")->type(), "Input1 and Input2 should have the same type");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
template <typename T>
class CorrelationOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
std::unique_ptr<T> Apply() const override {
auto* op = new T();
op->SetType("correlation_grad");
op->SetInput("Input1", this->Input("Input1"));
op->SetInput("Input2", this->Input("Input2"));
op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
op->SetOutput(framework::GradVarName("Input1"), this->InputGrad("Input1"));
op->SetOutput(framework::GradVarName("Input2"), this->InputGrad("Input2"));
op->SetAttrMap(this->Attrs());
return std::unique_ptr<T>(op);
}
};
class CorrelationOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override{
PADDLE_ENFORCE_EQ(ctx->HasInput("Input1"), true, "Input(Input1) should not be null");
PADDLE_ENFORCE_EQ(ctx->HasInput("Input2"), true, "Input(Input2) should not be null");
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Output")), true, "Input(Output@GRAD) should not be null");
auto in1_dims = ctx->GetInputDim("Input1");
auto in2_dims = ctx->GetInputDim("Input2");
ctx->SetOutputDim(framework::GradVarName("Input1"), in1_dims);
ctx->SetOutputDim(framework::GradVarName("Input2"), in1_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override{
const auto* var = ctx.InputVar(framework::GradVarName("Output"));
if (var == nullptr) {
PADDLE_THROW("cannot find Output@GRAD");
}
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(ctx, "Input1"), ctx.GetPlace());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(correlation, ops::CorrelationOp, ops::CorrelationOpMaker,
ops::CorrelationOpGradMaker<paddle::framework::OpDesc>,
ops::CorrelationOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(correlation_grad, ops::CorrelationOpGrad);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#define THREADS_PER_BLOCK 32
#define FULL_MASK 0xffffffff
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
__forceinline__ __device__ T warpReduceSum(T val) {
for (int offset = 16; offset > 0; offset /= 2) {
val += __shfl_down_sync(FULL_MASK, val, offset);
}
return val;
}
template <typename T>
__forceinline__ __device__ T blockReduceSum(T val) {
static __shared__ T shared[32];
int lane = threadIdx.x % warpSize;
int wid = threadIdx.x / warpSize;
val = warpReduceSum(val);
if (lane == 0)
shared[wid] = val;
__syncthreads();
val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0;
if (wid == 0)
val = warpReduceSum(val);
return val;
}
template <typename T>
__global__ void set_zero(T *x, int num) {
for(int i = blockIdx.x * blockDim.x + threadIdx.x; i < num; i += blockDim.x * gridDim.x)
x[i] = static_cast<T>(0);
}
template <typename T>
__global__ void channel_first(const T *input, T *rinput, const int channel, const int height, const int width, const int pad_size) {
int n = blockIdx.x;
int h = blockIdx.y;
int w = blockIdx.z;
int ch_off = threadIdx.x;
T value;
int dimchw = channel * height * width;
int dimhw = height * width;
int p_dimw = (width + 2 * pad_size);
int p_dimh = (height + 2 * pad_size);
int p_dimchw = channel * p_dimw * p_dimh;
int p_dimcw = channel * p_dimw;
for (int c = ch_off; c < channel; c += THREADS_PER_BLOCK) {
value = input[n * dimchw + c * dimhw + h * width + w];
rinput[n * p_dimchw + (h + pad_size) * p_dimcw + (w + pad_size) * channel + c] = value;
}
}
template <typename T>
__global__ void correlation_forward(T *output, const int output_channel, const int output_height, const int output_width, const T *rinput1, const int input_channel, const int input_height, const int input_width, const T *rinput2, const int pad_size, const int kernel_size, const int max_displacement, const int stride1, const int stride2) {
int p_input_width = input_width + 2 * pad_size;
int p_input_height = input_height + 2 * pad_size;
int kernel_rad = (kernel_size - 1) / 2;
int displacement_rad = max_displacement / stride2;
int displacement_size = 2 * displacement_rad + 1;
int n = blockIdx.x;
int h1 = blockIdx.y * stride1 + max_displacement;
int w1 = blockIdx.z * stride1 + max_displacement;
int c = threadIdx.x;
int p_dimchw = p_input_height * p_input_width * input_channel;
int p_dimcw = p_input_width * input_channel;
int p_dimc = input_channel;
int t_dimchw = output_channel * output_height * output_width;
int t_dimhw = output_height * output_width;
int t_dimw = output_width;
int nelems = kernel_size * kernel_size * p_dimc;
for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) {
for(int ti = -displacement_rad; ti <= displacement_rad; ++ti) {
int w2 = w1 + ti * stride2;
int h2 = h1 + tj * stride2;
T acc0 = 0;
for(int j = -kernel_rad; j <= kernel_rad; ++j) {
for(int i = -kernel_rad; i <= kernel_rad; ++i) {
for(int ch = c; ch < p_dimc; ch += blockDim.x) {
int index1 = n * p_dimchw + (h1 + j) * p_dimcw + (w1 + i) * p_dimc + ch;
int index2 = n * p_dimchw + (h2 + j) * p_dimcw + (w2 + i) * p_dimc + ch;
acc0 += static_cast<T>(rinput1[index1] * rinput2[index2]);
}
}
}
if (blockDim.x == warpSize) {
__syncwarp();
acc0 = warpReduceSum(acc0);
} else {
__syncthreads();
acc0 = blockReduceSum(acc0);
}
if (threadIdx.x == 0) {
int tc = (tj + displacement_rad) * displacement_size + (ti + displacement_rad);
const int t_index = n * t_dimchw + tc * t_dimhw + blockIdx.y * t_dimw + blockIdx.z;
output[t_index] = static_cast<T>(acc0 / nelems);
}
}
}
}
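// In effect, for batch n, displacement index tc = (tj + rad) * (2 * rad + 1)
// + (ti + rad) and output location (y, x) = (blockIdx.y, blockIdx.z), the
// kernel above computes one cost-volume entry: the mean over the kernel
// window and channels of
//   rinput1[n, h1 + j, w1 + i, ch] * rinput2[n, h1 + tj*stride2 + j,
//                                            w1 + ti*stride2 + i, ch]
// with h1 = y * stride1 + max_displacement and w1 = x * stride1 +
// max_displacement in padded NHWC coordinates.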
template <typename T>
class CorrelationKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, "It must be CUDAPlace");
auto *input1 = ctx.Input<Tensor>("Input1");
auto *input2 = ctx.Input<Tensor>("Input2");
int pad_size = ctx.Attr<int>("pad_size");
int kernel_size = ctx.Attr<int>("kernel_size");
int stride1 = ctx.Attr<int>("stride1");
int stride2 = ctx.Attr<int>("stride2");
int max_displacement = ctx.Attr<int>("max_displacement");
int corr_type_multiply = ctx.Attr<int>("corr_type_multiply");
auto *output = ctx.Output<Tensor>("Output");
output->mutable_data<T>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
// base on input1, NCHW
auto in_dims = input1->dims();
int N = in_dims[0];
int C = in_dims[1];
int H = in_dims[2];
int W = in_dims[3];
int padded_input_height = H + 2 * pad_size;
int padded_input_width = W + 2 * pad_size;
Tensor rinput1 = ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({N, padded_input_height, padded_input_width, C}, dev_ctx);
rinput1.mutable_data<T>(ctx.GetPlace());
Tensor rinput2 = ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({N, padded_input_height, padded_input_width, C}, dev_ctx);
rinput2.mutable_data<T>(ctx.GetPlace());
set_zero<<<(rinput1.numel() + 512 - 1)/512, 512, 0, dev_ctx.stream()>>>(rinput1.data<T>(), rinput1.numel());
set_zero<<<(rinput2.numel() + 512 - 1)/512, 512, 0, dev_ctx.stream()>>>(rinput2.data<T>(), rinput2.numel());
set_zero<<<(output->numel() + 512 - 1)/512, 512, 0, dev_ctx.stream()>>>(output->data<T>(), output->numel());
auto out_dims = output->dims();
int OC = out_dims[1];
int OH = out_dims[2];
int OW = out_dims[3];
dim3 blocks_grid(N, H, W);
dim3 threads_block(THREADS_PER_BLOCK);
channel_first<T><<<blocks_grid, threads_block, 0, dev_ctx.stream()>>>(input1->data<T>(), rinput1.data<T>(), C, H, W, pad_size);
channel_first<T><<<blocks_grid, threads_block, 0, dev_ctx.stream()>>>(input2->data<T>(), rinput2.data<T>(), C, H, W, pad_size);
dim3 threadsPerBlock(THREADS_PER_BLOCK);
dim3 totalBlocksCorr(N, OH, OW);
correlation_forward<T><<<totalBlocksCorr, threadsPerBlock, 0, dev_ctx.stream()>>>(output->data<T>(), OC, OH, OW, rinput1.data<T>(),
C, H, W, rinput2.data<T>(), pad_size, kernel_size, max_displacement, stride1, stride2);
}
};
template <typename T>
__global__ void correlation_backward_input1(int item, T *grad_input1, const int input_channel, const int input_height, const int input_width, const T *grad_output, const int output_channel, const int output_height, const int output_width, const T *rinput2, const int pad_size, const int kernel_size, const int max_displacement, const int stride1, const int stride2) {
int n = item;
int h = blockIdx.x * stride1 + pad_size;
int w = blockIdx.y * stride1 + pad_size;
int c = blockIdx.z;
int tch_off = threadIdx.x;
int kernel_rad = (kernel_size - 1) / 2;
int displacement_rad = max_displacement / stride2;
int displacement_size = 2 * displacement_rad + 1;
int xmin = (w - kernel_rad - max_displacement) / stride1;
int ymin = (h - kernel_rad - max_displacement) / stride1;
int xmax = (w + kernel_rad - max_displacement) / stride1;
int ymax = (h + kernel_rad - max_displacement) / stride1;
if (xmax < 0 || ymax < 0 || xmin >= output_width || ymin >= output_height) {
return;
}
if (xmin > xmax || ymin > ymax) {
return;
}
xmin = max(0, xmin);
xmax = min(output_width - 1, xmax);
ymin = max(0, ymin);
ymax = min(output_height - 1, ymax);
int p_input_width = input_width + 2 * pad_size;
int p_input_height = input_height + 2 * pad_size;
int p_dimchw = input_channel * p_input_height * p_input_width;
int p_dimcw = input_channel * p_input_width;
int p_dimc = input_channel;
int t_dimchw = output_channel * output_height * output_width;
int t_dimhw = output_height * output_width;
int t_dimw = output_width;
int o_dimchw = input_channel * input_height * input_width;
int o_dimhw = input_height * input_width;
int o_dimw = input_width;
int nelems = kernel_size * kernel_size * input_channel;
__shared__ T prod_sum[THREADS_PER_BLOCK];
prod_sum[tch_off] = 0;
for (int tc = tch_off; tc < output_channel; tc += THREADS_PER_BLOCK) {
int i2 = (tc % displacement_size - displacement_rad) * stride2;
int j2 = (tc / displacement_size - displacement_rad) * stride2;
int index2 = n * p_dimchw + (h + j2) * p_dimcw + (w + i2) * p_dimc + c;
T val2 = rinput2[index2];
for (int j = ymin; j <= ymax; ++j) {
for (int i = xmin; i <= xmax; ++i) {
int t_index = n * t_dimchw + tc * t_dimhw + j * t_dimw + i;
prod_sum[tch_off] += grad_output[t_index] * val2;
}
}
}
__syncthreads();
if (tch_off == 0) {
T reduce_sum = 0;
for (int index = 0; index < THREADS_PER_BLOCK; index++) {
reduce_sum += prod_sum[index];
}
const int index1 = n * o_dimchw + c * o_dimhw + (h - pad_size) * o_dimw + (w - pad_size);
grad_input1[index1] = static_cast<T>(reduce_sum / nelems);
}
}
template <typename T>
__global__ void correlation_backward_input2(int item, T *grad_input2, const int input_channel, const int input_height, const int input_width, const T *grad_output, const int output_channel, const int output_height, const int output_width, const T *rinput1, const int pad_size, const int kernel_size, const int max_displacement, const int stride1, const int stride2){
int n = item;
int h = blockIdx.x * stride1 + pad_size;
int w = blockIdx.y * stride1 + pad_size;
int c = blockIdx.z;
int tch_off = threadIdx.x;
int kernel_rad = (kernel_size - 1) / 2;
int displacement_rad = max_displacement / stride2;
int displacement_size = 2 * displacement_rad + 1;
int p_input_width = input_width + 2 * pad_size;
int p_input_height = input_height + 2 * pad_size;
int p_dimchw = input_channel * p_input_height * p_input_width;
int p_dimcw = input_channel * p_input_width;
int p_dimc = input_channel;
int t_dimchw = output_channel * output_height * output_width;
int t_dimhw = output_height * output_width;
int t_dimw = output_width;
int o_dimchw = input_channel * input_height * input_width;
int o_dimhw = input_height * input_width;
int o_dimw = input_width;
int nelems = kernel_size * kernel_size * input_channel;
__shared__ T prod_sum[THREADS_PER_BLOCK];
prod_sum[tch_off] = 0;
for (int tc = tch_off; tc < output_channel; tc += THREADS_PER_BLOCK) {
int i2 = (tc % displacement_size - displacement_rad) * stride2;
int j2 = (tc / displacement_size - displacement_rad) * stride2;
int xmin = (w - kernel_rad - max_displacement - i2) / stride1;
int ymin = (h - kernel_rad - max_displacement - j2) / stride1;
int xmax = (w + kernel_rad - max_displacement - i2) / stride1;
int ymax = (h + kernel_rad - max_displacement - j2) / stride1;
if (xmax < 0 || ymax < 0 || xmin >= output_width || ymin >= output_height) {
continue;
}
if (xmin > xmax || ymin > ymax) {
continue;
}
xmin = max(0, xmin);
xmax = min(output_width - 1, xmax);
ymin = max(0, ymin);
ymax = min(output_height - 1, ymax);
int index1 = n * p_dimchw + (h - j2) * p_dimcw + (w - i2) * p_dimc + c;
T val1 = rinput1[index1];
for (int j = ymin; j <= ymax; ++j) {
for (int i = xmin; i <= xmax; ++i) {
int t_index = n * t_dimchw + tc * t_dimhw + j * t_dimw + i;
prod_sum[tch_off] += grad_output[t_index] * val1;
}
}
}
__syncthreads();
if (tch_off == 0) {
T reduce_sum = 0;
for (int index = 0; index < THREADS_PER_BLOCK; index++) {
reduce_sum += prod_sum[index];
}
const int index2 = n * o_dimchw + c * o_dimhw + (h - pad_size) * o_dimw + (w - pad_size);
grad_input2[index2] = static_cast<T>(reduce_sum / nelems);
}
}
template <typename T>
class CorrelationGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true, "It must use CUDAPlace.");
const auto *input1 = ctx.Input<Tensor>("Input1");
const auto *input2 = ctx.Input<Tensor>("Input2");
const auto *grad_output = ctx.Input<Tensor>(framework::GradVarName("Output"));
const int pad_size = ctx.Attr<int>("pad_size");
const int kernel_size = ctx.Attr<int>("kernel_size");
const int stride1 = ctx.Attr<int>("stride1");
const int stride2 = ctx.Attr<int>("stride2");
const int max_displacement = ctx.Attr<int>("max_displacement");
const int corr_type_multiply = ctx.Attr<int>("corr_type_multiply");
auto *grad_input1 = ctx.Output<Tensor>(framework::GradVarName("Input1"));
grad_input1->mutable_data<T>(ctx.GetPlace());
auto *grad_input2 = ctx.Output<Tensor>(framework::GradVarName("Input2"));
grad_input2->mutable_data<T>(ctx.GetPlace());
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto in_dims = input1->dims();
int N = in_dims[0];
int C = in_dims[1];
int H = in_dims[2];
int W = in_dims[3];
int padded_input_height = H + 2 * pad_size;
int padded_input_width = W + 2 * pad_size;
Tensor rinput1 = ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({N, padded_input_height, padded_input_width, C}, dev_ctx);
rinput1.mutable_data<T>(ctx.GetPlace());
Tensor rinput2 = ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>({N, padded_input_height, padded_input_width, C}, dev_ctx);
rinput2.mutable_data<T>(ctx.GetPlace());
set_zero<<<(rinput1.numel() + 512 - 1)/512, 512, 0, dev_ctx.stream()>>>(rinput1.data<T>(), rinput1.numel());
set_zero<<<(rinput2.numel() + 512 - 1)/512, 512, 0, dev_ctx.stream()>>>(rinput2.data<T>(), rinput2.numel());
set_zero<<<(grad_input1->numel() + 512 - 1)/512, 512, 0, dev_ctx.stream()>>>(grad_input1->data<T>(), grad_input1->numel());
set_zero<<<(grad_input2->numel() + 512 - 1)/512, 512, 0, dev_ctx.stream()>>>(grad_input2->data<T>(), grad_input2->numel());
auto grad_out_dims = grad_output->dims();
int GOC = grad_out_dims[1];
int GOH = grad_out_dims[2];
int GOW = grad_out_dims[3];
dim3 blocks_grid(N, H, W);
dim3 threads_block(THREADS_PER_BLOCK);
channel_first<T><<<blocks_grid, threads_block, 0, dev_ctx.stream()>>>(input1->data<T>(), rinput1.data<T>(), C, H, W, pad_size);
channel_first<T><<<blocks_grid, threads_block, 0, dev_ctx.stream()>>>(input2->data<T>(), rinput2.data<T>(), C, H, W, pad_size);
dim3 threadsPerBlock(THREADS_PER_BLOCK);
dim3 totalBlocksCorr(H, W, C);
for (int n = 0; n < N; n++) {
correlation_backward_input1<T><<<totalBlocksCorr, threadsPerBlock, 0, dev_ctx.stream()>>>(n, grad_input1->data<T>(), C, H, W, grad_output->data<T>(), GOC, GOH, GOW, rinput2.data<T>(), pad_size, kernel_size, max_displacement, stride1, stride2);
}
for (int n = 0; n < N; n++) {
correlation_backward_input2<T><<<totalBlocksCorr, threadsPerBlock, 0, dev_ctx.stream()>>>(n, grad_input2->data<T>(), C, H, W, grad_output->data<T>(), GOC, GOH, GOW, rinput1.data<T>(), pad_size, kernel_size, max_displacement, stride1, stride2);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
correlation, ops::CorrelationKernel<float>,
ops::CorrelationKernel<double>);
REGISTER_OP_CUDA_KERNEL(
correlation_grad, ops::CorrelationGradKernel<float>,
ops::CorrelationGradKernel<double>);
include_dir=$( python3.7 -c 'import paddle; print(paddle.sysconfig.get_include())' )
lib_dir=$( python3.7 -c 'import paddle; print(paddle.sysconfig.get_lib())' )
echo $include_dir
echo $lib_dir
OPS='correlation_op'
for op in ${OPS}
do
nvcc ${op}.cu -c -o ${op}.cu.o -ccbin cc -DPADDLE_WITH_CUDA -DEIGEN_USE_GPU -DPADDLE_USE_DSO -DPADDLE_WITH_MKLDNN -Xcompiler -fPIC -std=c++11 -Xcompiler -fPIC -w --expt-relaxed-constexpr -O0 -g -DNVCC \
-I ${include_dir}/third_party/ \
-I ${include_dir}
done
##g++-4.8 correlation_op.cu.o correlation_op.cc -o correlation_lib.so -DPADDLE_WITH_MKLDNN -shared -fPIC -std=c++11 -O0 -g \
g++ correlation_op.cu.o correlation_op.cc -o correlation_lib.so -DPADDLE_WITH_MKLDNN -shared -fPIC -std=c++11 -O0 -g \
-I ${include_dir}/third_party/ \
-I ${include_dir} \
-L ${lib_dir} \
-L /usr/local/cuda/lib64 -lpaddle_framework -lcudart
rm *.cu.o
import unittest
from correlation import correlation
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable
def corr(x_1, x_2, pad_size=4, kernel_size=1, max_displacement=4, stride1=1,
         stride2=1, corr_multiply=1):
    K = kernel_size
    rinput1 = np.pad(x_1, ((0, 0), (0, 0), (pad_size, pad_size), (pad_size, pad_size)),
                     mode='constant')
    rinput2 = np.pad(x_2, ((0, 0), (0, 0), (pad_size, pad_size), (pad_size, pad_size)),
                     mode='constant')
    rinput1 = np.transpose(rinput1, (0, 2, 3, 1))
    rinput2 = np.transpose(rinput2, (0, 2, 3, 1))
    B = int(rinput1.shape[0])
    H = int(x_1.shape[2])
    W = int(x_2.shape[3])
    d = max_displacement
    D = 2 * d + 1
    output = np.zeros((B, D * D, H, W), dtype=np.float32)
    for b in range(B):
        for i in range(H):
            for j in range(W):
                for k in range(-d, d + 1):
                    for l in range(-d, d + 1):
                        x1_index = i + pad_size
                        y1_index = j + pad_size
                        x2_index = x1_index + k
                        y2_index = y1_index + l
                        output[b, l + d + D * (k + d), i, j] = np.mean(
                            rinput1[b, x1_index:x1_index + K, y1_index:y1_index + K] *
                            rinput2[b, x2_index:x2_index + K, y2_index:y2_index + K])
    return output
class TestCorrelationOp(unittest.TestCase):
    def test_check_output(self):
        np.random.seed(13)
        np.set_printoptions(threshold=np.inf)
        # the declared shape matches the arrays fed below
        x_shape = (2, 3, 4, 5)
        x_type = 'float32'
        x1 = fluid.layers.data(name='x1', shape=x_shape, dtype=x_type,
                               append_batch_size=False)
        x2 = fluid.layers.data(name='x2', shape=x_shape, dtype=x_type,
                               append_batch_size=False)
        x1_np = np.random.randn(*x_shape).astype(x_type)
        x2_np = np.random.randn(*x_shape).astype(x_type)
        out_np = corr(x1_np, x2_np, pad_size=4, kernel_size=1,
                      max_displacement=4, stride1=1, stride2=1)
        out = correlation(x1, x2, pad_size=4, kernel_size=1,
                          max_displacement=4, stride1=1, stride2=1)
        place = fluid.CUDAPlace(0)
        exe = fluid.Executor(place)
        res = exe.run(feed={'x1': x1_np, 'x2': x2_np}, fetch_list=[out.name])
        self.assertTrue(np.allclose(res[0], out_np))
class Net(fluid.dygraph.Layer):
    def __init__(self, name_scope):
        super(Net, self).__init__(name_scope)

    def forward(self, x1, x2):
        y = correlation(x1, x2, pad_size=4, kernel_size=1, max_displacement=4,
                        stride1=1, stride2=1)
        return y

class TestCorrelationOpDyGraph(unittest.TestCase):
    def test_check_output(self):
        np.random.seed(13)
        np.set_printoptions(threshold=np.inf)
        x_shape = (2, 3, 4, 5)
        x_type = 'float32'
        place = fluid.CUDAPlace(0)
        with fluid.dygraph.guard(place):
            x1_np = np.random.randn(*x_shape).astype(x_type)
            x2_np = np.random.randn(*x_shape).astype(x_type)
            out_np = corr(x1_np, x2_np, pad_size=4, kernel_size=1,
                          max_displacement=4, stride1=1, stride2=1)
            x1 = to_variable(x1_np)
            x2 = to_variable(x2_np)
            corr_pd = Net('corr_pd')
            y = corr_pd(x1, x2)
            out = y.numpy()
            self.assertTrue(np.allclose(out, out_np))

if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# @FileName: datasets.py, adapted from https://github.com/NVIDIA/flownet2-pytorch/blob/master/datasets.py
import paddle
import paddle.fluid as fluid
import numpy as np
import argparse
import os, math, random
import sys
from os.path import *
from glob import glob
sys.path.append('../')
import data.utils.frame_utils as frame_utils
from scipy.misc import imsave
from src import flow_vis
from src.read_files import read_txt_to_index
class StaticRandomCrop(object):
    def __init__(self, image_size, crop_size):
        self.th, self.tw = crop_size
        h, w = image_size
        self.h1 = random.randint(0, h - self.th)
        self.w1 = random.randint(0, w - self.tw)

    def __call__(self, img):
        return img[self.h1:(self.h1 + self.th), self.w1:(self.w1 + self.tw), :]

class StaticCenterCrop(object):
    def __init__(self, image_size, crop_size):
        self.th, self.tw = crop_size
        self.h, self.w = image_size

    def __call__(self, img):
        return img[(self.h - self.th) // 2:(self.h + self.th) // 2,
                   (self.w - self.tw) // 2:(self.w + self.tw) // 2, :]
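# Note: a cropper instance draws its offset once in __init__ and is then
# applied to img1, img2 and the flow in __getitem__ below, so all three stay
# spatially aligned, e.g.:
#
#   cropper = StaticRandomCrop(img1.shape[:2], (256, 256))  # offset fixed here
#   img1, img2, flow = cropper(img1), cropper(img2), cropper(flow)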
class MpiSintel(object):
    def __init__(self, args, is_cropped=False, root='', dstype='clean', replicates=1):
        self.args = args
        self.is_cropped = is_cropped
        self.crop_size = args.crop_size
        self.render_size = args.inference_size
        self.replicates = replicates
        flow_root = join(root, 'flow')
        image_root = join(root, dstype)
        file_list = sorted(glob(join(flow_root, '*/*.flo')))
        self.flow_list = []
        self.image_list = []
        for file in file_list:
            if 'test' in file:
                continue
            fbase = file[len(flow_root) + 1:]
            fprefix = fbase[:-8]
            fnum = int(fbase[-8:-4])
            img1 = join(image_root, fprefix + "%04d" % (fnum + 0) + '.png')
            img2 = join(image_root, fprefix + "%04d" % (fnum + 1) + '.png')
            if not isfile(img1) or not isfile(img2) or not isfile(file):
                continue
            self.image_list += [[img1, img2]]
            self.flow_list += [file]
        self.size = len(self.image_list)
        self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape
        if (self.render_size[0] < 0) or (self.render_size[1] < 0) or (
                self.frame_size[0] % 64) or (self.frame_size[1] % 64):
            self.render_size[0] = ((self.frame_size[0]) // 64) * 64
            self.render_size[1] = ((self.frame_size[1]) // 64) * 64
        args.inference_size = self.render_size
        assert (len(self.image_list) == len(self.flow_list))

    def __getitem__(self, index):
        index = index % self.size
        img1 = frame_utils.read_gen(self.image_list[index][0])
        img2 = frame_utils.read_gen(self.image_list[index][1])
        flow = frame_utils.read_gen(self.flow_list[index])
        images = [img1, img2]
        image_size = img1.shape[:2]
        if self.is_cropped:
            cropper = StaticRandomCrop(image_size, self.crop_size)
        else:
            cropper = StaticCenterCrop(image_size, self.render_size)
        images = list(map(cropper, images))
        flow = cropper(flow)
        images = np.array(images).transpose(3, 0, 1, 2)
        flow = flow.transpose(2, 0, 1)
        return [images], [flow]

    def __len__(self):
        return self.size * self.replicates

class MpiSintelClean(MpiSintel):
    def __init__(self, args, is_cropped=False, root='', replicates=1):
        super(MpiSintelClean, self).__init__(args, is_cropped=is_cropped, root=root,
                                             dstype='clean', replicates=replicates)

class MpiSintelFinal(MpiSintel):
    def __init__(self, args, is_cropped=False, root='', replicates=1):
        super(MpiSintelFinal, self).__init__(args, is_cropped=is_cropped, root=root,
                                             dstype='final', replicates=replicates)
class FlyingChairs(object):
    def __init__(self, train_val, args, is_cropped, txt_file,
                 root='/path/to/FlyingChairs_release/data', replicates=1):
        self.args = args
        self.is_cropped = is_cropped
        self.crop_size = args.crop_size
        self.render_size = args.inference_size
        self.replicates = replicates
        images = sorted(glob(join(root, '*.ppm')))
        flow_list = sorted(glob(join(root, '*.flo')))
        assert (len(images) // 2 == len(flow_list))
        image_list = []
        for i in range(len(flow_list)):
            im1 = images[2 * i]
            im2 = images[2 * i + 1]
            image_list += [[im1, im2]]
        assert len(image_list) == len(flow_list)
        if train_val == 'train':
            intindex = np.array(read_txt_to_index(txt_file))
            image_list = np.array(image_list)
            image_list = image_list[intindex == 1]
            image_list = image_list.tolist()
            flow_list = np.array(flow_list)
            flow_list = flow_list[intindex == 1]
            flow_list = flow_list.tolist()
            assert len(image_list) == len(flow_list)
        elif train_val == 'val':
            intindex = np.array(read_txt_to_index(txt_file))
            image_list = np.array(image_list)
            image_list = image_list[intindex == 2]
            image_list = image_list.tolist()
            flow_list = np.array(flow_list)
            flow_list = flow_list[intindex == 2]
            flow_list = flow_list.tolist()
            assert len(image_list) == len(flow_list)
        else:
            raise ValueError('train_val must be "train" or "val", got %r' % train_val)
        self.flow_list = flow_list
        self.image_list = image_list
        self.size = len(self.image_list)
        self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape
        if (self.render_size[0] < 0) or (self.render_size[1] < 0) or (
                self.frame_size[0] % 64) or (self.frame_size[1] % 64):
            self.render_size[0] = ((self.frame_size[0]) // 64) * 64
            self.render_size[1] = ((self.frame_size[1]) // 64) * 64
        args.inference_size = self.render_size

    def __getitem__(self, index):
        index = index % self.size
        img1 = frame_utils.read_gen(self.image_list[index][0])
        img2 = frame_utils.read_gen(self.image_list[index][1])
        flow = frame_utils.read_gen(self.flow_list[index])
        images = [img1, img2]
        image_size = img1.shape[:2]
        if self.is_cropped:
            cropper = StaticRandomCrop(image_size, self.crop_size)
        else:
            cropper = StaticCenterCrop(image_size, self.render_size)
        images = list(map(cropper, images))
        flow = cropper(flow)
        images = np.array(images).transpose(3, 0, 1, 2)
        flow = flow.transpose(2, 0, 1)
        return [images], [flow]

    def __len__(self):
        return self.size * self.replicates

def reader_flyingchairs(dataset):
    n = len(dataset)

    def reader():
        # a single entry of data is created each time
        for i in range(n):
            a, b = dataset[i]
            yield (a[0][:, 0, :, :].transpose(1, 2, 0),
                   a[0][:, 1, :, :].transpose(1, 2, 0),
                   b[0].transpose(1, 2, 0))

    return reader
class FlyingThings(object):
    def __init__(self, args, is_cropped, root='/path/to/flyingthings3d',
                 dstype='frames_cleanpass', replicates=1):
        self.args = args
        self.is_cropped = is_cropped
        self.crop_size = args.crop_size
        self.render_size = args.inference_size
        self.replicates = replicates
        image_dirs = sorted(glob(join(root, dstype, 'TRAIN/*/*')))
        image_dirs = sorted([join(f, 'left') for f in image_dirs] +
                            [join(f, 'right') for f in image_dirs])
        flow_dirs = sorted(glob(join(root, 'optical_flow_flo_format/TRAIN/*/*')))
        flow_dirs = sorted([join(f, 'into_future/left') for f in flow_dirs] +
                           [join(f, 'into_future/right') for f in flow_dirs])
        assert (len(image_dirs) == len(flow_dirs))
        self.image_list = []
        self.flow_list = []
        for idir, fdir in zip(image_dirs, flow_dirs):
            images = sorted(glob(join(idir, '*.png')))
            flows = sorted(glob(join(fdir, '*.flo')))
            for i in range(len(flows)):
                self.image_list += [[images[i], images[i + 1]]]
                self.flow_list += [flows[i]]
        assert len(self.image_list) == len(self.flow_list)
        self.size = len(self.image_list)
        self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape
        if (self.render_size[0] < 0) or (self.render_size[1] < 0) or (
                self.frame_size[0] % 64) or (self.frame_size[1] % 64):
            self.render_size[0] = ((self.frame_size[0]) // 64) * 64
            self.render_size[1] = ((self.frame_size[1]) // 64) * 64
        args.inference_size = self.render_size

    def __getitem__(self, index):
        index = index % self.size
        img1 = frame_utils.read_gen(self.image_list[index][0])
        img2 = frame_utils.read_gen(self.image_list[index][1])
        flow = frame_utils.read_gen(self.flow_list[index])
        images = [img1, img2]
        image_size = img1.shape[:2]
        if self.is_cropped:
            cropper = StaticRandomCrop(image_size, self.crop_size)
        else:
            cropper = StaticCenterCrop(image_size, self.render_size)
        images = list(map(cropper, images))
        flow = cropper(flow)
        images = np.array(images).transpose(3, 0, 1, 2)
        flow = flow.transpose(2, 0, 1)
        return [images], [flow]

    def __len__(self):
        return self.size * self.replicates

class FlyingThingsClean(FlyingThings):
    def __init__(self, args, is_cropped=False, root='', replicates=1):
        super(FlyingThingsClean, self).__init__(args, is_cropped=is_cropped, root=root,
                                                dstype='frames_cleanpass',
                                                replicates=replicates)

class FlyingThingsFinal(FlyingThings):
    def __init__(self, args, is_cropped=False, root='', replicates=1):
        super(FlyingThingsFinal, self).__init__(args, is_cropped=is_cropped, root=root,
                                                dstype='frames_finalpass',
                                                replicates=replicates)
class ChairsSDHom(object):
    def __init__(self, args, is_cropped, root='/path/to/chairssdhom/data',
                 dstype='train', replicates=1):
        self.args = args
        self.is_cropped = is_cropped
        self.crop_size = args.crop_size
        self.render_size = args.inference_size
        self.replicates = replicates
        image1 = sorted(glob(join(root, dstype, 't0/*.png')))
        image2 = sorted(glob(join(root, dstype, 't1/*.png')))
        self.flow_list = sorted(glob(join(root, dstype, 'flow/*.flo')))
        assert (len(image1) == len(self.flow_list))
        self.image_list = []
        for i in range(len(self.flow_list)):
            im1 = image1[i]
            im2 = image2[i]
            self.image_list += [[im1, im2]]
        assert len(self.image_list) == len(self.flow_list)
        self.size = len(self.image_list)
        self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape
        if (self.render_size[0] < 0) or (self.render_size[1] < 0) or (
                self.frame_size[0] % 64) or (self.frame_size[1] % 64):
            self.render_size[0] = ((self.frame_size[0]) // 64) * 64
            self.render_size[1] = ((self.frame_size[1]) // 64) * 64
        args.inference_size = self.render_size

    def __getitem__(self, index):
        index = index % self.size
        img1 = frame_utils.read_gen(self.image_list[index][0])
        img2 = frame_utils.read_gen(self.image_list[index][1])
        flow = frame_utils.read_gen(self.flow_list[index])
        flow = flow[::-1, :, :]
        images = [img1, img2]
        image_size = img1.shape[:2]
        if self.is_cropped:
            cropper = StaticRandomCrop(image_size, self.crop_size)
        else:
            cropper = StaticCenterCrop(image_size, self.render_size)
        images = list(map(cropper, images))
        flow = cropper(flow)
        images = np.array(images).transpose(3, 0, 1, 2)
        flow = flow.transpose(2, 0, 1)
        return [images], [flow]

    def __len__(self):
        return self.size * self.replicates

class ChairsSDHomTrain(ChairsSDHom):
    def __init__(self, args, is_cropped=False, root='', replicates=1):
        super(ChairsSDHomTrain, self).__init__(args, is_cropped=is_cropped, root=root,
                                               dstype='train', replicates=replicates)

class ChairsSDHomTest(ChairsSDHom):
    def __init__(self, args, is_cropped=False, root='', replicates=1):
        super(ChairsSDHomTest, self).__init__(args, is_cropped=is_cropped, root=root,
                                              dstype='test', replicates=replicates)
class ImagesFromFolder(object):
    def __init__(self, args, is_cropped, root='/path/to/frames/only/folder',
                 iext='png', replicates=1):
        self.args = args
        self.is_cropped = is_cropped
        self.crop_size = args.crop_size
        self.render_size = args.inference_size
        self.replicates = replicates
        images = sorted(glob(join(root, '*.' + iext)))
        self.image_list = []
        for i in range(len(images) - 1):
            im1 = images[i]
            im2 = images[i + 1]
            self.image_list += [[im1, im2]]
        self.size = len(self.image_list)
        self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape
        if (self.render_size[0] < 0) or (self.render_size[1] < 0) or (
                self.frame_size[0] % 64) or (self.frame_size[1] % 64):
            self.render_size[0] = ((self.frame_size[0]) // 64) * 64
            self.render_size[1] = ((self.frame_size[1]) // 64) * 64
        args.inference_size = self.render_size

    def __getitem__(self, index):
        index = index % self.size
        img1 = frame_utils.read_gen(self.image_list[index][0])
        img2 = frame_utils.read_gen(self.image_list[index][1])
        images = [img1, img2]
        image_size = img1.shape[:2]
        if self.is_cropped:
            cropper = StaticRandomCrop(image_size, self.crop_size)
        else:
            cropper = StaticCenterCrop(image_size, self.render_size)
        images = list(map(cropper, images))
        images = np.array(images).transpose(3, 0, 1, 2)
        # numpy arrays have no .size() method; use .shape for the dummy flow
        return [images], [np.zeros(images.shape[0:1] + (2,) + images.shape[-2:])]

    def __len__(self):
        return self.size * self.replicates
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    args = parser.parse_args()
    args.inference_size = [1080, 1920]
    args.crop_size = [384, 512]
    index = 50
    # NOTE: FlyingChairs requires train_val and txt_file; the txt file is
    # assumed to sit next to the data directory, as in train.sh.
    flyingchairs_dataset = FlyingChairs(
        'train', args, True,
        '/ssd2/zhenghe/DATA/FlyingChairs_release/FlyingChairs_train_val.txt',
        root='/ssd2/zhenghe/DATA/FlyingChairs_release/data')
    # a, b = flyingchairs_dataset[index]
    # im1 = a[0][:, 0, :, :].transpose(1, 2, 0)
    # im2 = a[0][:, 1, :, :].transpose(1, 2, 0)
    # flo = b[0].transpose(1, 2, 0) / 20.0
    # flow_color = flow_vis.flow_to_color(flo, convert_to_bgr=False)
    # imsave('./hsv_pd.png', flow_color)
    sample_num = len(flyingchairs_dataset)
    reader = reader_flyingchairs(flyingchairs_dataset)
    BATCH_SIZE = 8
    train_batch_reader = paddle.batch(reader, BATCH_SIZE, drop_last=True)
    epoch_num = 1
    with fluid.dygraph.guard():
        for epoch in range(epoch_num):
            for batch_id, data in enumerate(train_batch_reader()):
                im1_data = np.array([x[0] for x in data]).astype('float32')
                im2_data = np.array([x[1] for x in data]).astype('float32')
                flo_data = np.array([x[2] for x in data]).astype('float32')
                if batch_id % 500 == 0:
                    print(batch_id)
                    print(im1_data.shape)
                    print(im2_data.shape)
                    print(flo_data.shape)
                    im1 = im1_data[0, :, :, :]
                    im2 = im2_data[0, :, :, :]
                    flo = flo_data[0, :, :, :]
                    print(im1.shape)
                    print(im2.shape)
                    print(flo.shape)
                    imsave('./img1.png', im1)
                    imsave('./img2.png', im2)
                    flow_color = flow_vis.flow_to_color(flo, convert_to_bgr=False)
                    imsave('./hsv_pd.png', flow_color)
                print("batch_id:", batch_id)
        print(batch_id * BATCH_SIZE)
        print(sample_num)
#!/bin/bash
#mkdir FlyingThings3D_release
#cd FlyingThings3D_release
#
#wget http://lmb.informatik.uni-freiburg.de/data/SceneFlowDatasets_CVPR16/Release_april16/data/FlyingThings3D/raw_data/flyingthings3d__frames_cleanpass.tar
#wget http://lmb.informatik.uni-freiburg.de/data/SceneFlowDatasets_CVPR16/Release_april16/data/FlyingThings3D/derived_data/flyingthings3d__optical_flow.tar.bz2
#
#tar xvf flyingthings3d__frames_cleanpass.tar
#tar xvf flyingthings3d__optical_flow.tar.bz2
#
#cd ..
wget http://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs/FlyingChairs.zip
unzip FlyingChairs.zip
#wget https://lmb.informatik.uni-freiburg.de/data/FlowNet2/ChairsSDHom/ChairsSDHom.tar.gz
#tar xvzf ChairsSDHom.tar.gz
import numpy as np
TAG_CHAR = np.array([202021.25], np.float32)
def readFlow(fn):
    """Read a .flo file in Middlebury format."""
    # Code adapted from:
    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
    # WARNING: this will work on little-endian architectures (eg Intel x86) only!
    with open(fn, 'rb') as f:
        magic = np.fromfile(f, np.float32, count=1)
        if 202021.25 != magic:
            print('Magic number incorrect. Invalid .flo file')
            return None
        else:
            w = np.fromfile(f, np.int32, count=1)
            h = np.fromfile(f, np.int32, count=1)
            data = np.fromfile(f, np.float32, count=2 * int(w) * int(h))
            # Reshape data into 3D array (columns, rows, bands).
            # The reshape here is for visualization; the original code is (w, h, 2).
            return np.resize(data, (int(h), int(w), 2))

def writeFlow(filename, uv, v=None):
    """Write optical flow to file.

    If v is None, uv is assumed to contain both u and v channels,
    stacked in depth.
    Original code by Deqing Sun, adapted from Daniel Scharstein.
    """
    nBands = 2
    if v is None:
        assert (uv.ndim == 3)
        assert (uv.shape[2] == 2)
        u = uv[:, :, 0]
        v = uv[:, :, 1]
    else:
        u = uv
    assert (u.shape == v.shape)
    height, width = u.shape
    f = open(filename, 'wb')
    # write the header
    f.write(TAG_CHAR)
    np.array(width).astype(np.int32).tofile(f)
    np.array(height).astype(np.int32).tofile(f)
    # arrange into matrix form
    tmp = np.zeros((height, width * nBands))
    tmp[:, np.arange(width) * 2] = u
    tmp[:, np.arange(width) * 2 + 1] = v
    tmp.astype(np.float32).tofile(f)
    f.close()
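
if __name__ == '__main__':
    # Round-trip sanity check for the two helpers above; the file path is
    # illustrative.
    _flow = np.random.randn(4, 6, 2).astype(np.float32)
    writeFlow('/tmp/_test.flo', _flow)
    _back = readFlow('/tmp/_test.flo')
    assert _back.shape == (4, 6, 2) and np.allclose(_back, _flow)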
import numpy as np
from os.path import *
from scipy.misc import imread
from . import flow_utils
def read_gen(file_name):
    ext = splitext(file_name)[-1]
    if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
        im = imread(file_name)
        if im.shape[2] > 3:
            return im[:, :, :3]
        else:
            return im
    elif ext == '.bin' or ext == '.raw':
        return np.load(file_name)
    elif ext == '.flo':
        return flow_utils.readFlow(file_name).astype(np.float32)
    return []
#!/usr/bin/env bash
python3 train.py --loss l1 --pretrained ./out/pwc_net_paddle --dataset FlyingChairs --train_val_txt data_dir/FlyingChairs_release/FlyingChairs_train_val.txt --data_root data_dir/FlyingChairs_release/data
# multi-GPU training (to be supported later)
#python3 -m paddle.distributed.launch --selected_gpus=0,1 train.py --use_multi_gpu --batch_size 40 --loss l1 --pretrained ./out/pwc_net_paddle --dataset FlyingChairs --train_val_txt data_dir/FlyingChairs_release/FlyingChairs_train_val.txt --data_root data_dir/FlyingChairs_release/data
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Infer for PWCNet."""
import sys
import pickle
import time
import cv2
import numpy as np
from math import ceil
from scipy.ndimage import imread
from scipy.misc import imsave
import paddle.fluid as fluid
from models.model import PWCDCNet
from src import flow_vis
def writeFlowFile(filename, uv):
    """
    According to the matlab code of Deqing Sun and c++ source code of Daniel Scharstein
    Contact: dqsun@cs.brown.edu
    Contact: schar@middlebury.edu
    """
    TAG_STRING = np.array(202021.25, dtype=np.float32)
    if uv.shape[2] != 2:
        sys.exit("writeFlowFile: flow must have two bands!")
    H = np.array(uv.shape[0], dtype=np.int32)
    W = np.array(uv.shape[1], dtype=np.int32)
    with open(filename, 'wb') as f:
        f.write(TAG_STRING.tobytes())
        f.write(W.tobytes())
        f.write(H.tobytes())
        f.write(uv.tobytes())

def load_dict(filename_):
    with open(filename_, 'rb') as f:
        ret_di = pickle.load(f)
    return ret_di

def pad_input(x0):
    intWidth = x0.shape[2]
    intHeight = x0.shape[3]
    if intWidth != ((intWidth >> 6) << 6):
        intWidth_pad = (((intWidth >> 6) + 1) << 6)  # more than necessary
        intPaddingLeft = int((intWidth_pad - intWidth) / 2)
        intPaddingRight = intWidth_pad - intWidth - intPaddingLeft
    else:
        intWidth_pad = intWidth
        intPaddingLeft = 0
        intPaddingRight = 0
    if intHeight != ((intHeight >> 6) << 6):
        intHeight_pad = (((intHeight >> 6) + 1) << 6)  # more than necessary
        intPaddingTop = int((intHeight_pad - intHeight) / 2)
        intPaddingBottom = intHeight_pad - intHeight - intPaddingTop
    else:
        intHeight_pad = intHeight
        intPaddingTop = 0
        intPaddingBottom = 0
    out = fluid.layers.pad2d(
        input=x0,
        paddings=[intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom],
        mode='edge')
    return out, [intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom,
                 intWidth, intHeight]
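# Padding arithmetic sketch: a 436-pixel dimension is padded up to the next
# multiple of 64 and split symmetrically:
#   (((436 >> 6) + 1) << 6) == 448, so 6 pixels before and 6 after.
# Note that main() below already resizes the inputs to multiples of 64, so
# the paddings are all zero on that path.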
def main():
    im1_fn = 'data/frame_0010.png'
    im2_fn = 'data/frame_0011.png'
    flow_fn = './tmp/frame_0010_pd.flo'
    if len(sys.argv) > 1:
        im1_fn = sys.argv[1]
    if len(sys.argv) > 2:
        im2_fn = sys.argv[2]
    if len(sys.argv) > 3:
        flow_fn = sys.argv[3]
    im_all = [imread(img) for img in [im1_fn, im2_fn]]
    im_all = [im[:, :, :3] for im in im_all]
    # rescale the image size to be multiples of 64
    divisor = 64.
    H = im_all[0].shape[0]
    W = im_all[0].shape[1]
    print('origin shape : ', H, W)
    H_ = int(ceil(H / divisor) * divisor)
    W_ = int(ceil(W / divisor) * divisor)
    print('resize shape: ', H_, W_)
    for i in range(len(im_all)):
        im_all[i] = cv2.resize(im_all[i], (W_, H_))
    for _i, _inputs in enumerate(im_all):
        im_all[_i] = im_all[_i][:, :, ::-1]
        im_all[_i] = 1.0 * im_all[_i] / 255.0
        im_all[_i] = np.transpose(im_all[_i], (2, 0, 1))
    im_all = np.concatenate((im_all[0], im_all[1]), axis=0).astype(np.float32)
    im_all = im_all[np.newaxis, :, :, :]
    with fluid.dygraph.guard(place=fluid.CUDAPlace(0)):
        im_all = fluid.dygraph.to_variable(im_all)
        im_all, [intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom,
                 intWidth, intHeight] = pad_input(im_all)
        model = PWCDCNet("pwcnet")
        model.eval()
        pd_pretrain, _ = fluid.dygraph.load_dygraph("paddle_model/pwc_net_paddle")
        model.set_dict(pd_pretrain)
        start = time.time()
        flo = model(im_all)
        end = time.time()
        print('Time of PWCNet model for one infer step: ', end - start)
        flo = flo[0].numpy() * 20.0
        # scale the flow back to the input size
        flo = np.swapaxes(np.swapaxes(flo, 0, 1), 1, 2)
        flo = flo[intPaddingTop * 2:intPaddingTop * 2 + intHeight * 2,
                  intPaddingLeft * 2:intPaddingLeft * 2 + intWidth * 2, :]
        u_ = cv2.resize(flo[:, :, 0], (W, H))
        v_ = cv2.resize(flo[:, :, 1], (W, H))
        u_ *= W / float(W_)
        v_ *= H / float(H_)
        flo = np.dstack((u_, v_))
        # Apply the coloring (for OpenCV, set convert_to_bgr=True)
        flow_color = flow_vis.flow_to_color(flo, convert_to_bgr=False)
        imsave('./tmp/hsv_pd.png', flow_color)
        writeFlowFile(flow_fn, flo)

if __name__ == '__main__':
    main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import models.model
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
from paddle.fluid.dygraph import Conv2D, Conv2DTranspose
from correlation_op.correlation import correlation
class PWCDCNet(fluid.dygraph.Layer):
def __init__(self, name_scope, md=4):
super(PWCDCNet, self).__init__(name_scope)
self.param_attr = fluid.ParamAttr(
name='conv_weights',
regularizer=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0004),
initializer=fluid.initializer.MSRAInitializer(uniform=True, fan_in=None, seed=0))
self.md = md
self.conv1a = Conv2D("conv1a", 16, filter_size=3, stride=2, padding=1, param_attr=self.param_attr)
self.conv1aa = Conv2D("conv1aa", 16, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv1b = Conv2D("conv1b", 16, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv2a = Conv2D("conv2a", 32, filter_size=3, stride=2, padding=1, param_attr=self.param_attr)
self.conv2aa = Conv2D("conv2aa", 32, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv2b = Conv2D("conv2b", 32, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv3a = Conv2D("conv3a", 64, filter_size=3, stride=2, padding=1, param_attr=self.param_attr)
self.conv3aa = Conv2D("conv3aa", 64, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv3b = Conv2D("conv3b", 64, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv4a = Conv2D("conv4a", 96, filter_size=3, stride=2, padding=1, param_attr=self.param_attr)
self.conv4aa = Conv2D("conv4aa", 96, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv4b = Conv2D("conv4b", 96, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv5a = Conv2D("conv5a", 128, filter_size=3, stride=2, padding=1, param_attr=self.param_attr)
self.conv5aa = Conv2D("conv5aa", 128, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv5b = Conv2D("conv5b", 128, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv6aa = Conv2D("conv6aa", 196, filter_size=3, stride=2, padding=1, param_attr=self.param_attr)
self.conv6a = Conv2D("conv6a", 196, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv6b = Conv2D("conv6b", 196, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv6_0 = Conv2D("conv6_0", 128, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv6_1 = Conv2D("conv6_1", 128, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv6_2 = Conv2D("conv6_2", 96, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv6_3 = Conv2D("conv6_3", 64, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv6_4 = Conv2D("conv6_4", 32, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.predict_flow6 = Conv2D("predict_flow6", 2, filter_size=3,stride=1,padding=1, param_attr=self.param_attr)
self.deconv6 = Conv2DTranspose("deconv6", 2, filter_size=4, stride=2, padding=1, param_attr=self.param_attr)
self.upfeat6 = Conv2DTranspose("upfeat6", 2, filter_size=4, stride=2, padding=1, param_attr=self.param_attr)
self.conv5_0 = Conv2D("conv5_0", 128, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv5_1 = Conv2D("conv5_1", 128, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv5_2 = Conv2D("conv5_2", 96, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv5_3 = Conv2D("conv5_3", 64, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv5_4 = Conv2D("conv5_4", 32, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.predict_flow5 = Conv2D("predict_flow5", 2, filter_size=3,stride=1,padding=1, param_attr=self.param_attr)
self.deconv5 = Conv2DTranspose("deconv5", 2, filter_size=4, stride=2, padding=1, param_attr=self.param_attr)
self.upfeat5 = Conv2DTranspose("upfeat5", 2, filter_size=4, stride=2, padding=1, param_attr=self.param_attr)
self.conv4_0 = Conv2D("conv4_0", 128, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv4_1 = Conv2D("conv4_1", 128, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv4_2 = Conv2D("conv4_2", 96, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv4_3 = Conv2D("conv4_3", 64, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv4_4 = Conv2D("conv4_4", 32, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.predict_flow4 = Conv2D("predict_flow4", 2, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.deconv4 = Conv2DTranspose("deconv4", 2, filter_size=4, stride=2, padding=1, param_attr=self.param_attr)
self.upfeat4 = Conv2DTranspose("upfeat4", 2, filter_size=4, stride=2, padding=1, param_attr=self.param_attr)
self.conv3_0 = Conv2D("conv3_0", 128, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv3_1 = Conv2D("conv3_1", 128, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv3_2 = Conv2D("conv3_2", 96, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv3_3 = Conv2D("conv3_3", 64, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv3_4 = Conv2D("conv3_4", 32, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.predict_flow3 = Conv2D("predict_flow3", 2, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.deconv3 = Conv2DTranspose("deconv3", 2, filter_size=4, stride=2, padding=1, param_attr=self.param_attr)
self.upfeat3 = Conv2DTranspose("upfeat3", 2, filter_size=4, stride=2, padding=1, param_attr=self.param_attr)
self.conv2_0 = Conv2D("conv2_0", 128, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv2_1 = Conv2D("conv2_1", 128, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv2_2 = Conv2D("conv2_2", 96, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv2_3 = Conv2D("conv2_3", 64, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.conv2_4 = Conv2D("conv2_4", 32, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.predict_flow2 = Conv2D("predict_flow2", 2, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
self.deconv2 = Conv2DTranspose("deconv2", 2, filter_size=4, stride=2, padding=1, param_attr=self.param_attr)
self.dc_conv1 = Conv2D("dc_conv1", 128, filter_size=3, stride=1, padding=1, dilation=1, param_attr=self.param_attr)
self.dc_conv2 = Conv2D("dc_conv2", 128, filter_size=3, stride=1, padding=2, dilation=2, param_attr=self.param_attr)
self.dc_conv3 = Conv2D("dc_conv3", 128, filter_size=3, stride=1, padding=4, dilation=4, param_attr=self.param_attr)
self.dc_conv4 = Conv2D("dc_conv4", 96, filter_size=3, stride=1, padding=8, dilation=8, param_attr=self.param_attr)
self.dc_conv5 = Conv2D("dc_conv5", 64, filter_size=3, stride=1, padding=16, dilation=16, param_attr=self.param_attr)
self.dc_conv6 = Conv2D("dc_conv6", 32, filter_size=3, stride=1, padding=1, dilation=1, param_attr=self.param_attr)
self.dc_conv7 = Conv2D("dc_conv7", 2, filter_size=3, stride=1, padding=1, param_attr=self.param_attr)
def warp(self, x, flo):
"""
warp an image/tensor (im2) back to im1, according to the optical flow
x: [B, C, H, W] (im2)
flo: [B, 2, H, W] flow
"""
B, C, H, W = x.shape
# mesh grid
xx_pd = fluid.layers.range(0, W, 1, 'float32')
xx_pd = fluid.layers.reshape(xx_pd, shape=[1, -1])
xx_pd = fluid.layers.expand(x=xx_pd, expand_times=[H, 1])
xx_pd = fluid.layers.reshape(xx_pd, shape=[1, 1, H, W])
xx_pd = fluid.layers.expand(x=xx_pd, expand_times=[B, 1, 1, 1])
yy_pd = fluid.layers.range(0, H, 1, 'float32')
yy_pd = fluid.layers.reshape(yy_pd, shape=[-1, 1])
yy_pd = fluid.layers.expand(x=yy_pd, expand_times=[1, W])
yy_pd = fluid.layers.reshape(x=yy_pd, shape=[1, 1, H, W])
yy_pd = fluid.layers.expand(x=yy_pd, expand_times=[B, 1, 1, 1])
grid_pd = fluid.layers.concat(input=[xx_pd, yy_pd], axis=1)
vgrid_pd = fluid.layers.elementwise_add(grid_pd, flo)
vgrid_pd_0 = 2.0 * fluid.layers.slice(vgrid_pd, axes=[1], starts=[0], ends=[1]) / max(W - 1, 1) - 1.0
vgrid_pd_1 = 2.0 * fluid.layers.slice(vgrid_pd, axes=[1], starts=[1], ends=[2]) / max(H - 1, 1) - 1.0
vgrid_pd = fluid.layers.concat(input=[vgrid_pd_0, vgrid_pd_1], axis=1)
vgrid_pd = fluid.layers.transpose(vgrid_pd, [0, 2, 3, 1])
output = fluid.layers.grid_sampler(name='grid_sample', x=x, grid=vgrid_pd)
mask = fluid.layers.zeros_like(x)
mask = mask + 1.0
mask = fluid.layers.grid_sampler(name='grid_sample', x=mask, grid=vgrid_pd)
mask_temp1 = fluid.layers.cast(mask < 0.9990, 'float32')
mask = mask * (1 - mask_temp1)
mask = fluid.layers.cast(mask > 0, 'float32')
outwarp = fluid.layers.elementwise_mul(output, mask)
return outwarp
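# The two slices above rescale absolute pixel coordinates into the [-1, 1]
# range expected by grid_sampler, and the mask zeroes locations that sample
# from outside the source image. A minimal sanity-check sketch (hypothetical,
# not used anywhere in this repo; `net` stands for a constructed PWCDCNet
# instance and numpy is assumed imported as np): with an all-zero flow,
# warp() reduces to an identity resample, so inside a dygraph guard
#   x = fluid.dygraph.to_variable(np.random.rand(1, 3, 64, 64).astype('float32'))
#   zero_flo = fluid.dygraph.to_variable(np.zeros((1, 2, 64, 64), dtype='float32'))
#   out = net.warp(x, zero_flo)  # out.numpy() should match x.numpy() up to borders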
def corr(self, x_1, x_2):
out = correlation(x_1, x_2, pad_size=self.md, kernel_size=1, max_displacement=self.md,
stride1=1, stride2=1, corr_type_multiply=1)
return out
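# The correlation op compares each position in x_1 against a
# (2 * md + 1) x (2 * md + 1) neighborhood in x_2, so its output has
# (2 * md + 1) ** 2 channels; with the usual PWC-Net setting md = 4
# (set in __init__, not shown here) that is 81 channels per pixel.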
def forward(self, x, output_more=False):
im1 = fluid.layers.slice(x, axes=[1], starts=[0], ends=[3])
im2 = fluid.layers.slice(x, axes=[1], starts=[3], ends=[6])
# print("\n\n***************************PWC Net details *************** \n\n")
c11 = fluid.layers.leaky_relu(self.conv1a(im1), 0.1)
c11 = fluid.layers.leaky_relu(self.conv1aa(c11), 0.1)
c11 = fluid.layers.leaky_relu(self.conv1b(c11), 0.1)
c21 = fluid.layers.leaky_relu(self.conv1a(im2), 0.1)
c21 = fluid.layers.leaky_relu(self.conv1aa(c21), 0.1)
c21 = fluid.layers.leaky_relu(self.conv1b(c21), 0.1)
c12 = fluid.layers.leaky_relu(self.conv2a(c11), 0.1)
c12 = fluid.layers.leaky_relu(self.conv2aa(c12), 0.1)
c12 = fluid.layers.leaky_relu(self.conv2b(c12), 0.1)
c22 = fluid.layers.leaky_relu(self.conv2a(c21), 0.1)
c22 = fluid.layers.leaky_relu(self.conv2aa(c22), 0.1)
c22 = fluid.layers.leaky_relu(self.conv2b(c22), 0.1)
c13 = fluid.layers.leaky_relu(self.conv3a(c12), 0.1)
c13 = fluid.layers.leaky_relu(self.conv3aa(c13), 0.1)
c13 = fluid.layers.leaky_relu(self.conv3b(c13), 0.1)
c23 = fluid.layers.leaky_relu(self.conv3a(c22), 0.1)
c23 = fluid.layers.leaky_relu(self.conv3aa(c23), 0.1)
c23 = fluid.layers.leaky_relu(self.conv3b(c23), 0.1)
c14 = fluid.layers.leaky_relu(self.conv4a(c13), 0.1)
c14 = fluid.layers.leaky_relu(self.conv4aa(c14), 0.1)
c14 = fluid.layers.leaky_relu(self.conv4b(c14), 0.1)
c24 = fluid.layers.leaky_relu(self.conv4a(c23), 0.1)
c24 = fluid.layers.leaky_relu(self.conv4aa(c24), 0.1)
c24 = fluid.layers.leaky_relu(self.conv4b(c24), 0.1)
c15 = fluid.layers.leaky_relu(self.conv5a(c14), 0.1)
c15 = fluid.layers.leaky_relu(self.conv5aa(c15), 0.1)
c15 = fluid.layers.leaky_relu(self.conv5b(c15), 0.1)
c25 = fluid.layers.leaky_relu(self.conv5a(c24), 0.1)
c25 = fluid.layers.leaky_relu(self.conv5aa(c25), 0.1)
c25 = fluid.layers.leaky_relu(self.conv5b(c25), 0.1)
c16 = fluid.layers.leaky_relu(self.conv6aa(c15), 0.1)
c16 = fluid.layers.leaky_relu(self.conv6a(c16), 0.1)
c16 = fluid.layers.leaky_relu(self.conv6b(c16), 0.1)
c26 = fluid.layers.leaky_relu(self.conv6aa(c25), 0.1)
c26 = fluid.layers.leaky_relu(self.conv6a(c26), 0.1)
c26 = fluid.layers.leaky_relu(self.conv6b(c26), 0.1)
corr6 = self.corr(c16, c26)
corr6 = fluid.layers.leaky_relu(corr6, alpha=0.1)
x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv6_0(corr6), 0.1), corr6], axis=1)
x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv6_1(x), 0.1), x], axis=1)
x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv6_2(x), 0.1), x], axis=1)
x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv6_3(x), 0.1), x], axis=1)
x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv6_4(x), 0.1), x], axis=1)
flow6 = self.predict_flow6(x)
up_flow6 = self.deconv6(flow6)
up_feat6 = self.upfeat6(x)
warp5 = self.warp(c25, up_flow6 * 0.625)
corr5 = self.corr(c15, warp5)
corr5 = fluid.layers.leaky_relu(corr5, alpha=0.1)
x = fluid.layers.concat(input=[corr5, c15, up_flow6, up_feat6], axis=1)
x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv5_0(x), 0.1), x], axis=1)
x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv5_1(x), 0.1), x], axis=1)
x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv5_2(x), 0.1), x], axis=1)
x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv5_3(x), 0.1), x], axis=1)
x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv5_4(x), 0.1), x], axis=1)
flow5 = self.predict_flow5(x)
up_flow5 = self.deconv5(flow5)
up_feat5 = self.upfeat5(x)
warp4 = self.warp(c24, up_flow5 * 1.25)
corr4 = self.corr(c14, warp4)
corr4 = fluid.layers.leaky_relu(corr4, alpha=0.1)
x = fluid.layers.concat(input=[corr4, c14, up_flow5, up_feat5], axis=1)
x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv4_0(x), 0.1), x], axis=1)
x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv4_1(x), 0.1), x], axis=1)
x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv4_2(x), 0.1), x], axis=1)
x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv4_3(x), 0.1), x], axis=1)
x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv4_4(x), 0.1), x], axis=1)
flow4 = self.predict_flow4(x)
up_flow4 = self.deconv4(flow4)
up_feat4 = self.upfeat4(x)
warp3 = self.warp(c23, up_flow4 * 2.5)
corr3 = self.corr(c13, warp3)
corr3 = fluid.layers.leaky_relu(corr3, alpha=0.1)
x = fluid.layers.concat(input=[corr3, c13, up_flow4, up_feat4], axis=1)
x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv3_0(x), 0.1), x], axis=1)
x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv3_1(x), 0.1), x], axis=1)
x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv3_2(x), 0.1), x], axis=1)
x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv3_3(x), 0.1), x], axis=1)
x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv3_4(x), 0.1), x], axis=1)
flow3 = self.predict_flow3(x)
up_flow3 = self.deconv3(flow3)
up_feat3 = self.upfeat3(x)
warp2 = self.warp(c22, up_flow3 * 5.0)
corr2 = self.corr(c12, warp2)
corr2 = fluid.layers.leaky_relu(corr2, alpha=0.1)
x = fluid.layers.concat(input=[corr2, c12, up_flow3, up_feat3], axis=1)
x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv2_0(x), 0.1), x], axis=1)
x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv2_1(x), 0.1), x], axis=1)
x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv2_2(x), 0.1), x], axis=1)
x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv2_3(x), 0.1), x], axis=1)
x = fluid.layers.concat(input=[fluid.layers.leaky_relu(self.conv2_4(x), 0.1), x], axis=1)
flow2 = self.predict_flow2(x)
x = fluid.layers.leaky_relu(self.dc_conv1(x), 0.1)
x = fluid.layers.leaky_relu(self.dc_conv2(x), 0.1)
x = fluid.layers.leaky_relu(self.dc_conv3(x), 0.1)
x = fluid.layers.leaky_relu(self.dc_conv4(x), 0.1)
x = fluid.layers.leaky_relu(self.dc_conv5(x), 0.1)
x = fluid.layers.leaky_relu(self.dc_conv6(x), 0.1)
flow2 += self.dc_conv7(x)
if not output_more:
return flow2
else:
return [flow2, flow3, flow4, flow5, flow6]
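# Note on the warp multipliers used above: ground-truth flow is trained at
# 1/20 scale (the trainer uses label = flo_data / 20.0), and the level-l
# feature maps are downsampled by 2 ** l, so warping them needs the flow
# rescaled by 20 / 2 ** l: 20/32 = 0.625 for the level-5 features,
# 20/16 = 1.25 for level 4, 20/8 = 2.5 for level 3 and 20/4 = 5.0 for
# level 2, following the original PWC-Net implementation.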
import argparse
parser = argparse.ArgumentParser(description='PWCNet_paddle')
parser.add_argument('--dataset', default='FlyingChairs', help='dataset type : FlyingChairs')
parser.add_argument('--data_root', default='', help='root path of the selected dataset')
parser.add_argument('--model_out_dir', default='./out', help='directory to save model checkpoints')
parser.add_argument('--loss', default='l2', help='loss type : first train with l2 and finetune with l1')
parser.add_argument('--train_val_txt', default='', help='the path of selected train_val_txt of dataset')
parser.add_argument('--numEpoch', '-e', type=int, default=100, help='Number of epochs to train')
parser.add_argument('--batch_size', '-b', type=int, default=40, help='batch size')
parser.add_argument('--pretrained', default=None, help='path to the pretrained model weights')
parser.add_argument('--optimize', default=None, help='path to the pretrained optimizer weights')
parser.add_argument('--use_multi_gpu', action='store_true', help='enable multi-gpu mode')
args = parser.parse_args()
args.inference_size = [384, 512]
args.crop_size = [384, 448]
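# Example invocation (mirrors train.sh; the data paths are placeholders):
#   python3 train.py --dataset FlyingChairs \
#       --train_val_txt data_dir/FlyingChairs_release/FlyingChairs_train_val.txt \
#       --data_root data_dir/FlyingChairs_release/data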
# MIT License
#
# Copyright (c) 2018 Tom Runia
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to conditions.
#
# Author: Tom Runia
# Date Created: 2018-08-03
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
def make_colorwheel():
'''
Generates a color wheel for optical flow visualization as presented in:
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
According to the C++ source code of Daniel Scharstein
According to the Matlab source code of Deqing Sun
'''
RY = 15
YG = 6
GC = 4
CB = 11
BM = 13
MR = 6
ncols = RY + YG + GC + CB + BM + MR
colorwheel = np.zeros((ncols, 3))
col = 0
# RY
colorwheel[0:RY, 0] = 255
colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
col = col+RY
# YG
colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
colorwheel[col:col+YG, 1] = 255
col = col+YG
# GC
colorwheel[col:col+GC, 1] = 255
colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
col = col+GC
# CB
colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
colorwheel[col:col+CB, 2] = 255
col = col+CB
# BM
colorwheel[col:col+BM, 2] = 255
colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
col = col+BM
# MR
colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
colorwheel[col:col+MR, 0] = 255
return colorwheel
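# Sanity check: the wheel has RY + YG + GC + CB + BM + MR
# = 15 + 6 + 4 + 11 + 13 + 6 = 55 rows, which is the "[55x3]" shape the
# comment in flow_compute_color below refers to.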
def flow_compute_color(u, v, convert_to_bgr=False):
'''
Applies the flow color wheel to (possibly clipped) flow components u and v.
According to the C++ source code of Daniel Scharstein
According to the Matlab source code of Deqing Sun
:param u: np.ndarray, input horizontal flow
:param v: np.ndarray, input vertical flow
:param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB
:return:
'''
flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
colorwheel = make_colorwheel() # shape [55x3]
ncols = colorwheel.shape[0]
rad = np.sqrt(np.square(u) + np.square(v))
a = np.arctan2(-v, -u)/np.pi
fk = (a + 1) / 2 * (ncols - 1)
k0 = np.floor(fk).astype(np.int32)
k1 = k0 + 1
k1[k1 == ncols] = 0
f = fk - k0
for i in range(colorwheel.shape[1]):
tmp = colorwheel[:,i]
col0 = tmp[k0] / 255.0
col1 = tmp[k1] / 255.0
col = (1-f)*col0 + f*col1
idx = (rad <= 1)
col[idx] = 1 - rad[idx] * (1-col[idx])
col[~idx] = col[~idx] * 0.75 # out of range?
# Note the 2-i => BGR instead of RGB
ch_idx = 2-i if convert_to_bgr else i
flow_image[:,:,ch_idx] = np.floor(255 * col)
return flow_image
def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False):
'''
Expects a two-channel flow image of shape [H,W,2]
According to the C++ source code of Daniel Scharstein
According to the Matlab source code of Deqing Sun
:param flow_uv: np.ndarray of shape [H,W,2]
:param clip_flow: float, maximum clipping value for flow
:param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB
:return:
'''
assert flow_uv.ndim == 3, 'input flow must have three dimensions'
assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
if clip_flow is not None:
flow_uv = np.clip(flow_uv, 0, clip_flow)
u = flow_uv[:,:,0]
v = flow_uv[:,:,1]
rad = np.sqrt(np.square(u) + np.square(v))
rad_max = np.max(rad)
epsilon = 1e-5
u = u / (rad_max + epsilon)
v = v / (rad_max + epsilon)
return flow_compute_color(u, v, convert_to_bgr)
def read_flow(filename):
"""
https://github.com/sampepose/flownet2-tf/blob/master/src/flowlib.py
read optical flow from Middlebury .flo file
:param filename: name of the flow file
:return: optical flow data in matrix
"""
f = open(filename, 'rb')
magic = np.fromfile(f, np.float32, count=1)
data2d = None
if 202021.25 != magic:
print('Magic number incorrect. Invalid .flo file')
else:
w = np.fromfile(f, np.int32, count=1)
h = np.fromfile(f, np.int32, count=1)
print("Reading %d x %d flo file" % (h, w))
data2d = np.fromfile(f, np.float32, count=2 * w[0] * h[0])
# reshape data into 3D array (columns, rows, channels)
data2d = np.resize(data2d, (h[0], w[0], 2))
f.close()
return data2d
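# Small self-test sketch (an addition, not part of the upstream flow_vis
# module): colorize a synthetic radial flow field; a real Middlebury file
# could be loaded with read_flow(path) instead.
if __name__ == '__main__':
    yy, xx = np.mgrid[0:64, 0:64]
    demo_flow = np.dstack((xx - 32.0, yy - 32.0)).astype(np.float32)
    demo_img = flow_to_color(demo_flow)
    print('flow_to_color output:', demo_img.shape, demo_img.dtype)  # (64, 64, 3) uint8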
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
def EPE(input_flow, target_flow, loss_type, sparse=False, mean=True):
if loss_type == 'l1':
EPE_map = fluid.layers.abs(input_flow - target_flow)
else:
EPE_map = fluid.layers.square(input_flow - target_flow)
if sparse:
    # TODO: boolean masking once supported:
    # mask = (target_flow[:, 0] == 0) & (target_flow[:, 1] == 0); EPE_map = EPE_map[~mask]
mask_temp1 = fluid.layers.cast(target_flow[:, 0] == 0, 'float32')
mask_temp2 = fluid.layers.cast(target_flow[:, 1] == 0, 'float32')
mask = 1 - fluid.layers.elementwise_mul(mask_temp1, mask_temp2)
mask = fluid.layers.reshape(mask, [mask.shape[0], 1, mask.shape[1], mask.shape[2]])
mask = fluid.layers.concat([mask, mask], 1)
EPE_map = EPE_map * mask
if mean:
return fluid.layers.mean(EPE_map)
else:
batch_size = EPE_map.shape[0]
res_sum = fluid.layers.reduce_sum(EPE_map)
res = res_sum / batch_size
return res
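# Note: with loss_type 'l2' the map holds per-element squared differences (an
# MSE-style penalty) rather than the Euclidean endpoint error per pixel; as
# the --loss help text says, training starts with 'l2' and finetunes with 'l1'.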
def sparse_max_pool(input, size):
'''Downsample the input by considering 0 values as invalid.
Unfortunately, no generic interpolation mode can resize a sparse map correctly,
the strategy here is to use max pooling for positive values and "min pooling"
for negative values, the two results are then summed.
This technique allows sparsity to be minimized, contrary to nearest interpolation,
which could potentially lose information for isolated data points.'''
positive = fluid.layers.cast(input > 0, 'float32')
negative = fluid.layers.cast(input < 0, 'float32')
output = fluid.layers.adaptive_pool2d(input * positive, size) - fluid.layers.adaptive_pool2d(-input * negative,
size)
return output
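# Worked example: downsampling [[0, 5], [-3, 0]] to a single cell with the
# default 'max' pool_type keeps both extremes: max(input * positive) = 5 and
# max(-input * negative) = 3, giving 5 - 3 = 2 instead of the zero-diluted
# average 0.5.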
def multiscaleEPE(network_output, target_flow, loss_type, weights=None, sparse=False):
def one_scale(output, target, sparse, loss_type):
if sparse:
h = output.shape[2]
w = output.shape[3]
target_scaled = sparse_max_pool(target, [h, w])
else:
target_scaled = fluid.layers.resize_bilinear(target, out_shape=[output.shape[2],
output.shape[3]],
                                                         align_corners=False, align_mode=0)
return EPE(output, target_scaled, loss_type=loss_type, sparse=sparse, mean=False)
if type(network_output) not in [tuple, list]:
network_output = [network_output]
if weights is None:
weights = [0.005, 0.01, 0.02, 0.08, 0.32] # as in original article
assert(len(weights) == len(network_output))
loss = 0
for output, weight in zip(network_output, weights):
loss += weight * one_scale(output, target_flow, sparse, loss_type)
return loss
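# weights[i] pairs with network_output[i]; the model returns
# [flow2, flow3, flow4, flow5, flow6], so the defaults put 0.005 on the finest
# prediction (flow2) and 0.32 on the coarsest (flow6), matching the schedule
# of the original article cited above.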
def realEPE(output, target, sparse=False):
upsampled_output = fluid.layers.resize_bilinear(output, out_shape=[target.shape[2],
target.shape[3]],
                                                align_corners=False, align_mode=0)
return EPE(upsampled_output, target, loss_type='l2', sparse=sparse, mean=True)  # pass sparse by keyword so it is not consumed as loss_type
def read_txt(videoTxt):
with open(videoTxt, 'r') as f:
videolist = f.readlines()
return videolist
def read_txt_to_index(file):
data = read_txt(file)
data = list(map(int, data))
return data
def main():
file = 'data_dir/FlyingChairs_release/FlyingChairs_train_val.txt'
data = read_txt_to_index(file)
print(data)
print(len(data))
if __name__ == '__main__':
main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer for PWCNet."""
import sys
import os
os.environ['FLAGS_fraction_of_gpu_memory_to_use'] = "0.99999"
os.environ["FLAGS_eager_delete_tensor_gb"] = "0"
import pickle
import time
import cv2
import numpy as np
import paddle
import paddle.fluid as fluid
from scipy.misc import imsave  # scipy.misc.imsave was removed in SciPy 1.2, so this requires SciPy <= 1.1
from src import flow_vis
from models.model import PWCDCNet
from data.datasets import FlyingChairs, reader_flyingchairs
from src.multiscaleloss import multiscaleEPE, realEPE
from AverageMeter import *
from my_args import args
def writeFlowFile(filename, uv):
"""
According to the matlab code of Deqing Sun and c++ source code of Daniel Scharstein
Contact: dqsun@cs.brown.edu
Contact: schar@middlebury.edu
"""
TAG_STRING = np.array(202021.25, dtype=np.float32)
if uv.shape[2] != 2:
sys.exit("writeFlowFile: flow must have two bands!")
H = np.array(uv.shape[0], dtype=np.int32)
W = np.array(uv.shape[1], dtype=np.int32)
with open(filename, 'wb') as f:
f.write(TAG_STRING.tobytes())
f.write(W.tobytes())
f.write(H.tobytes())
f.write(uv.tobytes())
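# .flo layout written above (Middlebury convention): a float32 magic number
# 202021.25, then int32 width and height, then H * W * 2 float32 values with
# u and v interleaved per pixel. read_flow in src/flow_vis.py parses the same
# layout.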
def load_dict(filename_):
with open(filename_, 'rb') as f:
ret_di = pickle.load(f)
return ret_di
def pad_input(x0):
    # x0 is laid out as NCHW, so shape[2] is the height and shape[3] the width
    intHeight = x0.shape[2]
    intWidth = x0.shape[3]
    if intWidth != ((intWidth >> 6) << 6):
        intWidth_pad = (((intWidth >> 6) + 1) << 6)  # round up to the next multiple of 64
        intPaddingLeft = int((intWidth_pad - intWidth) / 2)
        intPaddingRight = intWidth_pad - intWidth - intPaddingLeft
    else:
        intWidth_pad = intWidth
        intPaddingLeft = 0
        intPaddingRight = 0
    if intHeight != ((intHeight >> 6) << 6):
        intHeight_pad = (((intHeight >> 6) + 1) << 6)  # round up to the next multiple of 64
        intPaddingTop = int((intHeight_pad - intHeight) / 2)
        intPaddingBottom = intHeight_pad - intHeight - intPaddingTop
    else:
        intHeight_pad = intHeight
        intPaddingTop = 0
        intPaddingBottom = 0
    # fluid.layers.pad2d expects paddings in the order [top, bottom, left, right]
    out = fluid.layers.pad2d(input=x0,
                             paddings=[intPaddingTop, intPaddingBottom, intPaddingLeft, intPaddingRight],
                             mode='edge')
    return out, [intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom, intWidth, intHeight]
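# With the sizes configured in my_args.py (crop_size [384, 448],
# inference_size [384, 512]) every dimension is already a multiple of 64
# (384 = 6 * 64, 448 = 7 * 64, 512 = 8 * 64), so all four paddings are zero;
# padding only kicks in for inputs not divisible by 64, the stride of the
# coarsest (level-6) feature map.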
def val(model, batch_reader, epoch, batch_num):
model.eval()
loss_cnt = AverageMeter()
for batch_id, data in enumerate(batch_reader()):
start = time.time()
im1_data = np.array(
[x[0] for x in data]).astype('float32')
im2_data = np.array(
[x[1] for x in data]).astype('float32')
flo_data = np.array(
[x[2] for x in data]).astype('float32')
step = im1_data.shape[0]
im_all = np.concatenate((im1_data, im2_data), axis=3).astype(np.float32)
im_all = im_all / 255.0
im_all = np.swapaxes(np.swapaxes(im_all, 1, 2), 1, 3)
label = flo_data / 20.0
label = np.swapaxes(np.swapaxes(label, 1, 2), 1, 3)
im_all = fluid.dygraph.to_variable(im_all)
label = fluid.dygraph.to_variable(label)
# im_all, [intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom, intWidth, intHeight] = pad_input(
# im_all)
end = time.time()
read_data_time = end - start
start = time.time()
network_output = model(im_all, output_more=False)
loss = realEPE(network_output, label)
end = time.time()
loss_cnt.update(loss.numpy()[0], step)
print('val epoch {} batch {}/{} run time: {}s read data time {}s loss {}'.format(epoch, batch_id, batch_num,
round(end - start, 2),
round(read_data_time, 2),
loss.numpy()))
return round(loss_cnt.avg, 4)
def train(model, train_batch_reader, adam, epoch, batch_num, args):
loss_type = args.loss
model.train()
for batch_id, data in enumerate(train_batch_reader()):
start = time.time()
im1_data = np.array(
[x[0] for x in data]).astype('float32')
im2_data = np.array(
[x[1] for x in data]).astype('float32')
flo_data = np.array(
[x[2] for x in data]).astype('float32')
im_all = np.concatenate((im1_data, im2_data), axis=3).astype(np.float32)
im_all = im_all / 255.0
im_all = np.swapaxes(np.swapaxes(im_all, 1, 2), 1, 3)
label = flo_data / 20.0
label = np.swapaxes(np.swapaxes(label, 1, 2), 1, 3)
if batch_id % 10 == 0:
im1 = im_all[0, :3, :, :] * 255
im2 = im_all[0, 3:, :, :] * 255
im1 = np.swapaxes(np.swapaxes(im1, 0, 1), 1, 2).astype(np.uint8)
im2 = np.swapaxes(np.swapaxes(im2, 0, 1), 1, 2).astype(np.uint8)
flo = label[0, :, :, :] * 20
flo = np.swapaxes(np.swapaxes(flo, 0, 1), 1, 2)
imsave('./img1.png', im1)
imsave('./img2.png', im2)
flow_color = flow_vis.flow_to_color(flo, convert_to_bgr=False)
imsave('./hsv_pd.png', flow_color)
H = im_all[0].shape[1]
W = im_all[0].shape[2]
im_all = fluid.dygraph.to_variable(im_all)
label = fluid.dygraph.to_variable(label)
im_all, [intPaddingLeft, intPaddingRight, intPaddingTop, intPaddingBottom, intWidth, intHeight] = pad_input(
im_all)
label, _ = pad_input(label)
end = time.time()
read_data_time = end - start
start = time.time()
network_output = model(im_all, output_more=True)
if batch_id % 10 == 0:
flo = network_output[0][0].numpy() * 20.0
# scale the flow back to the input size
flo = np.swapaxes(np.swapaxes(flo, 0, 1), 1, 2)
flo = flo[intPaddingTop * 2:intPaddingTop * 2 + intHeight * 2,
intPaddingLeft * 2: intPaddingLeft * 2 + intWidth * 2, :]
u_ = cv2.resize(flo[:, :, 0], (W, H))
v_ = cv2.resize(flo[:, :, 1], (W, H))
flo = np.dstack((u_, v_))
flow_color = flow_vis.flow_to_color(flo, convert_to_bgr=False)
imsave('./hsv_predict.png', flow_color)
loss = multiscaleEPE(network_output, label, loss_type, weights=None, sparse=False)
end = time.time()
loss.backward()
if args.use_multi_gpu:
model.apply_collective_grads()
adam.minimize(loss)
model.clear_gradients()
print('epoch {} batch {}/{} run time: {}s read data time {}s loss {}'.format(epoch, batch_id, batch_num,
round(end - start, 2),
round(read_data_time, 2),
loss.numpy()))
def main():
print(args)
if args.use_multi_gpu:
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id)
else:
place = fluid.CUDAPlace(0)
with fluid.dygraph.guard(place=place):
if args.use_multi_gpu:
strategy = fluid.dygraph.parallel.prepare_context()
model = PWCDCNet("pwcnet")
if args.pretrained:
print('-----------load pretrained model:', args.pretrained)
pd_pretrain, _ = fluid.dygraph.load_dygraph(args.pretrained)
model.set_dict(pd_pretrain)
adam = fluid.optimizer.AdamOptimizer(learning_rate=0.0001, regularization=fluid.regularizer.L2DecayRegularizer(
regularization_coeff=0.0004))
if args.optimize:
print('--------------load pretrained model:', args.optimize)
adam_pretrain, _ = fluid.dygraph.load_dygraph(args.optimize)
adam.set_dict(adam_pretrain)
if args.use_multi_gpu:
model = fluid.dygraph.parallel.DataParallel(model, strategy)
if args.dataset == 'FlyingChairs':
train_flyingchairs_dataset = FlyingChairs('train', args, is_cropped=True, txt_file=args.train_val_txt,
root=args.data_root)
val_flyingchairs_dataset = FlyingChairs('val', args, is_cropped=False, txt_file=args.train_val_txt,
root=args.data_root)
else:
raise ValueError('unsupported dataset; only FlyingChairs is supported, check --dataset')
train_sample_num = len(train_flyingchairs_dataset)
val_sample_num = len(val_flyingchairs_dataset)
print('train sample num: ', train_sample_num)
print('val sample num: ', val_sample_num)
train_reader = reader_flyingchairs(train_flyingchairs_dataset)
val_reader = reader_flyingchairs(val_flyingchairs_dataset)
if args.use_multi_gpu:
train_reader = fluid.contrib.reader.distributed_batch_reader(
train_reader)
val_reader = fluid.contrib.reader.distributed_batch_reader(
val_reader)
BATCH_SIZE = args.batch_size
train_batch_num = round(train_sample_num / BATCH_SIZE)
val_batch_num = round(val_sample_num / BATCH_SIZE)
train_batch_reader = paddle.batch(paddle.reader.shuffle(train_reader, buf_size=BATCH_SIZE * 100), BATCH_SIZE,
drop_last=True)
val_batch_reader = paddle.batch(val_reader, BATCH_SIZE, drop_last=False)
epoch_num = args.numEpoch
val_value = 100000000
rm_best_model = ""
for epoch in range(epoch_num):
train(model, train_batch_reader, adam, epoch, train_batch_num, args)
pd_save_dir = args.model_out_dir
if not os.path.exists(pd_save_dir):
os.makedirs(pd_save_dir)
pd_model_save = os.path.join(pd_save_dir, 'epoch_' + str(epoch) + "_pwc_net_paddle")
rm_dir = os.path.join(pd_save_dir, 'epoch_' + str(epoch - 1) + "_pwc_net_paddle.pdparams")
if os.path.exists(rm_dir):
os.remove(rm_dir)
if args.use_multi_gpu:
if fluid.dygraph.parallel.Env().local_rank == 0:
fluid.dygraph.save_dygraph(model.state_dict(), pd_model_save)
fluid.dygraph.save_dygraph(adam.state_dict(), os.path.join(pd_save_dir, 'adam'))
else:
fluid.dygraph.save_dygraph(model.state_dict(), pd_model_save)
fluid.dygraph.save_dygraph(adam.state_dict(), os.path.join(pd_save_dir, 'adam'))
val_loss_value = val(model, val_batch_reader, epoch, val_batch_num)
if val_loss_value < val_value:
best_model = os.path.join(pd_save_dir, "pwc_net_paddle_" + str(val_loss_value) + '.pdparams')
os.link(pd_model_save + '.pdparams', best_model)
if os.path.exists(rm_best_model):
os.remove(rm_best_model)
rm_best_model = best_model
val_value = val_loss_value
if __name__ == '__main__':
main()
#!/usr/bin/env bash
python3 train.py --dataset FlyingChairs --train_val_txt data_dir/FlyingChairs_release/FlyingChairs_train_val.txt --data_root data_dir/FlyingChairs_release/data
# multi-GPU training: TODO, enable the command below once it is supported
#python3 -m paddle.distributed.launch --selected_gpus=0,1 --log_dir ./mylog train.py --use_multi_gpu --batch_size 20 --dataset FlyingChairs --train_val_txt data_dir/FlyingChairs_release/FlyingChairs_train_val.txt --data_root data_dir/FlyingChairs_release/data