Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1dc53a28
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
1dc53a28
编写于
7月 18, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use friend not to expose tensor's `type/place`
上级
a89c7ffa
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
21 addition
and
15 deletion
+21
-15
paddle/framework/tensor.h
paddle/framework/tensor.h
+9
-5
paddle/pybind/pybind.cc
paddle/pybind/pybind.cc
+1
-3
paddle/pybind/tensor_bind.h
paddle/pybind/tensor_bind.h
+11
-7
未找到文件。
paddle/framework/tensor.h
浏览文件 @
1dc53a28
...
...
@@ -24,6 +24,12 @@ limitations under the License. */
#include "paddle/platform/place.h"
namespace
paddle
{
namespace
pybind
{
namespace
details
{
// forward declare
template
<
bool
less
,
size_t
i
,
typename
...
args
>
struct
CastToPyBufferImpl
;
}
// namespace details
}
// namespace pybind
namespace
framework
{
class
Tensor
{
...
...
@@ -128,10 +134,6 @@ class Tensor {
DDim
dims
()
const
{
return
dims_
;
}
platform
::
Place
place
()
const
{
return
holder_
->
place
();
}
std
::
type_index
type
()
const
{
return
holder_
->
type
();
}
private:
// Placeholder hides type T, so it doesn't appear as a template
// parameter of Variable.
...
...
@@ -186,7 +188,9 @@ class Tensor {
DDim
dims_
;
size_t
numel_
;
// cache of `product(dims_)`
size_t
offset_
;
// marks the begin of tensor data area.
};
// namespace framework
template
<
bool
less
,
size_t
i
,
typename
...
args
>
friend
struct
paddle
::
pybind
::
details
::
CastToPyBufferImpl
;
};
// namespace framework
}
// namespace framework
}
// namespace paddle
paddle/pybind/pybind.cc
浏览文件 @
1dc53a28
...
...
@@ -15,7 +15,7 @@ limitations under the License. */
#include <Python.h>
#include <paddle/framework/op_registry.h>
#include <paddle/framework/scope.h>
#include <paddle/pybind/tensor.h>
#include <paddle/pybind/tensor
_bind
.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
...
...
@@ -32,8 +32,6 @@ PYBIND11_PLUGIN(core) {
py
::
class_
<
pd
::
Tensor
>
(
m
,
"Tensor"
,
py
::
buffer_protocol
())
.
def_buffer
([](
pd
::
Tensor
&
self
)
->
py
::
buffer_info
{
PADDLE_ENFORCE
(
paddle
::
platform
::
is_cpu_place
(
self
.
place
()),
"Only CPU tensor can cast to numpy array"
);
return
paddle
::
pybind
::
CastToPyBuffer
(
self
);
})
.
def
(
"get_dims"
,
...
...
paddle/pybind/tensor.h
→
paddle/pybind/tensor
_bind
.h
浏览文件 @
1dc53a28
...
...
@@ -40,7 +40,10 @@ template <size_t I, typename... ARGS>
struct
CastToPyBufferImpl
<
true
,
I
,
ARGS
...
>
{
using
CUR_TYPE
=
typename
std
::
tuple_element
<
I
,
std
::
tuple
<
ARGS
...
>>::
type
;
py
::
buffer_info
operator
()(
framework
::
Tensor
&
tensor
)
{
if
(
std
::
type_index
(
typeid
(
CUR_TYPE
))
==
tensor
.
type
())
{
PADDLE_ENFORCE
(
paddle
::
platform
::
is_cpu_place
(
tensor
.
holder_
->
place
()),
"Only CPU tensor can cast to numpy array"
);
if
(
std
::
type_index
(
typeid
(
CUR_TYPE
))
==
tensor
.
holder_
->
type
())
{
auto
dim_vec
=
framework
::
vectorize
(
tensor
.
dims
());
std
::
vector
<
size_t
>
dims_outside
;
std
::
vector
<
size_t
>
strides
;
...
...
@@ -54,12 +57,13 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
prod
*=
dims_outside
[
i
-
1
];
}
return
py
::
buffer_info
(
tensor
.
mutable_data
<
CUR_TYPE
>
(
tensor
.
place
()),
sizeof
(
CUR_TYPE
),
py
::
format_descriptor
<
CUR_TYPE
>::
format
(),
(
size_t
)
framework
::
arity
(
tensor
.
dims
()),
dims_outside
,
strides
);
return
py
::
buffer_info
(
tensor
.
mutable_data
<
CUR_TYPE
>
(
tensor
.
holder_
->
place
()),
sizeof
(
CUR_TYPE
),
py
::
format_descriptor
<
CUR_TYPE
>::
format
(),
(
size_t
)
framework
::
arity
(
tensor
.
dims
()),
dims_outside
,
strides
);
}
else
{
constexpr
bool
less
=
I
+
1
<
std
::
tuple_size
<
std
::
tuple
<
ARGS
...
>>::
value
;
return
CastToPyBufferImpl
<
less
,
I
+
1
,
ARGS
...
>
()(
tensor
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录