Unverified commit 8df5b4d6
Authored by LielinJiang on Sep 08, 2020; committed via GitHub on Sep 08, 2020
Add correlation api to contrib (#27015)
* add correlation api to contrib
Parent: cbcd5e40
Showing 4 changed files with 908 additions and 2 deletions (+908 −2)
paddle/fluid/operators/correlation_op.cc                 +181  −0
paddle/fluid/operators/correlation_op.cu                 +483  −0
python/paddle/fluid/contrib/layers/nn.py                 +81   −2
python/paddle/fluid/contrib/tests/test_correlation.py    +163  −0
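For orientation (an editorial note derived from the correlation_forward CUDA kernel added below, not part of the commit): with kernel radius $k = (\text{kernel\_size} - 1)/2$, displacement radius $d = \text{max\_displacement}/\text{stride2}$, and $D = 2d + 1$, the operator computes, for each displacement pair $(t_i, t_j) \in [-d, d]^2$, the mean of the elementwise products over the kernel window and input channels:

$$\text{Output}(n,\ (t_j + d)\,D + (t_i + d),\ y,\ x) = \frac{1}{K^2 C} \sum_{c=0}^{C-1} \sum_{j=-k}^{k} \sum_{i=-k}^{k} I_1(n, c, y_1 + j, x_1 + i)\, I_2(n, c, y_1 + t_j s_2 + j, x_1 + t_i s_2 + i)$$

where $K = \text{kernel\_size}$, $C$ is the input channel count, $s_1, s_2$ are stride1 and stride2, and $(x_1, y_1) = (x\,s_1 + \text{max\_displacement},\ y\,s_1 + \text{max\_displacement})$ index into the zero-padded inputs.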
paddle/fluid/operators/correlation_op.cc (new file, mode 100644)
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

inline std::vector<int64_t> CorrelationOutputSize(int batch, int input_height,
                                                  int input_width, int stride1,
                                                  int stride2, int kernel_size,
                                                  int pad_size,
                                                  int max_displacement) {
  std::vector<int64_t> output_shape({batch});
  int kernel_radius = (kernel_size - 1) / 2;
  int border_radius = kernel_radius + max_displacement;
  int padded_input_height = input_height + 2 * pad_size;
  int padded_input_width = input_width + 2 * pad_size;
  int output_channel = ((max_displacement / stride2) * 2 + 1) *
                       ((max_displacement / stride2) * 2 + 1);
  output_shape.push_back(output_channel);
  int output_height =
      std::ceil(static_cast<float>(padded_input_height - 2 * border_radius) /
                static_cast<float>(stride1));
  int output_width =
      std::ceil(static_cast<float>(padded_input_width - 2 * border_radius) /
                static_cast<float>(stride1));
  output_shape.push_back(output_height);
  output_shape.push_back(output_width);
  return output_shape;
}

class CorrelationOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Input1", "Input is a 4-D Tensor with shape [N, C, H, W]");
    AddInput("Input2", "Input is a 4-D Tensor with shape [N, C, H, W]");
    AddOutput("Output",
              "(Tensor) The output tensor of correlation operator. "
              "It has same data format and data type as the Input.");
    AddAttr<int>("pad_size", "pad size for input1 and input2");
    AddAttr<int>("kernel_size", "kernel size of input1 and input2");
    AddAttr<int>("max_displacement", "max displacement of input1 and input2");
    AddAttr<int>("stride1", "Input1 stride");
    AddAttr<int>("stride2", "Input2 stride");
    AddAttr<int>("corr_type_multiply", "correlation coefficient").SetDefault(1);
    AddComment(
        R"DOC(Correlation of two feature map. Only support NCHW data format.)DOC");
  }
};

class CorrelationOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("Input1"), "Input", "X", "CorrelationOp");
    OP_INOUT_CHECK(ctx->HasInput("Input2"), "Input", "Y", "CorrelationOp");

    int stride1 = ctx->Attrs().Get<int>("stride1");
    int stride2 = ctx->Attrs().Get<int>("stride2");
    int max_displacement = ctx->Attrs().Get<int>("max_displacement");
    int pad_size = ctx->Attrs().Get<int>("pad_size");
    int kernel_size = ctx->Attrs().Get<int>("kernel_size");

    auto in_dims = ctx->GetInputDim("Input1");
    auto in2_dims = ctx->GetInputDim("Input2");

    PADDLE_ENFORCE_EQ(in_dims.size() == 4, true,
                      platform::errors::InvalidArgument(
                          "Input(X) of CorrelationOp must be 4 dims."
                          "But received dims is %d.",
                          in_dims.size()));
    PADDLE_ENFORCE_EQ(in2_dims.size() == 4, true,
                      platform::errors::InvalidArgument(
                          "Input(Y) of CorrelationOp must be 4 dims."
                          "But received dims is %d.",
                          in2_dims.size()));
    std::vector<int64_t> output_shape =
        CorrelationOutputSize(in_dims[0], in_dims[2], in_dims[3], stride1,
                              stride2, kernel_size, pad_size, max_displacement);
    ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto input_data_type =
        OperatorWithKernel::IndicateVarDataType(ctx, "Input1");
    PADDLE_ENFORCE_EQ(input_data_type, ctx.Input<Tensor>("Input2")->type(),
                      platform::errors::InvalidArgument(
                          "X and Y should have the same datatype"));
    return framework::OpKernelType(input_data_type, ctx.GetPlace());
  }

  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    return framework::OpKernelType(expected_kernel_type.data_type_,
                                   tensor.place(), tensor.layout());
  }
};

template <typename T>
class CorrelationOpGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("correlation_grad");
    op->SetInput("Input1", this->Input("Input1"));
    op->SetInput("Input2", this->Input("Input2"));
    op->SetInput(framework::GradVarName("Output"), this->OutputGrad("Output"));
    op->SetOutput(framework::GradVarName("Input1"), this->InputGrad("Input1"));
    op->SetOutput(framework::GradVarName("Input2"), this->InputGrad("Input2"));
    op->SetAttrMap(this->Attrs());
  }
};

class CorrelationOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("Input1"), "Input", "X", "CorrelationOp");
    OP_INOUT_CHECK(ctx->HasInput("Input2"), "Input", "Y", "CorrelationOp");
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Output")), "Input",
                   "Output@GRAD", "CorrelationGradOp");

    auto in1_dims = ctx->GetInputDim("Input1");
    auto in2_dims = ctx->GetInputDim("Input2");
    ctx->SetOutputDim(framework::GradVarName("Input1"), in1_dims);
    ctx->SetOutputDim(framework::GradVarName("Input2"), in2_dims);
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Input1"),
        ctx.GetPlace());
  }
};

template <typename T>
class CorrelationKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_gpu_place(ctx.GetPlace()), true,
        platform::errors::Unimplemented("Correlation only supports GPU now."));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(correlation, ops::CorrelationOp, ops::CorrelationOpMaker,
                  ops::CorrelationOpGradMaker<paddle::framework::OpDesc>,
                  ops::CorrelationOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(correlation_grad, ops::CorrelationOpGrad);
REGISTER_OP_CPU_KERNEL(correlation, ops::CorrelationKernel<float>,
                       ops::CorrelationKernel<double>);
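As an editorial aid (not part of the commit), the Python sketch below mirrors the arithmetic of CorrelationOutputSize above and checks it against the settings used by the tests further down (pad_size=4, kernel_size=1, max_displacement=4, stride1=stride2=1 on a [2, 3, 4, 5] input):

import math

def correlation_output_size(batch, in_h, in_w, stride1, stride2,
                            kernel_size, pad_size, max_displacement):
    # Mirrors CorrelationOutputSize: integer divisions match the C++ code.
    kernel_radius = (kernel_size - 1) // 2
    border_radius = kernel_radius + max_displacement
    padded_h = in_h + 2 * pad_size
    padded_w = in_w + 2 * pad_size
    d = max_displacement // stride2
    out_c = (2 * d + 1) ** 2  # one channel per displacement pair
    out_h = math.ceil((padded_h - 2 * border_radius) / stride1)
    out_w = math.ceil((padded_w - 2 * border_radius) / stride1)
    return [batch, out_c, out_h, out_w]

# For the test settings: d = 4, so 9 * 9 = 81 output channels.
print(correlation_output_size(2, 4, 5, 1, 1, 1, 4, 4))  # [2, 81, 4, 5]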
paddle/fluid/operators/correlation_op.cu (new file, mode 100644)
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <string>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

#define THREADS_PER_BLOCK 32
#define FULL_MASK 0xffffffff

using framework::Tensor;
using DataLayout = framework::DataLayout;

template <typename T>
__forceinline__ __device__ T warpReduceSum(T val) {
  for (int offset = 16; offset > 0; offset /= 2) {
    val += __shfl_down_sync(FULL_MASK, val, offset);
  }
  return val;
}

template <typename T>
__forceinline__ __device__ T blockReduceSum(T val) {
  static __shared__ T shared[32];
  int lane = threadIdx.x % warpSize;
  int wid = threadIdx.x / warpSize;

  val = warpReduceSum(val);
  if (lane == 0) shared[wid] = val;

  __syncthreads();
  val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : 0;

  if (wid == 0) val = warpReduceSum(val);

  return val;
}

template <typename T>
__global__ void set_zero(T *x, int num) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < num;
       i += blockDim.x * gridDim.x)
    x[i] = static_cast<T>(0);
}

template <typename T>
__global__ void channel_first(const T *input, T *rinput, const int channel,
                              const int height, const int width,
                              const int pad_size) {
  int n = blockIdx.x;
  int h = blockIdx.y;
  int w = blockIdx.z;

  int ch_off = threadIdx.x;
  T value;
  int dimchw = channel * height * width;
  int dimhw = height * width;

  int p_dimw = (width + 2 * pad_size);
  int p_dimh = (height + 2 * pad_size);
  int p_dimchw = channel * p_dimw * p_dimh;
  int p_dimcw = channel * p_dimw;

  for (int c = ch_off; c < channel; c += THREADS_PER_BLOCK) {
    value = input[n * dimchw + c * dimhw + h * width + w];
    rinput[n * p_dimchw + (h + pad_size) * p_dimcw + (w + pad_size) * channel +
           c] = value;
  }
}

template <typename T>
__global__ void correlation_forward(
    T *output, const int output_channel, const int output_height,
    const int output_width, const T *rinput1, const int input_channel,
    const int input_height, const int input_width, const T *rinput2,
    const int pad_size, const int kernel_size, const int max_displacement,
    const int stride1, const int stride2) {
  int p_input_width = input_width + 2 * pad_size;
  int p_input_height = input_height + 2 * pad_size;

  int kernel_rad = (kernel_size - 1) / 2;
  int displacement_rad = max_displacement / stride2;
  int displacement_size = 2 * displacement_rad + 1;

  int n = blockIdx.x;
  int h1 = blockIdx.y * stride1 + max_displacement;
  int w1 = blockIdx.z * stride1 + max_displacement;
  int c = threadIdx.x;

  int p_dimchw = p_input_height * p_input_width * input_channel;
  int p_dimcw = p_input_width * input_channel;
  int p_dimc = input_channel;

  int t_dimchw = output_channel * output_height * output_width;
  int t_dimhw = output_height * output_width;
  int t_dimw = output_width;

  int nelems = kernel_size * kernel_size * p_dimc;

  for (int tj = -displacement_rad; tj <= displacement_rad; ++tj) {
    for (int ti = -displacement_rad; ti <= displacement_rad; ++ti) {
      int w2 = w1 + ti * stride2;
      int h2 = h1 + tj * stride2;

      T acc0 = 0;
      for (int j = -kernel_rad; j <= kernel_rad; ++j) {
        for (int i = -kernel_rad; i <= kernel_rad; ++i) {
          for (int ch = c; ch < p_dimc; ch += blockDim.x) {
            int index1 =
                n * p_dimchw + (h1 + j) * p_dimcw + (w1 + i) * p_dimc + ch;
            int index2 =
                n * p_dimchw + (h2 + j) * p_dimcw + (w2 + i) * p_dimc + ch;
            acc0 += static_cast<T>(rinput1[index1] * rinput2[index2]);
          }
        }
      }
      if (blockDim.x == warpSize) {
        __syncwarp();
        acc0 = warpReduceSum(acc0);
      } else {
        __syncthreads();
        acc0 = blockReduceSum(acc0);
      }

      if (threadIdx.x == 0) {
        int tc = (tj + displacement_rad) * displacement_size +
                 (ti + displacement_rad);
        const int t_index =
            n * t_dimchw + tc * t_dimhw + blockIdx.y * t_dimw + blockIdx.z;
        output[t_index] = static_cast<T>(acc0 / nelems);
      }
    }
  }
}  // class CorrelationKernel<platform::CUDADeviceContext, T>

template <typename T>
class CorrelationCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
                      platform::errors::InvalidArgument(
                          "Correlation only supports GPU now."));

    auto *input1 = ctx.Input<Tensor>("Input1");
    auto *input2 = ctx.Input<Tensor>("Input2");
    int pad_size = ctx.Attr<int>("pad_size");
    int kernel_size = ctx.Attr<int>("kernel_size");
    int stride1 = ctx.Attr<int>("stride1");
    int stride2 = ctx.Attr<int>("stride2");
    int max_displacement = ctx.Attr<int>("max_displacement");
    int corr_type_multiply = ctx.Attr<int>("corr_type_multiply");

    auto *output = ctx.Output<Tensor>("Output");
    output->mutable_data<T>(ctx.GetPlace());
    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();

    // base on input1, NCHW
    auto in_dims = input1->dims();
    int N = in_dims[0];
    int C = in_dims[1];
    int H = in_dims[2];
    int W = in_dims[3];

    int padded_input_height = H + 2 * pad_size;
    int padded_input_width = W + 2 * pad_size;

    Tensor rinput1 = ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>(
        {N, padded_input_height, padded_input_width, C}, dev_ctx);
    rinput1.mutable_data<T>(ctx.GetPlace());

    Tensor rinput2 = ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>(
        {N, padded_input_height, padded_input_width, C}, dev_ctx);
    rinput2.mutable_data<T>(ctx.GetPlace());

    set_zero<<<(rinput1.numel() + 512 - 1) / 512, 512, 0, dev_ctx.stream()>>>(
        rinput1.data<T>(), rinput1.numel());
    set_zero<<<(rinput2.numel() + 512 - 1) / 512, 512, 0, dev_ctx.stream()>>>(
        rinput2.data<T>(), rinput2.numel());
    set_zero<<<(output->numel() + 512 - 1) / 512, 512, 0, dev_ctx.stream()>>>(
        output->data<T>(), output->numel());

    auto out_dims = output->dims();
    int OC = out_dims[1];
    int OH = out_dims[2];
    int OW = out_dims[3];

    dim3 blocks_grid(N, H, W);
    dim3 threads_block(THREADS_PER_BLOCK);

    channel_first<T><<<blocks_grid, threads_block, 0, dev_ctx.stream()>>>(
        input1->data<T>(), rinput1.data<T>(), C, H, W, pad_size);
    channel_first<T><<<blocks_grid, threads_block, 0, dev_ctx.stream()>>>(
        input2->data<T>(), rinput2.data<T>(), C, H, W, pad_size);

    dim3 threadsPerBlock(THREADS_PER_BLOCK);
    dim3 totalBlocksCorr(N, OH, OW);

    correlation_forward<
        T><<<totalBlocksCorr, threadsPerBlock, 0, dev_ctx.stream()>>>(
        output->data<T>(), OC, OH, OW, rinput1.data<T>(), C, H, W,
        rinput2.data<T>(), pad_size, kernel_size, max_displacement, stride1,
        stride2);
  }
};

template <typename T>
__global__ void correlation_backward_input1(
    int item, T *grad_input1, const int input_channel, const int input_height,
    const int input_width, const T *grad_output, const int output_channel,
    const int output_height, const int output_width, const T *rinput2,
    const int pad_size, const int kernel_size, const int max_displacement,
    const int stride1, const int stride2) {
  int n = item;
  int h = blockIdx.x * stride1 + pad_size;
  int w = blockIdx.y * stride1 + pad_size;
  int c = blockIdx.z;
  int tch_off = threadIdx.x;

  int kernel_rad = (kernel_size - 1) / 2;
  int displacement_rad = max_displacement / stride2;
  int displacement_size = 2 * displacement_rad + 1;

  int xmin = (w - kernel_rad - max_displacement) / stride1;
  int ymin = (h - kernel_rad - max_displacement) / stride1;
  int xmax = (w + kernel_rad - max_displacement) / stride1;
  int ymax = (h + kernel_rad - max_displacement) / stride1;

  if (xmax < 0 || ymax < 0 || xmin >= output_width || ymin >= output_height) {
    return;
  }

  if (xmin > xmax || ymin > ymax) {
    return;
  }

  xmin = max(0, xmin);
  xmax = min(output_width - 1, xmax);

  ymin = max(0, ymin);
  ymax = min(output_height - 1, ymax);

  int p_input_width = input_width + 2 * pad_size;
  int p_input_height = input_height + 2 * pad_size;
  int p_dimchw = input_channel * p_input_height * p_input_width;
  int p_dimcw = input_channel * p_input_width;
  int p_dimc = input_channel;

  int t_dimchw = output_channel * output_height * output_width;
  int t_dimhw = output_height * output_width;
  int t_dimw = output_width;

  int o_dimchw = input_channel * input_height * input_width;
  int o_dimhw = input_height * input_width;
  int o_dimw = input_width;

  int nelems = kernel_size * kernel_size * input_channel;

  __shared__ T prod_sum[THREADS_PER_BLOCK];
  prod_sum[tch_off] = 0;

  for (int tc = tch_off; tc < output_channel; tc += THREADS_PER_BLOCK) {
    int i2 = (tc % displacement_size - displacement_rad) * stride2;
    int j2 = (tc / displacement_size - displacement_rad) * stride2;

    int index2 = n * p_dimchw + (h + j2) * p_dimcw + (w + i2) * p_dimc + c;

    T val2 = rinput2[index2];
    for (int j = ymin; j <= ymax; ++j) {
      for (int i = xmin; i <= xmax; ++i) {
        int t_index = n * t_dimchw + tc * t_dimhw + j * t_dimw + i;
        prod_sum[tch_off] += grad_output[t_index] * val2;
      }
    }
  }

  __syncthreads();

  if (tch_off == 0) {
    T reduce_sum = 0;
    for (int index = 0; index < THREADS_PER_BLOCK; index++) {
      reduce_sum += prod_sum[index];
    }
    const int index1 =
        n * o_dimchw + c * o_dimhw + (h - pad_size) * o_dimw + (w - pad_size);
    grad_input1[index1] = static_cast<T>(reduce_sum / nelems);
  }
}

template <typename T>
__global__ void correlation_backward_input2(
    int item, T *grad_input2, const int input_channel, const int input_height,
    const int input_width, const T *grad_output, const int output_channel,
    const int output_height, const int output_width, const T *rinput1,
    const int pad_size, const int kernel_size, const int max_displacement,
    const int stride1, const int stride2) {
  int n = item;
  int h = blockIdx.x * stride1 + pad_size;
  int w = blockIdx.y * stride1 + pad_size;
  int c = blockIdx.z;
  int tch_off = threadIdx.x;

  int kernel_rad = (kernel_size - 1) / 2;
  int displacement_rad = max_displacement / stride2;
  int displacement_size = 2 * displacement_rad + 1;

  int p_input_width = input_width + 2 * pad_size;
  int p_input_height = input_height + 2 * pad_size;
  int p_dimchw = input_channel * p_input_height * p_input_width;
  int p_dimcw = input_channel * p_input_width;
  int p_dimc = input_channel;

  int t_dimchw = output_channel * output_height * output_width;
  int t_dimhw = output_height * output_width;
  int t_dimw = output_width;

  int o_dimchw = input_channel * input_height * input_width;
  int o_dimhw = input_height * input_width;
  int o_dimw = input_width;

  int nelems = kernel_size * kernel_size * input_channel;

  __shared__ T prod_sum[THREADS_PER_BLOCK];
  prod_sum[tch_off] = 0;

  for (int tc = tch_off; tc < output_channel; tc += THREADS_PER_BLOCK) {
    int i2 = (tc % displacement_size - displacement_rad) * stride2;
    int j2 = (tc / displacement_size - displacement_rad) * stride2;

    int xmin = (w - kernel_rad - max_displacement - i2) / stride1;
    int ymin = (h - kernel_rad - max_displacement - j2) / stride1;
    int xmax = (w + kernel_rad - max_displacement - i2) / stride1;
    int ymax = (h + kernel_rad - max_displacement - j2) / stride1;

    if (xmax < 0 || ymax < 0 || xmin >= output_width ||
        ymin >= output_height) {
      continue;
    }

    if (xmin > xmax || ymin > ymax) {
      continue;
    }

    xmin = max(0, xmin);
    xmax = min(output_width - 1, xmax);

    ymin = max(0, ymin);
    ymax = min(output_height - 1, ymax);

    int index1 = n * p_dimchw + (h - j2) * p_dimcw + (w - i2) * p_dimc + c;
    T val1 = rinput1[index1];
    for (int j = ymin; j <= ymax; ++j) {
      for (int i = xmin; i <= xmax; ++i) {
        int t_index = n * t_dimchw + tc * t_dimhw + j * t_dimw + i;
        prod_sum[tch_off] += grad_output[t_index] * val1;
      }
    }
  }

  __syncthreads();

  if (tch_off == 0) {
    T reduce_sum = 0;
    for (int index = 0; index < THREADS_PER_BLOCK; index++) {
      reduce_sum += prod_sum[index];
    }
    const int index2 =
        n * o_dimchw + c * o_dimhw + (h - pad_size) * o_dimw + (w - pad_size);
    grad_input2[index2] = static_cast<T>(reduce_sum / nelems);
  }
}

template <typename T>
class CorrelationCUDAGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
                      platform::errors::InvalidArgument(
                          "Correlation only supports GPU now."));
    const auto *input1 = ctx.Input<Tensor>("Input1");
    const auto *input2 = ctx.Input<Tensor>("Input2");
    const auto *grad_output =
        ctx.Input<Tensor>(framework::GradVarName("Output"));
    const int pad_size = ctx.Attr<int>("pad_size");
    const int kernel_size = ctx.Attr<int>("kernel_size");
    const int stride1 = ctx.Attr<int>("stride1");
    const int stride2 = ctx.Attr<int>("stride2");
    const int max_displacement = ctx.Attr<int>("max_displacement");
    const int corr_type_multiply = ctx.Attr<int>("corr_type_multiply");

    auto *grad_input1 = ctx.Output<Tensor>(framework::GradVarName("Input1"));
    grad_input1->mutable_data<T>(ctx.GetPlace());
    auto *grad_input2 = ctx.Output<Tensor>(framework::GradVarName("Input2"));
    grad_input2->mutable_data<T>(ctx.GetPlace());
    auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();

    auto in_dims = input1->dims();
    int N = in_dims[0];
    int C = in_dims[1];
    int H = in_dims[2];
    int W = in_dims[3];

    int padded_input_height = H + 2 * pad_size;
    int padded_input_width = W + 2 * pad_size;

    Tensor rinput1 = ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>(
        {N, padded_input_height, padded_input_width, C}, dev_ctx);
    rinput1.mutable_data<T>(ctx.GetPlace());

    Tensor rinput2 = ctx.AllocateTmpTensor<T, platform::CUDADeviceContext>(
        {N, padded_input_height, padded_input_width, C}, dev_ctx);
    rinput2.mutable_data<T>(ctx.GetPlace());

    set_zero<<<(rinput1.numel() + 512 - 1) / 512, 512, 0, dev_ctx.stream()>>>(
        rinput1.data<T>(), rinput1.numel());
    set_zero<<<(rinput2.numel() + 512 - 1) / 512, 512, 0, dev_ctx.stream()>>>(
        rinput2.data<T>(), rinput2.numel());
    set_zero<<<(grad_input1->numel() + 512 - 1) / 512, 512, 0,
               dev_ctx.stream()>>>(grad_input1->data<T>(),
                                   grad_input1->numel());
    set_zero<<<(grad_input2->numel() + 512 - 1) / 512, 512, 0,
               dev_ctx.stream()>>>(grad_input2->data<T>(),
                                   grad_input2->numel());

    auto grad_out_dims = grad_output->dims();
    int GOC = grad_out_dims[1];
    int GOH = grad_out_dims[2];
    int GOW = grad_out_dims[3];

    dim3 blocks_grid(N, H, W);
    dim3 threads_block(THREADS_PER_BLOCK);

    channel_first<T><<<blocks_grid, threads_block, 0, dev_ctx.stream()>>>(
        input1->data<T>(), rinput1.data<T>(), C, H, W, pad_size);
    channel_first<T><<<blocks_grid, threads_block, 0, dev_ctx.stream()>>>(
        input2->data<T>(), rinput2.data<T>(), C, H, W, pad_size);

    dim3 threadsPerBlock(THREADS_PER_BLOCK);
    dim3 totalBlocksCorr(H, W, C);

    for (int n = 0; n < N; n++) {
      correlation_backward_input1<
          T><<<totalBlocksCorr, threadsPerBlock, 0, dev_ctx.stream()>>>(
          n, grad_input1->data<T>(), C, H, W, grad_output->data<T>(), GOC, GOH,
          GOW, rinput2.data<T>(), pad_size, kernel_size, max_displacement,
          stride1, stride2);
    }

    for (int n = 0; n < N; n++) {
      correlation_backward_input2<
          T><<<totalBlocksCorr, threadsPerBlock, 0, dev_ctx.stream()>>>(
          n, grad_input2->data<T>(), C, H, W, grad_output->data<T>(), GOC, GOH,
          GOW, rinput1.data<T>(), pad_size, kernel_size, max_displacement,
          stride1, stride2);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(correlation, ops::CorrelationCUDAKernel<float>,
                        ops::CorrelationCUDAKernel<double>);
REGISTER_OP_CUDA_KERNEL(correlation_grad,
                        ops::CorrelationCUDAGradKernel<float>,
                        ops::CorrelationCUDAGradKernel<double>);
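A note on the reduction used by correlation_forward above: each thread accumulates a partial sum over a strided slice of the channels, and the block then combines these with warp shuffles (warpReduceSum) plus, when the block spans more than one warp, a shared-memory stage (blockReduceSum). The NumPy sketch below (illustrative only, not part of the commit) emulates the shuffle-down pattern to show why lane 0 ends up holding the whole warp's total:

import numpy as np

# Emulate warpReduceSum: at each step, lane i adds the value held by
# lane i + offset, halving offset from 16 down to 1. Lanes past the end
# of the warp are modeled as contributing zero, since only lane 0's
# final value is consumed by the kernel.
vals = np.arange(32, dtype=np.float64)  # one partial sum per lane
offset = 16
while offset > 0:
    shifted = np.concatenate([vals[offset:], np.zeros(offset)])
    vals = vals + shifted  # __shfl_down_sync(FULL_MASK, val, offset)
    offset //= 2

assert vals[0] == np.arange(32).sum()  # 496: lane 0 holds the warp sum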
python/paddle/fluid/contrib/layers/nn.py
@@ -37,7 +37,7 @@ import warnings
import inspect
import numpy as np
import paddle
from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.layers import utils
from ... import unique_name

@@ -56,7 +56,8 @@ __all__ = [
    'match_matrix_tensor', 'tree_conv', 'fused_embedding_seq_pool',
    'multiclass_nms2', 'search_pyramid_hash', 'shuffle_batch',
    'partial_concat', 'sparse_embedding', 'partial_sum', 'tdm_child',
    'rank_attention', 'tdm_sampler', 'batch_fc', '_pull_box_extended_sparse',
    'bilateral_slice', 'correlation'
]

@@ -1546,3 +1547,81 @@ def bilateral_slice(x, guide, grid, has_offset, name=None):
        attrs={'has_offset': has_offset},
        outputs={'Out': out})
    return out


def correlation(x,
                y,
                pad_size,
                kernel_size,
                max_displacement,
                stride1,
                stride2,
                corr_type_multiply=1):
    """
    This operation computes the correlation of two input Tensors.
    For more information on correlation, please refer to PWC-Net:
    CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume
    <https://arxiv.org/pdf/1709.02371.pdf>_

    Args:
        x(Tensor): The input x is a 4-D Tensor with shape [N, C, H, W]. The data type is float32 or float64.
        y(Tensor): The input y is a 4-D Tensor with shape [N, C, H, W]. The data type is float32 or float64.
        pad_size(int): Pad size. The data type is int.
        kernel_size(int): Kernel size. The data type is int.
        max_displacement(int): Max displacement. The data type is int.
        stride1(int): stride size of x. The data type is int.
        stride2(int): stride size of y. The data type is int.
        corr_type_multiply(int, optional): The type of multiply. The data type is int. Default: 1.

    Returns:
        Tensor: The data type is the same as the input tensors.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            x1 = fluid.layers.data(name='x1',
                                   shape=x_shape,
                                   dtype=x_type,
                                   append_batch_size=False)
            x2 = fluid.layers.data(name='x2',
                                   shape=x_shape,
                                   dtype=x_type,
                                   append_batch_size=False)

            out = fluid.contrib.correlation(
                x1,
                x2,
                pad_size=4,
                kernel_size=1,
                max_displacement=4,
                stride1=1,
                stride2=1)
    """
    helper = LayerHelper("correlation", **locals())
    output = helper.create_variable_for_type_inference(dtype=x.dtype)
    if paddle.fluid.in_dygraph_mode():
        attrs = ("pad_size", pad_size, "kernel_size", kernel_size,
                 "max_displacement", max_displacement, "stride1", stride1,
                 "stride2", stride2, "corr_type_multiply", corr_type_multiply)
        output = getattr(core.ops, "correlation")(x, y, *attrs)
    else:
        helper.append_op(
            type="correlation",
            inputs={"Input1": x,
                    "Input2": y},
            attrs={
                "pad_size": pad_size,
                "kernel_size": kernel_size,
                "max_displacement": max_displacement,
                "stride1": stride1,
                "stride2": stride2,
                "corr_type_multiply": corr_type_multiply
            },
            outputs={"Output": output})
    return output
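As a quick usage sketch (assuming a CUDA build of Paddle from around this commit; the fluid dygraph APIs shown are the ones this change targets), the new contrib layer can be exercised like this:

import numpy as np
import paddle.fluid as fluid

# Two random NCHW feature maps; the correlation op requires a GPU place.
x1_np = np.random.randn(2, 3, 4, 5).astype('float32')
x2_np = np.random.randn(2, 3, 4, 5).astype('float32')

with fluid.dygraph.guard(fluid.CUDAPlace(0)):
    x1 = fluid.dygraph.to_variable(x1_np)
    x2 = fluid.dygraph.to_variable(x2_np)
    out = fluid.contrib.correlation(
        x1, x2,
        pad_size=4,
        kernel_size=1,
        max_displacement=4,
        stride1=1,
        stride2=1)
    # With max_displacement=4 and stride2=1, D = 2*4 + 1 = 9, so the
    # output has D*D = 81 channels: shape [2, 81, 4, 5].
    print(out.shape)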
python/paddle/fluid/contrib/tests/test_correlation.py (new file, mode 100644)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.base import to_variable


def corr(x_1,
         x_2,
         pad_size=4,
         kernel_size=1,
         max_displacement=4,
         stride1=1,
         stride2=1,
         corr_multiply=1):
    K = kernel_size

    rinput1 = np.pad(x_1, ((0, 0), (0, 0), (pad_size, pad_size),
                           (pad_size, pad_size)),
                     mode='constant')
    rinput2 = np.pad(x_2, ((0, 0), (0, 0), (pad_size, pad_size),
                           (pad_size, pad_size)),
                     mode='constant')
    rinput1 = np.transpose(rinput1, (0, 2, 3, 1))
    rinput2 = np.transpose(rinput2, (0, 2, 3, 1))
    B = int(rinput1.shape[0])
    H = int(x_1.shape[2])
    W = int(x_2.shape[3])
    d = max_displacement
    D = 2 * d + 1
    output = np.zeros((B, D * D, H, W), dtype=np.float32)

    for b in range(B):
        for i in range(H):
            for j in range(W):
                for k in range(-d, d + 1):
                    for l in range(-d, d + 1):
                        x1_index = i + pad_size
                        y1_index = j + pad_size
                        x2_index = x1_index + k
                        y2_index = y1_index + l
                        output[b, l + d + D * (k + d), i, j] = np.mean(
                            rinput1[b, x1_index:x1_index + K,
                                    y1_index:y1_index + K] *
                            rinput2[b, x2_index:x2_index + K,
                                    y2_index:y2_index + K])

    return output


class TestCorrelationOp(unittest.TestCase):
    def test_check_output(self):
        if not fluid.core.is_compiled_with_cuda():
            return
        np.random.seed(13)
        np.set_printoptions(threshold=np.inf)
        x_shape = (2, 10, 3, 3)
        x_type = 'float32'
        x1 = fluid.layers.data(
            name='x1',
            shape=x_shape,
            dtype=x_type,
            append_batch_size=False,
            stop_gradient=False)
        x2 = fluid.layers.data(
            name='x2',
            shape=x_shape,
            dtype=x_type,
            append_batch_size=False,
            stop_gradient=False)

        x1_np = np.random.randn(2, 3, 4, 5).astype(x_type)
        x2_np = np.random.randn(2, 3, 4, 5).astype(x_type)
        out_np = corr(
            x1_np,
            x2_np,
            pad_size=4,
            kernel_size=1,
            max_displacement=4,
            stride1=1,
            stride2=1)

        out = fluid.contrib.correlation(
            x1,
            x2,
            pad_size=4,
            kernel_size=1,
            max_displacement=4,
            stride1=1,
            stride2=1)

        loss = fluid.layers.reduce_mean(out)
        optimizer = fluid.optimizer.Momentum(0.0001, 0.9)
        optimizer.minimize(loss)

        place = fluid.CUDAPlace(0)
        exe = fluid.Executor(place)
        res = exe.run(feed={'x1': x1_np,
                            'x2': x2_np},
                      fetch_list=[out.name, loss.name])

        self.assertTrue(np.allclose(res[0], out_np))


class Net(fluid.dygraph.Layer):
    def __init__(self, name_scope):
        super(Net, self).__init__(name_scope)

    def forward(self, x1, x2):
        y = fluid.contrib.correlation(
            x1,
            x2,
            pad_size=4,
            kernel_size=1,
            max_displacement=4,
            stride1=1,
            stride2=1)
        return y


class TestCorrelationOpDyGraph(unittest.TestCase):
    def test_check_output(self):
        if not fluid.core.is_compiled_with_cuda():
            return
        np.random.seed(13)
        np.set_printoptions(threshold=np.inf)
        x_shape = (2, 10, 3, 3)
        x_type = 'float32'
        place = fluid.CUDAPlace(0)
        with fluid.dygraph.guard(place):
            x1_np = np.random.randn(2, 3, 4, 5).astype(x_type)
            x2_np = np.random.randn(2, 3, 4, 5).astype(x_type)
            out_np = corr(
                x1_np,
                x2_np,
                pad_size=4,
                kernel_size=1,
                max_displacement=4,
                stride1=1,
                stride2=1)

            x1 = to_variable(x1_np)
            x2 = to_variable(x2_np)
            corr_pd = Net('corr_pd')
            y = corr_pd(x1, x2)
            out = y.numpy()
            self.assertTrue(np.allclose(out, out_np))


if __name__ == '__main__':
    unittest.main()