Commit 0e77b31a
Authored Jul 17, 2017 by Yu Yang; committed via GitHub on Jul 17, 2017

Merge branch 'develop' into feature/op_creation_methods

Parents: c78a5e5d, a0caf234

Showing 18 changed files with 579 additions and 152 deletions (+579 −152)
Dockerfile.android                        +11    -0
cmake/cross_compiling/android.cmake        +8    -3
paddle/framework/CMakeLists.txt            +7    -5
paddle/framework/ddim.cc                  +49    -5
paddle/framework/ddim.h                    +9    -0
paddle/framework/ddim_test.cc             +20    -0
paddle/framework/dim_test.cu              +82   -81
paddle/framework/enforce.cc               +15    -0
paddle/framework/enforce.h                 +6    -0
paddle/framework/op_registry.h            +33    -8
paddle/framework/op_registry_test.cc      +34    -2
paddle/framework/operator.cc              +58    -0
paddle/framework/operator.h               +72   -27
paddle/framework/operator_test.cc        +108    -8
paddle/framework/tensor.cc                +19    -0
paddle/operators/CMakeLists.txt           +43    -5
paddle/operators/add_op.h                  +2    -2
paddle/scripts/docker/build_android.sh     +3    -6
Dockerfile.android
@@ -14,6 +14,17 @@ RUN apt-get update && \
     wget curl tar unzip gcc g++ locales clang-format-3.8 swig cmake && \
     apt-get clean -y
 
+# Install Go and glide
+RUN wget -O go.tgz https://storage.googleapis.com/golang/go1.8.1.linux-amd64.tar.gz && \
+    tar -C /usr/local -xzf go.tgz && \
+    mkdir /root/gopath && \
+    mkdir /root/gopath/bin && \
+    mkdir /root/gopath/src && \
+    rm go.tgz
+ENV GOROOT=/usr/local/go GOPATH=/root/gopath
+# should not be in the same line with GOROOT definition, otherwise docker build could not find GOROOT.
+ENV PATH=${PATH}:${GOROOT}/bin:${GOPATH}/bin
+
 # git credential to skip password typing
 RUN git config --global credential.helper store
...
cmake/cross_compiling/android.cmake
@@ -108,6 +108,7 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0")
     ENDIF()
     IF(ANDROID_ABI STREQUAL "arm64-v8a")
         SET(ANDROID_TOOLCHAIN_NAME aarch64-linux-android)
+        SET(CMAKE_SYSTEM_PROCESSOR aarch64)
     ENDIF()
     SET(ANDROID_TOOLCHAIN_PREFIX "${ANDROID_TOOLCHAIN_ROOT}/bin/${ANDROID_TOOLCHAIN_NAME}-")
 ENDIF()
...
@@ -166,7 +167,7 @@ IF("${CMAKE_VERSION}" VERSION_LESS "3.7.0")
     ENDIF()
     IF(ANDROID_ABI STREQUAL "arm64-v8a")
         LIST(APPEND ANDROID_COMPILER_FLAGS -march=armv8-a)
     ENDIF()
     STRING(REPLACE ";" " " ANDROID_COMPILER_FLAGS "${ANDROID_COMPILER_FLAGS}")
...
@@ -193,6 +194,10 @@ ELSE()
         SET(CMAKE_ANDROID_STANDALONE_TOOLCHAIN ${ANDROID_STANDALONE_TOOLCHAIN})
     ENDIF()
     SET(CMAKE_ANDROID_ARCH_ABI ${ANDROID_ABI})
-    SET(CMAKE_ANDROID_ARM_MODE ${ANDROID_ARM_MODE})
-    SET(CMAKE_ANDROID_ARM_NEON ${ANDROID_ARM_NEON})
+    IF(ANDROID_ABI MATCHES "^armeabi(-v7a)?$")
+        SET(CMAKE_ANDROID_ARM_MODE ${ANDROID_ARM_MODE})
+    ENDIF()
+    IF(ANDROID_ABI STREQUAL "armeabi-v7a")
+        SET(CMAKE_ANDROID_ARM_NEON ${ANDROID_ARM_NEON})
+    ENDIF()
 ENDIF()
paddle/framework/CMakeLists.txt
 # ddim lib
+cc_library(enforce SRCS enforce.cc DEPS glog)
+cc_test(enforce_test SRCS enforce_test.cc DEPS enforce)
 cc_library(ddim SRCS ddim.cc)
 cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
 nv_test(dim_test SRCS dim_test.cu DEPS ddim)
-cc_test(tensor_test SRCS tensor_test.cc DEPS ddim)
+cc_library(tensor SRCS tensor.cc DEPS ddim place enforce paddle_memory)
+cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
 cc_test(variable_test SRCS variable_test.cc)
 cc_test(scope_test SRCS scope_test.cc)
-cc_test(enforce_test SRCS enforce_test.cc)
 proto_library(attr_type SRCS attr_type.proto)
 proto_library(op_proto SRCS op_proto.proto DEPS attr_type)
 cc_test(op_proto_test SRCS op_proto_test.cc DEPS op_proto protobuf)
 proto_library(op_desc SRCS op_desc.proto DEPS attr_type)
 cc_test(op_desc_test SRCS op_desc_test.cc DEPS op_desc protobuf)
-cc_library(operator SRCS operator.cc DEPS op_desc device_context)
+cc_library(operator SRCS operator.cc DEPS op_desc device_context tensor)
 cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
-cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc)
+cc_library(op_registry SRCS op_registry.cc DEPS op_proto op_desc enforce)
 cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry operator)
 py_proto_compile(framework_py_proto SRCS attr_type.proto op_proto.proto op_desc.proto)
 # Generate an empty __init__.py to make framework_py_proto as a valid python module.
 add_custom_target(framework_py_proto_init ALL COMMAND ${CMAKE_COMMAND} -E touch __init__.py)
...
paddle/framework/ddim.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/framework/ddim.h"
+#include "paddle/framework/enforce.h"
 
 namespace paddle {
 namespace framework {
...
@@ -192,13 +193,56 @@ std::vector<int> vectorize(const DDim& ddim) {
   return result;
 }
 
+struct ProductVisitor : public boost::static_visitor<ssize_t> {
+  template <int D>
+  ssize_t operator()(const Dim<D>& dim) {
+    return product(dim);
+  }
+};
+
 ssize_t product(const DDim& ddim) {
-  ssize_t result = 1;
-  std::vector<int> v = vectorize(ddim);
-  for (auto i : v) {
-    result *= i;
-  }
-  return result;
+  ProductVisitor visitor;
+  return boost::apply_visitor(visitor, ddim);
 }
 
+struct SliceVectorizeVisitor : public boost::static_visitor<> {
+  std::vector<int>& vector;
+  int begin;
+  int end;
+
+  SliceVectorizeVisitor(std::vector<int>& v, int b, int e)
+      : vector(v), begin(b), end(e) {
+    PADDLE_ENFORCE(begin < end,
+                   "Begin index must be less than end index in ddim slice.");
+    PADDLE_ENFORCE(begin >= 0,
+                   "Begin index can't be less than zero in ddim slice.");
+  }
+
+  template <int S>
+  void operator()(const Dim<S>& dim) {
+    if (begin == 0) {
+      vector.push_back(dim.head);
+    } else {
+      --begin;
+    }
+    --end;
+    if (end > 0) {
+      this->operator()(dim.tail);
+    }
+  }
+
+  void operator()(const Dim<1>& dim) {
+    PADDLE_ENFORCE(end == 1, "End index in ddim slice is out of bound.");
+    vector.push_back(dim.head);
+  }
+};
+
+DDim slice_ddim(const DDim& dim, int begin, int end) {
+  std::vector<int> vec;
+  vec.reserve(end - begin);
+  SliceVectorizeVisitor visitor(vec, begin, end);
+  boost::apply_visitor(visitor, dim);
+  return make_ddim(vec);
+}
+
 /// \cond HIDDEN
...
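As a reading aid (this snippet is not part of the commit): the new SliceVectorizeVisitor walks the compile-time Dim<N> list head-first, decrementing `begin` until it starts copying and decrementing `end` until it stops, which is what makes slice_ddim a [begin, end) slice. A minimal trace, assuming the headers added above are available:

#include "paddle/framework/ddim.h"

void slice_ddim_trace() {
  auto d = paddle::framework::make_ddim({1, 2, 3, 4, 5, 6});
  // Visitor state starts at begin = 2, end = 5, vec = {}.
  //   head = 1: begin != 0 -> skip, now begin = 1, end = 4
  //   head = 2: begin != 0 -> skip, now begin = 0, end = 3
  //   head = 3: begin == 0 -> push 3, end = 2
  //   head = 4: push 4, end = 1
  //   head = 5: push 5, end = 0 -> recursion stops before dim 6
  auto s = paddle::framework::slice_ddim(d, 2, 5);  // same as make_ddim({3, 4, 5})
  (void)s;
}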
paddle/framework/ddim.h
@@ -81,6 +81,15 @@ std::vector<int> vectorize(const DDim& ddim);
 
 ssize_t product(const DDim& ddim);
 
+/**
+ * \brief Slice a ddim
+ *
+ * Slice dim with [begin, end).
+ * e.g. DDim d = make_ddim({1,2,3,4,5});
+ * slice_ddim(d, 1, 3); ====> {2,3}
+ */
+DDim slice_ddim(const DDim& dim, int begin, int end);
+
 /**
  * \brief What is the length of this dimension?
  *
...
paddle/framework/ddim_test.cc
@@ -52,6 +52,26 @@ TEST(DDim, Equality) {
   // product of a DDim
   EXPECT_EQ(paddle::framework::product(vddim), 45);
+  EXPECT_EQ(
+      paddle::framework::product(paddle::framework::make_ddim({3, 2, 5, 3})),
+      90);
+
+  // slice a DDim
+  paddle::framework::DDim ddim2 =
+      paddle::framework::make_ddim({1, 2, 3, 4, 5, 6});
+  paddle::framework::DDim ss = paddle::framework::slice_ddim(ddim2, 2, 5);
+  EXPECT_EQ(arity(ss), 3);
+  EXPECT_EQ(ss[0], 3);
+  EXPECT_EQ(ss[1], 4);
+  EXPECT_EQ(ss[2], 5);
+  paddle::framework::DDim ss2 = paddle::framework::slice_ddim(ddim2, 0, 6);
+  EXPECT_EQ(arity(ss2), 6);
+  EXPECT_EQ(ss2[0], 1);
+  EXPECT_EQ(ss2[1], 2);
+  EXPECT_EQ(ss2[2], 3);
+  EXPECT_EQ(ss2[3], 4);
+  EXPECT_EQ(ss2[4], 5);
+  EXPECT_EQ(ss2[5], 6);
 }
 
 TEST(DDim, Print) {
...
paddle/framework/dim_test.cu
Updated file contents (the diff touches nearly every line: +82 −81):

#include <thrust/device_vector.h>
#include <sstream>

#include "gtest/gtest.h"
#include "paddle/framework/dim.h"

__global__ void test(paddle::framework::Dim<2>* o) {
  o[0] = paddle::framework::make_dim(5, 6);
}

__global__ void dyn_idx_gpu(int* o) {
  auto d = paddle::framework::make_dim(5, 6);
  o[0] = d[1];
}

TEST(Dim, Equality) {
  // construct a Dim on the CPU
  auto a = paddle::framework::make_dim(3, 4);
  EXPECT_EQ(paddle::framework::get<0>(a), 3);
  EXPECT_EQ(paddle::framework::get<1>(a), 4);

  // construct a Dim on the GPU
  thrust::device_vector<paddle::framework::Dim<2>> t(2);
  test<<<1, 1>>>(thrust::raw_pointer_cast(t.data()));
  a = t[0];
  EXPECT_EQ(paddle::framework::get<0>(a), 5);
  EXPECT_EQ(paddle::framework::get<1>(a), 6);

  // linearization
  auto b = paddle::framework::make_dim(7, 8);
  EXPECT_EQ(paddle::framework::linearize(a, b), 83);

  // product
  EXPECT_EQ(paddle::framework::product(a), 30);

  // mutate a Dim
  paddle::framework::get<1>(b) = 10;
  EXPECT_EQ(paddle::framework::get<0>(b), 7);
  EXPECT_EQ(paddle::framework::get<1>(b), 10);

  // dynamic access
  paddle::framework::get(b, 0) = 8;
  b[1] = 11;
  EXPECT_EQ(paddle::framework::get<0>(b), 8);
  EXPECT_EQ(paddle::framework::get<1>(b), 11);
  EXPECT_EQ(paddle::framework::get(b, 0), 8);
  EXPECT_EQ(b[1], 11);

  // dynamic access on GPU
  thrust::device_vector<int> r(1);
  dyn_idx_gpu<<<1, 1>>>(thrust::raw_pointer_cast(r.data()));
  int res = r[0];
  EXPECT_EQ(res, 6);

  // ex_prefix_mul
  paddle::framework::Dim<3> c =
      paddle::framework::ex_prefix_mul(paddle::framework::Dim<3>(3, 4, 5));
  EXPECT_EQ(paddle::framework::get<0>(c), 1);
  EXPECT_EQ(paddle::framework::get<1>(c), 3);
  EXPECT_EQ(paddle::framework::get<2>(c), 12);

  // generate from an index
  auto size = paddle::framework::make_dim(4, 5, 2);
  c = paddle::framework::Dim<3>(14, size);
  EXPECT_EQ(paddle::framework::get<0>(c), 2);
  EXPECT_EQ(paddle::framework::get<1>(c), 3);
  EXPECT_EQ(paddle::framework::get<2>(c), 0);
  c = paddle::framework::Dim<3>(25, size);
  EXPECT_EQ(paddle::framework::get<0>(c), 1);
  EXPECT_EQ(paddle::framework::get<1>(c), 1);
  EXPECT_EQ(paddle::framework::get<2>(c), 1);
}

TEST(Dim, Bool) {
  auto a = paddle::framework::make_dim(3, 4);
  auto b = paddle::framework::make_dim(5, 6);
  auto c = paddle::framework::make_dim(3, 4);

  // in_bounds check
  EXPECT_TRUE(paddle::framework::contained(a, b));
  EXPECT_FALSE(paddle::framework::contained(b, a));

  // comparison
  EXPECT_TRUE(a == a);
  EXPECT_FALSE(a == b);
  EXPECT_TRUE(a == c);
}

TEST(Dim, Print) {
  {
    std::stringstream ss;
    auto a = paddle::framework::make_dim(2, 3);
    ss << a;
    EXPECT_EQ(ss.str(), "2, 3");
  }
  {
    std::stringstream ss;
    ss << paddle::framework::make_dim(8);
    EXPECT_EQ(ss.str(), "8");
  }
}
paddle/framework/enforce.cc
new file (mode 100644)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/enforce.h"
paddle/framework/enforce.h
@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <glog/logging.h>
 #include <paddle/string/printf.h>
 #include <exception>
 #include <sstream>
...
@@ -58,12 +59,17 @@ class EnforceNotMet : public std::exception {
 /**
  * @brief Enforce a condition, otherwise throw an EnforceNotMet
  */
+#ifdef NDEBUG
 #define PADDLE_ENFORCE(condition, ...) \
   do {                                 \
     if (UNLIKELY(!(condition))) {      \
       PADDLE_THROW(__VA_ARGS__);       \
     }                                  \
   } while (0)
+#else
+#define PADDLE_ENFORCE(condition, ...) \
+  CHECK(condition) << ::paddle::string::Sprintf(__VA_ARGS__);
+#endif
 
 }  // namespace framework
 }  // namespace paddle
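Note, as a reading aid rather than part of the diff: after this change PADDLE_ENFORCE has two behaviours. In NDEBUG (release) builds it keeps the original definition and throws EnforceNotMet via PADDLE_THROW, so callers can catch it; in other builds it now expands to glog's CHECK, which logs the Sprintf-formatted message and aborts instead of throwing. A minimal sketch of the catchable path, assuming the enforce.h added here is on the include path:

#include <iostream>
#include "paddle/framework/enforce.h"

int main() {
  try {
    // The condition fails, so in an NDEBUG build this throws EnforceNotMet.
    PADDLE_ENFORCE(2 + 2 == 5, "arithmetic is broken: got %d", 4);
  } catch (paddle::framework::EnforceNotMet& err) {
    std::cerr << err.what() << std::endl;
  }
  // In a non-NDEBUG build the same macro is CHECK(...), which aborts the
  // process instead of throwing, so the catch block above never runs.
  return 0;
}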
paddle/framework/op_registry.h
@@ -62,7 +62,14 @@ class OpProtoAndCheckerMaker {
   OpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
       : proto_(proto), op_checker_(op_checker) {}
 
-  ~OpProtoAndCheckerMaker() { CheckNoDuplicatedAttrs(); }
+  ~OpProtoAndCheckerMaker() {
+    PADDLE_ENFORCE(validated_, "should call Validate after build");
+  }
+
+  void Validate() {
+    validated_ = true;
+    CheckNoDuplicatedInOutAttrs();
+  }
 
  protected:
   void AddInput(const std::string& name, const std::string& comment,
...
@@ -164,19 +171,26 @@ Add a mark to which output is temporary is helpful for future optimization.
     }
   }
 
-  void CheckNoDuplicatedAttrs() {
+  void CheckNoDuplicatedInOutAttrs() {
     std::unordered_set<std::string> names;
-    size_t cnt = 0;
+    auto checker = [&](const std::string& name) {
+      PADDLE_ENFORCE(!names.count(name), "[%s] is duplicated", name);
+      names.insert(name);
+    };
     for (auto& attr : proto_->attrs()) {
-      names.insert(attr.name());
-      ++cnt;
+      checker(attr.name());
+    }
+    for (auto& input : proto_->inputs()) {
+      checker(input.name());
+    }
+    for (auto& output : proto_->outputs()) {
+      checker(output.name());
     }
-    PADDLE_ENFORCE(names.size() == cnt,
-                   "Cannot register two attribute in same name!");
   }
 
   OpProto* proto_;
   OpAttrChecker* op_checker_;
+  bool validated_{false};
   bool has_multiple_input_{false};
   bool has_multiple_output_{false};
   bool has_temporary_output_{false};
...
@@ -191,7 +205,8 @@ class OpRegistry {
     creators()[op_type] = [] { return new OpType; };
     OpProto& op_proto = protos()[op_type];
     OpAttrChecker& op_checker = op_checkers()[op_type];
-    ProtoMakerType(&op_proto, &op_checker);
+    auto maker = ProtoMakerType(&op_proto, &op_checker);
+    maker.Validate();
     *op_proto.mutable_type() = op_type;
     PADDLE_ENFORCE(
         op_proto.IsInitialized(),
...
@@ -205,10 +220,13 @@ class OpRegistry {
     OperatorPtr op(creators().at(op_type)());
     //! Fill op's data member. Not use constructor because it will be noising
     //! for Op developer.
+    const OpProto& op_proto = protos().at(op_type);
     op->type_ = op_desc.type();
+    // set op's inputs_ from desc.
     op->inputs_.reserve((size_t)op_desc.inputs_size());
     std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
               std::back_inserter(op->inputs_));
+    // set op's outputs_ from desc.
     op->outputs_.reserve((size_t)op_desc.outputs_size());
     std::copy(op_desc.outputs().begin(), op_desc.outputs().end(),
               std::back_inserter(op->outputs_));
...
@@ -222,12 +240,19 @@ class OpRegistry {
     //! Convert Temporary variable name to an unique variable name.
     GenerateTempVariableName(op.get());
+
+    // set argument offsets stored in op.
+    CreateInOutOffsetMap(op, op_proto);
     //! Other op's custom Init for a complex Op. For simple Op, the Init
     //! method do nothing.
     op->Init();
     return op;
   }
 
+  // init op.in_out_idxs_ to accelerate argument's offset lookup.
+  static void CreateInOutOffsetMap(OperatorPtr op, const OpProto& proto) {
+    op->CreateInOutOffsetMap(proto);
+  }
+
   static std::unordered_map<std::string, OpProto>& protos() {
     static std::unordered_map<std::string, OpProto> protos_;
     return protos_;
...
paddle/framework/op_registry_test.cc
 #include "paddle/framework/op_registry.h"
 #include <gtest/gtest.h>
 
+namespace pd = paddle::framework;
+
 namespace paddle {
 namespace framework {
 class CosineOp : public OperatorBase {
...
@@ -28,8 +30,6 @@ class MyTestOp : public OperatorBase {
   void InferShape(const ScopePtr& scope) const override {}
   void Run(const ScopePtr& scope,
            const platform::DeviceContext& dev_ctx) const override {}
-
- public:
 };
 
 class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...
@@ -182,3 +182,35 @@ TEST(OpRegistry, CustomChecker) {
   int test_attr = op->GetAttr<int>("test_attr");
   ASSERT_EQ(test_attr, 4);
 }
+
+class TestAttrProtoMaker : public pd::OpProtoAndCheckerMaker {
+ public:
+  TestAttrProtoMaker(pd::OpProto* proto, pd::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddAttr<float>("scale", "scale of test op");
+    AddAttr<float>("scale", "scale of test op");
+  }
+};
+
+TEST(ProtoMaker, DuplicatedAttr) {
+  pd::OpProto op_proto;
+  pd::OpAttrChecker op_checker;
+  auto proto_maker = TestAttrProtoMaker(&op_proto, &op_checker);
+  ASSERT_THROW(proto_maker.Validate(), paddle::framework::EnforceNotMet);
+}
+
+class TestInOutProtoMaker : public pd::OpProtoAndCheckerMaker {
+ public:
+  TestInOutProtoMaker(pd::OpProto* proto, pd::OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("input", "input of test op");
+    AddInput("input", "input of test op");
+  }
+};
+
+TEST(ProtoMaker, DuplicatedInOut) {
+  pd::OpProto op_proto;
+  pd::OpAttrChecker op_checker;
+  auto proto_maker = TestInOutProtoMaker(&op_proto, &op_checker);
+  ASSERT_THROW(proto_maker.Validate(), paddle::framework::EnforceNotMet);
+}
paddle/framework/operator.cc
@@ -12,11 +12,69 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include <algorithm>
+
 #include "paddle/framework/operator.h"
 
 namespace paddle {
 namespace framework {
 
+void OperatorBase::CreateInOutOffsetMap(const OpProto& proto) {
+  PADDLE_ENFORCE(in_out_idxs_.empty(), "duplicate call CreateInOutOffsetMap");
+  for (int i = 0; i < proto.inputs_size(); i++) {
+    const auto& name = proto.inputs()[i].name();
+    in_out_idxs_[name] = i;
+  }
+  for (int i = 0; i < proto.outputs_size(); i++) {
+    const auto& name = proto.outputs()[i].name();
+    in_out_idxs_[name] = i;
+  }
+}
+
+const std::string& OperatorBase::Input(const std::string& name) const {
+  auto it = in_out_idxs_.find(name);
+  PADDLE_ENFORCE(it != in_out_idxs_.end(), "no key [%s] in in_out_idxs_", name);
+
+  if (attrs_.count("input_format") == 0) {
+    return inputs_[it->second];
+  } else {
+    const auto& input_format = GetAttr<std::vector<int>>("input_format");
+    int idx = input_format[it->second];
+    return inputs_.at(idx);
+  }
+}
+
+std::vector<std::string> OperatorBase::Inputs(const std::string& name) const {
+  auto input_format = GetAttr<std::vector<int>>("input_format");
+  auto offset = in_out_idxs_.at(name);
+
+  return std::vector<std::string>{
+      inputs_.begin() + input_format.at(offset),
+      inputs_.begin() + input_format.at(offset + 1)};
+}
+
+const std::string& OperatorBase::Output(const std::string& name) const {
+  auto it = in_out_idxs_.find(name);
+  PADDLE_ENFORCE(it != in_out_idxs_.end(), "no key [%s] in in_out_idxs_", name);
+
+  if (attrs_.count("output_format") == 0) {
+    return outputs_[it->second];
+  } else {
+    const auto& output_format = GetAttr<std::vector<int>>("output_format");
+    int idx = output_format[it->second];
+    return outputs_.at(idx);
+  }
+}
+
+std::vector<std::string> OperatorBase::Outputs(const std::string& name) const {
+  auto output_format = GetAttr<std::vector<int>>("output_format");
+  auto offset = in_out_idxs_.at(name);
+
+  return std::vector<std::string>{
+      outputs_.begin() + output_format.at(offset),
+      outputs_.begin() + output_format.at(offset + 1)};
+}
+
 std::string OperatorBase::DebugString() const {
   std::stringstream ss;
   ss << "Op(" << type_ << "), inputs:(";
...
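The new name-based lookups lean on two pieces of state: in_out_idxs_ maps each argument name declared in the OpProto to its declaration order, and the optional "input_format"/"output_format" attributes record where each argument's variables start inside inputs_/outputs_, with one trailing end offset. A multi-variable argument is therefore the half-open range [format[i], format[i+1]). A worked example using the same values as the multi_inputs test added later in this commit:

  OpProto declares:   xs (multiple), k (single)      ->  in_out_idxs_ = {xs: 0, k: 1}
  OpDesc inputs:      inputs_ = {"x0", "x1", "x2", "k0"}
  attr input_format:  {0, 3, 4}                          // start of xs, start of k, end marker

  Inputs("xs")  ->  inputs_[0 .. 3)              ==  {"x0", "x1", "x2"}
  Input("k")    ->  inputs_[input_format[1]]     ==  inputs_[3]  ==  "k0"

When no "input_format" attribute is present, Input(name) simply indexes inputs_ with the declaration-order offset.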
paddle/framework/operator.h
@@ -14,18 +14,20 @@ limitations under the License. */
 
 #pragma once
 
-#include <paddle/framework/attr_checker.h>
-#include <paddle/framework/op_desc.pb.h>
-#include <paddle/framework/scope.h>
-#include <paddle/framework/tensor.h>
-#include <paddle/platform/device_context.h>
-#include <paddle/platform/place.h>
-#include <paddle/utils/Error.h>
 #include <boost/variant.hpp>
 #include <string>
 #include <unordered_map>
 #include <vector>
 
+#include "paddle/framework/attr_checker.h"
+#include "paddle/framework/op_desc.pb.h"
+#include "paddle/framework/op_proto.pb.h"
+#include "paddle/framework/scope.h"
+#include "paddle/framework/tensor.h"
+#include "paddle/platform/device_context.h"
+#include "paddle/platform/place.h"
+#include "paddle/utils/Error.h"
+
 namespace paddle {
 namespace framework {
...
@@ -69,11 +71,72 @@ class OperatorBase {
   virtual void Run(const ScopePtr& scope,
                    const platform::DeviceContext& dev_ctx) const = 0;
 
+  // Get a input with argument's name described in `op_proto`
+  const std::string& Input(const std::string& name) const;
+  // Get a input which has multiple variables.
+  // TODO add a vector_view to prevent memory copy.
+  std::vector<std::string> Inputs(const std::string& name) const;
+  // Get a output with argument's name described in `op_proto`
+  const std::string& Output(const std::string& name) const;
+  // Get an output which has multiple variables.
+  // TODO add a vector_view to prevent memory copy.
+  std::vector<std::string> Outputs(const std::string& name) const;
+
+  // init in_out_idxs_ to accelerate argument's offset lookup.
+  void CreateInOutOffsetMap(const OpProto& proto);
+
  public:
   std::string type_;
   std::vector<std::string> inputs_;
   std::vector<std::string> outputs_;
   AttributeMap attrs_;
+  // store the arguments' offset described in op_desc.
+  std::unordered_map<std::string, int> in_out_idxs_;
+};
+
+class KernelContext {
+ public:
+  KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
+                const platform::DeviceContext& device_context)
+      : op_(*op), scope_(scope), device_context_(device_context) {}
+
+  const Variable* Input(int index) const {
+    return scope_->GetVariable(op_.inputs_[index]);
+  }
+
+  Variable* Output(int index) const {
+    return scope_->GetVariable(op_.outputs_[index]);
+  }
+
+  const Variable* Input(const std::string& name) const {
+    return scope_->GetVariable(op_.Input(name));
+  }
+
+  const Variable* Output(const std::string& name) const {
+    return scope_->GetVariable(op_.Output(name));
+  }
+
+  const std::vector<const Variable*> Inputs(const std::string& name) const {
+    auto names = op_.Inputs(name);
+    std::vector<const Variable*> res;
+    std::transform(names.begin(), names.end(), res.begin(),
+                   [this](const std::string& name) {
+                     return scope_->GetVariable(name);
+                   });
+    return res;
+  }
+
+  const std::vector<const Variable*> Outputs(const std::string& name) const {
+    auto names = op_.Outputs(name);
+    std::vector<const Variable*> res;
+    std::transform(names.begin(), names.end(), res.begin(),
+                   [this](const std::string& name) {
+                     return scope_->GetVariable(name);
+                   });
+    return res;
+  }
+
+  const OperatorBase& op_;
+  const std::shared_ptr<Scope>& scope_;
+  const platform::DeviceContext& device_context_;
 };
 
 class OpKernel {
...
@@ -84,25 +147,6 @@ class OpKernel {
  * device resource such as CUDA stream, cublas handle, etc. from
  * KernelContext. User should construct it before run the Operator.
  */
-  class KernelContext {
-   public:
-    KernelContext(const OperatorBase* op, const ScopePtr& scope,
-                  const platform::DeviceContext& device_context)
-        : op_(*op), scope_(scope), device_context_(device_context) {}
-
-    const Variable* Input(int index) const {
-      return scope_->GetVariable(op_.inputs_[index]);
-    }
-
-    Variable* Output(int index) const {
-      return scope_->GetVariable(op_.outputs_[index]);
-    }
-
-    const OperatorBase& op_;
-    const ScopePtr& scope_;
-    const platform::DeviceContext& device_context_;
-  };
-
   virtual void Compute(const KernelContext& context) const = 0;
 
   virtual ~OpKernel() {}
...
@@ -147,7 +191,7 @@ class OperatorWithKernel : public OperatorBase {
   void Run(const ScopePtr& scope,
            const platform::DeviceContext& dev_ctx) const final {
     auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
-    opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx));
+    opKernel->Compute(KernelContext(this, scope, dev_ctx));
   }
 
   static std::unordered_map<std::string /* op_type */, OpKernelMap>&
...
@@ -155,6 +199,7 @@ class OperatorWithKernel : public OperatorBase {
     static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
     return g_all_op_kernels;
   }
+
   void InferShape(const std::shared_ptr<Scope>& scope) const final {
     std::vector<const Tensor*> ins;
     VarNamesToTensors(scope, inputs_, &ins);
...
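With KernelContext hoisted out of OpKernel and given name-based accessors, a kernel's Compute can resolve Scope variables by the argument names declared in the op's ProtoMaker instead of raw indices. A minimal hypothetical kernel as a sketch (not part of the commit; the argument names "X" and "Out" are made up for illustration and would have to match whatever the op's ProtoMaker declares):

#include "paddle/framework/operator.h"

namespace paddle {
namespace framework {

class ScaleByTwoKernel : public OpKernel {
 public:
  void Compute(const KernelContext& ctx) const override {
    // Name-based lookup goes through op_.Input("X") / op_.Output("Out"),
    // which in turn use in_out_idxs_ and the optional *_format attributes.
    const Variable* x = ctx.Input("X");
    const Variable* out = ctx.Output("Out");
    // ... read the tensor held by `x`, write the result through `out` ...
    (void)x;
    (void)out;
  }
};

}  // namespace framework
}  // namespace paddle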
paddle/framework/operator_test.cc
@@ -30,7 +30,6 @@ class OpWithoutKernelTest : public OperatorBase {
     op_run_num++;
     ASSERT_EQ((int)inputs_.size(), 1);
     ASSERT_EQ((int)outputs_.size(), 1);
-    ASSERT_NEAR(GetAttr<float>("scale"), 3.14, 1e-5);
     ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
     ASSERT_EQ(x, 1);
     ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr);
...
@@ -86,9 +85,11 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
  public:
   OpKernelTestProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("input", "input of test op");
-    AddOutput("output", "output of test op");
-    AddAttr<float>("scale", "scale of cosine op");
+    AddInput("x", "input of test op");
+    AddOutput("y", "output of test op");
+    AddAttr<float>("scale", "scale of cosine op")
+        .SetDefault(1.0)
+        .LargerThan(0.0);
     AddComment("This is test op");
   }
 };
...
@@ -103,11 +104,65 @@ class OpWithKernelTest : public OperatorWithKernel {
 class CPUKernelTest : public OpKernel {
  public:
-  void Compute(const KernelContext& context) const {
-    std::cout << "this is cpu kernel" << std::endl;
+  void Compute(const KernelContext& ctx) const {
+    std::cout << ctx.op_.DebugString() << std::endl;
     cpu_kernel_run_num++;
-    ASSERT_EQ((int)context.op_.inputs_.size(), 1);
-    ASSERT_EQ((int)context.op_.outputs_.size(), 1);
-    ASSERT_NEAR(context.op_.GetAttr<float>("scale"), 3.14, 1e-5);
+    ASSERT_EQ(ctx.op_.Input("x"), "IN1");
+    ASSERT_EQ(ctx.op_.Output("y"), "OUT1");
   }
 };
+
+// multiple inputs test
+class OperatorMultiInputsTest : public OperatorBase {
+ public:
+  void Init() override { x = 1; }
+  void InferShape(const std::shared_ptr<Scope>& scope) const override {}
+  void Run(const std::shared_ptr<Scope>& scope,
+           const platform::DeviceContext& dev_ctx) const override {
+    ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
+    ASSERT_EQ(x, 1);
+    ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr);
+    ASSERT_EQ(Input("x"), "IN1");
+    ASSERT_EQ(Input("y"), "OUT1");
+  }
+
+ public:
+  float x = 0;
+};
+
+class OpKernelTestMultiInputsProtoAndCheckerMaker
+    : public OpProtoAndCheckerMaker {
+ public:
+  OpKernelTestMultiInputsProtoAndCheckerMaker(OpProto* proto,
+                                              OpAttrChecker* op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInputs("xs", "inputs of test op");
+    AddInput("k", "input of test op");
+    AddOutputs("ys", "outputs of test op");
+    AddAttr<float>("scale", "scale of cosine op")
+        .SetDefault(1.0)
+        .LargerThan(0.0);
+    AddComment("This is test op");
+  }
+};
+
+class CPUKernalMultiInputsTest : public OpKernel {
+ public:
+  void Compute(const KernelContext& ctx) const {
+    auto xs = ctx.op_.Inputs("xs");
+    ASSERT_EQ(xs.size(), 3UL);
+    ASSERT_EQ(xs[0], "x0");
+    ASSERT_EQ(xs[1], "x1");
+    ASSERT_EQ(xs[2], "x2");
+
+    auto k = ctx.op_.Input("k");
+    ASSERT_EQ(k, "k0");
+
+    auto ys = ctx.op_.Outputs("ys");
+    ASSERT_EQ(ys.size(), 2UL);
+    ASSERT_EQ(ys[0], "y0");
+    ASSERT_EQ(ys[1], "y1");
+  }
+};
...
@@ -118,6 +173,7 @@ REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest,
             paddle::framework::OpKernelTestProtoAndCheckerMaker);
 REGISTER_OP_CPU_KERNEL(op_with_kernel, paddle::framework::CPUKernelTest);
 
+// test with single input
 TEST(OpKernel, all) {
   paddle::framework::OpDesc op_desc;
   op_desc.set_type("op_with_kernel");
...
@@ -137,3 +193,47 @@ TEST(OpKernel, all) {
   op->Run(scope, cpu_device_context);
   ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
 }
+
+REGISTER_OP(op_multi_inputs_with_kernel, paddle::framework::OpWithKernelTest,
+            paddle::framework::OpKernelTestMultiInputsProtoAndCheckerMaker);
+REGISTER_OP_CPU_KERNEL(op_multi_inputs_with_kernel,
+                       paddle::framework::CPUKernalMultiInputsTest);
+
+// test with multi inputs
+TEST(OpKernel, multi_inputs) {
+  using namespace paddle::framework;
+
+  OpDesc op_desc;
+  op_desc.set_type("op_multi_inputs_with_kernel");
+  *op_desc.mutable_inputs()->Add() = "x0";
+  *op_desc.mutable_inputs()->Add() = "x1";
+  *op_desc.mutable_inputs()->Add() = "x2";
+  *op_desc.mutable_inputs()->Add() = "k0";
+  *op_desc.mutable_outputs()->Add() = "y0";
+  *op_desc.mutable_outputs()->Add() = "y1";
+
+  auto attr = op_desc.mutable_attrs()->Add();
+  attr->set_name("scale");
+  attr->set_type(paddle::framework::AttrType::FLOAT);
+  attr->set_f(3.14);
+
+  auto attr0 = op_desc.mutable_attrs()->Add();
+  attr0->set_name("input_format");
+  attr0->set_type(paddle::framework::AttrType::INTS);
+  auto input_format = attr0->mutable_ints();
+  input_format->Add(0);  // x0
+  input_format->Add(3);  // k
+  input_format->Add(4);  // end
+
+  auto attr1 = op_desc.mutable_attrs()->Add();
+  attr1->set_name("output_format");
+  attr1->set_type(paddle::framework::AttrType::INTS);
+  auto output_format = attr1->mutable_ints();
+  output_format->Add(0);  // y0
+  output_format->Add(2);  // y1
+
+  paddle::platform::CPUDeviceContext cpu_device_context;
+  auto scope = std::make_shared<Scope>();
+
+  OperatorPtr op(paddle::framework::OpRegistry::CreateOp(op_desc));
+  op->Run(scope, cpu_device_context);
+}
paddle/framework/tensor.cc
new file (mode 100644)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <paddle/framework/tensor.h>
namespace paddle {
namespace framework {}
}  // namespace paddle
paddle/operators/CMakeLists.txt
-if(WITH_GPU)
-    nv_library(add_op SRCS add_op.cc add_op.cu DEPS operator op_registry glog ddim)
-else()
-    cc_library(add_op SRCS add_op.cc DEPS operator op_registry glog ddim)
-endif()
+function(op_library TARGET)
+    # op_library is a function to create op library. The interface is same as
+    # cc_library. But it handle split GPU/CPU code and link some common library
+    # for ops.
+    set(cc_srcs)
+    set(cu_srcs)
+    set(op_common_deps operator op_registry)
+    set(options "")
+    set(oneValueArgs "")
+    set(multiValueArgs SRCS DEPS)
+    cmake_parse_arguments(op_library "${options}" "${oneValueArgs}"
+            "${multiValueArgs}" ${ARGN})
+
+    foreach(src ${op_library_SRCS})
+        if (${src} MATCHES ".*\\.cu$")
+            list(APPEND cu_srcs ${src})
+        elseif(${src} MATCHES ".*\\.cc$")
+            list(APPEND cc_srcs ${src})
+        else()
+            message(FATAL_ERROR "${TARGET} Source file ${src} should only be .cc or .cu")
+        endif()
+    endforeach()
+
+    list(LENGTH cc_srcs cc_srcs_len)
+    if (${cc_srcs_len} EQUAL 0)
+        message(FATAL_ERROR "The op library ${TARGET} should contains at least one .cc file")
+    endif()
+
+    list(LENGTH cu_srcs cu_srcs_len)
+    if (${cu_srcs_len} EQUAL 0)
+        message(WARNING "The op library ${TARGET} not support GPU!")
+    endif()
+
+    if (WITH_GPU)
+        nv_library(${TARGET} SRCS ${cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
+                ${op_common_deps})
+    else()
+        cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${op_library_DEPS}
+                ${op_common_deps})
+    endif()
+endfunction()
+
+op_library(add_op SRCS add_op.cc add_op.cu)
 cc_test(add_op_test SRCS add_op_test.cc DEPS add_op)
paddle/operators/add_op.h
@@ -8,10 +8,10 @@ namespace operators {
 template <typename Place>
 class AddKernel : public framework::OpKernel {
  public:
-  void Compute(const KernelContext& context) const override {
+  void Compute(const framework::KernelContext& context) const override {
     LOG(INFO) << "Add kernel in " << typeid(Place).name();
   }
 };
 
 }  // namespace operators
 }  // namespace paddle
paddle/scripts/docker/build_android.sh
@@ -2,9 +2,9 @@
 set -xe
 
-mkdir -p /paddle/build
-cd /paddle/build
-rm -f /paddle/install 2>/dev/null || true
+mkdir -p /paddle/build_android
+cd /paddle/build_android
+rm -rf /paddle/install 2>/dev/null || true
 cmake -DCMAKE_SYSTEM_NAME=Android \
       -DANDROID_STANDALONE_TOOLCHAIN=$ANDROID_STANDALONE_TOOLCHAIN \
       -DANDROID_ABI=armeabi-v7a \
...
@@ -21,6 +21,3 @@ cmake -DCMAKE_SYSTEM_NAME=Android \
       ..
 make -j `nproc`
 make install
-
-export PATH=/paddle/install/bin:/paddle/install/opt/paddle/bin:$PATH
-paddle version