Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
69fd376b
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
69fd376b
编写于
10月 11, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' of
https://github.com/PaddlePaddle/Paddle
into feature/polish_infer_shape
上级
cb2ef7d9
cec1f598
变更
32
展开全部
隐藏空白更改
内联
并排
Showing
32 changed file
with
1778 addition
and
45 deletion
+1778
-45
paddle/framework/data_type.h
paddle/framework/data_type.h
+0
-1
paddle/framework/op_desc.cc
paddle/framework/op_desc.cc
+1
-0
paddle/framework/type_defs.h
paddle/framework/type_defs.h
+1
-0
paddle/math/tests/test_GpuProfiler.cpp
paddle/math/tests/test_GpuProfiler.cpp
+1
-1
paddle/memory/detail/buddy_allocator.cc
paddle/memory/detail/buddy_allocator.cc
+1
-1
paddle/memory/detail/system_allocator.cc
paddle/memory/detail/system_allocator.cc
+1
-1
paddle/memory/detail/system_allocator.h
paddle/memory/detail/system_allocator.h
+1
-1
paddle/memory/detail/system_allocator_test.cc
paddle/memory/detail/system_allocator_test.cc
+1
-1
paddle/memory/memcpy.cc
paddle/memory/memcpy.cc
+1
-1
paddle/memory/memcpy.h
paddle/memory/memcpy.h
+1
-1
paddle/memory/memory.cc
paddle/memory/memory.cc
+1
-1
paddle/memory/memory_test.cc
paddle/memory/memory_test.cc
+1
-1
paddle/operators/CMakeLists.txt
paddle/operators/CMakeLists.txt
+8
-0
paddle/operators/activation_op.cc
paddle/operators/activation_op.cc
+24
-0
paddle/operators/activation_op.h
paddle/operators/activation_op.h
+48
-18
paddle/operators/fill_constant_op.cc
paddle/operators/fill_constant_op.cc
+68
-0
paddle/operators/fill_constant_op.cu
paddle/operators/fill_constant_op.cu
+22
-0
paddle/operators/fill_constant_op.h
paddle/operators/fill_constant_op.h
+37
-0
paddle/operators/interp_op.cc
paddle/operators/interp_op.cc
+113
-0
paddle/operators/math/pooling.cc
paddle/operators/math/pooling.cc
+279
-2
paddle/operators/math/pooling.cu
paddle/operators/math/pooling.cu
+432
-8
paddle/operators/math/pooling.h
paddle/operators/math/pooling.h
+76
-4
paddle/operators/pool_with_index_op.cc
paddle/operators/pool_with_index_op.cc
+228
-0
paddle/operators/pool_with_index_op.cu
paddle/operators/pool_with_index_op.cu
+31
-0
paddle/operators/pool_with_index_op.h
paddle/operators/pool_with_index_op.h
+103
-0
paddle/platform/device_context.cc
paddle/platform/device_context.cc
+1
-1
paddle/platform/enforce.h
paddle/platform/enforce.h
+1
-1
paddle/platform/gpu_info.h
paddle/platform/gpu_info.h
+1
-1
python/paddle/v2/framework/tests/test_activation_op.py
python/paddle/v2/framework/tests/test_activation_op.py
+20
-0
python/paddle/v2/framework/tests/test_fill_constant_op.py
python/paddle/v2/framework/tests/test_fill_constant_op.py
+35
-0
python/paddle/v2/framework/tests/test_interp_op.py
python/paddle/v2/framework/tests/test_interp_op.py
+28
-0
python/paddle/v2/framework/tests/test_pool_max_op.py
python/paddle/v2/framework/tests/test_pool_max_op.py
+212
-0
未找到文件。
paddle/framework/data_type.h
浏览文件 @
69fd376b
...
...
@@ -28,7 +28,6 @@ inline DataType ToDataType(std::type_index type) {
return
DataType
::
INT32
;
}
else
{
PADDLE_THROW
(
"Not supported"
);
return
static_cast
<
DataType
>
(
-
1
);
}
}
...
...
paddle/framework/op_desc.cc
浏览文件 @
69fd376b
...
...
@@ -28,6 +28,7 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
inputs_
=
inputs
;
outputs_
=
outputs
;
attrs_
=
attrs
;
need_update_
=
true
;
}
OpDesc
*
OpDescBind
::
Proto
()
{
...
...
paddle/framework/type_defs.h
浏览文件 @
69fd376b
...
...
@@ -15,6 +15,7 @@
#pragma once
#include <functional>
#include <map>
#include <memory>
#include "paddle/platform/variant.h"
namespace
paddle
{
...
...
paddle/math/tests/test_GpuProfiler.cpp
浏览文件 @
69fd376b
...
...
@@ -162,4 +162,4 @@ int main(int argc, char** argv) {
return
RUN_ALL_TESTS
();
}
#endif
/* PADDLE_ONLY_CPU */
#endif
paddle/memory/detail/buddy_allocator.cc
浏览文件 @
69fd376b
...
...
@@ -182,7 +182,7 @@ BuddyAllocator::PoolSet::iterator BuddyAllocator::RefillPool() {
max_chunk_size_
=
platform
::
GpuMaxChunkSize
();
}
}
#endif
// PADDLE_ONLY_CPU
#endif
// Allocate a new maximum sized block
size_t
index
=
0
;
...
...
paddle/memory/detail/system_allocator.cc
浏览文件 @
69fd376b
...
...
@@ -134,7 +134,7 @@ void GPUAllocator::Free(void* p, size_t size, size_t index) {
bool
GPUAllocator
::
UseGpu
()
const
{
return
true
;
}
#endif
// PADDLE_ONLY_CPU
#endif
}
// namespace detail
}
// namespace memory
...
...
paddle/memory/detail/system_allocator.h
浏览文件 @
69fd376b
...
...
@@ -51,7 +51,7 @@ class GPUAllocator : public SystemAllocator {
size_t
gpu_alloc_size_
=
0
;
size_t
fallback_alloc_size_
=
0
;
};
#endif
// PADDLE_ONLY_CPU
#endif
}
// namespace detail
}
// namespace memory
...
...
paddle/memory/detail/system_allocator_test.cc
浏览文件 @
69fd376b
...
...
@@ -62,4 +62,4 @@ TEST(GPUAllocator, Alloc) {
TestAllocator
(
a
,
2048
);
TestAllocator
(
a
,
0
);
}
#endif
// PADDLE_ONLY_CPU
#endif
paddle/memory/memcpy.cc
浏览文件 @
69fd376b
...
...
@@ -89,7 +89,7 @@ void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place,
platform
::
GpuMemcpySync
(
dst
,
src
,
num
,
cudaMemcpyDeviceToDevice
);
}
#endif
// PADDLE_ONLY_CPU
#endif
}
// namespace memory
}
// namespace paddle
paddle/memory/memcpy.h
浏览文件 @
69fd376b
...
...
@@ -53,7 +53,7 @@ template <typename DstPlace, typename SrcPlace>
void
Copy
(
DstPlace
,
void
*
dst
,
SrcPlace
,
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
);
#endif
// PADDLE_ONLY_CPU
#endif
}
// namespace memory
}
// namespace paddle
paddle/memory/memory.cc
浏览文件 @
69fd376b
...
...
@@ -111,7 +111,7 @@ size_t Used<platform::GPUPlace>(platform::GPUPlace place) {
return
GetGPUBuddyAllocator
(
place
.
device
)
->
Used
();
}
#endif
// PADDLE_ONLY_CPU
#endif
}
// namespace memory
}
// namespace paddle
paddle/memory/memory_test.cc
浏览文件 @
69fd376b
...
...
@@ -135,4 +135,4 @@ TEST(BuddyAllocator, GPUMultAlloc) {
}
}
#endif
// PADDLE_ONLY_CPU
#endif
paddle/operators/CMakeLists.txt
浏览文件 @
69fd376b
...
...
@@ -55,12 +55,20 @@ function(op_library TARGET)
set
(
pybind_flag 1
)
endif
()
# pool_op contains several operators
if
(
"
${
TARGET
}
"
STREQUAL
"pool_op"
)
set
(
pybind_flag 1
)
# It's enough to just adding one operator to pybind
file
(
APPEND
${
pybind_file
}
"USE_OP(pool2d);
\n
"
)
endif
()
# pool_with_index_op contains several operators
if
(
"
${
TARGET
}
"
STREQUAL
"pool_with_index_op"
)
set
(
pybind_flag 1
)
# It's enough to just adding one operator to pybind
file
(
APPEND
${
pybind_file
}
"USE_OP(max_pool2d_with_index);
\n
"
)
endif
()
# activation_op contains several operators
if
(
"
${
TARGET
}
"
STREQUAL
"activation_op"
)
set
(
pybind_flag 1
)
...
...
paddle/operators/activation_op.cc
浏览文件 @
69fd376b
...
...
@@ -201,6 +201,27 @@ class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
template
<
typename
AttrType
>
class
ELUOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
ELUOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"(Tensor) The input of ELU operator, it shouldn't be empty. Input "
"is flattened and treated as a 1D array."
);
AddOutput
(
"Y"
,
"(Tensor) The output of ELU operator. It has the same shape as "
"the input."
);
AddAttr
<
AttrType
>
(
"alpha"
,
"(float, default 1.0) Alpha value in the elu formulation."
)
.
SetDefault
(
static_cast
<
AttrType
>
(
1.
));
AddComment
(
R"DOC(
ELU activation operator. It applies this element-wise computation on
the input: f(x) = max(0, x) + min(0, alpha * (exp(x) - 1)).
Check .. _Link: https://arxiv.org/abs/1511.07289 for more details.)DOC"
);
}
};
template
<
typename
AttrType
>
class
Relu6OpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
...
...
@@ -289,6 +310,9 @@ REGISTER_OP(leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker<float>,
REGISTER_OP
(
soft_relu
,
ops
::
ActivationOp
,
ops
::
SoftReluOpMaker
<
float
>
,
soft_relu_grad
,
ops
::
ActivationOpGrad
);
REGISTER_OP
(
elu
,
ops
::
ActivationOp
,
ops
::
ELUOpMaker
<
float
>
,
elu_grad
,
ops
::
ActivationOpGrad
);
REGISTER_OP
(
relu6
,
ops
::
ActivationOp
,
ops
::
Relu6OpMaker
<
float
>
,
relu6_grad
,
ops
::
ActivationOpGrad
);
...
...
paddle/operators/activation_op.h
浏览文件 @
69fd376b
...
...
@@ -384,6 +384,35 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
}
};
template
<
typename
T
>
struct
ELUFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
alpha
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"alpha"
,
&
alpha
}};
}
template
<
typename
Device
,
typename
X
,
typename
Y
>
void
operator
()(
Device
d
,
X
x
,
Y
y
)
const
{
y
.
device
(
d
)
=
x
.
cwiseMax
(
static_cast
<
T
>
(
0
))
+
(
alpha
*
(
x
.
exp
()
-
static_cast
<
T
>
(
1
))).
cwiseMin
(
static_cast
<
T
>
(
0
));
}
};
template
<
typename
T
>
struct
ELUGradFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
alpha
;
typename
BaseActivationFunctor
<
T
>::
AttrPair
GetAttrs
()
{
return
{{
"alpha"
,
&
alpha
}};
}
template
<
typename
Device
,
typename
X
,
typename
Y
,
typename
dY
,
typename
dX
>
void
operator
()(
Device
d
,
X
x
,
Y
y
,
dY
dy
,
dX
dx
)
const
{
dx
.
device
(
d
)
=
dy
*
(
x
>
static_cast
<
T
>
(
0
)).
template
cast
<
T
>()
+
dy
*
(
y
+
alpha
)
*
(
x
<
static_cast
<
T
>
(
0
)).
template
cast
<
T
>();
}
};
template
<
typename
T
>
struct
PowFunctor
:
public
BaseActivationFunctor
<
T
>
{
float
factor
;
...
...
@@ -440,21 +469,22 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
}
// namespace operators
}
// namespace paddle
#define FOR_EACH_KERNEL_FUNCTOR(__macro) \
__macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
__macro(exp, ExpFunctor, ExpGradFunctor); \
__macro(relu, ReluFunctor, ReluGradFunctor); \
__macro(tanh, TanhFunctor, TanhGradFunctor); \
__macro(sqrt, SqrtFunctor, SqrtGradFunctor); \
__macro(abs, AbsFunctor, AbsGradFunctor); \
__macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(log, LogFunctor, LogGradFunctor); \
__macro(square, SquareFunctor, SquareGradFunctor); \
__macro(brelu, BReluFunctor, BReluGradFunctor); \
__macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(pow, PowFunctor, PowGradFunctor); \
__macro(stanh, STanhFunctor, STanhGradFunctor); \
__macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \
__macro(relu6, Relu6Functor, Relu6GradFunctor); \
__macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \
__macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor)
#define FOR_EACH_KERNEL_FUNCTOR(__macro) \
__macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
__macro(exp, ExpFunctor, ExpGradFunctor); \
__macro(relu, ReluFunctor, ReluGradFunctor); \
__macro(tanh, TanhFunctor, TanhGradFunctor); \
__macro(sqrt, SqrtFunctor, SqrtGradFunctor); \
__macro(abs, AbsFunctor, AbsGradFunctor); \
__macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(log, LogFunctor, LogGradFunctor); \
__macro(square, SquareFunctor, SquareGradFunctor); \
__macro(brelu, BReluFunctor, BReluGradFunctor); \
__macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(pow, PowFunctor, PowGradFunctor); \
__macro(stanh, STanhFunctor, STanhGradFunctor); \
__macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \
__macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \
__macro(relu6, Relu6Functor, Relu6GradFunctor); \
__macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \
__macro(elu, ELUFunctor, ELUGradFunctor)
paddle/operators/fill_constant_op.cc
0 → 100644
浏览文件 @
69fd376b
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/fill_constant_op.h"
namespace
paddle
{
namespace
operators
{
class
FillConstantOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of FillConstantOp should not be null."
);
auto
&
shape
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"shape"
);
std
::
vector
<
int64_t
>
shape_int64
(
shape
.
size
(),
0
);
std
::
transform
(
shape
.
begin
(),
shape
.
end
(),
shape_int64
.
begin
(),
[](
int
a
)
{
return
static_cast
<
int64_t
>
(
a
);
});
auto
dims
=
framework
::
make_ddim
(
shape_int64
);
ctx
->
SetOutputDim
(
"Out"
,
dims
);
}
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
static_cast
<
framework
::
DataType
>
(
ctx
.
Attr
<
int
>
(
"dataType"
));
}
};
class
FillConstantOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
FillConstantOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
framework
::
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddAttr
<
int
>
(
"dataType"
,
"(int, default 5 (FP32)) "
"Output data type"
)
.
SetDefault
(
framework
::
DataType
::
FP32
);
AddAttr
<
std
::
vector
<
int
>>
(
"shape"
,
"(vector<int>) The shape of the output"
);
AddAttr
<
float
>
(
"value"
,
"(float, default 0) The value to be filled"
)
.
SetDefault
(
0.0
f
);
AddOutput
(
"Out"
,
"(Tensor) Tensor of specified shape will be filled "
"with the specified value"
);
AddComment
(
R"DOC(Fill up a variable with specified constant value.)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_WITHOUT_GRADIENT
(
fill_constant
,
ops
::
FillConstantOp
,
ops
::
FillConstantOpMaker
);
REGISTER_OP_CPU_KERNEL
(
fill_constant
,
ops
::
FillConstantOpKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/fill_constant_op.cu
0 → 100644
浏览文件 @
69fd376b
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/framework/op_registry.h"
#include "paddle/operators/fill_constant_op.h"
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
fill_constant
,
ops
::
FillConstantOpKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
paddle/operators/fill_constant_op.h
0 → 100644
浏览文件 @
69fd376b
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
Place
,
typename
T
>
class
FillConstantOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
value
=
ctx
.
Attr
<
T
>
(
"value"
);
auto
out_eigen
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
out
);
auto
place
=
ctx
.
GetEigenDevice
<
Place
>
();
out_eigen
.
device
(
place
)
=
out_eigen
.
constant
(
static_cast
<
T
>
(
value
));
}
};
}
// namespace operators
}
// namespace paddle
paddle/operators/interp_op.cc
0 → 100644
浏览文件 @
69fd376b
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/operators/net_op.h"
namespace
paddle
{
namespace
operators
{
class
InterpOp
:
public
NetOp
{
public:
InterpOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
NetOp
(
type
,
inputs
,
outputs
,
attrs
)
{
PADDLE_ENFORCE_NE
(
Input
(
"X"
),
framework
::
kEmptyVarName
,
"Input(X) of InterpOp should not be null."
);
PADDLE_ENFORCE_NE
(
Input
(
"Y"
),
framework
::
kEmptyVarName
,
"Input(Y) of InterpOp should not be null."
);
PADDLE_ENFORCE_NE
(
Input
(
"W"
),
framework
::
kEmptyVarName
,
"Input(W) of InterpOp should not be null."
);
PADDLE_ENFORCE_NE
(
Output
(
"SubOut"
),
framework
::
kEmptyVarName
,
"Output(SubOut) of InterpOp should not be null."
);
PADDLE_ENFORCE_NE
(
Output
(
"MulOut"
),
framework
::
kEmptyVarName
,
"Output(MulOut) of InterpOp should not be null."
);
PADDLE_ENFORCE_NE
(
Output
(
"Out"
),
framework
::
kEmptyVarName
,
"Output(Out) of InterpOp should not be null."
);
// SubOut = X - Y
auto
x
=
Input
(
"X"
);
auto
y
=
Input
(
"Y"
);
auto
sub_out
=
Output
(
"SubOut"
);
AppendOp
(
framework
::
OpRegistry
::
CreateOp
(
"elementwise_sub"
,
{{
"X"
,
{
x
}},
{
"Y"
,
{
y
}}},
{{
"Out"
,
{
sub_out
}}},
{}));
// MulOut = SubOut * W = (X - Y) * W
auto
w
=
Input
(
"W"
);
auto
mul_out
=
Output
(
"MulOut"
);
AppendOp
(
framework
::
OpRegistry
::
CreateOp
(
"elementwise_mul"
,
{{
"X"
,
{
sub_out
}},
{
"Y"
,
{
w
}}},
{{
"Out"
,
{
mul_out
}}},
{{
"axis"
,
0
}}));
// Out = MulOut + Y = (X - Y) * W + Y = X * W + Y * (1 - W)
AppendOp
(
framework
::
OpRegistry
::
CreateOp
(
"elementwise_add"
,
{{
"X"
,
{
mul_out
}},
{
"Y"
,
{
y
}}},
{{
"Out"
,
{
Output
(
"Out"
)}}},
{}));
CompleteAddOp
(
false
);
}
};
class
InterpOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
InterpOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"(Tensor), 2-D Matrix of shape [batch_size, data_dim]"
"containing data samples, the first input of interp_op"
);
AddInput
(
"Y"
,
"(Tensor), 2-D Matrix of shape `[batch_size, data_dim]`"
"containing data samples, the second input of interp_op"
);
AddInput
(
"W"
,
"(Tensor), 1-D Vector of shape [batch_size],"
"the interpolated values in the half-open interval [0.0, 1.0)"
);
AddOutput
(
"SubOut"
,
"(Tensor), the intermediate subtraction outputs, saving X - Y."
)
.
AsIntermediate
();
AddOutput
(
"MulOut"
,
"(Tensor), the intermediate multiplication outputs,"
"saving the elementwise multiplication of (X - Y) and W."
)
.
AsIntermediate
();
AddOutput
(
"Out"
,
"(Tensor), the output of interp_op, same shape with X,"
"returns the first-dimensional piecewise linear interpolant "
"between X and Y"
);
AddComment
(
R"DOC(
Linear Interpolation with two inputs, used in NEURAL TURING MACHINE.
Equation:
Out.row[i] = X.row[i] * W[i] + Y.row[i] * (1 - W[i])
= (X.row[i] - Y.row[i]) * W[i] + Y.row[i]
Example:
X = [[1,2],[3,4]],
Y = [[2,1],[4,3]],
W = [0.3, 0.4]
Then, Out = [[1.7,1.3],[3.6,3.4]]
where 1.7 = 1*0.3+2*(1-0.3),
1.3 = 2*0.3+1*(1-0.3),
3.6 = 3*0.4+4*(1-0.4),
3.4 = 4*0.4+3*(1-0.4)
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_WITHOUT_GRADIENT
(
interp
,
ops
::
InterpOp
,
ops
::
InterpOpMaker
);
paddle/operators/math/pooling.cc
浏览文件 @
69fd376b
...
...
@@ -18,6 +18,11 @@ namespace paddle {
namespace
operators
{
namespace
math
{
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template
<
typename
PoolProcess
,
typename
T
>
class
Pool2dFunctor
<
platform
::
CPUPlace
,
PoolProcess
,
T
>
{
public:
...
...
@@ -73,6 +78,11 @@ class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
}
};
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent height
* and width, respectively.
*/
template
<
typename
PoolProcess
,
class
T
>
class
Pool2dGradFunctor
<
platform
::
CPUPlace
,
PoolProcess
,
T
>
{
public:
...
...
@@ -135,6 +145,11 @@ class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
}
};
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template
<
class
T
>
class
MaxPool2dGradFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
...
...
@@ -197,7 +212,7 @@ class MaxPool2dGradFunctor<platform::CPUPlace, T> {
};
template
class
MaxPool2dGradFunctor
<
platform
::
CPUPlace
,
float
>;
//
template class MaxPool2dGradFunctor<platform::CPUPlace, double>;
template
class
MaxPool2dGradFunctor
<
platform
::
CPUPlace
,
double
>;
template
class
Pool2dFunctor
<
platform
::
CPUPlace
,
paddle
::
operators
::
math
::
MaxPool
<
float
>,
float
>
;
...
...
@@ -216,6 +231,11 @@ template class Pool2dGradFunctor<
template
class
Pool2dGradFunctor
<
platform
::
CPUPlace
,
paddle
::
operators
::
math
::
AvgPoolGrad
<
double
>,
double
>
;
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template
<
typename
PoolProcess
,
class
T
>
class
Pool3dFunctor
<
platform
::
CPUPlace
,
PoolProcess
,
T
>
{
public:
...
...
@@ -286,6 +306,11 @@ class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
}
};
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template
<
typename
PoolProcess
,
class
T
>
class
Pool3dGradFunctor
<
platform
::
CPUPlace
,
PoolProcess
,
T
>
{
public:
...
...
@@ -364,6 +389,11 @@ class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
}
};
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template
<
class
T
>
class
MaxPool3dGradFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
...
...
@@ -440,7 +470,7 @@ class MaxPool3dGradFunctor<platform::CPUPlace, T> {
};
template
class
MaxPool3dGradFunctor
<
platform
::
CPUPlace
,
float
>;
//
template class MaxPool3dGradFunctor<platform::CPUPlace, double>;
template
class
MaxPool3dGradFunctor
<
platform
::
CPUPlace
,
double
>;
template
class
Pool3dFunctor
<
platform
::
CPUPlace
,
paddle
::
operators
::
math
::
MaxPool
<
float
>,
float
>
;
...
...
@@ -458,6 +488,253 @@ template class Pool3dGradFunctor<
platform
::
CPUPlace
,
paddle
::
operators
::
math
::
MaxPoolGrad
<
double
>,
double
>
;
template
class
Pool3dGradFunctor
<
platform
::
CPUPlace
,
paddle
::
operators
::
math
::
AvgPoolGrad
<
double
>,
double
>
;
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template
<
typename
T
>
class
MaxPool2dWithIndexFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
output
,
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_height
=
input
.
dims
()[
2
];
const
int
input_width
=
input
.
dims
()[
3
];
const
int
output_channels
=
output
.
dims
()[
1
];
const
int
output_height
=
output
.
dims
()[
2
];
const
int
output_width
=
output
.
dims
()[
3
];
const
int
ksize_height
=
ksize
[
0
];
const
int
ksize_width
=
ksize
[
1
];
const
int
stride_height
=
strides
[
0
];
const
int
stride_width
=
strides
[
1
];
const
int
padding_height
=
paddings
[
0
];
const
int
padding_width
=
paddings
[
1
];
const
int
input_stride
=
input_height
*
input_width
;
const
int
output_stride
=
output_height
*
output_width
;
const
T
*
input_data
=
input
.
data
<
T
>
();
T
*
output_data
=
output
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
mask_data
=
mask
.
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
for
(
int
ph
=
0
;
ph
<
output_height
;
++
ph
)
{
int
hstart
=
ph
*
stride_height
-
padding_height
;
int
hend
=
std
::
min
(
hstart
+
ksize_height
,
input_height
);
hstart
=
std
::
max
(
hstart
,
0
);
for
(
int
pw
=
0
;
pw
<
output_width
;
++
pw
)
{
int
wstart
=
pw
*
stride_width
-
padding_width
;
int
wend
=
std
::
min
(
wstart
+
ksize_width
,
input_width
);
wstart
=
std
::
max
(
wstart
,
0
);
T
ele
=
static_cast
<
T
>
(
-
FLT_MAX
);
int
index
=
-
1
;
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
if
(
ele
<
input_data
[
h
*
input_width
+
w
])
{
ele
=
input_data
[
h
*
input_width
+
w
];
index
=
h
*
input_width
+
w
;
}
}
}
output_data
[
ph
*
output_width
+
pw
]
=
ele
;
mask_data
[
ph
*
output_width
+
pw
]
=
index
;
}
}
// offset
input_data
+=
input_stride
;
output_data
+=
output_stride
;
mask_data
+=
output_stride
;
}
}
}
};
/*
* All tensors are in NCHW format.
* Ksize, strides, paddings are two elements. These two elements represent
* height and width, respectively.
*/
template
<
typename
T
>
class
MaxPool2dWithIndexGradFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
)
{
const
int
batch_size
=
input_grad
.
dims
()[
0
];
const
int
input_height
=
input_grad
.
dims
()[
2
];
const
int
input_width
=
input_grad
.
dims
()[
3
];
const
int
output_channels
=
output_grad
.
dims
()[
1
];
const
int
output_height
=
output_grad
.
dims
()[
2
];
const
int
output_width
=
output_grad
.
dims
()[
3
];
const
int
input_stride
=
input_height
*
input_width
;
const
int
output_stride
=
output_height
*
output_width
;
const
T
*
mask_data
=
mask
.
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
.
data
<
T
>
();
T
*
input_grad_data
=
input_grad
.
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
n
=
0
;
n
<
batch_size
;
++
n
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
for
(
int
ph
=
0
;
ph
<
output_height
;
++
ph
)
{
for
(
int
pw
=
0
;
pw
<
output_width
;
++
pw
)
{
const
int
output_idx
=
ph
*
output_width
+
pw
;
const
int
input_idx
=
static_cast
<
int
>
(
mask_data
[
output_idx
]);
input_grad_data
[
input_idx
]
+=
output_grad_data
[
output_idx
];
}
}
// offset
input_grad_data
+=
input_stride
;
output_grad_data
+=
output_stride
;
mask_data
+=
output_stride
;
}
}
}
};
template
class
MaxPool2dWithIndexFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
MaxPool2dWithIndexGradFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
MaxPool2dWithIndexFunctor
<
platform
::
CPUPlace
,
double
>;
template
class
MaxPool2dWithIndexGradFunctor
<
platform
::
CPUPlace
,
double
>;
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template
<
typename
T
>
class
MaxPool3dWithIndexFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
output
,
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
)
{
const
int
batch_size
=
input
.
dims
()[
0
];
const
int
input_depth
=
input
.
dims
()[
2
];
const
int
input_height
=
input
.
dims
()[
3
];
const
int
input_width
=
input
.
dims
()[
4
];
const
int
output_channels
=
output
.
dims
()[
1
];
const
int
output_depth
=
output
.
dims
()[
2
];
const
int
output_height
=
output
.
dims
()[
3
];
const
int
output_width
=
output
.
dims
()[
4
];
const
int
ksize_depth
=
ksize
[
0
];
const
int
ksize_height
=
ksize
[
1
];
const
int
ksize_width
=
ksize
[
2
];
const
int
stride_depth
=
strides
[
0
];
const
int
stride_height
=
strides
[
1
];
const
int
stride_width
=
strides
[
2
];
const
int
padding_depth
=
paddings
[
0
];
const
int
padding_height
=
paddings
[
1
];
const
int
padding_width
=
paddings
[
2
];
const
int
input_stride
=
input_depth
*
input_height
*
input_width
;
const
int
output_stride
=
output_depth
*
output_height
*
output_width
;
const
T
*
input_data
=
input
.
data
<
T
>
();
T
*
output_data
=
output
.
mutable_data
<
T
>
(
context
.
GetPlace
());
T
*
mask_data
=
mask
.
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
for
(
int
pd
=
0
;
pd
<
output_depth
;
++
pd
)
{
int
dstart
=
pd
*
stride_depth
-
padding_depth
;
int
dend
=
std
::
min
(
dstart
+
ksize_depth
,
input_depth
);
dstart
=
std
::
max
(
dstart
,
0
);
for
(
int
ph
=
0
;
ph
<
output_height
;
++
ph
)
{
int
hstart
=
ph
*
stride_height
-
padding_height
;
int
hend
=
std
::
min
(
hstart
+
ksize_height
,
input_height
);
hstart
=
std
::
max
(
hstart
,
0
);
for
(
int
pw
=
0
;
pw
<
output_width
;
++
pw
)
{
int
wstart
=
pw
*
stride_width
-
padding_width
;
int
wend
=
std
::
min
(
wstart
+
ksize_width
,
input_width
);
wstart
=
std
::
max
(
wstart
,
0
);
int
output_idx
=
(
pd
*
output_height
+
ph
)
*
output_width
+
pw
;
T
ele
=
static_cast
<
T
>
(
-
FLT_MAX
);
int
index
=
-
1
;
for
(
int
d
=
dstart
;
d
<
dend
;
++
d
)
{
for
(
int
h
=
hstart
;
h
<
hend
;
++
h
)
{
for
(
int
w
=
wstart
;
w
<
wend
;
++
w
)
{
int
input_idx
=
(
d
*
input_height
+
h
)
*
input_width
+
w
;
if
(
ele
<
input_data
[
input_idx
])
{
index
=
input_idx
;
ele
=
input_data
[
input_idx
];
}
}
}
}
output_data
[
output_idx
]
=
ele
;
mask_data
[
output_idx
]
=
index
;
}
}
}
// offset
input_data
+=
input_stride
;
output_data
+=
output_stride
;
mask_data
+=
output_stride
;
}
}
}
};
/*
* All tensors are in NCDHW format.
* Ksize, strides, paddings are three elements. These three elements represent
* depth, height and width, respectively.
*/
template
<
typename
T
>
class
MaxPool3dWithIndexGradFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
)
{
const
int
batch_size
=
input_grad
.
dims
()[
0
];
const
int
input_depth
=
input_grad
.
dims
()[
2
];
const
int
input_height
=
input_grad
.
dims
()[
3
];
const
int
input_width
=
input_grad
.
dims
()[
4
];
const
int
output_channels
=
output_grad
.
dims
()[
1
];
const
int
output_depth
=
output_grad
.
dims
()[
2
];
const
int
output_height
=
output_grad
.
dims
()[
3
];
const
int
output_width
=
output_grad
.
dims
()[
4
];
const
int
input_stride
=
input_depth
*
input_height
*
input_width
;
const
int
output_stride
=
output_depth
*
output_height
*
output_width
;
const
T
*
mask_data
=
mask
.
data
<
T
>
();
const
T
*
output_grad_data
=
output_grad
.
data
<
T
>
();
T
*
input_grad_data
=
input_grad
.
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int
n
=
0
;
n
<
batch_size
;
++
n
)
{
for
(
int
c
=
0
;
c
<
output_channels
;
++
c
)
{
for
(
int
pd
=
0
;
pd
<
output_depth
;
++
pd
)
{
for
(
int
ph
=
0
;
ph
<
output_height
;
++
ph
)
{
for
(
int
pw
=
0
;
pw
<
output_width
;
++
pw
)
{
const
int
output_idx
=
(
pd
*
output_height
+
ph
)
*
output_width
+
pw
;
const
int
input_idx
=
static_cast
<
int
>
(
mask_data
[
output_idx
]);
input_grad_data
[
input_idx
]
+=
output_grad_data
[
output_idx
];
}
}
}
// offset
input_grad_data
+=
input_stride
;
output_grad_data
+=
output_stride
;
mask_data
+=
output_stride
;
}
}
}
};
template
class
MaxPool3dWithIndexFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
MaxPool3dWithIndexGradFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
MaxPool3dWithIndexFunctor
<
platform
::
CPUPlace
,
double
>;
template
class
MaxPool3dWithIndexGradFunctor
<
platform
::
CPUPlace
,
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/pooling.cu
浏览文件 @
69fd376b
此差异已折叠。
点击以展开。
paddle/operators/math/pooling.h
浏览文件 @
69fd376b
...
...
@@ -21,15 +21,27 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
namespace
math
{
//////////////////////
#define FLT_MAX __FLT_MAX__ //
#define FLT_MAX \
__FLT_MAX__ // It might need to be placed in another file, but I'm still
// wondering where to put it.
/*
* \brief Extracting simple operations from pooling.
* Both MaxPool and AvgPool need "initial", "compute" and "finalize"
* operation.
* MaxPool initializes temp variable to the negative maximum to find the
* maximum value in the pooling field.
* AvgPool initializes temp variable to the zero to accumulate all values
* in pool pooling, and finally takes the average.
* MaxPoolGrad and AvgPoolGrad are gradient operations respectively.
*/
template
<
class
T
>
class
MaxPool
{
public:
DEVICE
inline
T
initial
()
{
return
static_cast
<
T
>
(
-
FLT_MAX
);
}
DEVICE
inline
void
compute
(
T
&
y
,
const
T
&
x
)
{
y
=
y
>
x
?
y
:
x
;
}
DEVICE
inline
void
finalize
(
T
&
y
,
const
T
&
poo
_size
)
{}
DEVICE
inline
void
finalize
(
T
&
y
,
const
T
&
poo
l_field
)
{}
};
template
<
class
T
>
...
...
@@ -37,8 +49,9 @@ class AvgPool {
public:
DEVICE
inline
T
initial
()
{
return
static_cast
<
T
>
(
0
);
}
DEVICE
inline
void
compute
(
T
&
y
,
const
T
&
x
)
{
y
+=
x
;
}
DEVICE
inline
void
finalize
(
T
&
y
,
const
T
&
poo
_size
)
{
y
/=
poo_size
;
}
DEVICE
inline
void
finalize
(
T
&
y
,
const
T
&
poo
l_field
)
{
y
/=
pool_field
;
}
};
template
<
class
T
>
class
MaxPoolGrad
{
public:
...
...
@@ -57,6 +70,20 @@ class AvgPoolGrad {
}
};
/*
* \brief Getting pooling results, and calculating gradient.
*
* In pool2d, all tensors are in NCHW format. Where N is batch size, C is the
* number of channels, H and W is the height and width of feature.
* In pool3d, all tensors are in NCDHW format. Where N is batch size, C is the
* number of channels, D, H and W is the depth, height and width of feature.
*
* In max pooling, it is possible that the pooling region has multiple maximum
* elements. In this case, we should compute the gradient of the first maximum
* element.
* This is different from average pooling. So we rewrite the max_pool_grad:
* MaxPool2dGradFunctor, MaxPool3dGradFunctor.
*/
template
<
typename
Place
,
typename
PoolProcess
,
typename
T
>
class
Pool2dFunctor
{
public:
...
...
@@ -117,6 +144,51 @@ class MaxPool3dGradFunctor {
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
);
};
/*
* \brief Getting max pooling results and corresponding max index, and
* calculating gradient.
* In up-sampling-pooling, it is necessary to know max element index.
* In pool2d, all tensors are in NCHW format. In pool3d, all tensors are in
* NCDHW format.
*/
template
<
typename
Place
,
typename
T
>
class
MaxPool2dWithIndexFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
output
,
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
);
};
template
<
typename
Place
,
typename
T
>
class
MaxPool2dWithIndexGradFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
);
};
template
<
typename
Place
,
typename
T
>
class
MaxPool3dWithIndexFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
framework
::
Tensor
&
output
,
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
);
};
template
<
typename
Place
,
typename
T
>
class
MaxPool3dWithIndexGradFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
framework
::
Tensor
&
input_grad
,
const
framework
::
Tensor
&
output_grad
,
const
framework
::
Tensor
&
mask
,
std
::
vector
<
int
>&
ksize
,
std
::
vector
<
int
>&
strides
,
std
::
vector
<
int
>&
paddings
);
};
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/pool_with_index_op.cc
0 → 100644
浏览文件 @
69fd376b
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/pool_with_index_op.h"
namespace
paddle
{
namespace
operators
{
inline
int
OutputSizeMaxPool
(
int
input_size
,
int
filter_size
,
int
padding
,
int
stride
)
{
int
output_size
=
(
input_size
-
filter_size
+
2
*
padding
)
/
stride
+
1
;
return
output_size
;
}
class
MaxPoolWithIndexOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"X(Input) of Pooling should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Out(Output) of Pooling should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Mask"
),
"Mask(Output) of Pooling should not be null."
);
auto
in_x_dims
=
ctx
->
GetInputDim
(
"X"
);
std
::
vector
<
int
>
ksize
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
strides
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"paddings"
);
PADDLE_ENFORCE
(
in_x_dims
.
size
()
==
4
||
in_x_dims
.
size
()
==
5
,
"Pooling intput should be 4-D or 5-D"
);
if
(
ctx
->
Attrs
().
Get
<
bool
>
(
"globalPooling"
))
{
ksize
.
resize
(
static_cast
<
size_t
>
(
in_x_dims
.
size
())
-
2
);
for
(
size_t
i
=
0
;
i
<
ksize
.
size
();
++
i
)
ksize
[
i
]
=
static_cast
<
int
>
(
in_x_dims
[
i
+
2
]);
}
PADDLE_ENFORCE
(
in_x_dims
.
size
()
-
ksize
.
size
()
==
2U
,
"Intput size and pooling size should be consistent."
);
PADDLE_ENFORCE_EQ
(
ksize
.
size
(),
strides
.
size
(),
"Strides size and pooling size should be the same."
);
PADDLE_ENFORCE_EQ
(
ksize
.
size
(),
paddings
.
size
(),
"Paddings size and pooling size should be the same."
);
std
::
vector
<
int64_t
>
output_shape
({
in_x_dims
[
0
],
in_x_dims
[
1
]});
for
(
size_t
i
=
0
;
i
<
ksize
.
size
();
++
i
)
{
output_shape
.
push_back
(
OutputSizeMaxPool
(
in_x_dims
[
i
+
2
],
ksize
[
i
],
paddings
[
i
],
strides
[
i
]));
}
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
output_shape
));
ctx
->
SetOutputDim
(
"Mask"
,
framework
::
make_ddim
(
output_shape
));
}
};
class
MaxPoolWithIndexOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)),
"Input(X@GRAD) should not be null."
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
};
class
MaxPool2dWithIndexOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
MaxPool2dWithIndexOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The input tensor of pooling operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of image."
);
AddOutput
(
"Out"
,
"The output tensor of pooling operator."
"The format of output tensor is also NCHW."
"Where N is batch size, C is "
"the number of channels, H and W is the height and "
"width of image."
);
AddOutput
(
"Mask"
,
"The Mask tensor of pooling operator."
"The format of output tensor is also NCHW."
"Where N is batch size, C is the number of channels, H and W "
"is the height and width of image."
"The value in it is the index in current feature map"
);
AddAttr
<
std
::
vector
<
int
>>
(
"ksize"
,
"The pooling size(height, width) of pooling operator."
"If globalPooling = true, ksize is ignored and need not be "
"specified."
);
// TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr
<
bool
>
(
"globalPooling"
,
"Whether to use the globalPooling."
"Bool constant equal to false or true."
"Default false."
"If globalPooling = true, ksize is ignored and need not be specified."
)
.
SetDefault
(
false
);
AddAttr
<
std
::
vector
<
int
>>
(
"strides"
,
"Strides(height, width) of pooling operator."
"Default {1,1}."
)
.
SetDefault
({
1
,
1
});
// TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr
<
std
::
vector
<
int
>>
(
"paddings"
,
"Paddings(height, width) of pooling operator."
"Default {0,0}."
)
.
SetDefault
({
0
,
0
});
// TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddComment
(
R"DOC(
The maxPooling2d with index operation calculates the output and the mask
based on the input and ksize, strides, paddings parameters. Input(X) and
output(Out, Mask) are in NCHW format. Where N is batch size, C is the
number of channels, H and W is the height and width of feature.
Parameters(ksize, strides, paddings) are two elements.
These two elements represent height and width, respectively.
)DOC"
);
}
};
class
MaxPool3dWithIndexOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
MaxPool3dWithIndexOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The input tensor of pooling operator. "
"The format of input tensor is NCDHW. Where N is batch size, C is "
"the number of channels, D, H and W is the depth, height and width of "
"image."
);
AddOutput
(
"Out"
,
"The output tensor of pooling operator."
"The format of output tensor is also NCDHW."
"Where N is batch size, C is "
"the number of channels, D, H and W is the depth, height and "
"width of image."
);
AddOutput
(
"Mask"
,
"The Mask tensor of pooling operator."
"The format of output tensor is also NCDHW."
"Where N is batch size, C is the number of channels, D, H and W "
"is the depth, height and width of image."
"The value in it is the index in current feature map"
);
AddAttr
<
std
::
vector
<
int
>>
(
"ksize"
,
"The pooling size(depth, height, width) of pooling operator."
"If globalPooling = true, ksize is ignored and need not be "
"specified."
);
// TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr
<
bool
>
(
"globalPooling"
,
"Whether to use the globalPooling."
"Bool constant equal to false or true."
"Default false."
"If globalPooling = true, ksize is ignored and need not be specified."
)
.
SetDefault
(
false
);
AddAttr
<
std
::
vector
<
int
>>
(
"strides"
,
"Strides(depth, height, width) of pooling operator."
"Default {1,1,1}."
)
.
SetDefault
({
1
,
1
,
1
});
// TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddAttr
<
std
::
vector
<
int
>>
(
"paddings"
,
"Paddings(depth, height, width) of pooling operator."
"Default {0,0,0}."
)
.
SetDefault
({
0
,
0
,
0
});
// TODO(Chengduo): Add checker. (Currently,
// TypedAttrChecker don't support vector type.)
AddComment
(
R"DOC(
The maxpooling3d with index operation calculates the output and the mask
based on the input and ksize, strides, paddings parameters.
Input(X) and output(Out, Mask) are in NCDHW format. Where N is batch
size, C is the number of channels, D, H and W is the depth, height and
width of feature. Parameters(ksize, strides, paddings) are three elements.
These three elements represent depth, height and width, respectively.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
max_pool2d_with_index
,
ops
::
MaxPoolWithIndexOp
,
ops
::
MaxPool2dWithIndexOpMaker
,
max_pool2d_with_index_grad
,
ops
::
MaxPoolWithIndexOpGrad
);
REGISTER_OP_CPU_KERNEL
(
max_pool2d_with_index
,
ops
::
MaxPoolWithIndexKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
max_pool2d_with_index_grad
,
ops
::
MaxPoolWithIndexGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
)
REGISTER_OP
(
max_pool3d_with_index
,
ops
::
MaxPoolWithIndexOp
,
ops
::
MaxPool3dWithIndexOpMaker
,
max_pool3d_with_index_grad
,
ops
::
MaxPoolWithIndexOpGrad
);
REGISTER_OP_CPU_KERNEL
(
max_pool3d_with_index
,
ops
::
MaxPoolWithIndexKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
max_pool3d_with_index_grad
,
ops
::
MaxPoolWithIndexGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
)
paddle/operators/pool_with_index_op.cu
0 → 100644
浏览文件 @
69fd376b
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/pool_with_index_op.h"
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
max_pool2d_with_index
,
ops
::
MaxPoolWithIndexKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
REGISTER_OP_GPU_KERNEL
(
max_pool2d_with_index_grad
,
ops
::
MaxPoolWithIndexGradKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
)
REGISTER_OP_GPU_KERNEL
(
max_pool3d_with_index
,
ops
::
MaxPoolWithIndexKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
REGISTER_OP_GPU_KERNEL
(
max_pool3d_with_index_grad
,
ops
::
MaxPoolWithIndexGradKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
)
paddle/operators/pool_with_index_op.h
0 → 100644
浏览文件 @
69fd376b
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/pooling.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
Place
,
typename
T
>
class
MaxPoolWithIndexKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
const
Tensor
*
in_x
=
context
.
Input
<
Tensor
>
(
"X"
);
Tensor
*
out
=
context
.
Output
<
Tensor
>
(
"Out"
);
Tensor
*
mask
=
context
.
Output
<
Tensor
>
(
"Mask"
);
std
::
vector
<
int
>
ksize
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
strides
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
if
(
context
.
Attr
<
bool
>
(
"globalPooling"
))
{
for
(
size_t
i
=
0
;
i
<
ksize
.
size
();
++
i
)
{
ksize
[
i
]
=
static_cast
<
int
>
(
in_x
->
dims
()[
i
+
2
]);
}
}
switch
(
ksize
.
size
())
{
case
2
:
{
paddle
::
operators
::
math
::
MaxPool2dWithIndexFunctor
<
Place
,
T
>
pool2d_forward
;
pool2d_forward
(
context
.
device_context
(),
*
in_x
,
*
out
,
*
mask
,
ksize
,
strides
,
paddings
);
}
break
;
case
3
:
{
paddle
::
operators
::
math
::
MaxPool3dWithIndexFunctor
<
Place
,
T
>
pool3d_forward
;
pool3d_forward
(
context
.
device_context
(),
*
in_x
,
*
out
,
*
mask
,
ksize
,
strides
,
paddings
);
}
break
;
}
}
};
template
<
typename
Place
,
typename
T
>
class
MaxPoolWithIndexGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
const
Tensor
*
mask
=
context
.
Input
<
Tensor
>
(
"Mask"
);
const
Tensor
*
out_grad
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
Tensor
*
in_x_grad
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
std
::
vector
<
int
>
ksize
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"ksize"
);
std
::
vector
<
int
>
strides
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"strides"
);
std
::
vector
<
int
>
paddings
=
context
.
Attr
<
std
::
vector
<
int
>>
(
"paddings"
);
if
(
context
.
Attr
<
bool
>
(
"globalPooling"
))
{
for
(
size_t
i
=
0
;
i
<
ksize
.
size
();
++
i
)
{
ksize
[
i
]
=
static_cast
<
int
>
(
in_x_grad
->
dims
()[
i
+
2
]);
}
}
if
(
in_x_grad
)
{
in_x_grad
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
temp
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
in_x_grad
);
temp
.
device
(
context
.
GetEigenDevice
<
Place
>
())
=
temp
.
constant
(
static_cast
<
T
>
(
0
));
switch
(
ksize
.
size
())
{
case
2
:
{
paddle
::
operators
::
math
::
MaxPool2dWithIndexGradFunctor
<
Place
,
T
>
pool2d_backward
;
pool2d_backward
(
context
.
device_context
(),
*
in_x_grad
,
*
out_grad
,
*
mask
,
ksize
,
strides
,
paddings
);
}
break
;
case
3
:
{
paddle
::
operators
::
math
::
MaxPool3dWithIndexGradFunctor
<
Place
,
T
>
pool3d_backward
;
pool3d_backward
(
context
.
device_context
(),
*
in_x_grad
,
*
out_grad
,
*
mask
,
ksize
,
strides
,
paddings
);
}
break
;
}
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/platform/device_context.cc
浏览文件 @
69fd376b
...
...
@@ -136,7 +136,7 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
cudaStream_t
CUDADeviceContext
::
stream
()
const
{
return
stream_
;
}
#endif
// PADDLE_ONLY_CPU
#endif
}
// namespace platform
}
// namespace paddle
paddle/platform/enforce.h
浏览文件 @
69fd376b
...
...
@@ -41,7 +41,7 @@ limitations under the License. */
#include <thrust/system/cuda/error.h>
#include <thrust/system_error.h>
#endif
// PADDLE_ONLY_CPU
#endif
namespace
paddle
{
namespace
platform
{
...
...
paddle/platform/gpu_info.h
浏览文件 @
69fd376b
...
...
@@ -63,4 +63,4 @@ void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device,
}
// namespace platform
}
// namespace paddle
#endif
// PADDLE_ONLY_CPU
#endif
python/paddle/v2/framework/tests/test_activation_op.py
浏览文件 @
69fd376b
...
...
@@ -181,6 +181,26 @@ class TestSoftRelu(OpTest):
self
.
check_grad
([
'X'
],
'Y'
,
max_relative_error
=
0.02
)
class
TestELU
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"elu"
x
=
np
.
random
.
uniform
(
-
3
,
3
,
[
4
,
4
]).
astype
(
"float32"
)
alpha
=
1.
# Note: unlike other Relu extensions, point 0 on standard ELU function (i.e. alpha = 1)
# is differentiable, so we can skip modifications like x[np.abs(x) < 0.005] = 0.02 here
self
.
inputs
=
{
'X'
:
x
}
self
.
attrs
=
{
'alpha'
:
alpha
}
self
.
outputs
=
{
'Y'
:
np
.
maximum
(
0
,
x
)
+
np
.
minimum
(
0
,
alpha
*
(
np
.
exp
(
x
)
-
1
))
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
],
'Y'
,
max_relative_error
=
0.02
)
class
TestReciprocal
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"reciprocal"
...
...
python/paddle/v2/framework/tests/test_fill_constant_op.py
0 → 100644
浏览文件 @
69fd376b
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
class
TestFillConstantOp1
(
OpTest
):
def
setUp
(
self
):
'''Test fill_constant op with specified value
'''
self
.
op_type
=
"fill_constant"
self
.
inputs
=
{}
self
.
attrs
=
{
'shape'
:
[
123
,
92
],
'value'
:
3.8
}
self
.
outputs
=
{
'Out'
:
np
.
full
((
123
,
92
),
3.8
)}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestFillConstantOp2
(
OpTest
):
def
setUp
(
self
):
'''Test fill_constant op with default value
'''
self
.
op_type
=
"fill_constant"
self
.
inputs
=
{}
self
.
attrs
=
{
'shape'
:
[
123
,
92
]}
self
.
outputs
=
{
'Out'
:
np
.
full
((
123
,
92
),
0.0
)}
def
test_check_output
(
self
):
self
.
check_output
()
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/v2/framework/tests/test_interp_op.py
0 → 100644
浏览文件 @
69fd376b
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
class
TestInterpOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"interp"
x
=
np
.
random
.
random
((
2
,
3
)).
astype
(
"float32"
)
y
=
np
.
random
.
random
((
2
,
3
)).
astype
(
"float32"
)
w
=
np
.
random
.
random
(
2
).
astype
(
"float32"
)
sub_out
=
x
-
y
mul_out
=
sub_out
*
w
.
reshape
(
2
,
1
)
out
=
mul_out
+
y
self
.
inputs
=
{
'X'
:
x
,
'Y'
:
y
,
'W'
:
w
}
self
.
outputs
=
{
'Out'
:
out
,
'SubOut'
:
sub_out
,
'MulOut'
:
mul_out
}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad_normal
(
self
):
self
.
check_grad
([
'X'
,
'Y'
],
'Out'
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/v2/framework/tests/test_pool_max_op.py
0 → 100644
浏览文件 @
69fd376b
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
def
max_pool3D_forward_naive
(
x
,
ksize
,
strides
,
paddings
=
[
0
,
0
,
0
],
global_pool
=
0
):
N
,
C
,
D
,
H
,
W
=
x
.
shape
if
global_pool
==
1
:
ksize
=
[
D
,
H
,
W
]
D_out
=
(
D
-
ksize
[
0
]
+
2
*
paddings
[
0
])
/
strides
[
0
]
+
1
H_out
=
(
H
-
ksize
[
1
]
+
2
*
paddings
[
1
])
/
strides
[
1
]
+
1
W_out
=
(
W
-
ksize
[
2
]
+
2
*
paddings
[
2
])
/
strides
[
2
]
+
1
out
=
np
.
zeros
((
N
,
C
,
D_out
,
H_out
,
W_out
))
mask
=
np
.
zeros
((
N
,
C
,
D_out
,
H_out
,
W_out
))
for
k
in
xrange
(
D_out
):
d_start
=
np
.
max
((
k
*
strides
[
0
]
-
paddings
[
0
],
0
))
d_end
=
np
.
min
((
k
*
strides
[
0
]
+
ksize
[
0
]
-
paddings
[
0
],
D
))
for
i
in
xrange
(
H_out
):
h_start
=
np
.
max
((
i
*
strides
[
0
]
-
paddings
[
0
],
0
))
h_end
=
np
.
min
((
i
*
strides
[
0
]
+
ksize
[
0
]
-
paddings
[
0
],
H
))
for
j
in
xrange
(
W_out
):
w_start
=
np
.
max
((
j
*
strides
[
1
]
-
paddings
[
1
],
0
))
w_end
=
np
.
min
((
j
*
strides
[
1
]
+
ksize
[
1
]
-
paddings
[
1
],
W
))
x_masked
=
x
[:,
:,
d_start
:
d_end
,
h_start
:
h_end
,
w_start
:
w_end
]
out
[:,
:,
k
,
i
,
j
]
=
np
.
max
(
x_masked
,
axis
=
(
2
,
3
,
4
))
for
n
in
xrange
(
N
):
for
c
in
xrange
(
C
):
arr
=
x_masked
[
n
,
c
,
:,
:,
:]
index
=
np
.
where
(
arr
==
np
.
max
(
arr
))
sub_deep
=
index
[
0
][
0
]
sub_row
=
index
[
1
][
0
]
sub_col
=
index
[
2
][
0
]
index
=
((
d_start
+
sub_deep
)
*
H
+
(
h_start
+
sub_row
))
*
W
+
w_start
+
sub_col
mask
[
n
,
c
,
k
,
i
,
j
]
=
index
return
out
,
mask
def
max_pool2D_forward_naive
(
x
,
ksize
,
strides
,
paddings
=
[
0
,
0
],
global_pool
=
0
):
N
,
C
,
H
,
W
=
x
.
shape
if
global_pool
==
1
:
ksize
=
[
H
,
W
]
H_out
=
(
H
-
ksize
[
0
]
+
2
*
paddings
[
0
])
/
strides
[
0
]
+
1
W_out
=
(
W
-
ksize
[
1
]
+
2
*
paddings
[
1
])
/
strides
[
1
]
+
1
out
=
np
.
zeros
((
N
,
C
,
H_out
,
W_out
))
mask
=
np
.
zeros
((
N
,
C
,
H_out
,
W_out
))
for
i
in
xrange
(
H_out
):
for
j
in
xrange
(
W_out
):
r_start
=
np
.
max
((
i
*
strides
[
0
]
-
paddings
[
0
],
0
))
r_end
=
np
.
min
((
i
*
strides
[
0
]
+
ksize
[
0
]
-
paddings
[
0
],
H
))
c_start
=
np
.
max
((
j
*
strides
[
1
]
-
paddings
[
1
],
0
))
c_end
=
np
.
min
((
j
*
strides
[
1
]
+
ksize
[
1
]
-
paddings
[
1
],
W
))
x_masked
=
x
[:,
:,
r_start
:
r_end
,
c_start
:
c_end
]
out
[:,
:,
i
,
j
]
=
np
.
max
(
x_masked
,
axis
=
(
2
,
3
))
for
n
in
xrange
(
N
):
for
c
in
xrange
(
C
):
arr
=
x_masked
[
n
,
c
,
:,
:]
index
=
np
.
where
(
arr
==
np
.
max
(
arr
))
sub_row
=
index
[
0
][
0
]
sub_col
=
index
[
1
][
0
]
index
=
(
r_start
+
sub_row
)
*
W
+
c_start
+
sub_col
mask
[
n
,
c
,
i
,
j
]
=
index
return
out
,
mask
class
TestMaxPoolWithIndex_Op
(
OpTest
):
def
setUp
(
self
):
self
.
initTestCase
()
input
=
np
.
random
.
random
(
self
.
shape
).
astype
(
"float32"
)
output
,
mask
=
self
.
pool_forward_naive
(
input
,
self
.
ksize
,
self
.
strides
,
self
.
paddings
,
self
.
global_pool
)
self
.
attrs
=
{
'strides'
:
self
.
strides
,
'paddings'
:
self
.
paddings
,
'ksize'
:
self
.
ksize
,
'globalPooling'
:
self
.
global_pool
,
}
self
.
inputs
=
{
'X'
:
input
}
self
.
outputs
=
{
'Out'
:
output
,
"Mask"
:
mask
}
def
test_check_output
(
self
):
self
.
check_output
()
# def test_check_grad(self):
# self.check_grad(set(['X']), ['Out'], max_relative_error=0.07)
def
initTestCase
(
self
):
self
.
global_pool
=
True
self
.
index
=
"max_pool3d_with_index"
self
.
op_type
=
"%s"
%
self
.
index
self
.
pool_forward_naive
=
max_pool3D_forward_naive
self
.
shape
=
[
2
,
3
,
5
,
5
,
5
]
self
.
ksize
=
[
3
,
3
,
3
]
self
.
strides
=
[
1
,
1
,
1
]
self
.
paddings
=
[
1
,
1
,
1
]
class
TestCase1
(
TestMaxPoolWithIndex_Op
):
def
initTestCase
(
self
):
self
.
global_pool
=
True
self
.
op_type
=
"max_pool3d_with_index"
self
.
pool_forward_naive
=
max_pool3D_forward_naive
self
.
shape
=
[
2
,
3
,
5
,
5
,
5
]
self
.
ksize
=
[
3
,
3
,
3
]
self
.
strides
=
[
1
,
1
,
1
]
self
.
paddings
=
[
1
,
1
,
1
]
class
TestCase2
(
TestMaxPoolWithIndex_Op
):
def
initTestCase
(
self
):
self
.
global_pool
=
False
self
.
op_type
=
"max_pool3d_with_index"
self
.
pool_forward_naive
=
max_pool3D_forward_naive
self
.
shape
=
[
2
,
3
,
7
,
7
,
7
]
self
.
ksize
=
[
3
,
3
,
3
]
self
.
strides
=
[
1
,
1
,
1
]
self
.
paddings
=
[
1
,
1
,
1
]
class
TestCase3
(
TestMaxPoolWithIndex_Op
):
def
initTestCase
(
self
):
self
.
global_pool
=
False
self
.
op_type
=
"max_pool3d_with_index"
self
.
pool_forward_naive
=
max_pool3D_forward_naive
self
.
shape
=
[
2
,
3
,
7
,
7
,
7
]
self
.
ksize
=
[
3
,
3
,
3
]
self
.
strides
=
[
2
,
2
,
2
]
self
.
paddings
=
[
0
,
0
,
0
]
class
TestCase4
(
TestMaxPoolWithIndex_Op
):
def
initTestCase
(
self
):
self
.
global_pool
=
True
self
.
op_type
=
"max_pool3d_with_index"
self
.
pool_forward_naive
=
max_pool3D_forward_naive
self
.
shape
=
[
2
,
3
,
5
,
5
,
5
]
self
.
ksize
=
[
3
,
3
,
3
]
self
.
strides
=
[
1
,
1
,
1
]
self
.
paddings
=
[
1
,
1
,
1
]
class
TestCase5
(
TestMaxPoolWithIndex_Op
):
def
initTestCase
(
self
):
self
.
global_pool
=
True
self
.
op_type
=
"max_pool3d_with_index"
self
.
pool_forward_naive
=
max_pool3D_forward_naive
self
.
shape
=
[
2
,
3
,
5
,
5
,
5
]
self
.
ksize
=
[
3
,
3
,
3
]
self
.
strides
=
[
2
,
2
,
2
]
self
.
paddings
=
[
0
,
0
,
0
]
class
TestCase6
(
TestMaxPoolWithIndex_Op
):
def
initTestCase
(
self
):
self
.
global_pool
=
False
self
.
op_type
=
"max_pool2d_with_index"
self
.
pool_forward_naive
=
max_pool2D_forward_naive
self
.
shape
=
[
2
,
3
,
7
,
7
]
self
.
ksize
=
[
3
,
3
]
self
.
strides
=
[
1
,
1
]
self
.
paddings
=
[
1
,
1
]
class
TestCase7
(
TestMaxPoolWithIndex_Op
):
def
initTestCase
(
self
):
self
.
global_pool
=
False
self
.
op_type
=
"max_pool2d_with_index"
self
.
pool_forward_naive
=
max_pool2D_forward_naive
self
.
shape
=
[
2
,
3
,
7
,
7
]
self
.
ksize
=
[
3
,
3
]
self
.
strides
=
[
2
,
2
]
self
.
paddings
=
[
0
,
0
]
class
TestCase8
(
TestMaxPoolWithIndex_Op
):
def
initTestCase
(
self
):
self
.
global_pool
=
True
self
.
op_type
=
"max_pool2d_with_index"
self
.
pool_forward_naive
=
max_pool2D_forward_naive
self
.
shape
=
[
2
,
3
,
5
,
5
]
self
.
ksize
=
[
3
,
3
]
self
.
strides
=
[
1
,
1
]
self
.
paddings
=
[
1
,
1
]
class
TestCase9
(
TestMaxPoolWithIndex_Op
):
def
initTestCase
(
self
):
self
.
global_pool
=
True
self
.
op_type
=
"max_pool2d_with_index"
self
.
pool_forward_naive
=
max_pool2D_forward_naive
self
.
shape
=
[
2
,
3
,
5
,
5
]
self
.
ksize
=
[
3
,
3
]
self
.
strides
=
[
2
,
2
]
self
.
paddings
=
[
0
,
0
]
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录