Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
ca23d861
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ca23d861
编写于
7月 12, 2017
作者:
Q
qijun
浏览文件
操作
浏览文件
下载
差异文件
merge baidu/develop
上级
4d336d90
0a320081
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
294 addition
and
264 deletion
+294
-264
paddle/CMakeLists.txt
paddle/CMakeLists.txt
+0
-1
paddle/framework/dim.h
paddle/framework/dim.h
+0
-48
paddle/framework/dim_test.cu
paddle/framework/dim_test.cu
+0
-28
paddle/framework/op_registry_test.cc
paddle/framework/op_registry_test.cc
+16
-17
paddle/framework/operator.cc
paddle/framework/operator.cc
+0
-8
paddle/framework/operator.h
paddle/framework/operator.h
+80
-37
paddle/framework/operator_test.cc
paddle/framework/operator_test.cc
+16
-23
paddle/framework/tensor.h
paddle/framework/tensor.h
+48
-9
paddle/framework/tensor_test.cc
paddle/framework/tensor_test.cc
+118
-22
paddle/operators/.clang-format
paddle/operators/.clang-format
+0
-5
paddle/operators/CMakeLists.txt
paddle/operators/CMakeLists.txt
+0
-0
paddle/operators/demo_op.h
paddle/operators/demo_op.h
+0
-59
paddle/platform/cuda_device_context.h
paddle/platform/cuda_device_context.h
+6
-3
paddle/platform/device_context.h
paddle/platform/device_context.h
+10
-4
未找到文件。
paddle/CMakeLists.txt
浏览文件 @
ca23d861
...
@@ -15,7 +15,6 @@ if(Boost_FOUND)
...
@@ -15,7 +15,6 @@ if(Boost_FOUND)
add_subdirectory
(
memory
)
add_subdirectory
(
memory
)
add_subdirectory
(
platform
)
add_subdirectory
(
platform
)
add_subdirectory
(
framework
)
add_subdirectory
(
framework
)
add_subdirectory
(
operators
)
add_subdirectory
(
pybind
)
add_subdirectory
(
pybind
)
endif
()
endif
()
...
...
paddle/framework/dim.h
浏览文件 @
ca23d861
...
@@ -266,29 +266,6 @@ HOSTDEVICE inline bool contained(const Dim<1>& idx, const Dim<1>& size) {
...
@@ -266,29 +266,6 @@ HOSTDEVICE inline bool contained(const Dim<1>& idx, const Dim<1>& size) {
return
((
0
<=
idx
.
head
)
&&
(
idx
.
head
<
size
.
head
));
return
((
0
<=
idx
.
head
)
&&
(
idx
.
head
<
size
.
head
));
}
}
/**
* \brief Check if a size and a stride create a Fortran order contiguous
* block of memory.
*/
template
<
int
i
>
HOST
bool
contiguous
(
const
Dim
<
i
>&
size
,
const
Dim
<
i
>&
stride
,
int
mul
=
1
)
{
if
(
product
(
size
)
==
0
)
return
true
;
int
contiguous_stride
=
get
<
0
>
(
size
)
==
1
?
0
:
mul
;
return
(
get
<
0
>
(
stride
)
==
contiguous_stride
&&
contiguous
(
size
.
tail
,
stride
.
tail
,
mul
*
get
<
0
>
(
size
)));
}
///\cond HIDDEN
// Base case of contiguous, check the nth stride is the size of
// the prefix multiply of n-1 dims.
template
<
>
inline
bool
contiguous
(
const
Dim
<
1
>&
size
,
const
Dim
<
1
>&
stride
,
int
mul
)
{
if
(
get
<
0
>
(
size
)
==
0
)
return
true
;
int
contiguous_stride
=
get
<
0
>
(
size
)
==
1
?
0
:
mul
;
return
get
<
0
>
(
stride
)
==
contiguous_stride
;
}
///\endcond
/**
/**
* \brief Compute exclusive prefix-multiply of a Dim.
* \brief Compute exclusive prefix-multiply of a Dim.
*/
*/
...
@@ -306,31 +283,6 @@ HOSTDEVICE inline Dim<1> ex_prefix_mul(const Dim<1>& src, int mul) {
...
@@ -306,31 +283,6 @@ HOSTDEVICE inline Dim<1> ex_prefix_mul(const Dim<1>& src, int mul) {
}
}
///\endcond
///\endcond
/**
* \brief Calculate strides of a contiguous array of the given size
*
* Sets the stride for any dimension with an extent of 1 to 0.
* \param size Dim object containing the size of the array.
* \param base The base stride to use.
* \return Dim object the same size as \p size with the strides.
*/
template
<
int
i
>
HOSTDEVICE
Dim
<
i
>
contiguous_strides
(
const
Dim
<
i
>&
size
,
int
base
=
1
)
{
int
stride
=
size
.
head
==
1
?
0
:
base
;
return
Dim
<
i
>
(
stride
,
contiguous_strides
(
size
.
tail
,
base
*
size
.
head
));
}
///\cond HIDDEN
// Base case of contiguous_strides
template
<
>
HOSTDEVICE
inline
Dim
<
1
>
contiguous_strides
(
const
Dim
<
1
>&
size
,
int
base
)
{
int
stride
=
size
.
head
==
1
?
0
:
base
;
return
Dim
<
1
>
(
stride
);
}
///\endcond
/**
/**
* Add two dimensions together
* Add two dimensions together
*/
*/
...
...
paddle/framework/dim_test.cu
浏览文件 @
ca23d861
...
@@ -58,24 +58,6 @@ TEST(Dim, Equality) {
...
@@ -58,24 +58,6 @@ TEST(Dim, Equality) {
EXPECT_EQ
(
paddle
::
framework
::
get
<
1
>
(
c
),
3
);
EXPECT_EQ
(
paddle
::
framework
::
get
<
1
>
(
c
),
3
);
EXPECT_EQ
(
paddle
::
framework
::
get
<
2
>
(
c
),
12
);
EXPECT_EQ
(
paddle
::
framework
::
get
<
2
>
(
c
),
12
);
// contiguous_strides
c
=
paddle
::
framework
::
contiguous_strides
(
paddle
::
framework
::
Dim
<
3
>
(
10
,
1
,
10
));
EXPECT_EQ
(
paddle
::
framework
::
get
<
0
>
(
c
),
1
);
EXPECT_EQ
(
paddle
::
framework
::
get
<
1
>
(
c
),
0
);
EXPECT_EQ
(
paddle
::
framework
::
get
<
2
>
(
c
),
10
);
c
=
paddle
::
framework
::
contiguous_strides
(
paddle
::
framework
::
Dim
<
3
>
(
10
,
10
,
1
));
EXPECT_EQ
(
paddle
::
framework
::
get
<
0
>
(
c
),
1
);
EXPECT_EQ
(
paddle
::
framework
::
get
<
1
>
(
c
),
10
);
EXPECT_EQ
(
paddle
::
framework
::
get
<
2
>
(
c
),
0
);
c
=
paddle
::
framework
::
contiguous_strides
(
paddle
::
framework
::
Dim
<
3
>
(
1
,
10
,
10
));
EXPECT_EQ
(
paddle
::
framework
::
get
<
0
>
(
c
),
0
);
EXPECT_EQ
(
paddle
::
framework
::
get
<
1
>
(
c
),
1
);
EXPECT_EQ
(
paddle
::
framework
::
get
<
2
>
(
c
),
10
);
c
=
paddle
::
framework
::
contiguous_strides
(
paddle
::
framework
::
Dim
<
3
>
(
2
,
3
,
4
));
EXPECT_EQ
(
paddle
::
framework
::
get
<
0
>
(
c
),
1
);
EXPECT_EQ
(
paddle
::
framework
::
get
<
1
>
(
c
),
2
);
EXPECT_EQ
(
paddle
::
framework
::
get
<
2
>
(
c
),
6
);
// generate from an index
// generate from an index
auto
size
=
paddle
::
framework
::
make_dim
(
4
,
5
,
2
);
auto
size
=
paddle
::
framework
::
make_dim
(
4
,
5
,
2
);
c
=
paddle
::
framework
::
Dim
<
3
>
(
14
,
size
);
c
=
paddle
::
framework
::
Dim
<
3
>
(
14
,
size
);
...
@@ -101,16 +83,6 @@ TEST(Dim, Bool) {
...
@@ -101,16 +83,6 @@ TEST(Dim, Bool) {
EXPECT_TRUE
(
a
==
a
);
EXPECT_TRUE
(
a
==
a
);
EXPECT_FALSE
(
a
==
b
);
EXPECT_FALSE
(
a
==
b
);
EXPECT_TRUE
(
a
==
c
);
EXPECT_TRUE
(
a
==
c
);
// contiguous check
int
x
=
4
,
y
=
5
,
z
=
2
;
paddle
::
framework
::
Dim
<
3
>
sizef
(
x
,
y
,
z
);
paddle
::
framework
::
Dim
<
3
>
stridea
(
1
,
x
,
x
*
y
);
paddle
::
framework
::
Dim
<
3
>
strideb
(
2
,
2
*
x
,
2
*
x
*
y
);
paddle
::
framework
::
Dim
<
3
>
stridec
(
1
,
x
,
2
*
x
*
y
);
EXPECT_TRUE
(
paddle
::
framework
::
contiguous
(
sizef
,
stridea
));
EXPECT_FALSE
(
paddle
::
framework
::
contiguous
(
sizef
,
strideb
));
EXPECT_FALSE
(
paddle
::
framework
::
contiguous
(
sizef
,
stridec
));
}
}
TEST
(
Dim
,
Print
)
{
TEST
(
Dim
,
Print
)
{
...
...
paddle/framework/op_registry_test.cc
浏览文件 @
ca23d861
#include "paddle/framework/op_registry.h"
#include "paddle/framework/op_registry.h"
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "paddle/framework/operator.h"
#include "paddle/operators/demo_op.h"
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
framework
;
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
class
CosineOp
:
public
Operator
WithKernel
{
class
CosineOp
:
public
Operator
Base
{
public:
public:
void
Run
(
const
OpRunContext
*
context
)
const
override
{
void
Run
(
const
std
::
shared_ptr
<
Scope
>&
scope
,
printf
(
"%s
\n
"
,
DebugString
().
c_str
());
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
}
void
InferShape
(
const
std
::
shared_ptr
<
Scope
>&
scope
)
const
override
{
}
};
};
class
CosineOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
class
CosineOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
...
@@ -30,12 +28,13 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...
@@ -30,12 +28,13 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
REGISTER_OP
(
CosineOp
,
CosineOpProtoAndCheckerMaker
,
cos_sim
)
REGISTER_OP
(
CosineOp
,
CosineOpProtoAndCheckerMaker
,
cos_sim
)
class
MyTestOp
:
public
OperatorWithKernel
{
class
MyTestOp
:
public
OperatorBase
{
public:
void
InferShape
(
const
std
::
shared_ptr
<
Scope
>&
scope
)
const
override
{}
void
Run
(
const
std
::
shared_ptr
<
Scope
>&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{}
public:
public:
void
Run
(
const
OpRunContext
*
ctx
)
const
override
{
printf
(
"%s
\n
"
,
DebugString
().
c_str
());
printf
(
"test_attr = %d
\n
"
,
ctx
->
op_
->
GetAttr
<
int
>
(
"test_attr"
));
}
};
};
class
MyTestOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
class
MyTestOpProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
...
@@ -73,8 +72,8 @@ TEST(OpRegistry, CreateOp) {
...
@@ -73,8 +72,8 @@ TEST(OpRegistry, CreateOp) {
paddle
::
framework
::
OperatorBase
*
op
=
paddle
::
framework
::
OperatorBase
*
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
dev_ctx
=
DeviceContext
()
;
paddle
::
platform
::
CPUDeviceContext
dev_ctx
;
op
->
Run
(
scope
,
&
dev_ctx
);
op
->
Run
(
scope
,
dev_ctx
);
float
scale_get
=
op
->
GetAttr
<
float
>
(
"scale"
);
float
scale_get
=
op
->
GetAttr
<
float
>
(
"scale"
);
ASSERT_EQ
(
scale_get
,
scale
);
ASSERT_EQ
(
scale_get
,
scale
);
}
}
...
@@ -116,8 +115,8 @@ TEST(OpRegistry, DefaultValue) {
...
@@ -116,8 +115,8 @@ TEST(OpRegistry, DefaultValue) {
paddle
::
framework
::
OperatorBase
*
op
=
paddle
::
framework
::
OperatorBase
*
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
dev_ctx
=
DeviceContext
()
;
paddle
::
platform
::
CPUDeviceContext
dev_ctx
;
op
->
Run
(
scope
,
&
dev_ctx
);
op
->
Run
(
scope
,
dev_ctx
);
ASSERT_EQ
(
op
->
GetAttr
<
float
>
(
"scale"
),
1.0
);
ASSERT_EQ
(
op
->
GetAttr
<
float
>
(
"scale"
),
1.0
);
}
}
...
@@ -169,9 +168,9 @@ TEST(OpRegistry, CustomChecker) {
...
@@ -169,9 +168,9 @@ TEST(OpRegistry, CustomChecker) {
attr
->
set_i
(
4
);
attr
->
set_i
(
4
);
paddle
::
framework
::
OperatorBase
*
op
=
paddle
::
framework
::
OperatorBase
*
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
auto
dev_ctx
=
DeviceContext
()
;
paddle
::
platform
::
CPUDeviceContext
dev_ctx
;
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
scope
=
std
::
make_shared
<
Scope
>
();
op
->
Run
(
scope
,
&
dev_ctx
);
op
->
Run
(
scope
,
dev_ctx
);
int
test_attr
=
op
->
GetAttr
<
int
>
(
"test_attr"
);
int
test_attr
=
op
->
GetAttr
<
int
>
(
"test_attr"
);
ASSERT_EQ
(
test_attr
,
4
);
ASSERT_EQ
(
test_attr
,
4
);
}
}
...
...
paddle/framework/operator.cc
浏览文件 @
ca23d861
...
@@ -39,13 +39,5 @@ std::string OperatorBase::DebugString() const {
...
@@ -39,13 +39,5 @@ std::string OperatorBase::DebugString() const {
return
ss
.
str
();
return
ss
.
str
();
}
}
const
Variable
*
OpRunContext
::
Input
(
int
index
)
const
{
return
scope_
->
GetVariable
(
op_
->
inputs_
[
index
]);
}
Variable
*
OpRunContext
::
Output
(
int
index
)
const
{
return
scope_
->
GetVariable
(
op_
->
outputs_
[
index
]);
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
\ No newline at end of file
paddle/framework/operator.h
浏览文件 @
ca23d861
...
@@ -14,44 +14,22 @@ limitations under the License. */
...
@@ -14,44 +14,22 @@ limitations under the License. */
#pragma once
#pragma once
#include <paddle/framework/attr_checker.h>
#include <paddle/framework/op_desc.pb.h>
#include <paddle/framework/scope.h>
#include <paddle/platform/device_context.h>
#include <paddle/platform/place.h>
#include <paddle/utils/Error.h>
#include <boost/variant.hpp>
#include <boost/variant.hpp>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
#include "paddle/framework/attr_checker.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/utils/Error.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
class
OperatorBase
;
class
OperatorBase
;
class
DeviceContext
{};
/**
* OpRunContext is the only parameter of Operator's Run function.
* Run will get input/output variables, state such as momentum and
* device resource such as CUDA stream, cublas handle, etc. from
* OpRunContext. User should construct it before run the Operator.
*/
class
OpRunContext
{
public:
OpRunContext
(
const
OperatorBase
*
op
,
const
std
::
shared_ptr
<
Scope
>
scope
,
const
DeviceContext
*
device_context
)
:
op_
(
op
),
scope_
(
scope
),
device_context_
(
device_context
)
{}
const
Variable
*
Input
(
int
index
)
const
;
Variable
*
Output
(
int
index
)
const
;
public:
const
OperatorBase
*
op_
;
const
std
::
shared_ptr
<
Scope
>
scope_
;
const
DeviceContext
*
device_context_
;
};
/**
/**
* OperatorBase has the basic element that Net will call to do computation.
* OperatorBase has the basic element that Net will call to do computation.
* Only CreateOperator from OpRegistry will new Operator directly. User
* Only CreateOperator from OpRegistry will new Operator directly. User
...
@@ -77,7 +55,10 @@ class OperatorBase {
...
@@ -77,7 +55,10 @@ class OperatorBase {
/// Net will call this function to Run an op.
/// Net will call this function to Run an op.
virtual
void
Run
(
const
std
::
shared_ptr
<
Scope
>&
scope
,
virtual
void
Run
(
const
std
::
shared_ptr
<
Scope
>&
scope
,
const
DeviceContext
*
dev_ctx
)
const
=
0
;
const
platform
::
DeviceContext
&
dev_ctx
)
const
=
0
;
protected:
std
::
string
Type
()
const
{
return
desc_
.
type
();
}
public:
public:
OpDesc
desc_
;
OpDesc
desc_
;
...
@@ -86,22 +67,84 @@ class OperatorBase {
...
@@ -86,22 +67,84 @@ class OperatorBase {
AttributeMap
attrs_
;
AttributeMap
attrs_
;
};
};
class
OpKernel
{
public:
/**
* KernelContext is the only parameter of Kernel Run function.
* Run will get input/output variables, state such as momentum and
* device resource such as CUDA stream, cublas handle, etc. from
* KernelContext. User should construct it before run the Operator.
*/
class
KernelContext
{
public:
KernelContext
(
const
OperatorBase
*
op
,
const
std
::
shared_ptr
<
Scope
>&
scope
,
const
platform
::
DeviceContext
&
device_context
)
:
op_
(
*
op
),
scope_
(
scope
),
device_context_
(
device_context
)
{}
const
Variable
*
Input
(
int
index
)
const
{
return
scope_
->
GetVariable
(
op_
.
inputs_
[
index
]);
}
Variable
*
Output
(
int
index
)
const
{
return
scope_
->
GetVariable
(
op_
.
outputs_
[
index
]);
}
const
OperatorBase
&
op_
;
const
std
::
shared_ptr
<
Scope
>&
scope_
;
const
platform
::
DeviceContext
&
device_context_
;
};
virtual
void
Compute
(
const
KernelContext
&
context
)
const
=
0
;
virtual
~
OpKernel
()
{}
};
class
OperatorWithKernel
:
public
OperatorBase
{
class
OperatorWithKernel
:
public
OperatorBase
{
public:
public:
virtual
~
OperatorWithKernel
()
{}
struct
OpKernelKey
{
platform
::
Place
place_
;
virtual
void
InferShape
(
const
std
::
shared_ptr
<
Scope
>&
scope
)
const
{}
OpKernelKey
()
=
default
;
OpKernelKey
(
const
platform
::
DeviceContext
&
dev_ctx
)
{
place_
=
dev_ctx
.
GetPlace
();
}
bool
operator
==
(
const
OpKernelKey
&
o
)
const
{
return
place_
==
o
.
place_
;
}
};
struct
OpKernelHash
{
std
::
hash
<
bool
>
hash_
;
size_t
operator
()(
const
OpKernelKey
&
key
)
const
{
return
hash_
(
platform
::
is_gpu_place
(
key
.
place_
));
}
};
using
OpKernelMap
=
std
::
unordered_map
<
OpKernelKey
,
std
::
unique_ptr
<
OpKernel
>
,
OpKernelHash
>
;
void
Run
(
const
std
::
shared_ptr
<
Scope
>&
scope
,
void
Run
(
const
std
::
shared_ptr
<
Scope
>&
scope
,
const
DeviceContext
*
dev_ctx
)
const
{
const
platform
::
DeviceContext
&
dev_ctx
)
const
final
{
OpRunContext
op_ctx
(
this
,
scope
,
dev_ctx
);
auto
&
opKernel
=
AllOpKernels
().
at
(
Type
()).
at
(
OpKernelKey
(
dev_ctx
)
);
Run
(
&
op_ctx
);
opKernel
->
Compute
(
OpKernel
::
KernelContext
(
this
,
scope
,
dev_ctx
)
);
}
}
/// when implement an Op, your should implement this function.
static
std
::
unordered_map
<
std
::
string
/* op_type */
,
OpKernelMap
>&
/// this function should be moved to OpKernel later
AllOpKernels
()
{
virtual
void
Run
(
const
OpRunContext
*
context
)
const
=
0
;
static
std
::
unordered_map
<
std
::
string
,
OpKernelMap
>
g_all_op_kernels
;
return
g_all_op_kernels
;
};
};
};
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
// Registers KernelType as the kernel of operator `type` for the place
// produced by PlaceType().
//
// The macro defines a registrar struct whose constructor builds an
// OperatorWithKernel::OpKernelKey, sets its place_ to PlaceType(), and stores
// a newly allocated KernelType in
// OperatorWithKernel::AllOpKernels()[#type][key]. It then declares a static
// instance of that struct, so the constructor runs when the translation
// unit's static objects are initialized.
//
// The expansion does not end with a semicolon; the use site supplies it.
// NOTE(review): the generated struct and variable names depend only on
// `type`, so two expansions with the same `type` in one translation unit
// would redefine them -- confirm whether registering several places for one
// operator needs distinct names.
#define REGISTER_OP_KERNEL(type, PlaceType, KernelType) \
struct __op_kernel_register__##type##__ { \
__op_kernel_register__##type##__() { \
::paddle::framework::OperatorWithKernel::OpKernelKey key; \
key.place_ = PlaceType(); \
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
.reset(new KernelType()); \
} \
}; \
static __op_kernel_register__##type##__ __reg_kernel_##type##__
paddle/framework/operator_test.cc
浏览文件 @
ca23d861
...
@@ -19,17 +19,15 @@ limitations under the License. */
...
@@ -19,17 +19,15 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
class
OperatorTest
:
public
Operator
WithKernel
{
class
OperatorTest
:
public
Operator
Base
{
public:
public:
void
Run
(
const
OpRunContext
*
ctx
)
const
override
{
void
InferShape
(
const
std
::
shared_ptr
<
Scope
>&
scope
)
const
override
{}
float
scale
=
ctx
->
op_
->
GetAttr
<
float
>
(
"scale"
);
void
Run
(
const
std
::
shared_ptr
<
Scope
>&
scope
,
PADDLE_ENFORCE
(
ctx
->
Input
(
0
)
==
nullptr
,
"Input(0) should not initialized"
);
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
Output
(
0
)
==
nullptr
,
float
scale
=
GetAttr
<
float
>
(
"scale"
);
"Output(1) should not initialized"
);
ASSERT_NEAR
(
scale
,
3.14
,
1e-5
);
auto
output1
=
ctx
->
scope_
->
CreateVariable
(
"output1"
);
ASSERT_EQ
(
scope
->
GetVariable
(
inputs_
[
0
]),
nullptr
);
PADDLE_ENFORCE
(
output1
!=
nullptr
,
"should create output1 from scope"
);
ASSERT_NE
(
scope
->
GetVariable
(
outputs_
[
0
]),
nullptr
);
printf
(
"get attr %s = %f
\n
"
,
"scale"
,
scale
);
printf
(
"%s
\n
"
,
DebugString
().
c_str
());
}
}
};
};
...
@@ -49,31 +47,26 @@ class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...
@@ -49,31 +47,26 @@ class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
REGISTER_OP
(
OperatorTest
,
OperatorTestProtoAndCheckerMaker
,
test_operator
)
REGISTER_OP
(
OperatorTest
,
OperatorTestProtoAndCheckerMaker
,
test_operator
)
TEST
(
OperatorBase
,
DebugString
)
{
TEST
(
OperatorBase
,
all
)
{
OpDesc
op_desc
;
OpDesc
op_desc
;
op_desc
.
set_type
(
"test_operator"
);
op_desc
.
set_type
(
"test_operator"
);
std
::
vector
<
std
::
string
>
inputs
=
{
"IN1"
,
"IN2"
};
*
op_desc
.
mutable_inputs
()
->
Add
()
=
"IN1"
;
for
(
auto
&
input
:
inputs
)
{
*
op_desc
.
mutable_outputs
()
->
Add
()
=
"OUT1"
;
op_desc
.
add_inputs
(
input
);
}
std
::
vector
<
std
::
string
>
outputs
=
{
"OUT1"
,
"OUT2"
};
for
(
auto
&
output
:
outputs
)
{
op_desc
.
add_outputs
(
output
);
}
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"scale"
);
attr
->
set_name
(
"scale"
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
float
scale
=
3.14
;
float
scale
=
3.14
;
attr
->
set_f
(
scale
);
attr
->
set_f
(
scale
);
DeviceContext
device_context
;
platform
::
CPU
DeviceContext
device_context
;
auto
scope
=
std
::
make_shared
<
Scope
>
();
auto
scope
=
std
::
make_shared
<
Scope
>
();
OperatorBase
*
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
OperatorBase
*
op
=
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
);
ASSERT_EQ
(
op
->
inputs_
,
inputs
);
ASSERT_EQ
(
op
->
outputs_
,
outputs
);
ASSERT_EQ
(
op
->
GetAttr
<
float
>
(
"scale"
),
scale
);
ASSERT_EQ
(
op
->
GetAttr
<
float
>
(
"scale"
),
scale
);
op
->
Run
(
scope
,
&
device_context
);
scope
->
CreateVariable
(
"OUT1"
);
op
->
Run
(
scope
,
device_context
);
std
::
cout
<<
op
->
DebugString
()
<<
std
::
endl
;
delete
op
;
}
}
}
// namespace framework
}
// namespace framework
...
...
paddle/framework/tensor.h
浏览文件 @
ca23d861
...
@@ -14,6 +14,7 @@ limitations under the License. */
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#pragma once
#include <cstdint>
#include <memory>
#include <memory>
#include <type_traits>
#include <type_traits>
#include "paddle/framework/ddim.h"
#include "paddle/framework/ddim.h"
...
@@ -26,31 +27,65 @@ namespace framework {
...
@@ -26,31 +27,65 @@ namespace framework {
class
Tensor
{
class
Tensor
{
public:
public:
Tensor
()
:
offset_
(
0
)
{}
explicit
Tensor
(
const
DDim
&
dims
)
:
dims_
(
dims
),
offset_
(
0
)
{}
template
<
typename
T
>
template
<
typename
T
>
const
T
*
data
()
const
{
const
T
*
data
()
const
{
PADDLE_ENFORCE
(
holder_
!=
nullptr
,
PADDLE_ENFORCE
(
"Tensor::data must be called after Tensor::mutable_data."
);
holder_
!=
nullptr
,
return
static_cast
<
const
T
*>
(
holder_
->
Ptr
());
"Tenosr has not been initialized. Call Tensor::mutable_data first."
);
return
reinterpret_cast
<
const
T
*>
(
reinterpret_cast
<
uintptr_t
>
(
holder_
->
Ptr
())
+
offset_
);
}
}
template
<
typename
T
,
// must be POD types
template
<
typename
T
,
// must be POD types
typename
std
::
enable_if
<
std
::
is_pod
<
T
>
::
value
>::
type
*
=
nullptr
>
typename
std
::
enable_if
<
std
::
is_pod
<
T
>
::
value
>::
type
*
=
nullptr
>
T
*
mutable_data
(
DDim
dims
,
paddle
::
platform
::
Place
place
)
{
T
*
mutable_data
(
DDim
dims
,
paddle
::
platform
::
Place
place
)
{
dims_
=
dims
;
if
(
holder_
==
nullptr
||
if
(
holder_
==
nullptr
||
!
(
holder_
->
Place
()
==
!
(
holder_
->
Place
()
==
place
)
/* some versions of boost::variant don't have operator!= */
place
)
/* some versions of boost::variant don't have operator!= */
||
holder_
->
Size
()
<
product
(
dims
)
*
sizeof
(
T
))
{
||
holder_
->
Size
()
<
product
(
dims
)
*
sizeof
(
T
)
+
offset_
)
{
holder_
.
reset
(
new
PlaceholderImpl
<
T
>
(
place
,
product
(
dims
)
*
sizeof
(
T
)));
holder_
.
reset
(
new
PlaceholderImpl
<
T
>
(
place
,
product
(
dims
)
*
sizeof
(
T
)));
offset_
=
0
;
}
}
return
static_cast
<
T
*>
(
holder_
->
Ptr
());
return
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
uintptr_t
>
(
holder_
->
Ptr
())
+
offset_
);
}
}
template
<
typename
T
,
// must be POD types
void
ShareDataFrom
(
const
Tensor
&
src
)
{
typename
std
::
enable_if
<
std
::
is_pod
<
T
>
::
value
>::
type
*
=
nullptr
>
PADDLE_ENFORCE
(
src
.
holder_
!=
nullptr
,
T
*
mutable_data
(
DDim
dims
)
{
"Can not share data from an uninitialized tensor."
);
return
mutable_data
<
T
>
(
dims
,
paddle
::
platform
::
get_place
());
holder_
=
src
.
holder_
;
dims_
=
src
.
dims_
;
offset_
=
src
.
offset_
;
}
}
Tensor
Slice
(
const
int
&
begin_idx
,
const
int
&
end_idx
)
const
{
PADDLE_ENFORCE
(
holder_
!=
nullptr
,
"The sliced tenosr has not been initialized."
);
PADDLE_ENFORCE
(
begin_idx
>=
0
&&
end_idx
<=
dims_
[
0
],
"Slice index is less than zero or out of bound."
);
PADDLE_ENFORCE
(
begin_idx
<
end_idx
,
"Begin index must be less than end index."
);
PADDLE_ENFORCE
(
dims_
[
0
]
!=
1
,
"Can not slice a tensor with dims_[0] = 1."
);
std
::
vector
<
int
>
d
=
vectorize
(
dims_
);
int
base
=
1
;
for
(
size_t
i
=
1
;
i
<
d
.
size
();
++
i
)
{
base
*=
d
[
i
];
}
Tensor
dst
;
dst
.
holder_
=
holder_
;
dst
.
dims_
=
dims_
;
dst
.
dims_
[
0
]
=
end_idx
-
begin_idx
;
dst
.
offset_
=
offset_
+
begin_idx
*
base
*
holder_
->
TypeSize
();
return
dst
;
}
DDim
dims
()
const
{
return
dims_
;
}
private:
private:
// Placeholder hides type T, so it doesn't appear as a template
// Placeholder hides type T, so it doesn't appear as a template
// parameter of Variable.
// parameter of Variable.
...
@@ -59,6 +94,7 @@ class Tensor {
...
@@ -59,6 +94,7 @@ class Tensor {
virtual
void
*
Ptr
()
const
=
0
;
virtual
void
*
Ptr
()
const
=
0
;
virtual
paddle
::
platform
::
Place
Place
()
const
=
0
;
virtual
paddle
::
platform
::
Place
Place
()
const
=
0
;
virtual
size_t
Size
()
const
=
0
;
virtual
size_t
Size
()
const
=
0
;
virtual
size_t
TypeSize
()
const
=
0
;
};
};
template
<
typename
T
>
template
<
typename
T
>
...
@@ -85,6 +121,7 @@ class Tensor {
...
@@ -85,6 +121,7 @@ class Tensor {
virtual
void
*
Ptr
()
const
{
return
static_cast
<
void
*>
(
ptr_
.
get
());
}
virtual
void
*
Ptr
()
const
{
return
static_cast
<
void
*>
(
ptr_
.
get
());
}
virtual
size_t
Size
()
const
{
return
size_
;
}
virtual
size_t
Size
()
const
{
return
size_
;
}
virtual
paddle
::
platform
::
Place
Place
()
const
{
return
place_
;
}
virtual
paddle
::
platform
::
Place
Place
()
const
{
return
place_
;
}
virtual
size_t
TypeSize
()
const
{
return
sizeof
(
T
);
}
std
::
unique_ptr
<
T
,
Deleter
>
ptr_
;
std
::
unique_ptr
<
T
,
Deleter
>
ptr_
;
paddle
::
platform
::
Place
place_
;
// record the place of ptr_.
paddle
::
platform
::
Place
place_
;
// record the place of ptr_.
...
@@ -92,6 +129,8 @@ class Tensor {
...
@@ -92,6 +129,8 @@ class Tensor {
};
};
std
::
shared_ptr
<
Placeholder
>
holder_
;
// holds the memory block if allocated.
std
::
shared_ptr
<
Placeholder
>
holder_
;
// holds the memory block if allocated.
DDim
dims_
;
size_t
offset_
;
// marks the begin of tensor data area.
};
};
}
// namespace framework
}
// namespace framework
...
...
paddle/framework/tensor_test.cc
浏览文件 @
ca23d861
...
@@ -15,15 +15,27 @@
...
@@ -15,15 +15,27 @@
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include <string>
#include <string>
TEST
(
Tensor
,
ASSERT
)
{
TEST
(
Tensor
,
Dims
)
{
paddle
::
framework
::
Tensor
cpu_tensor
;
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
platform
;
Tensor
tt
(
make_ddim
({
2
,
3
,
4
}));
DDim
dims
=
tt
.
dims
();
ASSERT_EQ
(
arity
(
dims
),
3
);
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
EXPECT_EQ
(
i
+
2
,
dims
[
i
]);
}
}
TEST
(
Tensor
,
DataAssert
)
{
paddle
::
framework
::
Tensor
src_tensor
;
bool
caught
=
false
;
bool
caught
=
false
;
try
{
try
{
const
double
*
p
__attribute__
((
unused
))
=
cpu
_tensor
.
data
<
double
>
();
src
_tensor
.
data
<
double
>
();
}
catch
(
paddle
::
framework
::
EnforceNotMet
err
)
{
}
catch
(
paddle
::
framework
::
EnforceNotMet
err
)
{
caught
=
true
;
caught
=
true
;
std
::
string
msg
=
"Tensor::data must be called after Tensor::mutable_data."
;
std
::
string
msg
=
"Tenosr has not been initialized. Call Tensor::mutable_data first."
;
const
char
*
what
=
err
.
what
();
const
char
*
what
=
err
.
what
();
for
(
size_t
i
=
0
;
i
<
msg
.
length
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
msg
.
length
();
++
i
)
{
ASSERT_EQ
(
what
[
i
],
msg
[
i
]);
ASSERT_EQ
(
what
[
i
],
msg
[
i
]);
...
@@ -32,54 +44,138 @@ TEST(Tensor, ASSERT) {
...
@@ -32,54 +44,138 @@ TEST(Tensor, ASSERT) {
ASSERT_TRUE
(
caught
);
ASSERT_TRUE
(
caught
);
}
}
/*
mutable_data() is not tested
at present
/*
following tests are not available
at present
because Memory::Alloc() and Memory::Free() have not been ready.
because Memory::Alloc() and Memory::Free() have not been ready.
TEST(Tensor, MutableData) {
TEST(Tensor, MutableData) {
using namespace paddle::framework;
using namespace paddle::framework;
using namespace paddle::platform;
using namespace paddle::platform;
{
{
Tensor
cpu
_tensor;
Tensor
src
_tensor;
float* p1 = nullptr;
float* p1 = nullptr;
float* p2 = nullptr;
float* p2 = nullptr;
// initialization
// initialization
p1 =
cpu
_tensor.mutable_data<float>(make_ddim({1, 2, 3}), CPUPlace());
p1 =
src
_tensor.mutable_data<float>(make_ddim({1, 2, 3}), CPUPlace());
EXPECT_NE(p1, nullptr);
EXPECT_NE(p1, nullptr);
// set
cpu
_tensor a new dim with large size
// set
src
_tensor a new dim with large size
// momery is supposed to be re-allocated
// momery is supposed to be re-allocated
p2 =
cpu_tensor.mutable_data<float>(make_ddim({3, 4}
));
p2 =
src_tensor.mutable_data<float>(make_ddim({3, 4}), CPUPlace(
));
EXPECT_NE(p2, nullptr);
EXPECT_NE(p2, nullptr);
EXPECT_NE(p1, p2);
EXPECT_NE(p1, p2);
// set
cpu
_tensor a new dim with same size
// set
src
_tensor a new dim with same size
// momery block is supposed to be unchanged
// momery block is supposed to be unchanged
p1 =
cpu_tensor.mutable_data<float>(make_ddim({2, 2, 3}
));
p1 =
src_tensor.mutable_data<float>(make_ddim({2, 2, 3}), CPUPlace(
));
EXPECT_EQ(p1, p2);
EXPECT_EQ(p1, p2);
// set
cpu
_tensor a new dim with smaller size
// set
src
_tensor a new dim with smaller size
// momery block is supposed to be unchanged
// momery block is supposed to be unchanged
p2 =
cpu_tensor.mutable_data<float>(make_ddim({2, 2}
));
p2 =
src_tensor.mutable_data<float>(make_ddim({2, 2}), CPUPlace(
));
EXPECT_EQ(p1, p2);
EXPECT_EQ(p1, p2);
}
}
{
{
Tensor
gpu
_tensor;
Tensor
src
_tensor;
float* p1 = nullptr;
float* p1 = nullptr;
float* p2 = nullptr;
float* p2 = nullptr;
// initialization
// initialization
p1 =
gpu
_tensor.mutable_data<float>(make_ddim({1, 2, 3}), GPUPlace());
p1 =
src
_tensor.mutable_data<float>(make_ddim({1, 2, 3}), GPUPlace());
EXPECT_NE(p1, nullptr);
EXPECT_NE(p1, nullptr);
// set
gpu
_tensor a new dim with large size
// set
src
_tensor a new dim with large size
// momery is supposed to be re-allocated
// momery is supposed to be re-allocated
p2 =
gpu_tensor.mutable_data<float>(make_ddim({3, 4}
));
p2 =
src_tensor.mutable_data<float>(make_ddim({3, 4}), GPUPlace(
));
EXPECT_NE(p2, nullptr);
EXPECT_NE(p2, nullptr);
EXPECT_NE(p1, p2);
EXPECT_NE(p1, p2);
// set
gpu
_tensor a new dim with same size
// set
src
_tensor a new dim with same size
// momery block is supposed to be unchanged
// momery block is supposed to be unchanged
p1 =
gpu_tensor.mutable_data<float>(make_ddim({2, 2, 3}
));
p1 =
src_tensor.mutable_data<float>(make_ddim({2, 2, 3}), GPUPlace(
));
EXPECT_EQ(p1, p2);
EXPECT_EQ(p1, p2);
// set
gpu
_tensor a new dim with smaller size
// set
src
_tensor a new dim with smaller size
// momery block is supposed to be unchanged
// momery block is supposed to be unchanged
p2 =
gpu_tensor.mutable_data<float>(make_ddim({2, 2}
));
p2 =
src_tensor.mutable_data<float>(make_ddim({2, 2}), GPUPlace(
));
EXPECT_EQ(p1, p2);
EXPECT_EQ(p1, p2);
}
}
}
}
*/
// Checks Tensor::ShareDataFrom: sharing from a tensor that holds no memory
// must raise EnforceNotMet with the expected message, and after a successful
// share both tensors must report the same data pointer.
TEST(Tensor, ShareDataFrom) {
  using namespace paddle::framework;
  using namespace paddle::platform;
  {
    Tensor src_tensor;
    Tensor dst_tensor;
    // Try to share data from an uninitialized tensor.
    bool caught = false;
    try {
      dst_tensor.ShareDataFrom(src_tensor);
    } catch (const EnforceNotMet& err) {
      // Caught by const reference so the exception object is neither copied
      // nor sliced.
      caught = true;
      std::string msg = "Can not share data from an uninitialized tensor.";
      const char* what = err.what();
      // Only the first msg.length() characters of what() are compared, so
      // any text that follows the expected message is ignored.
      for (size_t i = 0; i < msg.length(); ++i) {
        ASSERT_EQ(what[i], msg[i]);
      }
    }
    ASSERT_TRUE(caught);

    // Once the source owns memory, the destination must alias it.
    src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), CPUPlace());
    dst_tensor.ShareDataFrom(src_tensor);
    ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
  }

  {
    // Same aliasing check with the memory requested for GPUPlace.
    Tensor src_tensor;
    Tensor dst_tensor;
    src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), GPUPlace());
    dst_tensor.ShareDataFrom(src_tensor);
    ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
  }
}
// Checks Tensor::Slice along the first dimension: the slice keeps the
// trailing dimensions, data() and mutable_data() agree for both the source
// and the slice, and the slice starts begin_idx rows past the source.
TEST(Tensor, Slice) {
  using namespace paddle::framework;
  using namespace paddle::platform;

  // Turns a data pointer into an integer so byte offsets can be compared.
  auto address_of = [](const void* ptr) {
    return reinterpret_cast<uintptr_t>(ptr);
  };

  {
    // int tensor of shape {5, 3, 4} requested for CPUPlace, rows [1, 3).
    Tensor src_tensor;
    src_tensor.mutable_data<int>(make_ddim({5, 3, 4}), CPUPlace());
    Tensor slice_tensor = src_tensor.Slice(1, 3);

    DDim slice_dims = slice_tensor.dims();
    ASSERT_EQ(arity(slice_dims), 3);
    EXPECT_EQ(slice_dims[0], 2);
    EXPECT_EQ(slice_dims[1], 3);
    EXPECT_EQ(slice_dims[2], 4);

    const uintptr_t src_data_address = address_of(src_tensor.data<int>());
    const uintptr_t src_mutable_data_address = address_of(
        src_tensor.mutable_data<int>(src_tensor.dims(), CPUPlace()));
    const uintptr_t slice_data_address = address_of(slice_tensor.data<int>());
    const uintptr_t slice_mutable_data_address = address_of(
        slice_tensor.mutable_data<int>(slice_tensor.dims(), CPUPlace()));

    EXPECT_EQ(src_data_address, src_mutable_data_address);
    EXPECT_EQ(slice_data_address, slice_mutable_data_address);
    // One row of 3 * 4 ints lies before the slice.
    EXPECT_EQ(src_data_address + 3 * 4 * 1 * sizeof(int), slice_data_address);
  }

  {
    // double tensor of shape {6, 9} requested for GPUPlace, rows [2, 6).
    Tensor src_tensor;
    src_tensor.mutable_data<double>(make_ddim({6, 9}), GPUPlace());
    Tensor slice_tensor = src_tensor.Slice(2, 6);

    DDim slice_dims = slice_tensor.dims();
    ASSERT_EQ(arity(slice_dims), 2);
    EXPECT_EQ(slice_dims[0], 4);
    EXPECT_EQ(slice_dims[1], 9);

    const uintptr_t src_data_address = address_of(src_tensor.data<double>());
    const uintptr_t src_mutable_data_address = address_of(
        src_tensor.mutable_data<double>(src_tensor.dims(), GPUPlace()));
    const uintptr_t slice_data_address =
        address_of(slice_tensor.data<double>());
    const uintptr_t slice_mutable_data_address = address_of(
        slice_tensor.mutable_data<double>(slice_tensor.dims(), GPUPlace()));

    EXPECT_EQ(src_data_address, src_mutable_data_address);
    EXPECT_EQ(slice_data_address, slice_mutable_data_address);
    // Two rows of 9 doubles lie before the slice.
    EXPECT_EQ(src_data_address + 9 * 2 * sizeof(double), slice_data_address);
  }
}
*/
\ No newline at end of file
paddle/operators/.clang-format
已删除
100644 → 0
浏览文件 @
4d336d90
---
Language: Cpp
BasedOnStyle: Google
Standard: Cpp11
...
paddle/operators/CMakeLists.txt
已删除
100644 → 0
浏览文件 @
4d336d90
paddle/operators/demo_op.h
已删除
100644 → 0
浏览文件 @
4d336d90
#pragma once
#include "paddle/framework/op_registry.h"
using
namespace
paddle
::
framework
;
namespace
paddle
{
namespace
operators
{
// Demo operator: Run() performs no computation and only prints the
// operator's DebugString().
class CosineOp : public OperatorWithKernel {
 public:
  // Prints DebugString() followed by a newline to stdout. `context` is not
  // read by this implementation.
  void Run(const OpRunContext* context) const override {
    const auto& description = DebugString();
    printf("%s\n", description.c_str());
  }
};
// Describes the proto and attribute checks for CosineOp: one input, one
// output and a float attribute "scale".
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 public:
  // Forwards `proto` and `op_checker` to the base maker, then declares the
  // operator's input, output, attribute, type and comment.
  CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // Kept as double so the values passed on are the same as the literals
    // 1.0 and 0.0.
    const double default_scale = 1.0;
    const double scale_lower_bound = 0.0;

    AddInput("input", "input of cosine op");
    AddOutput("output", "output of cosine op");
    AddAttr<float>("scale", "scale of cosine op")
        .SetDefault(default_scale)
        .LargerThan(scale_lower_bound);
    AddType("cos");
    AddComment("This is cos op");
  }
};

// Registers CosineOp together with its maker under the identifier cos_sim.
// NOTE(review): the maker above calls AddType("cos") while the registration
// uses cos_sim; confirm which name is intended.
REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim)
// Test operator: Run() performs no computation and only prints the
// operator's DebugString().
class MyTestOp : public OperatorWithKernel {
 public:
  // Prints DebugString() followed by a newline to stdout. `context` is not
  // read by this implementation.
  void Run(const OpRunContext* context) const override {
    const auto& description = DebugString();
    printf("%s\n", description.c_str());
  }
};
// Describes the proto and attribute checks for MyTestOp: one input, one
// output and an int attribute "test_attr" guarded by a custom checker.
class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 public:
  // Forwards `proto` and `op_checker` to the base maker, then declares the
  // operator's input, output, attribute, type and comment.
  MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    // NOTE(review): both descriptions still say "cosine op", presumably
    // copied from CosineOpProtoAndCheckerMaker; confirm before rewording.
    AddInput("input", "input of cosine op");
    AddOutput("output", "output of cosine op");

    // Custom check for "test_attr": enforces that the value is even.
    auto even_only = [](int i) {
      PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
    };
    AddAttr<int>("test_attr", "a simple test attribute")
        .AddCustomChecker(even_only);

    AddType("my_test_op");
    AddComment("This is my_test op");
  }
};

// Registers MyTestOp together with its maker under the identifier
// my_test_op.
REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op)
}
// namespace operators
}
// namespace paddle
paddle/platform/cuda_device_context.h
浏览文件 @
ca23d861
...
@@ -23,15 +23,13 @@ limitations under the License. */
...
@@ -23,15 +23,13 @@ limitations under the License. */
#include "paddle/platform/place.h"
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"
#include "unsupported/Eigen/CXX11/Tensor"
using
DEVICE_GPU
=
Eigen
::
GpuDevice
;
namespace
paddle
{
namespace
paddle
{
namespace
platform
{
namespace
platform
{
class
CUDADeviceContext
;
class
CUDADeviceContext
;
template
<
>
template
<
>
DEVICE_GPU
DeviceContext
::
get_eigen_device
<
DEVICE_GPU
>
()
{
Eigen
::
GpuDevice
DeviceContext
::
get_eigen_device
<
Eigen
::
GpuDevice
>
()
{
return
static_cast
<
CUDADeviceContext
*>
(
this
)
->
eigen_handle
();
return
static_cast
<
CUDADeviceContext
*>
(
this
)
->
eigen_handle
();
}
}
...
@@ -59,6 +57,11 @@ class CUDADeviceContext : public DeviceContext {
...
@@ -59,6 +57,11 @@ class CUDADeviceContext : public DeviceContext {
eigen_device_
=
new
Eigen
::
GpuDevice
(
eigen_stream_
);
eigen_device_
=
new
Eigen
::
GpuDevice
(
eigen_stream_
);
}
}
// Returns a Place built from a value-initialized GPUPlace.
Place GetPlace() const override {
  return GPUPlace();
}
void
Wait
()
{
void
Wait
()
{
paddle
::
platform
::
throw_on_error
(
cudaStreamSynchronize
(
stream_
),
paddle
::
platform
::
throw_on_error
(
cudaStreamSynchronize
(
stream_
),
"cudaStreamSynchronize failed"
);
"cudaStreamSynchronize failed"
);
...
...
paddle/platform/device_context.h
浏览文件 @
ca23d861
...
@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
...
@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#include "paddle/framework/enforce.h"
#include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor"
#include "unsupported/Eigen/CXX11/Tensor"
using
DEVICE_CPU
=
Eigen
::
DefaultDevice
;
namespace
paddle
{
namespace
paddle
{
namespace
platform
{
namespace
platform
{
...
@@ -28,10 +28,12 @@ class DeviceContext {
...
@@ -28,10 +28,12 @@ class DeviceContext {
template
<
typename
DeviceType
>
template
<
typename
DeviceType
>
DeviceType
get_eigen_device
();
DeviceType
get_eigen_device
();
// Returns the Place associated with this device context. Pure virtual, so
// each concrete context has to provide its own implementation.
virtual
Place
GetPlace
()
const
=
0
;
};
};
template
<
>
template
<
>
DEVICE_CPU
DeviceContext
::
get_eigen_device
<
DEVICE_CPU
>
()
{
Eigen
::
DefaultDevice
DeviceContext
::
get_eigen_device
<
Eigen
::
DefaultDevice
>
()
{
return
static_cast
<
CPUDeviceContext
*>
(
this
)
->
eigen_handle
();
return
static_cast
<
CPUDeviceContext
*>
(
this
)
->
eigen_handle
();
}
}
...
@@ -44,9 +46,13 @@ class CPUDeviceContext : public DeviceContext {
...
@@ -44,9 +46,13 @@ class CPUDeviceContext : public DeviceContext {
return
*
eigen_handle_
;
return
*
eigen_handle_
;
}
}
// Returns a Place built from a value-initialized CPUPlace.
Place GetPlace() const override {
  return CPUPlace();
}
private:
private:
Eigen
::
DefaultDevice
*
eigen_handle_
{
nullptr
};
Eigen
::
DefaultDevice
*
eigen_handle_
{
nullptr
};
};
};
}
// namespace platform
}
// namespace platform
}
// namespace paddle
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录