Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
9c9e635c
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9c9e635c
编写于
5月 21, 2020
作者:
L
Leo Chen
提交者:
GitHub
5月 21, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support tensor to varbase, test=develop (#24660)
上级
fdbe114b
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
40 addition
and
3 deletion
+40
-3
paddle/fluid/pybind/imperative.cc
paddle/fluid/pybind/imperative.cc
+25
-0
python/paddle/fluid/dygraph/base.py
python/paddle/fluid/dygraph/base.py
+4
-2
python/paddle/fluid/tests/unittests/test_imperative_basic.py
python/paddle/fluid/tests/unittests/test_imperative_basic.py
+4
-1
python/paddle/fluid/tests/unittests/test_var_base.py
python/paddle/fluid/tests/unittests/test_var_base.py
+7
-0
未找到文件。
paddle/fluid/pybind/imperative.cc
浏览文件 @
9c9e635c
...
...
@@ -101,6 +101,7 @@ static void InitTensorForVarBase(imperative::VarBase *self,
static
void
InitVarBaseFromNumpyWithKwargs
(
imperative
::
VarBase
*
self
,
const
py
::
kwargs
&
kwargs
)
{
VLOG
(
4
)
<<
"Init VarBase"
;
PADDLE_ENFORCE_EQ
(
kwargs
.
contains
(
"value"
),
true
,
platform
::
errors
::
NotFound
(
...
...
@@ -126,6 +127,7 @@ static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self,
bool
persistable
=
false
,
bool
zero_copy
=
false
,
std
::
string
name
=
""
)
{
VLOG
(
4
)
<<
"Init VarBase"
;
// 0: self, 1: value, 2: place, 3: persistable, 4: zero_copy, 5: name
if
(
name
==
""
)
{
name
=
imperative
::
GetCurrentTracer
()
->
GenerateUniqueName
(
"generated_var"
);
...
...
@@ -140,10 +142,31 @@ static void InitVarBaseFromNumpyWithArg(imperative::VarBase *self,
static
void
InitVarBaseFromNumpyWithArgDefault
(
imperative
::
VarBase
*
self
,
const
py
::
array
&
array
)
{
VLOG
(
4
)
<<
"Init VarBase"
;
auto
place
=
imperative
::
GetCurrentTracer
()
->
ExpectedPlace
();
InitTensorForVarBase
(
self
,
array
,
place
);
}
static
void
InitVarBaseFromTensorWithArgDefault
(
imperative
::
VarBase
*
self
,
const
framework
::
LoDTensor
&
tensor
)
{
VLOG
(
4
)
<<
"Init VarBase"
;
auto
place
=
imperative
::
GetCurrentTracer
()
->
ExpectedPlace
();
new
(
self
)
imperative
::
VarBase
(
imperative
::
GetCurrentTracer
()
->
GenerateUniqueName
(
"generated_var"
));
self
->
SetPersistable
(
false
);
self
->
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
self
->
SetDataType
(
tensor
.
type
());
auto
*
new_tensor
=
self
->
MutableVar
()
->
GetMutable
<
framework
::
LoDTensor
>
();
// Same place,share data directly
if
(
place
==
tensor
.
place
())
{
new_tensor
->
ShareDataWith
(
tensor
);
VLOG
(
4
)
<<
"Same place, do ShareDataWith"
;
}
else
{
framework
::
TensorCopy
(
tensor
,
place
,
new_tensor
);
VLOG
(
4
)
<<
"Different place, do TensorCopy"
;
}
}
static
std
::
string
GetTypeName
(
const
imperative
::
VarBase
&
var
)
{
if
(
var
.
Type
()
==
framework
::
proto
::
VarType
::
RAW
)
{
return
"RAW"
;
...
...
@@ -520,6 +543,7 @@ void BindImperative(py::module *m_ptr) {
[](
imperative
::
VarBase
&
self
,
framework
::
proto
::
VarType
::
Type
dtype
,
const
std
::
vector
<
int
>
&
dims
,
const
py
::
handle
&
name
,
framework
::
proto
::
VarType
::
Type
type
,
bool
persistable
)
{
VLOG
(
4
)
<<
"Init VarBase"
;
std
::
string
act_name
=
""
;
if
(
!
name
.
ptr
()
||
name
.
ptr
()
==
Py_None
)
{
act_name
=
imperative
::
GetCurrentTracer
()
->
GenerateUniqueName
(
...
...
@@ -547,6 +571,7 @@ void BindImperative(py::module *m_ptr) {
py
::
arg
(
"value"
),
py
::
arg
(
"place"
),
py
::
arg
(
"persistable"
)
=
false
,
py
::
arg
(
"zero_copy"
)
=
false
,
py
::
arg
(
"name"
)
=
""
)
.
def
(
"__init__"
,
&
InitVarBaseFromNumpyWithArgDefault
,
py
::
arg
(
"value"
))
.
def
(
"__init__"
,
&
InitVarBaseFromTensorWithArgDefault
,
py
::
arg
(
"tensor"
))
.
def
(
"__init__"
,
&
InitVarBaseFromNumpyWithKwargs
)
.
def
(
"__getitem__"
,
[](
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
,
py
::
handle
_index
)
{
...
...
python/paddle/fluid/dygraph/base.py
浏览文件 @
9c9e635c
...
...
@@ -538,8 +538,8 @@ def to_variable(value, name=None, zero_copy=None):
numpy\.ndarray, Variable or ComplexVariable object.
Parameters:
value(ndarray|Variable|ComplexVariable): The numpy\.ndarray, Variable
or ComplexVariable object that needs to be converted, it can be
value(ndarray|Variable|
Tensor|
ComplexVariable): The numpy\.ndarray, Variable
Tensor
or ComplexVariable object that needs to be converted, it can be
multi-dimension, and the data type is one of numpy\.{float16,
float32, float64, int16, int32, int64, uint8, uint16, complex64,
complex128}.
...
...
@@ -611,6 +611,8 @@ def to_variable(value, name=None, zero_copy=None):
elif
isinstance
(
value
,
(
core
.
VarBase
,
framework
.
Variable
,
framework
.
ComplexVariable
)):
return
value
elif
isinstance
(
value
,
(
core
.
Tensor
,
core
.
LoDTensor
)):
return
core
.
VarBase
(
value
)
else
:
raise
TypeError
(
"The type of input value is invalid, expected type is 'ndarray', "
...
...
python/paddle/fluid/tests/unittests/test_imperative_basic.py
浏览文件 @
9c9e635c
...
...
@@ -240,18 +240,22 @@ class TestImperative(unittest.TestCase):
def
test_create_VarBase
(
self
):
x
=
np
.
ones
([
2
,
2
],
np
.
float32
)
y
=
np
.
zeros
([
3
,
3
],
np
.
float32
)
t
=
fluid
.
Tensor
()
t
.
set
(
x
,
fluid
.
CPUPlace
())
with
fluid
.
dygraph
.
guard
():
tmp
=
fluid
.
core
.
VarBase
(
value
=
x
,
place
=
fluid
.
core
.
CPUPlace
())
tmp2
=
fluid
.
core
.
VarBase
(
y
,
fluid
.
core
.
CPUPlace
())
tmp3
=
fluid
.
dygraph
.
base
.
to_variable
(
x
)
tmp4
=
fluid
.
core
.
VarBase
(
y
)
tmp5
=
fluid
.
core
.
VarBase
(
value
=
x
)
tmp6
=
fluid
.
core
.
VarBase
(
t
)
self
.
assertTrue
(
np
.
array_equal
(
x
,
tmp
.
numpy
()))
self
.
assertTrue
(
np
.
array_equal
(
y
,
tmp2
.
numpy
()))
self
.
assertTrue
(
np
.
array_equal
(
x
,
tmp3
.
numpy
()))
self
.
assertTrue
(
np
.
array_equal
(
y
,
tmp4
.
numpy
()))
self
.
assertTrue
(
np
.
array_equal
(
x
,
tmp5
.
numpy
()))
self
.
assertTrue
(
np
.
array_equal
(
x
,
tmp6
.
numpy
()))
def
test_no_grad_guard
(
self
):
data
=
np
.
array
([[
2
,
3
],
[
4
,
5
]]).
astype
(
'float32'
)
...
...
@@ -384,7 +388,6 @@ class TestImperative(unittest.TestCase):
var_inp
=
fluid
.
dygraph
.
base
.
to_variable
(
np_inp
)
var_inp
.
stop_gradient
=
False
l
=
MyLayer
()
print
(
var_inp
)
x
=
l
(
var_inp
)[
0
]
self
.
assertIsNotNone
(
x
)
dy_out
=
x
.
numpy
()
...
...
python/paddle/fluid/tests/unittests/test_var_base.py
浏览文件 @
9c9e635c
...
...
@@ -47,6 +47,13 @@ class TestVarBase(unittest.TestCase):
linear
=
fluid
.
dygraph
.
Linear
(
32
,
64
)
var
=
linear
.
_helper
.
to_variable
(
"test"
,
name
=
"abc"
)
def
test_tensor_to_variable
(
self
):
with
fluid
.
dygraph
.
guard
():
t
=
fluid
.
Tensor
()
t
.
set
(
np
.
ndarray
([
5
,
30
]),
fluid
.
CPUPlace
())
var
=
fluid
.
dygraph
.
to_variable
(
t
)
self
.
assertTrue
(
np
.
array_equal
(
t
,
var
.
numpy
()))
def
test_write_property
(
self
):
with
fluid
.
dygraph
.
guard
():
var
=
fluid
.
dygraph
.
to_variable
(
self
.
array
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录