Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
06b42e9e
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
06b42e9e
编写于
9月 06, 2017
作者:
W
wanghaoshuang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add crop op.
上级
f2f839af
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
277 addition
and
0 deletion
+277
-0
paddle/operators/crop_op.cc
paddle/operators/crop_op.cc
+81
-0
paddle/operators/crop_op.cu
paddle/operators/crop_op.cu
+22
-0
paddle/operators/crop_op.h
paddle/operators/crop_op.h
+138
-0
paddle/pybind/pybind.cc
paddle/pybind/pybind.cc
+1
-0
python/paddle/v2/framework/tests/test_crop_op.py
python/paddle/v2/framework/tests/test_crop_op.py
+35
-0
未找到文件。
paddle/operators/crop_op.cc
0 → 100644
浏览文件 @
06b42e9e
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/crop_op.h"
namespace
paddle
{
namespace
operators
{
using
framework
::
Tensor
;
class
CropOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
auto
dim0
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
();
auto
Y
=
ctx
.
Input
<
Tensor
>
(
"Y"
);
if
(
Y
==
nullptr
)
{
auto
shape
=
GetAttr
<
std
::
vector
<
int
>>
(
"shape"
);
PADDLE_ENFORCE_EQ
(
shape
.
size
(),
dim0
.
size
(),
"Shape size should be equal to dimention size of input tensor."
);
ctx
.
Output
<
Tensor
>
(
"Out"
)
->
Resize
(
paddle
::
framework
::
make_ddim
(
shape
));
}
else
{
ctx
.
Output
<
Tensor
>
(
"Out"
)
->
Resize
(
Y
->
dims
());
}
}
};
class
CropOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
CropOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The input of crop op"
);
AddInput
(
"Y"
,
"The input used as reference for cropping. "
);
AddOutput
(
"Out"
,
"The output of crop op."
);
AddComment
(
R"DOC(
Crop Operator.
)DOC"
);
AddAttr
<
std
::
vector
<
int
>>
(
"offsets"
,
"The offsets for cropping."
);
AddAttr
<
std
::
vector
<
int
>>
(
"shape"
,
"The shape for cropping."
);
}
};
class
CropOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
auto
x_dims
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
();
auto
*
x_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
x_grad
->
Resize
(
x_dims
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
crop
,
ops
::
CropOp
,
ops
::
CropOpMaker
,
crop_grad
,
ops
::
CropOpGrad
);
REGISTER_OP_CPU_KERNEL
(
crop
,
ops
::
CropKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
crop_grad
,
ops
::
CropGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/crop_op.cu
0 → 100644
浏览文件 @
06b42e9e
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/crop_op.h"
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
crop
,
ops
::
CropKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
REGISTER_OP_GPU_KERNEL
(
crop_grad
,
ops
::
CropGradKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
paddle/operators/crop_op.h
0 → 100644
浏览文件 @
06b42e9e
/* Copyright (c) 2016 CropdleCropdle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
,
size_t
D
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenTensor
=
framework
::
EigenTensor
<
T
,
D
,
MajorType
,
IndexType
>
;
using
Tensor
=
framework
::
Tensor
;
template
<
typename
Place
,
typename
T
,
size_t
D
>
void
CropFunction
(
const
framework
::
ExecutionContext
&
context
)
{
auto
*
x
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
Tensor
>
(
"Out"
);
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
x_dims
=
x
->
dims
();
auto
out_dims
=
out
->
dims
();
auto
offsets
=
context
.
op
().
GetAttr
<
std
::
vector
<
int
>>
(
"offsets"
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
offsets
.
size
(),
"Offsets size should be equal to dimension size of input tensor."
);
Eigen
::
array
<
std
::
pair
<
int
,
int
>
,
D
>
paddings
;
for
(
size_t
i
=
0
;
i
<
D
;
++
i
)
{
paddings
[
i
].
first
=
-
(
offsets
[
i
]);
paddings
[
i
].
second
=
-
(
x_dims
[
i
]
-
out_dims
[
i
]
-
offsets
[
i
]);
}
auto
x_tensor
=
EigenTensor
<
T
,
D
>::
From
(
*
x
);
auto
out_tensor
=
EigenTensor
<
T
,
D
>::
From
(
*
out
);
auto
place
=
context
.
GetEigenDevice
<
Place
>
();
out_tensor
.
device
(
place
)
=
x_tensor
.
pad
(
paddings
,
0
);
}
template
<
typename
Place
,
typename
T
>
class
CropKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
int
dim
=
context
.
Input
<
Tensor
>
(
"X"
)
->
dims
().
size
();
switch
(
dim
)
{
case
1
:
CropFunction
<
Place
,
T
,
1
>
(
context
);
break
;
case
2
:
CropFunction
<
Place
,
T
,
2
>
(
context
);
break
;
case
3
:
CropFunction
<
Place
,
T
,
3
>
(
context
);
break
;
case
4
:
CropFunction
<
Place
,
T
,
4
>
(
context
);
break
;
case
5
:
CropFunction
<
Place
,
T
,
5
>
(
context
);
break
;
case
6
:
CropFunction
<
Place
,
T
,
6
>
(
context
);
break
;
default:
LOG
(
ERROR
)
<<
"Only ranks up to 6 supported."
;
}
}
};
template
<
typename
Place
,
typename
T
,
size_t
D
>
void
CropGradFunction
(
const
framework
::
ExecutionContext
&
context
)
{
auto
*
d_out
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_x
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
d_x
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
d_x_dims
=
d_x
->
dims
();
auto
d_out_dims
=
d_out
->
dims
();
auto
offsets
=
context
.
op
().
GetAttr
<
std
::
vector
<
int
>>
(
"offsets"
);
Eigen
::
array
<
std
::
pair
<
int
,
int
>
,
D
>
paddings
;
for
(
int
i
=
0
;
i
<
d_out_dims
.
size
();
++
i
)
{
paddings
[
i
].
first
=
offsets
[
i
];
paddings
[
i
].
second
=
d_x_dims
[
i
]
-
d_out_dims
[
i
]
-
offsets
[
i
];
}
auto
d_x_tensor
=
EigenTensor
<
T
,
D
>::
From
(
*
d_x
);
auto
d_out_tensor
=
EigenTensor
<
T
,
D
>::
From
(
*
d_out
);
auto
place
=
context
.
GetEigenDevice
<
Place
>
();
d_x_tensor
.
device
(
place
)
=
d_out_tensor
.
pad
(
paddings
,
0
);
}
template
<
typename
Place
,
typename
T
>
class
CropGradKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
size_t
dim
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
))
->
dims
().
size
();
switch
(
dim
)
{
case
1
:
CropGradFunction
<
Place
,
T
,
1
>
(
context
);
break
;
case
2
:
CropGradFunction
<
Place
,
T
,
2
>
(
context
);
break
;
case
3
:
CropGradFunction
<
Place
,
T
,
3
>
(
context
);
break
;
case
4
:
CropGradFunction
<
Place
,
T
,
4
>
(
context
);
break
;
case
5
:
CropGradFunction
<
Place
,
T
,
5
>
(
context
);
break
;
case
6
:
CropGradFunction
<
Place
,
T
,
6
>
(
context
);
break
;
default:
LOG
(
ERROR
)
<<
"Only ranks up to 6 supported."
;
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/pybind/pybind.cc
浏览文件 @
06b42e9e
...
...
@@ -48,6 +48,7 @@ USE_NO_KERNEL_OP(identity);
USE_OP
(
minus
);
USE_CPU_ONLY_OP
(
gather
);
USE_CPU_ONLY_OP
(
scatter
);
USE_OP
(
crop
);
namespace
paddle
{
namespace
framework
{
...
...
python/paddle/v2/framework/tests/test_crop_op.py
0 → 100644
浏览文件 @
06b42e9e
import
unittest
import
numpy
as
np
from
paddle.v2.framework.op
import
Operator
from
gradient_checker
import
GradientChecker
from
op_test_util
import
OpTestMeta
class
TestCropOp
(
unittest
.
TestCase
):
__metaclass__
=
OpTestMeta
def
setUp
(
self
):
self
.
type
=
"crop"
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
16
,
16
)).
astype
(
"float32"
),
}
self
.
attrs
=
{}
self
.
attrs
[
'offsets'
]
=
[
2
,
3
]
self
.
attrs
[
'shape'
]
=
[
8
,
8
]
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
][
2
:
10
,
3
:
11
]}
class
TestCropGradOp
(
GradientChecker
):
def
setUp
(
self
):
self
.
op
=
Operator
(
type
=
"crop"
,
X
=
"X"
,
Out
=
"Out"
,
offsets
=
[
2
,
3
],
shape
=
[
8
,
8
])
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
16
,
16
)).
astype
(
"float32"
),
}
def
test_normal
(
self
):
self
.
check_grad
(
self
.
op
,
self
.
inputs
,
set
([
"X"
]),
"Out"
,
max_relative_error
=
0.5
)
def
test_cpu_gpu_compare
(
self
):
self
.
compare_grad
(
self
.
op
,
self
.
inputs
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录