Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
9cc5603d
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9cc5603d
编写于
9月 28, 2020
作者:
W
whs
提交者:
GitHub
9月 28, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Make grid support stopping graients. (#27630)
上级
074a71bd
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
68 addition
and
55 deletion
+68
-55
paddle/fluid/operators/grid_sampler_op.cc
paddle/fluid/operators/grid_sampler_op.cc
+0
-2
paddle/fluid/operators/grid_sampler_op.cu
paddle/fluid/operators/grid_sampler_op.cu
+21
-13
paddle/fluid/operators/grid_sampler_op.h
paddle/fluid/operators/grid_sampler_op.h
+47
-40
未找到文件。
paddle/fluid/operators/grid_sampler_op.cc
浏览文件 @
9cc5603d
...
...
@@ -176,8 +176,6 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
"Output"
,
framework
::
GradVarName
(
"X"
),
"grid_sampler"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"Grid"
)),
"Output"
,
framework
::
GradVarName
(
"Grid"
),
"grid_sampler"
);
auto
input_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
grid_dims
=
ctx
->
GetInputDim
(
"Grid"
);
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)))
{
...
...
paddle/fluid/operators/grid_sampler_op.cu
浏览文件 @
9cc5603d
...
...
@@ -397,9 +397,11 @@ __global__ void grid_sampler_cuda_backward_kernel(
}
}
if
(
grad_grid
!=
nullptr
)
{
T
*
gGrid_ptr_NHW
=
grad_grid
+
index
*
grid_sW
;
gGrid_ptr_NHW
[
0
]
=
gix_mult
*
gix
;
gGrid_ptr_NHW
[
1
]
=
giy_mult
*
giy
;
}
}
else
if
(
mode
==
Mode
::
nearest
)
{
int
ix_nearest
=
static_cast
<
int
>
(
::
round
(
ix
));
int
iy_nearest
=
static_cast
<
int
>
(
::
round
(
iy
));
...
...
@@ -412,11 +414,13 @@ __global__ void grid_sampler_cuda_backward_kernel(
in_w
,
grad_output
[
gOut_offset
]);
}
if
(
grad_grid
!=
nullptr
)
{
T
*
gGrid_ptr_NHW
=
grad_grid
+
index
*
grid_sW
;
gGrid_ptr_NHW
[
0
]
=
static_cast
<
T
>
(
0
);
gGrid_ptr_NHW
[
1
]
=
static_cast
<
T
>
(
0
);
}
}
}
}
template
<
typename
T
>
...
...
@@ -460,11 +464,15 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> {
math
::
SetConstant
<
paddle
::
platform
::
CUDADeviceContext
,
T
>
()(
ctx
.
template
device_context
<
paddle
::
platform
::
CUDADeviceContext
>(),
input_grad
,
static_cast
<
T
>
(
0
));
T
*
grid_grad_data
=
nullptr
;
if
(
ctx
.
HasOutput
(
framework
::
GradVarName
(
"Grid"
)))
{
auto
*
grid_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Grid"
));
grid_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
grid_grad_data
=
grid_grad
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
math
::
SetConstant
<
paddle
::
platform
::
CUDADeviceContext
,
T
>
()(
ctx
.
template
device_context
<
paddle
::
platform
::
CUDADeviceContext
>(),
grid_grad
,
static_cast
<
T
>
(
0
));
}
int
count
=
static_cast
<
int
>
(
n
*
out_h
*
out_w
);
auto
cu_stream
=
dev_ctx
.
stream
();
...
...
@@ -472,8 +480,8 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> {
int
grid_size
=
(
count
+
block
-
1
)
/
block
;
grid_sampler_cuda_backward_kernel
<
T
><<<
block
,
grid_size
,
0
,
cu_stream
>>>
(
count
,
output_grad
->
data
<
T
>
(),
input
->
data
<
T
>
(),
grid
->
data
<
T
>
(),
n
,
c
,
out_h
,
out_w
,
in_h
,
in_w
,
input_grad
->
data
<
T
>
(),
grid_grad
->
data
<
T
>
()
,
mode
,
padding_mode
,
align_corners
);
out_h
,
out_w
,
in_h
,
in_w
,
input_grad
->
data
<
T
>
(),
grid_grad
_data
,
mode
,
padding_mode
,
align_corners
);
}
};
...
...
paddle/fluid/operators/grid_sampler_op.h
浏览文件 @
9cc5603d
...
...
@@ -450,6 +450,7 @@ static void gatherBilinearGrad(const platform::CPUDeviceContext& ctx,
auto
output_grad_t
=
EigenTensor
<
T
,
4
>::
From
(
output_grad
);
if
(
grid_grad
!=
nullptr
)
{
Tensor
grid_grad_x
,
grid_grad_y
;
grid_grad_x
.
mutable_data
<
T
>
({
n
,
out_h
,
out_w
},
ctx
.
GetPlace
());
grid_grad_y
.
mutable_data
<
T
>
({
n
,
out_h
,
out_w
},
ctx
.
GetPlace
());
...
...
@@ -490,6 +491,7 @@ static void gatherBilinearGrad(const platform::CPUDeviceContext& ctx,
grid_grad_data
[
2
*
i
]
=
grid_grad_x_data
[
i
];
grid_grad_data
[
2
*
i
+
1
]
=
grid_grad_y_data
[
i
];
}
}
}
template
<
typename
DeviceContext
,
typename
T
>
...
...
@@ -558,11 +560,16 @@ class GridSampleGradOpKernel : public framework::OpKernel<T> {
math
::
SetConstant
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
input_grad
,
static_cast
<
T
>
(
0
));
auto
*
grid_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Grid"
));
Tensor
*
grid_grad
=
nullptr
;
if
(
ctx
.
HasOutput
(
framework
::
GradVarName
(
"Grid"
)))
{
grid_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Grid"
));
grid_grad
->
mutable_data
<
T
>
({
n
,
out_h
,
out_w
,
2
},
ctx
.
GetPlace
());
math
::
SetConstant
<
DeviceContext
,
T
>
()(
ctx
.
template
device_context
<
DeviceContext
>(),
grid_grad
,
static_cast
<
T
>
(
0
));
}
Tensor
grid_x
,
grid_y
;
Tensor
grid_x_scale
,
grid_y_scale
;
calcGridLocationsWithGrad
<
T
>
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录