Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
987cdf11
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
987cdf11
编写于
9月 07, 2017
作者:
W
wanghaoshuang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add clip op
上级
ba43904a
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
255 addition
and
0 deletion
+255
-0
paddle/operators/clip_op.cc
paddle/operators/clip_op.cc
+73
-0
paddle/operators/clip_op.cu
paddle/operators/clip_op.cu
+67
-0
paddle/operators/clip_op.h
paddle/operators/clip_op.h
+70
-0
paddle/pybind/pybind.cc
paddle/pybind/pybind.cc
+1
-0
python/paddle/v2/framework/tests/op_test_util.py
python/paddle/v2/framework/tests/op_test_util.py
+5
-0
python/paddle/v2/framework/tests/test_clip_op.py
python/paddle/v2/framework/tests/test_clip_op.py
+39
-0
未找到文件。
paddle/operators/clip_op.cc
0 → 100644
浏览文件 @
987cdf11
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/clip_op.h"
namespace
paddle
{
namespace
operators
{
using
framework
::
Tensor
;
class
ClipOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
auto
x_dims
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
();
auto
max
=
GetAttr
<
float
>
(
"max"
);
auto
min
=
GetAttr
<
float
>
(
"min"
);
PADDLE_ENFORCE_LT
(
min
,
max
,
"max should be greater than min."
);
ctx
.
Output
<
Tensor
>
(
"Out"
)
->
Resize
(
x_dims
);
}
};
class
ClipOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
ClipOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The input of clip op"
);
AddOutput
(
"Out"
,
"The output of clip op"
);
AddComment
(
R"DOC(
Clip Operator.
)DOC"
);
AddAttr
<
float
>
(
"min"
,
"min value to be clipped."
);
AddAttr
<
float
>
(
"max"
,
"max value to be clipped."
);
}
};
class
ClipOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
auto
x_dims
=
ctx
.
Input
<
Tensor
>
(
"X"
)
->
dims
();
auto
*
x_grad
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
x_grad
->
Resize
(
x_dims
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
clip
,
ops
::
ClipOp
,
ops
::
ClipOpMaker
,
clip_grad
,
ops
::
ClipOpGrad
);
REGISTER_OP_CPU_KERNEL
(
clip
,
ops
::
ClipKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
clip_grad
,
ops
::
ClipGradKernel
<
float
>
);
paddle/operators/clip_op.cu
0 → 100644
浏览文件 @
987cdf11
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/clip_op.h"
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \
i += blockDim.x * gridDim.x)
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
__global__
void
ClipGradientKernel
(
const
int
N
,
const
T
min
,
const
T
max
,
const
T
*
Y
,
const
T
*
dY
,
T
*
dX
)
{
CUDA_1D_KERNEL_LOOP
(
i
,
N
)
{
dX
[
i
]
=
dY
[
i
]
*
(
Y
[
i
]
>
min
&&
Y
[
i
]
<
max
);
}
}
template
<
typename
T
>
class
ClipGradientOpCUDAKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
max
=
context
.
op
().
GetAttr
<
float
>
(
"max"
);
auto
min
=
context
.
op
().
GetAttr
<
float
>
(
"min"
);
auto
*
d_out
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_x
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
x
=
context
.
Output
<
Tensor
>
(
"X"
);
auto
dims
=
d_x
->
dims
();
size_t
count
=
1
;
for
(
int
i
=
0
;
i
<
dims
.
size
();
++
i
)
{
count
*=
dims
[
i
];
}
auto
d_x_data
=
d_x
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
d_out_data
=
d_out
->
data
<
T
>
();
auto
x_data
=
x
->
data
<
T
>
();
int
N
=
d_x
->
dims
()[
0
];
int
D
=
d_x
->
dims
()[
1
];
int
block
=
512
;
int
grid
=
(
N
*
D
+
block
-
1
)
/
block
;
ClipGradientKernel
<
T
><<<
grid
,
block
>>>
(
count
,
min
,
max
,
x_data
,
d_out_data
,
d_x_data
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
clip
,
ops
::
ClipKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
REGISTER_OP_GPU_KERNEL
(
clip_grad
,
ops
::
ClipGradientOpCUDAKernel
<
float
>
);
paddle/operators/clip_op.h
0 → 100644
浏览文件 @
987cdf11
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
,
size_t
D
,
int
MajorType
=
Eigen
::
RowMajor
,
typename
IndexType
=
Eigen
::
DenseIndex
>
using
EigenTensor
=
framework
::
EigenTensor
<
T
,
D
,
MajorType
,
IndexType
>
;
template
<
typename
Place
,
typename
T
>
class
ClipKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
max
=
context
.
op
().
GetAttr
<
float
>
(
"max"
);
auto
min
=
context
.
op
().
GetAttr
<
float
>
(
"min"
);
auto
*
x
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
Tensor
>
(
"Out"
);
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
x_tensor
=
EigenTensor
<
T
,
2
>::
From
(
*
x
);
auto
out_tensor
=
EigenTensor
<
T
,
2
>::
From
(
*
out
);
auto
place
=
context
.
GetEigenDevice
<
Place
>
();
out_tensor
.
device
(
place
)
=
x_tensor
.
cwiseMin
(
max
).
cwiseMax
(
min
);
}
};
template
<
typename
T
>
class
ClipGradKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
max
=
context
.
op
().
GetAttr
<
float
>
(
"max"
);
auto
min
=
context
.
op
().
GetAttr
<
float
>
(
"min"
);
auto
*
d_out
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
d_x
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
x
=
context
.
Output
<
Tensor
>
(
"X"
);
auto
dims
=
d_x
->
dims
();
size_t
count
=
1
;
for
(
int
i
=
0
;
i
<
dims
.
size
();
++
i
)
{
count
*=
dims
[
i
];
}
auto
d_x_data
=
d_x
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
d_out_data
=
d_out
->
data
<
T
>
();
auto
x_data
=
x
->
data
<
T
>
();
for
(
int
i
=
0
;
i
<
count
;
++
i
)
{
d_x_data
[
i
]
=
d_out_data
[
i
]
*
(
x_data
[
i
]
>
min
&&
x_data
[
i
]
<
max
);
}
}
};
}
// namespace operators
}
// namespace paddle
paddle/pybind/pybind.cc
浏览文件 @
987cdf11
...
@@ -48,6 +48,7 @@ USE_NO_KERNEL_OP(identity);
...
@@ -48,6 +48,7 @@ USE_NO_KERNEL_OP(identity);
USE_OP
(
minus
);
USE_OP
(
minus
);
USE_CPU_ONLY_OP
(
gather
);
USE_CPU_ONLY_OP
(
gather
);
USE_CPU_ONLY_OP
(
scatter
);
USE_CPU_ONLY_OP
(
scatter
);
USE_OP
(
clip
);
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
...
python/paddle/v2/framework/tests/op_test_util.py
浏览文件 @
987cdf11
...
@@ -34,8 +34,10 @@ class OpTestMeta(type):
...
@@ -34,8 +34,10 @@ class OpTestMeta(type):
arr
=
self
.
inputs
[
in_name
]
arr
=
self
.
inputs
[
in_name
]
var
.
set_dims
(
arr
.
shape
)
var
.
set_dims
(
arr
.
shape
)
var
.
set
(
arr
,
place
)
var
.
set
(
arr
,
place
)
print
"var: %s"
%
in_name
else
:
else
:
kwargs
[
in_name
]
=
"@EMPTY@"
kwargs
[
in_name
]
=
"@EMPTY@"
print
"var: %s=EMPTY"
%
in_name
for
out_name
in
Operator
.
get_op_output_names
(
self
.
type
):
for
out_name
in
Operator
.
get_op_output_names
(
self
.
type
):
if
not
hasattr
(
self
,
"outputs"
):
if
not
hasattr
(
self
,
"outputs"
):
...
@@ -46,6 +48,7 @@ class OpTestMeta(type):
...
@@ -46,6 +48,7 @@ class OpTestMeta(type):
(
out_name
))
(
out_name
))
kwargs
[
out_name
]
=
out_name
kwargs
[
out_name
]
=
out_name
scope
.
new_var
(
out_name
).
get_tensor
()
scope
.
new_var
(
out_name
).
get_tensor
()
print
"var: %s"
%
out_name
for
attr_name
in
Operator
.
get_op_attr_names
(
self
.
type
):
for
attr_name
in
Operator
.
get_op_attr_names
(
self
.
type
):
if
hasattr
(
self
,
"attrs"
)
and
attr_name
in
self
.
attrs
:
if
hasattr
(
self
,
"attrs"
)
and
attr_name
in
self
.
attrs
:
...
@@ -62,7 +65,9 @@ class OpTestMeta(type):
...
@@ -62,7 +65,9 @@ class OpTestMeta(type):
for
out_name
in
Operator
.
get_op_output_names
(
self
.
type
):
for
out_name
in
Operator
.
get_op_output_names
(
self
.
type
):
actual
=
numpy
.
array
(
scope
.
find_var
(
out_name
).
get_tensor
())
actual
=
numpy
.
array
(
scope
.
find_var
(
out_name
).
get_tensor
())
print
"actual: %s"
%
actual
expect
=
self
.
outputs
[
out_name
]
expect
=
self
.
outputs
[
out_name
]
print
"expect: %s"
%
expect
self
.
assertTrue
(
self
.
assertTrue
(
numpy
.
allclose
(
numpy
.
allclose
(
actual
,
expect
,
atol
=
1e-05
),
actual
,
expect
,
atol
=
1e-05
),
...
...
python/paddle/v2/framework/tests/test_clip_op.py
0 → 100644
浏览文件 @
987cdf11
import
unittest
import
numpy
as
np
from
paddle.v2.framework.op
import
Operator
from
gradient_checker
import
GradientChecker
from
op_test_util
import
OpTestMeta
class
TestClipOp
(
unittest
.
TestCase
):
__metaclass__
=
OpTestMeta
def
setUp
(
self
):
input
=
np
.
random
.
random
((
16
,
16
)).
astype
(
"float32"
)
print
"input: %s"
%
input
self
.
type
=
"clip"
self
.
inputs
=
{
'X'
:
input
,
}
self
.
attrs
=
{}
self
.
attrs
[
'min'
]
=
0.1
self
.
attrs
[
'max'
]
=
0.9
self
.
outputs
=
{
'Out'
:
np
.
clip
(
self
.
inputs
[
'X'
],
self
.
attrs
[
'min'
],
self
.
attrs
[
'max'
])
}
class
TestClipGradOp
(
GradientChecker
):
def
setUp
(
self
):
self
.
op
=
Operator
(
type
=
"clip"
,
X
=
"X"
,
Out
=
"Out"
,
min
=
0.1
,
max
=
0.9
)
self
.
inputs
=
{
'X'
:
np
.
random
.
random
((
16
,
16
)).
astype
(
"float32"
),
}
def
test_normal
(
self
):
self
.
check_grad
(
self
.
op
,
self
.
inputs
,
set
([
"X"
]),
"Out"
,
max_relative_error
=
0.5
)
def
test_cpu_gpu_compare
(
self
):
self
.
compare_grad
(
self
.
op
,
self
.
inputs
)
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录