Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
bb33c2b3
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
bb33c2b3
编写于
9月 30, 2017
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix kernel func
上级
2ed56df1
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
270 addition
and
8 deletion
+270
-8
paddle/operators/math/pooling.cc
paddle/operators/math/pooling.cc
+227
-0
paddle/operators/math/pooling.cu
paddle/operators/math/pooling.cu
+4
-6
paddle/operators/math/pooling.h
paddle/operators/math/pooling.h
+37
-0
paddle/operators/pool_with_index_op.h
paddle/operators/pool_with_index_op.h
+2
-2
未找到文件。
paddle/operators/math/pooling.cc
浏览文件 @
bb33c2b3
...
@@ -458,6 +458,233 @@ template class Pool3dGradFunctor<
...
@@ -458,6 +458,233 @@ template class Pool3dGradFunctor<
platform
::
CPUPlace
,
paddle
::
operators
::
math
::
MaxPoolGrad
<
double
>,
double
>
;
platform
::
CPUPlace
,
paddle
::
operators
::
math
::
MaxPoolGrad
<
double
>,
double
>
;
template
class
Pool3dGradFunctor
<
template
class
Pool3dGradFunctor
<
platform
::
CPUPlace
,
paddle
::
operators
::
math
::
AvgPoolGrad
<
double
>,
double
>
;
platform
::
CPUPlace
,
paddle
::
operators
::
math
::
AvgPoolGrad
<
double
>,
double
>
;
template
<
typename
T
>
class
MaxPool2dWithIndexFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
output
,
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
const
int
output_channels
=
output
.
dims
()[
1
];
const
int
output_height
=
output
.
dims
()[
2
];
const
int
output_width
=
output
.
dims
()[
3
];
const
int
ksize_height
=
ksize
[
0
];
const
int
ksize_width
=
ksize
[
1
];
const
int
stride_height
=
strides
[
0
];
const
int
stride_width
=
strides
[
1
];
const
int
padding_height
=
paddings
[
0
];
const
int
padding_width
=
paddings
[
1
];
const
int
input_stride
=
input_height
*
input_width
;
const
int
output_stride
=
output_height
*
output_width
;
const
T
*
input_data
=
input
.
data
<
T
>
();
T
*
output_data
=
output
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
mask_data
=
mask
.
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
for
(
int
ph
=
0
;
ph
<
output_height
;
++
ph
)
{
int
hstart
=
ph
*
stride_height
-
padding_height
;
int
hend
=
std
::
min
(
hstart
+
ksize_height
,
input_height
);
hstart
=
std
::
max
(
hstart
,
0
);
for
(
int
pw
=
0
;
pw
<
output_width
;
++
pw
)
{
int
wstart
=
pw
*
stride_width
-
padding_width
;
int
wend
=
std
::
min
(
wstart
+
ksize_width
,
input_width
);
wstart
=
std
::
max
(
wstart
,
0
);
T
ele
=
static_cast
<
T
>
(
-
FLT_MAX
);
int
index
=
-
1
;
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
if
(
ele
<
input_data
[
h
*
input_width
+
w
])
{
ele
=
input_data
[
h
*
input_width
+
w
];
index
=
h
*
input_width
+
w
;
}
}
}
output_data
[
ph
*
output_width
+
pw
]
=
ele
;
mask_data
[
ph
*
output_width
+
pw
]
=
index
;
}
}
// offset
input_data
+=
input_stride
;
output_data
+=
output_stride
;
mask_data
+=
output_stride
;
}
}
}
};
template
<
typename
T
>
class
MaxPool2dWithIndexGradFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
)
{
const
int
batch_size
=
input_grad
.
dims
()[
0
];
const
int
input_height
=
input_grad
.
dims
()[
2
];
const
int
input_width
=
input_grad
.
dims
()[
3
];
const
int
output_channels
=
output_grad
.
dims
()[
1
];
const
int
output_height
=
output_grad
.
dims
()[
2
];
const
int
output_width
=
output_grad
.
dims
()[
3
];
const
int
input_stride
=
input_height
*
input_width
;
const
int
output_stride
=
output_height
*
output_width
;
const
T
*
mask_data
=
mask
.
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
.
data
<
T
>
();
T
*
input_grad_data
=
input_grad
.
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
n
=
0
;
n
<
batch_size
;
++
n
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
for
(
int
ph
=
0
;
ph
<
output_height
;
++
ph
)
{
for
(
int
pw
=
0
;
pw
<
output_width
;
++
pw
)
{
const
int
output_idx
=
ph
*
output_width
+
pw
;
const
int
input_idx
=
static_cast
<
int
>
(
mask_data
[
output_idx
]);
input_grad_data
[
input_idx
]
+=
output_grad_data
[
output_idx
];
}
}
// offset
input_grad_data
+=
input_stride
;
output_grad_data
+=
output_stride
;
mask_data
+=
output_stride
;
}
}
}
};
template
class
MaxPool2dWithIndexFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
MaxPool2dWithIndexGradFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
MaxPool2dWithIndexFunctor
<
platform
::
CPUPlace
,
double
>;
template
class
MaxPool2dWithIndexGradFunctor
<
platform
::
CPUPlace
,
double
>;
template
<
typename
T
>
class
MaxPool3dWithIndexFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
output
,
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_depth
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
3
];
const
int
input_width
=
input
.
dims
()[
4
];
const
int
output_channels
=
output
.
dims
()[
1
];
const
int
output_depth
=
output
.
dims
()[
2
];
const
int
output_height
=
output
.
dims
()[
3
];
const
int
output_width
=
output
.
dims
()[
4
];
const
int
ksize_depth
=
ksize
[
0
];
const
int
ksize_height
=
ksize
[
1
];
const
int
ksize_width
=
ksize
[
2
];
const
int
stride_depth
=
strides
[
0
];
const
int
stride_height
=
strides
[
1
];
const
int
stride_width
=
strides
[
2
];
const
int
padding_depth
=
paddings
[
0
];
const
int
padding_height
=
paddings
[
1
];
const
int
padding_width
=
paddings
[
2
];
const
int
input_stride
=
input_depth
*
input_height
*
input_width
;
const
int
output_stride
=
output_depth
*
output_height
*
output_width
;
const
T
*
input_data
=
input
.
data
<
T
>
();
T
*
output_data
=
output
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
mask_data
=
mask
.
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
for
(
int
pd
=
0
;
pd
<
output_depth
;
++
pd
)
{
int
dstart
=
pd
*
stride_depth
-
padding_depth
;
int
dend
=
std
::
min
(
dstart
+
ksize_depth
,
input_depth
);
dstart
=
std
::
max
(
dstart
,
0
);
for
(
int
ph
=
0
;
ph
<
output_height
;
++
ph
)
{
int
hstart
=
ph
*
stride_height
-
padding_height
;
int
hend
=
std
::
min
(
hstart
+
ksize_height
,
input_height
);
hstart
=
std
::
max
(
hstart
,
0
);
for
(
int
pw
=
0
;
pw
<
output_width
;
++
pw
)
{
int
wstart
=
pw
*
stride_width
-
padding_width
;
int
wend
=
std
::
min
(
wstart
+
ksize_width
,
input_width
);
wstart
=
std
::
max
(
wstart
,
0
);
int
output_idx
=
(
pd
*
output_height
+
ph
)
*
output_width
+
pw
;
T
ele
=
static_cast
<
T
>
(
-
FLT_MAX
);
int
index
=
-
1
;
for
(
int
d
=
dstart
;
d
<
dend
;
++
d
)
{
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
int
input_idx
=
(
d
*
input_height
+
h
)
*
input_width
+
w
;
if
(
ele
<
input_data
[
input_idx
])
{
index
=
input_idx
;
ele
=
input_data
[
input_idx
];
}
}
}
}
output_data
[
output_idx
]
=
ele
;
mask_data
[
output_idx
]
=
index
;
}
}
}
// offset
input_data
+=
input_stride
;
output_data
+=
output_stride
;
mask_data
+=
output_stride
;
}
}
}
};
template
<
typename
T
>
class
MaxPool3dWithIndexGradFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
)
{
const
int
batch_size
=
input_grad
.
dims
()[
0
];
const
int
input_depth
=
input_grad
.
dims
()[
2
];
const
int
input_height
=
input_grad
.
dims
()[
3
];
const
int
input_width
=
input_grad
.
dims
()[
4
];
const
int
output_channels
=
output_grad
.
dims
()[
1
];
const
int
output_depth
=
output_grad
.
dims
()[
2
];
const
int
output_height
=
output_grad
.
dims
()[
3
];
const
int
output_width
=
output_grad
.
dims
()[
4
];
const
int
input_stride
=
input_depth
*
input_height
*
input_width
;
const
int
output_stride
=
output_depth
*
output_height
*
output_width
;
const
T
*
mask_data
=
mask
.
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
.
data
<
T
>
();
T
*
input_grad_data
=
input_grad
.
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
n
=
0
;
n
<
batch_size
;
++
n
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
for
(
int
pd
=
0
;
pd
<
output_depth
;
++
pd
)
{
for
(
int
ph
=
0
;
ph
<
output_height
;
++
ph
)
{
for
(
int
pw
=
0
;
pw
<
output_width
;
++
pw
)
{
const
int
output_idx
=
(
pd
*
output_height
+
ph
)
*
output_width
+
pw
;
const
int
input_idx
=
static_cast
<
int
>
(
mask_data
[
output_idx
]);
input_grad_data
[
input_idx
]
+=
output_grad_data
[
output_idx
];
}
}
}
// offset
input_grad_data
+=
input_stride
;
output_grad_data
+=
output_stride
;
mask_data
+=
output_stride
;
}
}
}
};
template
class
MaxPool3dWithIndexFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
MaxPool3dWithIndexGradFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
MaxPool3dWithIndexFunctor
<
platform
::
CPUPlace
,
double
>;
template
class
MaxPool3dWithIndexGradFunctor
<
platform
::
CPUPlace
,
double
>;
}
// namespace math
}
// namespace math
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/operators/math/pooling.cu
浏览文件 @
bb33c2b3
...
@@ -637,7 +637,7 @@ __global__ void KernelMaxPool2dWithIdx(
...
@@ -637,7 +637,7 @@ __global__ void KernelMaxPool2dWithIdx(
const
int
output_height
,
const
int
output_width
,
const
int
ksize_height
,
const
int
output_height
,
const
int
output_width
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_height
,
const
int
stride_width
,
const
int
ksize_width
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_height
,
const
int
padding_width
)
{
const
int
padding_height
,
const
int
padding_width
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
(
nthreads
)
;
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
pw
=
index
%
output_width
;
int
pw
=
index
%
output_width
;
int
ph
=
(
index
/
output_width
)
%
output_height
;
int
ph
=
(
index
/
output_width
)
%
output_height
;
...
@@ -676,7 +676,7 @@ __global__ void KernelMaxPool2DWithIdxGrad(
...
@@ -676,7 +676,7 @@ __global__ void KernelMaxPool2DWithIdxGrad(
const
int
output_height
,
const
int
output_width
,
const
int
ksize_height
,
const
int
output_height
,
const
int
output_width
,
const
int
ksize_height
,
const
int
ksize_width
,
const
int
stride_height
,
const
int
stride_width
,
const
int
ksize_width
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_height
,
const
int
padding_width
)
{
const
int
padding_height
,
const
int
padding_width
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
(
nthreads
)
;
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
w_offset
=
index
%
input_width
;
int
w_offset
=
index
%
input_width
;
int
h_offset
=
(
index
/
input_width
)
%
input_height
;
int
h_offset
=
(
index
/
input_width
)
%
input_height
;
...
@@ -766,7 +766,6 @@ class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, T> {
...
@@ -766,7 +766,6 @@ class MaxPool2dWithIndexGradFunctor<platform::GPUPlace, T> {
const
int
input_channels
=
input_grad
.
dims
()[
1
];
const
int
input_channels
=
input_grad
.
dims
()[
1
];
const
int
input_height
=
input_grad
.
dims
()[
2
];
const
int
input_height
=
input_grad
.
dims
()[
2
];
const
int
input_width
=
input_grad
.
dims
()[
3
];
const
int
input_width
=
input_grad
.
dims
()[
3
];
const
int
output_channels
=
output_grad
.
dims
()[
1
];
const
int
output_height
=
output_grad
.
dims
()[
2
];
const
int
output_height
=
output_grad
.
dims
()[
2
];
const
int
output_width
=
output_grad
.
dims
()[
3
];
const
int
output_width
=
output_grad
.
dims
()[
3
];
const
int
ksize_height
=
ksize
[
0
];
const
int
ksize_height
=
ksize
[
0
];
...
@@ -810,7 +809,7 @@ __global__ void KernelMaxPool3DWithIdx(
...
@@ -810,7 +809,7 @@ __global__ void KernelMaxPool3DWithIdx(
const
int
ksize_width
,
const
int
stride_depth
,
const
int
stride_height
,
const
int
ksize_width
,
const
int
stride_depth
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_depth
,
const
int
padding_height
,
const
int
stride_width
,
const
int
padding_depth
,
const
int
padding_height
,
const
int
padding_width
)
{
const
int
padding_width
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
(
nthreads
)
;
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
pw
=
index
%
output_width
;
int
pw
=
index
%
output_width
;
int
ph
=
(
index
/
output_width
)
%
output_height
;
int
ph
=
(
index
/
output_width
)
%
output_height
;
...
@@ -858,7 +857,7 @@ __global__ void KernelMaxPool3DWithIdxGrad(
...
@@ -858,7 +857,7 @@ __global__ void KernelMaxPool3DWithIdxGrad(
const
int
ksize_width
,
const
int
stride_depth
,
const
int
stride_height
,
const
int
ksize_width
,
const
int
stride_depth
,
const
int
stride_height
,
const
int
stride_width
,
const
int
padding_depth
,
const
int
padding_height
,
const
int
stride_width
,
const
int
padding_depth
,
const
int
padding_height
,
const
int
padding_width
)
{
const
int
padding_width
)
{
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
(
nthreads
)
;
for
(
int
index
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
index
<
nthreads
;
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
index
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
w_offset
=
index
%
input_width
;
int
w_offset
=
index
%
input_width
;
int
h_offset
=
(
index
/
input_width
)
%
input_height
;
int
h_offset
=
(
index
/
input_width
)
%
input_height
;
...
@@ -969,7 +968,6 @@ class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, T> {
...
@@ -969,7 +968,6 @@ class MaxPool3dWithIndexGradFunctor<platform::GPUPlace, T> {
const
int
input_depth
=
input_grad
.
dims
()[
2
];
const
int
input_depth
=
input_grad
.
dims
()[
2
];
const
int
input_height
=
input_grad
.
dims
()[
3
];
const
int
input_height
=
input_grad
.
dims
()[
3
];
const
int
input_width
=
input_grad
.
dims
()[
4
];
const
int
input_width
=
input_grad
.
dims
()[
4
];
const
int
output_channels
=
output_grad
.
dims
()[
1
];
const
int
output_depth
=
output_grad
.
dims
()[
2
];
const
int
output_depth
=
output_grad
.
dims
()[
2
];
const
int
output_height
=
output_grad
.
dims
()[
3
];
const
int
output_height
=
output_grad
.
dims
()[
3
];
const
int
output_width
=
output_grad
.
dims
()[
4
];
const
int
output_width
=
output_grad
.
dims
()[
4
];
...
...
paddle/operators/math/pooling.h
浏览文件 @
bb33c2b3
...
@@ -117,6 +117,43 @@ class MaxPool3dGradFunctor {
...
@@ -117,6 +117,43 @@ class MaxPool3dGradFunctor {
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
);
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
);
};
};
template
<
typename
Place
,
typename
T
>
class
MaxPool2dWithIndexFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
output
,
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
);
};
template
<
typename
Place
,
typename
T
>
class
MaxPool2dWithIndexGradFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
);
};
template
<
typename
Place
,
typename
T
>
class
MaxPool3dWithIndexFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
output
,
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
);
};
template
<
typename
Place
,
typename
T
>
class
MaxPool3dWithIndexGradFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
);
};
}
// namespace math
}
// namespace math
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/operators/pool_with_index_op.h
浏览文件 @
bb33c2b3
...
@@ -25,7 +25,7 @@ namespace operators {
...
@@ -25,7 +25,7 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
class
MaxPoolWithIndexKernel
:
public
framework
::
OpKernel
{
class
MaxPoolWithIndexKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
const
Tensor
*
in_x
=
context
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
*
in_x
=
context
.
Input
<
Tensor
>
(
"X"
);
...
@@ -59,7 +59,7 @@ class MaxPoolWithIndexKernel : public framework::OpKernel {
...
@@ -59,7 +59,7 @@ class MaxPoolWithIndexKernel : public framework::OpKernel {
};
};
template
<
typename
Place
,
typename
T
>
template
<
typename
Place
,
typename
T
>
class
MaxPoolWithIndexGradKernel
:
public
framework
::
OpKernel
{
class
MaxPoolWithIndexGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
const
Tensor
*
mask
=
context
.
Input
<
Tensor
>
(
"Mask"
);
const
Tensor
*
mask
=
context
.
Input
<
Tensor
>
(
"Mask"
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录