PaddlePaddle / Paddle

Commit ca23d861: merge baidu/develop
Authored on Jul 12, 2017 by qijun
Parents: 4d336d90, 0a320081

Showing 14 changed files with 294 additions and 264 deletions (+294 -264)
Changed files:

  paddle/CMakeLists.txt                   +0    -1
  paddle/framework/dim.h                  +0    -48
  paddle/framework/dim_test.cu            +0    -28
  paddle/framework/op_registry_test.cc    +16   -17
  paddle/framework/operator.cc            +0    -8
  paddle/framework/operator.h             +80   -37
  paddle/framework/operator_test.cc       +16   -23
  paddle/framework/tensor.h               +48   -9
  paddle/framework/tensor_test.cc         +118  -22
  paddle/operators/.clang-format          +0    -5
  paddle/operators/CMakeLists.txt         +0    -0
  paddle/operators/demo_op.h              +0    -59
  paddle/platform/cuda_device_context.h   +6    -3
  paddle/platform/device_context.h        +10   -4
paddle/CMakeLists.txt (+0 -1)

@@ -15,7 +15,6 @@ if(Boost_FOUND)
   add_subdirectory(memory)
   add_subdirectory(platform)
   add_subdirectory(framework)
-  add_subdirectory(operators)
   add_subdirectory(pybind)
 endif()
paddle/framework/dim.h (+0 -48)

@@ -266,29 +266,6 @@ HOSTDEVICE inline bool contained(const Dim<1>& idx, const Dim<1>& size) {
   return ((0 <= idx.head) && (idx.head < size.head));
 }

-/**
- * \brief Check if a size and a stride create a Fortran order contiguous
- * block of memory.
- */
-template <int i>
-HOST bool contiguous(const Dim<i>& size, const Dim<i>& stride, int mul = 1) {
-  if (product(size) == 0) return true;
-  int contiguous_stride = get<0>(size) == 1 ? 0 : mul;
-  return (get<0>(stride) == contiguous_stride &&
-          contiguous(size.tail, stride.tail, mul * get<0>(size)));
-}
-
-///\cond HIDDEN
-// Base case of contiguous, check the nth stride is the size of
-// the prefix multiply of n-1 dims.
-template <>
-inline bool contiguous(const Dim<1>& size, const Dim<1>& stride, int mul) {
-  if (get<0>(size) == 0) return true;
-  int contiguous_stride = get<0>(size) == 1 ? 0 : mul;
-  return get<0>(stride) == contiguous_stride;
-}
-///\endcond
-
 /**
  * \brief Compute exclusive prefix-multiply of a Dim.
  */

@@ -306,31 +283,6 @@ HOSTDEVICE inline Dim<1> ex_prefix_mul(const Dim<1>& src, int mul) {
 }
 ///\endcond

-/**
- * \brief Calculate strides of a contiguous array of the given size
- *
- * Sets the stride for any dimension with an extent of 1 to 0.
- * \param size Dim object containing the size of the array.
- * \param base The base stride to use.
- * \return Dim object the same size as \p size with the strides.
- */
-template <int i>
-HOSTDEVICE Dim<i> contiguous_strides(const Dim<i>& size, int base = 1) {
-  int stride = size.head == 1 ? 0 : base;
-  return Dim<i>(stride, contiguous_strides(size.tail, base * size.head));
-}
-
-///\cond HIDDEN
-
-// Base case of contiguous_strides
-template <>
-HOSTDEVICE inline Dim<1> contiguous_strides(const Dim<1>& size, int base) {
-  int stride = size.head == 1 ? 0 : base;
-  return Dim<1>(stride);
-}
-
-///\endcond
-
 /**
  * Add two dimensions together
  */
paddle/framework/dim_test.cu (+0 -28)

@@ -58,24 +58,6 @@ TEST(Dim, Equality) {
   EXPECT_EQ(paddle::framework::get<1>(c), 3);
   EXPECT_EQ(paddle::framework::get<2>(c), 12);

-  // contiguous_strides
-  c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(10, 1, 10));
-  EXPECT_EQ(paddle::framework::get<0>(c), 1);
-  EXPECT_EQ(paddle::framework::get<1>(c), 0);
-  EXPECT_EQ(paddle::framework::get<2>(c), 10);
-  c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(10, 10, 1));
-  EXPECT_EQ(paddle::framework::get<0>(c), 1);
-  EXPECT_EQ(paddle::framework::get<1>(c), 10);
-  EXPECT_EQ(paddle::framework::get<2>(c), 0);
-  c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(1, 10, 10));
-  EXPECT_EQ(paddle::framework::get<0>(c), 0);
-  EXPECT_EQ(paddle::framework::get<1>(c), 1);
-  EXPECT_EQ(paddle::framework::get<2>(c), 10);
-  c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(2, 3, 4));
-  EXPECT_EQ(paddle::framework::get<0>(c), 1);
-  EXPECT_EQ(paddle::framework::get<1>(c), 2);
-  EXPECT_EQ(paddle::framework::get<2>(c), 6);
-
   // generate from an index
   auto size = paddle::framework::make_dim(4, 5, 2);
   c = paddle::framework::Dim<3>(14, size);

@@ -101,16 +83,6 @@ TEST(Dim, Bool) {
   EXPECT_TRUE(a == a);
   EXPECT_FALSE(a == b);
   EXPECT_TRUE(a == c);
-
-  // contiguous check
-  int x = 4, y = 5, z = 2;
-  paddle::framework::Dim<3> sizef(x, y, z);
-  paddle::framework::Dim<3> stridea(1, x, x*y);
-  paddle::framework::Dim<3> strideb(2, 2*x, 2*x*y);
-  paddle::framework::Dim<3> stridec(1, x, 2*x*y);
-  EXPECT_TRUE(paddle::framework::contiguous(sizef, stridea));
-  EXPECT_FALSE(paddle::framework::contiguous(sizef, strideb));
-  EXPECT_FALSE(paddle::framework::contiguous(sizef, stridec));
 }

 TEST(Dim, Print) {
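Note: the contiguous()/contiguous_strides() helpers removed above computed Fortran-order (column-major) strides, where dimension i gets the product of all lower-dimension extents and any dimension of extent 1 is forced to stride 0. A minimal standalone sketch of the same computation, using a plain std::vector instead of Paddle's Dim<> template (the function name is invented for illustration); it reproduces two of the value sets checked by the deleted dim_test.cu cases:

#include <cassert>
#include <vector>

// Column-major ("Fortran order") strides, mirroring the removed
// contiguous_strides(): stride[i] = size[0] * ... * size[i-1],
// except that a dimension with extent 1 gets stride 0.
std::vector<int> contiguous_strides_sketch(const std::vector<int>& size) {
  std::vector<int> stride(size.size());
  int base = 1;
  for (size_t i = 0; i < size.size(); ++i) {
    stride[i] = (size[i] == 1) ? 0 : base;
    base *= size[i];
  }
  return stride;
}

int main() {
  // Same expected values as two of the deleted checks in dim_test.cu.
  assert((contiguous_strides_sketch({2, 3, 4}) == std::vector<int>{1, 2, 6}));
  assert((contiguous_strides_sketch({10, 1, 10}) == std::vector<int>{1, 0, 10}));
  return 0;
}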
paddle/framework/op_registry_test.cc (+16 -17)

 #include "paddle/framework/op_registry.h"
 #include <gtest/gtest.h>
-#include "paddle/operators/demo_op.h"
+#include "paddle/framework/operator.h"

 using namespace paddle::framework;

 namespace paddle {
 namespace framework {
-class CosineOp : public OperatorWithKernel {
+class CosineOp : public OperatorBase {
  public:
-  void Run(const OpRunContext* context) const override {
-    printf("%s\n", DebugString().c_str());
-  }
+  void Run(const std::shared_ptr<Scope>& scope,
+           const platform::DeviceContext& dev_ctx) const override {}
+  void InferShape(const std::shared_ptr<Scope>& scope) const override {}
 };

 class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {

@@ -30,12 +28,13 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim)

-class MyTestOp : public OperatorWithKernel {
+class MyTestOp : public OperatorBase {
+ public:
+  void InferShape(const std::shared_ptr<Scope>& scope) const override {}
+  void Run(const std::shared_ptr<Scope>& scope,
+           const platform::DeviceContext& dev_ctx) const override {}
+
  public:
-  void Run(const OpRunContext* ctx) const override {
-    printf("%s\n", DebugString().c_str());
-    printf("test_attr = %d\n", ctx->op_->GetAttr<int>("test_attr"));
-  }
 };

 class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {

@@ -73,8 +72,8 @@ TEST(OpRegistry, CreateOp) {
   paddle::framework::OperatorBase* op =
       paddle::framework::OpRegistry::CreateOp(op_desc);
   auto scope = std::make_shared<Scope>();
-  auto dev_ctx = DeviceContext();
-  op->Run(scope, &dev_ctx);
+  paddle::platform::CPUDeviceContext dev_ctx;
+  op->Run(scope, dev_ctx);
   float scale_get = op->GetAttr<float>("scale");
   ASSERT_EQ(scale_get, scale);
 }

@@ -116,8 +115,8 @@ TEST(OpRegistry, DefaultValue) {
   paddle::framework::OperatorBase* op =
       paddle::framework::OpRegistry::CreateOp(op_desc);
   auto scope = std::make_shared<Scope>();
-  auto dev_ctx = DeviceContext();
-  op->Run(scope, &dev_ctx);
+  paddle::platform::CPUDeviceContext dev_ctx;
+  op->Run(scope, dev_ctx);
   ASSERT_EQ(op->GetAttr<float>("scale"), 1.0);
 }

@@ -169,9 +168,9 @@ TEST(OpRegistry, CustomChecker) {
   attr->set_i(4);
   paddle::framework::OperatorBase* op =
       paddle::framework::OpRegistry::CreateOp(op_desc);
-  auto dev_ctx = DeviceContext();
+  paddle::platform::CPUDeviceContext dev_ctx;
   auto scope = std::make_shared<Scope>();
-  op->Run(scope, &dev_ctx);
+  op->Run(scope, dev_ctx);
   int test_attr = op->GetAttr<int>("test_attr");
   ASSERT_EQ(test_attr, 4);
 }
paddle/framework/operator.cc (+0 -8)

@@ -39,13 +39,5 @@ std::string OperatorBase::DebugString() const {
   return ss.str();
 }

-const Variable* OpRunContext::Input(int index) const {
-  return scope_->GetVariable(op_->inputs_[index]);
-}
-
-Variable* OpRunContext::Output(int index) const {
-  return scope_->GetVariable(op_->outputs_[index]);
-}
-
 }  // namespace framework
 }  // namespace paddle
\ No newline at end of file
paddle/framework/operator.h (+80 -37)

@@ -14,44 +14,22 @@ limitations under the License. */
 #pragma once

+#include <paddle/framework/attr_checker.h>
+#include <paddle/framework/op_desc.pb.h>
+#include <paddle/framework/scope.h>
+#include <paddle/platform/device_context.h>
+#include <paddle/platform/place.h>
+#include <paddle/utils/Error.h>
 #include <boost/variant.hpp>
 #include <string>
 #include <unordered_map>
 #include <vector>
-#include "paddle/framework/attr_checker.h"
-#include "paddle/framework/op_desc.pb.h"
-#include "paddle/framework/scope.h"
-#include "paddle/utils/Error.h"

 namespace paddle {
 namespace framework {

 class OperatorBase;

-class DeviceContext {};
-
-/**
- * OpRunContext is the only parameter of Operator's Run function.
- * Run will get input/output variables, state such as momentum and
- * device resource such as CUDA stream, cublas handle, etc. from
- * OpRunContext. User should construct it before run the Operator.
- */
-class OpRunContext {
- public:
-  OpRunContext(const OperatorBase* op, const std::shared_ptr<Scope> scope,
-               const DeviceContext* device_context)
-      : op_(op), scope_(scope), device_context_(device_context) {}
-
-  const Variable* Input(int index) const;
-  Variable* Output(int index) const;
-
- public:
-  const OperatorBase* op_;
-  const std::shared_ptr<Scope> scope_;
-  const DeviceContext* device_context_;
-};
-
 /**
  * OperatorBase has the basic element that Net will call to do computation.
  * Only CreateOperator from OpRegistry will new Operator directly. User

@@ -77,7 +55,10 @@ class OperatorBase {
   /// Net will call this function to Run an op.
   virtual void Run(const std::shared_ptr<Scope>& scope,
-                   const DeviceContext* dev_ctx) const = 0;
+                   const platform::DeviceContext& dev_ctx) const = 0;
+
+ protected:
+  std::string Type() const { return desc_.type(); }

  public:
   OpDesc desc_;

@@ -86,22 +67,84 @@ class OperatorBase {
   AttributeMap attrs_;
 };

+class OpKernel {
+ public:
+  /**
+   * KernelContext is the only parameter of Kernel Run function.
+   * Run will get input/output variables, state such as momentum and
+   * device resource such as CUDA stream, cublas handle, etc. from
+   * KernelContext. User should construct it before run the Operator.
+   */
+  class KernelContext {
+   public:
+    KernelContext(const OperatorBase* op, const std::shared_ptr<Scope>& scope,
+                  const platform::DeviceContext& device_context)
+        : op_(*op), scope_(scope), device_context_(device_context) {}
+
+    const Variable* Input(int index) const {
+      return scope_->GetVariable(op_.inputs_[index]);
+    }
+
+    Variable* Output(int index) const {
+      return scope_->GetVariable(op_.outputs_[index]);
+    }
+
+    const OperatorBase& op_;
+    const std::shared_ptr<Scope>& scope_;
+    const platform::DeviceContext& device_context_;
+  };
+
+  virtual void Compute(const KernelContext& context) const = 0;
+
+  virtual ~OpKernel() {}
+};
+
 class OperatorWithKernel : public OperatorBase {
  public:
-  virtual ~OperatorWithKernel() {}
-
-  virtual void InferShape(const std::shared_ptr<Scope>& scope) const {}
+  struct OpKernelKey {
+    platform::Place place_;
+
+    OpKernelKey() = default;
+    OpKernelKey(const platform::DeviceContext& dev_ctx) {
+      place_ = dev_ctx.GetPlace();
+    }
+
+    bool operator==(const OpKernelKey& o) const { return place_ == o.place_; }
+  };
+
+  struct OpKernelHash {
+    std::hash<bool> hash_;
+    size_t operator()(const OpKernelKey& key) const {
+      return hash_(platform::is_gpu_place(key.place_));
+    }
+  };
+
+  using OpKernelMap =
+      std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;

   void Run(const std::shared_ptr<Scope>& scope,
-           const DeviceContext* dev_ctx) const {
-    OpRunContext op_ctx(this, scope, dev_ctx);
-    Run(&op_ctx);
+           const platform::DeviceContext& dev_ctx) const final {
+    auto& opKernel = AllOpKernels().at(Type()).at(OpKernelKey(dev_ctx));
+    opKernel->Compute(OpKernel::KernelContext(this, scope, dev_ctx));
   }

-  /// when implement an Op, your should implement this function.
-  /// this function should be moved to OpKernel later
-  virtual void Run(const OpRunContext* context) const = 0;
+  static std::unordered_map<std::string /* op_type */, OpKernelMap>&
+  AllOpKernels() {
+    static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
+    return g_all_op_kernels;
+  };
 };

 }  // namespace framework
 }  // namespace paddle
+
+#define REGISTER_OP_KERNEL(type, PlaceType, KernelType)                      \
+  struct __op_kernel_register__##type##__ {                                  \
+    __op_kernel_register__##type##__() {                                     \
+      ::paddle::framework::OperatorWithKernel::OpKernelKey key;              \
+      key.place_ = PlaceType();                                              \
+      ::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key]    \
+          .reset(new KernelType());                                          \
+    }                                                                        \
+  };                                                                         \
+  static __op_kernel_register__##type##__ __reg_kernel_##type##__
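Note on the new kernel machinery: OperatorWithKernel::Run no longer dispatches to a virtual Run(OpRunContext*); it looks a kernel up in AllOpKernels() by op type and OpKernelKey(place) and calls its Compute(KernelContext). The REGISTER_OP_KERNEL macro above fills that table from a static initializer. A hedged sketch of how a kernel could plug in at this commit (CosKernel and its pairing with the cos_sim op type are hypothetical; OpKernel, KernelContext, Compute, AllOpKernels, CPUPlace and the macro itself come from the headers in this diff, and whether the snippet builds as-is depends on the rest of the tree):

#include "paddle/framework/operator.h"
#include "paddle/platform/place.h"

namespace paddle {
namespace operators {

// Hypothetical CPU kernel for the "cos_sim" op type registered in the tests.
// The base class and the Compute(KernelContext) signature are the ones added
// to operator.h by this commit.
class CosKernel : public framework::OpKernel {
 public:
  void Compute(const KernelContext& context) const override {
    // context.Input(i) / context.Output(i) resolve variables through the
    // scope captured in the KernelContext; a real kernel would read and
    // write tensors here.
  }
};

}  // namespace operators
}  // namespace paddle

// Installs the kernel under op type "cos_sim" for CPUPlace, so that
// OperatorWithKernel::Run can find it via
// AllOpKernels().at("cos_sim").at(OpKernelKey(dev_ctx)).
REGISTER_OP_KERNEL(cos_sim, ::paddle::platform::CPUPlace,
                   ::paddle::operators::CosKernel);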
paddle/framework/operator_test.cc (+16 -23)

@@ -19,17 +19,15 @@ limitations under the License. */
 namespace paddle {
 namespace framework {

-class OperatorTest : public OperatorWithKernel {
+class OperatorTest : public OperatorBase {
  public:
-  void Run(const OpRunContext* ctx) const override {
-    float scale = ctx->op_->GetAttr<float>("scale");
-    PADDLE_ENFORCE(ctx->Input(0) == nullptr, "Input(0) should not initialized");
-    PADDLE_ENFORCE(ctx->Output(0) == nullptr,
-                   "Output(1) should not initialized");
-    auto output1 = ctx->scope_->CreateVariable("output1");
-    PADDLE_ENFORCE(output1 != nullptr, "should create output1 from scope");
-    printf("get attr %s = %f\n", "scale", scale);
-    printf("%s\n", DebugString().c_str());
+  void InferShape(const std::shared_ptr<Scope>& scope) const override {}
+  void Run(const std::shared_ptr<Scope>& scope,
+           const platform::DeviceContext& dev_ctx) const override {
+    float scale = GetAttr<float>("scale");
+    ASSERT_NEAR(scale, 3.14, 1e-5);
+    ASSERT_EQ(scope->GetVariable(inputs_[0]), nullptr);
+    ASSERT_NE(scope->GetVariable(outputs_[0]), nullptr);
   }
 };

@@ -49,31 +47,26 @@ class OperatorTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
 REGISTER_OP(OperatorTest, OperatorTestProtoAndCheckerMaker, test_operator)

-TEST(OperatorBase, DebugString) {
+TEST(OperatorBase, all) {
   OpDesc op_desc;
   op_desc.set_type("test_operator");
-  std::vector<std::string> inputs = {"IN1", "IN2"};
-  for (auto& input : inputs) {
-    op_desc.add_inputs(input);
-  }
-  std::vector<std::string> outputs = {"OUT1", "OUT2"};
-  for (auto& output : outputs) {
-    op_desc.add_outputs(output);
-  }
+  *op_desc.mutable_inputs()->Add() = "IN1";
+  *op_desc.mutable_outputs()->Add() = "OUT1";
   auto attr = op_desc.mutable_attrs()->Add();
   attr->set_name("scale");
   attr->set_type(paddle::framework::AttrType::FLOAT);
   float scale = 3.14;
   attr->set_f(scale);

-  DeviceContext device_context;
+  platform::CPUDeviceContext device_context;
   auto scope = std::make_shared<Scope>();

   OperatorBase* op = paddle::framework::OpRegistry::CreateOp(op_desc);
-  ASSERT_EQ(op->inputs_, inputs);
-  ASSERT_EQ(op->outputs_, outputs);
   ASSERT_EQ(op->GetAttr<float>("scale"), scale);
-  op->Run(scope, &device_context);
+  scope->CreateVariable("OUT1");
+  op->Run(scope, device_context);
+  std::cout << op->DebugString() << std::endl;
+  delete op;
 }

 }  // namespace framework
paddle/framework/tensor.h (+48 -9)

@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once

+#include <cstdint>
 #include <memory>
 #include <type_traits>
 #include "paddle/framework/ddim.h"

@@ -26,31 +27,65 @@ namespace framework {

 class Tensor {
  public:
+  Tensor() : offset_(0) {}
+
+  explicit Tensor(const DDim& dims) : dims_(dims), offset_(0) {}
+
   template <typename T>
   const T* data() const {
-    PADDLE_ENFORCE(holder_ != nullptr,
-                   "Tensor::data must be called after Tensor::mutable_data.");
-    return static_cast<const T*>(holder_->Ptr());
+    PADDLE_ENFORCE(
+        holder_ != nullptr,
+        "Tenosr has not been initialized. Call Tensor::mutable_data first.");
+    return reinterpret_cast<const T*>(
+        reinterpret_cast<uintptr_t>(holder_->Ptr()) + offset_);
   }

   template <typename T,  // must be POD types
             typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
   T* mutable_data(DDim dims, paddle::platform::Place place) {
+    dims_ = dims;
     if (holder_ == nullptr ||
         !(holder_->Place() ==
           place) /* some versions of boost::variant don't have operator!= */
-        || holder_->Size() < product(dims) * sizeof(T)) {
+        || holder_->Size() < product(dims) * sizeof(T) + offset_) {
       holder_.reset(new PlaceholderImpl<T>(place, product(dims) * sizeof(T)));
+      offset_ = 0;
     }
-    return static_cast<T*>(holder_->Ptr());
+    return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->Ptr()) +
+                                offset_);
   }

-  template <typename T,  // must be POD types
-            typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
-  T* mutable_data(DDim dims) {
-    return mutable_data<T>(dims, paddle::platform::get_place());
+  void ShareDataFrom(const Tensor& src) {
+    PADDLE_ENFORCE(src.holder_ != nullptr,
+                   "Can not share data from an uninitialized tensor.");
+    holder_ = src.holder_;
+    dims_ = src.dims_;
+    offset_ = src.offset_;
   }

+  Tensor Slice(const int& begin_idx, const int& end_idx) const {
+    PADDLE_ENFORCE(holder_ != nullptr,
+                   "The sliced tenosr has not been initialized.");
+    PADDLE_ENFORCE(begin_idx >= 0 && end_idx <= dims_[0],
+                   "Slice index is less than zero or out of bound.");
+    PADDLE_ENFORCE(begin_idx < end_idx,
+                   "Begin index must be less than end index.");
+    PADDLE_ENFORCE(dims_[0] != 1, "Can not slice a tensor with dims_[0] = 1.");
+    std::vector<int> d = vectorize(dims_);
+    int base = 1;
+    for (size_t i = 1; i < d.size(); ++i) {
+      base *= d[i];
+    }
+    Tensor dst;
+    dst.holder_ = holder_;
+    dst.dims_ = dims_;
+    dst.dims_[0] = end_idx - begin_idx;
+    dst.offset_ = offset_ + begin_idx * base * holder_->TypeSize();
+    return dst;
+  }
+
+  DDim dims() const { return dims_; }
+
  private:
   // Placeholder hides type T, so it doesn't appear as a template
   // parameter of Variable.

@@ -59,6 +94,7 @@ class Tensor {
     virtual void* Ptr() const = 0;
     virtual paddle::platform::Place Place() const = 0;
     virtual size_t Size() const = 0;
+    virtual size_t TypeSize() const = 0;
   };

   template <typename T>

@@ -85,6 +121,7 @@ class Tensor {
     virtual void* Ptr() const { return static_cast<void*>(ptr_.get()); }
     virtual size_t Size() const { return size_; }
     virtual paddle::platform::Place Place() const { return place_; }
+    virtual size_t TypeSize() const { return sizeof(T); }

     std::unique_ptr<T, Deleter> ptr_;
     paddle::platform::Place place_;  // record the place of ptr_.

@@ -92,6 +129,8 @@ class Tensor {
   };

   std::shared_ptr<Placeholder> holder_;  // holds the memory block if allocated.
+  DDim dims_;
+  size_t offset_;  // marks the begin of tensor data area.
 };

 }  // namespace framework
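Note: Slice never copies; it shares holder_ and only advances offset_. For a row-major tensor, slicing rows [begin_idx, end_idx) moves the data pointer by begin_idx times the product of the remaining dimensions times the element size, which is exactly what the disabled TEST(Tensor, Slice) in tensor_test.cc below expects. A small standalone check of that arithmetic (plain C++, no Paddle headers; the helper name is invented):

#include <cassert>
#include <cstddef>

// Byte offset produced by Tensor::Slice(begin_idx, end_idx) in this commit:
// offset_ += begin_idx * (product of dims[1..n-1]) * element size.
size_t slice_byte_offset(const int* dims, int ndims, int begin_idx,
                         size_t elem_size) {
  size_t base = 1;
  for (int i = 1; i < ndims; ++i) {
    base *= static_cast<size_t>(dims[i]);
  }
  return static_cast<size_t>(begin_idx) * base * elem_size;
}

int main() {
  // Shapes used by the disabled TEST(Tensor, Slice) cases.
  const int d1[] = {5, 3, 4};  // Slice(1, 3) on int data
  assert(slice_byte_offset(d1, 3, 1, sizeof(int)) == 3 * 4 * 1 * sizeof(int));

  const int d2[] = {6, 9};  // Slice(2, 6) on double data
  assert(slice_byte_offset(d2, 2, 2, sizeof(double)) == 9 * 2 * sizeof(double));
  return 0;
}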
paddle/framework/tensor_test.cc (+118 -22)

@@ -15,15 +15,27 @@
 #include <gtest/gtest.h>
 #include <string>

-TEST(Tensor, ASSERT) {
-  paddle::framework::Tensor cpu_tensor;
+TEST(Tensor, Dims) {
+  using namespace paddle::framework;
+  using namespace paddle::platform;
+  Tensor tt(make_ddim({2, 3, 4}));
+  DDim dims = tt.dims();
+  ASSERT_EQ(arity(dims), 3);
+  for (int i = 0; i < 3; ++i) {
+    EXPECT_EQ(i + 2, dims[i]);
+  }
+}
+
+TEST(Tensor, DataAssert) {
+  paddle::framework::Tensor src_tensor;

   bool caught = false;
   try {
-    const double* p __attribute__((unused)) = cpu_tensor.data<double>();
+    src_tensor.data<double>();
   } catch (paddle::framework::EnforceNotMet err) {
     caught = true;
-    std::string msg = "Tensor::data must be called after Tensor::mutable_data.";
+    std::string msg =
+        "Tenosr has not been initialized. Call Tensor::mutable_data first.";
     const char* what = err.what();
     for (size_t i = 0; i < msg.length(); ++i) {
       ASSERT_EQ(what[i], msg[i]);

@@ -32,54 +44,138 @@ TEST(Tensor, ASSERT) {
   ASSERT_TRUE(caught);
 }

-/* mutable_data() is not tested at present
+/* following tests are not available at present
    because Memory::Alloc() and Memory::Free() have not been ready.

 TEST(Tensor, MutableData) {
   using namespace paddle::framework;
   using namespace paddle::platform;
   {
-    Tensor cpu_tensor;
+    Tensor src_tensor;
     float* p1 = nullptr;
     float* p2 = nullptr;
     // initialization
-    p1 = cpu_tensor.mutable_data<float>(make_ddim({1, 2, 3}), CPUPlace());
+    p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), CPUPlace());
     EXPECT_NE(p1, nullptr);
-    // set cpu_tensor a new dim with large size
+    // set src_tensor a new dim with large size
     // momery is supposed to be re-allocated
-    p2 = cpu_tensor.mutable_data<float>(make_ddim({3, 4}));
+    p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), CPUPlace());
     EXPECT_NE(p2, nullptr);
     EXPECT_NE(p1, p2);
-    // set cpu_tensor a new dim with same size
+    // set src_tensor a new dim with same size
     // momery block is supposed to be unchanged
-    p1 = cpu_tensor.mutable_data<float>(make_ddim({2, 2, 3}));
+    p1 = src_tensor.mutable_data<float>(make_ddim({2, 2, 3}), CPUPlace());
     EXPECT_EQ(p1, p2);
-    // set cpu_tensor a new dim with smaller size
+    // set src_tensor a new dim with smaller size
     // momery block is supposed to be unchanged
-    p2 = cpu_tensor.mutable_data<float>(make_ddim({2, 2}));
+    p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), CPUPlace());
     EXPECT_EQ(p1, p2);
   }
   {
-    Tensor gpu_tensor;
+    Tensor src_tensor;
     float* p1 = nullptr;
     float* p2 = nullptr;
     // initialization
-    p1 = gpu_tensor.mutable_data<float>(make_ddim({1, 2, 3}), GPUPlace());
+    p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), GPUPlace());
     EXPECT_NE(p1, nullptr);
-    // set gpu_tensor a new dim with large size
+    // set src_tensor a new dim with large size
     // momery is supposed to be re-allocated
-    p2 = gpu_tensor.mutable_data<float>(make_ddim({3, 4}));
+    p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), GPUPlace());
     EXPECT_NE(p2, nullptr);
     EXPECT_NE(p1, p2);
-    // set gpu_tensor a new dim with same size
+    // set src_tensor a new dim with same size
     // momery block is supposed to be unchanged
-    p1 = gpu_tensor.mutable_data<float>(make_ddim({2, 2, 3}));
+    p1 = src_tensor.mutable_data<float>(make_ddim({2, 2, 3}), GPUPlace());
     EXPECT_EQ(p1, p2);
-    // set gpu_tensor a new dim with smaller size
+    // set src_tensor a new dim with smaller size
     // momery block is supposed to be unchanged
-    p2 = gpu_tensor.mutable_data<float>(make_ddim({2, 2}));
+    p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), GPUPlace());
     EXPECT_EQ(p1, p2);
   }
 }
-*/
+
+TEST(Tensor, ShareDataFrom) {
+  using namespace paddle::framework;
+  using namespace paddle::platform;
+  {
+    Tensor src_tensor;
+    Tensor dst_tensor;
+    // Try to share data form uninitialized tensor
+    bool caught = false;
+    try {
+      dst_tensor.ShareDataFrom(src_tensor);
+    } catch (EnforceNotMet err) {
+      caught = true;
+      std::string msg = "Can not share data from an uninitialized tensor.";
+      const char* what = err.what();
+      for (size_t i = 0; i < msg.length(); ++i) {
+        ASSERT_EQ(what[i], msg[i]);
+      }
+    }
+    ASSERT_TRUE(caught);
+
+    src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), CPUPlace());
+    dst_tensor.ShareDataFrom(src_tensor);
+    ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
+  }
+
+  {
+    Tensor src_tensor;
+    Tensor dst_tensor;
+    src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), GPUPlace());
+    dst_tensor.ShareDataFrom(src_tensor);
+    ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
+  }
+}
+
+TEST(Tensor, Slice) {
+  using namespace paddle::framework;
+  using namespace paddle::platform;
+  {
+    Tensor src_tensor;
+    src_tensor.mutable_data<int>(make_ddim({5, 3, 4}), CPUPlace());
+    Tensor slice_tensor = src_tensor.Slice(1, 3);
+    DDim slice_dims = slice_tensor.dims();
+    ASSERT_EQ(arity(slice_dims), 3);
+    EXPECT_EQ(slice_dims[0], 2);
+    EXPECT_EQ(slice_dims[1], 3);
+    EXPECT_EQ(slice_dims[2], 4);
+
+    uintptr_t src_data_address =
+        reinterpret_cast<uintptr_t>(src_tensor.data<int>());
+    uintptr_t src_mutable_data_address = reinterpret_cast<uintptr_t>(
+        src_tensor.mutable_data<int>(src_tensor.dims(), CPUPlace()));
+    uintptr_t slice_data_address =
+        reinterpret_cast<uintptr_t>(slice_tensor.data<int>());
+    uintptr_t slice_mutable_data_address = reinterpret_cast<uintptr_t>(
+        slice_tensor.mutable_data<int>(slice_tensor.dims(), CPUPlace()));
+    EXPECT_EQ(src_data_address, src_mutable_data_address);
+    EXPECT_EQ(slice_data_address, slice_mutable_data_address);
+    EXPECT_EQ(src_data_address + 3 * 4 * 1 * sizeof(int), slice_data_address);
+  }
+
+  {
+    Tensor src_tensor;
+    src_tensor.mutable_data<double>(make_ddim({6, 9}), GPUPlace());
+    Tensor slice_tensor = src_tensor.Slice(2, 6);
+    DDim slice_dims = slice_tensor.dims();
+    ASSERT_EQ(arity(slice_dims), 2);
+    EXPECT_EQ(slice_dims[0], 4);
+    EXPECT_EQ(slice_dims[1], 9);
+
+    uintptr_t src_data_address =
+        reinterpret_cast<uintptr_t>(src_tensor.data<double>());
+    uintptr_t src_mutable_data_address = reinterpret_cast<uintptr_t>(
+        src_tensor.mutable_data<double>(src_tensor.dims(), GPUPlace()));
+    uintptr_t slice_data_address =
+        reinterpret_cast<uintptr_t>(slice_tensor.data<double>());
+    uintptr_t slice_mutable_data_address = reinterpret_cast<uintptr_t>(
+        slice_tensor.mutable_data<double>(slice_tensor.dims(), GPUPlace()));
+    EXPECT_EQ(src_data_address, src_mutable_data_address);
+    EXPECT_EQ(slice_data_address, slice_mutable_data_address);
+    EXPECT_EQ(src_data_address + 9 * 2 * sizeof(double), slice_data_address);
+  }
+}
+*/
\ No newline at end of file
paddle/operators/.clang-format (deleted, 100644 → 0; +0 -5)

----
-Language: Cpp
-BasedOnStyle: Google
-Standard: Cpp11
-...
paddle/operators/CMakeLists.txt (deleted, 100644 → 0; +0 -0)
paddle/operators/demo_op.h (deleted, 100644 → 0; +0 -59)

-#pragma once
-
-#include "paddle/framework/op_registry.h"
-
-using namespace paddle::framework;
-
-namespace paddle {
-namespace operators {
-
-class CosineOp : public OperatorWithKernel {
- public:
-  void Run(const OpRunContext* context) const override {
-    printf("%s\n", DebugString().c_str());
-  }
-};
-
-class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
- public:
-  CosineOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("input", "input of cosine op");
-    AddOutput("output", "output of cosine op");
-    AddAttr<float>("scale", "scale of cosine op")
-        .SetDefault(1.0)
-        .LargerThan(0.0);
-    AddType("cos");
-    AddComment("This is cos op");
-  }
-};
-
-REGISTER_OP(CosineOp, CosineOpProtoAndCheckerMaker, cos_sim)
-
-class MyTestOp : public OperatorWithKernel {
- public:
-  void Run(const OpRunContext* context) const override {
-    printf("%s\n", DebugString().c_str());
-  }
-};
-
-class MyTestOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
- public:
-  MyTestOpProtoAndCheckerMaker(OpProto* proto, OpAttrChecker* op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("input", "input of cosine op");
-    AddOutput("output", "output of cosine op");
-    auto my_checker = [](int i) {
-      PADDLE_ENFORCE(i % 2 == 0, "'test_attr' must be even!");
-    };
-    AddAttr<int>("test_attr", "a simple test attribute")
-        .AddCustomChecker(my_checker);
-    AddType("my_test_op");
-    AddComment("This is my_test op");
-  }
-};
-
-REGISTER_OP(MyTestOp, MyTestOpProtoAndCheckerMaker, my_test_op)
-
-}  // namespace operators
-}  // namespace operators
paddle/platform/cuda_device_context.h (+6 -3)

@@ -23,15 +23,13 @@ limitations under the License. */
 #include "paddle/platform/place.h"
 #include "unsupported/Eigen/CXX11/Tensor"

-using DEVICE_GPU = Eigen::GpuDevice;
-
 namespace paddle {
 namespace platform {

 class CUDADeviceContext;

 template <>
-DEVICE_GPU DeviceContext::get_eigen_device<DEVICE_GPU>() {
+Eigen::GpuDevice DeviceContext::get_eigen_device<Eigen::GpuDevice>() {
   return static_cast<CUDADeviceContext*>(this)->eigen_handle();
 }

@@ -59,6 +57,11 @@ class CUDADeviceContext : public DeviceContext {
     eigen_device_ = new Eigen::GpuDevice(eigen_stream_);
   }

+  Place GetPlace() const override {
+    Place retv = GPUPlace();
+    return retv;
+  }
+
   void Wait() {
     paddle::platform::throw_on_error(cudaStreamSynchronize(stream_),
                                      "cudaStreamSynchronize failed");
paddle/platform/device_context.h (+10 -4)

@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #pragma once
+#include "paddle/framework/enforce.h"
+#include "paddle/platform/place.h"
 #include "unsupported/Eigen/CXX11/Tensor"

-using DEVICE_CPU = Eigen::DefaultDevice;
-
 namespace paddle {
 namespace platform {

@@ -28,10 +28,12 @@ class DeviceContext {
   template <typename DeviceType>
   DeviceType get_eigen_device();
+
+  virtual Place GetPlace() const = 0;
 };

 template <>
-DEVICE_CPU DeviceContext::get_eigen_device<DEVICE_CPU>() {
+Eigen::DefaultDevice DeviceContext::get_eigen_device<Eigen::DefaultDevice>() {
   return static_cast<CPUDeviceContext*>(this)->eigen_handle();
 }

@@ -44,9 +46,13 @@ class CPUDeviceContext : public DeviceContext {
     return *eigen_handle_;
   }

+  Place GetPlace() const override {
+    Place retv = CPUPlace();
+    return retv;
+  }
+
  private:
   Eigen::DefaultDevice* eigen_handle_{nullptr};
 };

 }  // namespace platform
 }  // namespace paddle
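Note: after this change a device context reports its place, and the DEVICE_CPU/DEVICE_GPU aliases are gone in favour of spelling out the Eigen device type at the call site. A short sketch of the intended CPU-side usage, using only names visible in the diff above (whether it links as-is depends on the rest of the tree at this commit):

#include "paddle/platform/device_context.h"

#include <cassert>

int main() {
  paddle::platform::CPUDeviceContext cpu_ctx;

  // GetPlace() is the virtual added to DeviceContext in this commit;
  // the CPU context answers with a CPUPlace.
  paddle::platform::Place place = cpu_ctx.GetPlace();
  assert(!paddle::platform::is_gpu_place(place));

  // The Eigen device is now requested with the Eigen type spelled out,
  // instead of the removed DEVICE_CPU alias.
  Eigen::DefaultDevice dev = cpu_ctx.get_eigen_device<Eigen::DefaultDevice>();
  (void)dev;
  return 0;
}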