Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
3f630658
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3f630658
编写于
8月 01, 2023
作者:
W
wanghuancoder
提交者:
GitHub
8月 01, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
test del python varbase (#55788)
* del python varbase
上级
d7aef892
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
0 addition
and
1645 deletion
+0
-1645
paddle/fluid/pybind/imperative.cc
paddle/fluid/pybind/imperative.cc
+0
-1495
paddle/fluid/pybind/op_function_common.cc
paddle/fluid/pybind/op_function_common.cc
+0
-136
paddle/fluid/pybind/op_function_common.h
paddle/fluid/pybind/op_function_common.h
+0
-14
未找到文件。
paddle/fluid/pybind/imperative.cc
浏览文件 @
3f630658
...
...
@@ -69,7 +69,6 @@ namespace paddle {
namespace
pybind
{
std
::
atomic
<
int
>
VarBaseUniqueNameID
{
0
};
PyTypeObject
*
g_varbase_pytype
=
nullptr
;
namespace
py
=
::
pybind11
;
...
...
@@ -646,1500 +645,6 @@ void BindImperative(py::module *m_ptr) {
egr
::
Controller
::
Instance
().
SetCurrentTracer
(
tracer
);
imperative
::
SetCurrentTracer
(
tracer
);
});
py
::
class_
<
imperative
::
VarBase
,
std
::
shared_ptr
<
imperative
::
VarBase
>>
varbase
(
m
,
"VarBase"
,
R"DOC()DOC"
);
g_varbase_pytype
=
(
PyTypeObject
*
)
varbase
.
ptr
();
// NOLINT
varbase
.
def_static
(
"_alive_vars"
,
&
imperative
::
VarBase
::
AliveVarNames
)
.
def
(
"__init__"
,
[](
imperative
::
VarBase
&
self
)
{
std
::
string
name
=
imperative
::
GetCurrentTracer
()
->
GenerateUniqueName
(
"generated_tensor"
);
new
(
&
self
)
imperative
::
VarBase
(
name
);
})
.
def
(
"__init__"
,
[](
imperative
::
VarBase
&
self
,
framework
::
proto
::
VarType
::
Type
dtype
,
const
std
::
vector
<
int64_t
>
&
dims
,
const
py
::
handle
&
name
,
framework
::
proto
::
VarType
::
Type
type
,
bool
persistable
)
{
VLOG
(
4
)
<<
"Init VarBase"
;
std
::
string
act_name
=
""
;
if
(
!
name
.
ptr
()
||
name
.
ptr
()
==
Py_None
)
{
act_name
=
imperative
::
GetCurrentTracer
()
->
GenerateUniqueName
(
"generated_tensor"
);
}
else
{
act_name
=
name
.
cast
<
std
::
string
>
();
}
new
(
&
self
)
imperative
::
VarBase
(
act_name
);
self
.
SetPersistable
(
persistable
);
self
.
SetType
(
type
);
self
.
SetDataType
(
dtype
);
if
(
type
==
framework
::
proto
::
VarType
::
LOD_TENSOR
)
{
auto
*
tensor
=
self
.
MutableVar
()
->
GetMutable
<
phi
::
DenseTensor
>
();
tensor
->
Resize
(
phi
::
make_ddim
(
dims
));
}
})
.
def
(
"__init__"
,
&
InitVarBaseFromNumpyWithArg
<
platform
::
CPUPlace
>
,
py
::
arg
(
"value"
),
py
::
arg
(
"place"
),
py
::
arg
(
"persistable"
)
=
false
,
py
::
arg
(
"zero_copy"
)
=
false
,
py
::
arg
(
"name"
)
=
""
,
py
::
arg
(
"stop_gradient"
)
=
-
1
)
.
def
(
"__init__"
,
&
InitVarBaseFromNumpyWithArg
<
platform
::
XPUPlace
>
,
py
::
arg
(
"value"
),
py
::
arg
(
"place"
),
py
::
arg
(
"persistable"
)
=
false
,
py
::
arg
(
"zero_copy"
)
=
false
,
py
::
arg
(
"name"
)
=
""
,
py
::
arg
(
"stop_gradient"
)
=
-
1
)
.
def
(
"__init__"
,
&
InitVarBaseFromNumpyWithArg
<
platform
::
CUDAPlace
>
,
py
::
arg
(
"value"
),
py
::
arg
(
"place"
),
py
::
arg
(
"persistable"
)
=
false
,
py
::
arg
(
"zero_copy"
)
=
false
,
py
::
arg
(
"name"
)
=
""
,
py
::
arg
(
"stop_gradient"
)
=
-
1
)
.
def
(
"__init__"
,
&
InitVarBaseFromNumpyWithArg
<
platform
::
CUDAPinnedPlace
>
,
py
::
arg
(
"value"
),
py
::
arg
(
"place"
),
py
::
arg
(
"persistable"
)
=
false
,
py
::
arg
(
"zero_copy"
)
=
false
,
py
::
arg
(
"name"
)
=
""
,
py
::
arg
(
"stop_gradient"
)
=
-
1
)
.
def
(
"__init__"
,
&
InitVarBaseFromNumpyWithArg
<
platform
::
CustomPlace
>
,
py
::
arg
(
"value"
),
py
::
arg
(
"place"
),
py
::
arg
(
"persistable"
)
=
false
,
py
::
arg
(
"zero_copy"
)
=
false
,
py
::
arg
(
"name"
)
=
""
,
py
::
arg
(
"stop_gradient"
)
=
-
1
)
.
def
(
"__init__"
,
&
InitVarBaseFromNumpyWithArgDefault
,
py
::
arg
(
"value"
))
.
def
(
"__init__"
,
&
InitVarBaseFromTensorWithArgDefault
,
py
::
arg
(
"tensor"
),
py
::
arg
(
"name"
)
=
""
)
.
def
(
"__init__"
,
&
InitVarBaseFromTensorWithArg
<
platform
::
CPUPlace
>
,
py
::
arg
(
"tensor"
),
py
::
arg
(
"place"
),
py
::
arg
(
"name"
)
=
""
)
.
def
(
"__init__"
,
&
InitVarBaseFromTensorWithArg
<
platform
::
XPUPlace
>
,
py
::
arg
(
"tensor"
),
py
::
arg
(
"place"
),
py
::
arg
(
"name"
)
=
""
)
.
def
(
"__init__"
,
&
InitVarBaseFromTensorWithArg
<
platform
::
CUDAPlace
>
,
py
::
arg
(
"tensor"
),
py
::
arg
(
"place"
),
py
::
arg
(
"name"
)
=
""
)
.
def
(
"__init__"
,
&
InitVarBaseFromTensorWithArg
<
platform
::
CUDAPinnedPlace
>
,
py
::
arg
(
"tensor"
),
py
::
arg
(
"place"
),
py
::
arg
(
"name"
)
=
""
)
.
def
(
"__init__"
,
&
InitVarBaseFromTensorWithArg
<
platform
::
CustomPlace
>
,
py
::
arg
(
"tensor"
),
py
::
arg
(
"place"
),
py
::
arg
(
"name"
)
=
""
)
.
def
(
"__init__"
,
&
InitVarBaseFromNumpyWithKwargs
)
.
def
(
"__setitem_varbase__"
,
[](
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
,
py
::
handle
_index
,
py
::
object
&
value_obj
)
{
VLOG
(
4
)
<<
"Call __setitem_varbase__"
;
auto
self_tensor
=
self
->
MutableVar
()
->
GetMutable
<
phi
::
DenseTensor
>
();
// NOTE(zhiqiu): PyTuple_Pack increases refcount while PyTuple_New
// https://github.com/python/cpython/blob/24b63c695ae0a95b06379eaadace66735abac1e2/Objects/tupleobject.c#L251
PyObject
*
index_ptr
=
!
PyTuple_Check
(
_index
.
ptr
())
?
PyTuple_Pack
(
1
,
_index
.
ptr
())
:
_index
.
ptr
();
DEFINE_PADDLE_SCOPE_GUARD
([
index_ptr
,
&
_index
]()
{
if
(
!
PyTuple_Check
(
_index
.
ptr
()))
{
Py_DECREF
(
index_ptr
);
VLOG
(
4
)
<<
"Call Py_DECREF"
;
}
});
auto
is_tensor
=
[](
py
::
handle
var
)
{
if
(
!
var
.
ptr
()
||
var
.
ptr
()
==
Py_None
)
{
return
false
;
}
try
{
py
::
cast
<
std
::
shared_ptr
<
imperative
::
VarBase
>>
(
var
);
return
true
;
}
catch
(
py
::
cast_error
&
)
{
return
false
;
}
};
// NOTE(liym27):
// Increase the version of VarBase self because __setitem__ is an
// inplace operator for the VarBase self.
self
->
BumpInplaceVersion
();
// 1. Check arguments
bool
parse_index
=
true
;
// Check whether _index can be parsed.
const
int
size
=
PyTuple_GET_SIZE
(
index_ptr
);
for
(
int
dim
=
0
;
dim
<
size
;
++
dim
)
{
PyObject
*
slice_item
=
PyTuple_GetItem
(
index_ptr
,
dim
);
if
(
!
(
PyCheckInteger
(
slice_item
)
||
PySlice_Check
(
slice_item
)
||
slice_item
==
Py_Ellipsis
||
slice_item
==
Py_None
))
{
parse_index
=
false
;
break
;
}
}
// 2. Call op set_value to speed up if the condition is met,
// otherwise call TensorToPyArray.
// TODO(liym27): Try not to call TensorToPyArray because it always
// copys data to cpu place, which reduces performance.
if
(
parse_index
)
{
std
::
vector
<
int
>
axes
,
starts
,
ends
,
steps
,
decrease_axes
,
none_axes
,
infer_flags
;
std
::
vector
<
int64_t
>
list_select_idxs
;
// if index is a list, list_select_flag will be true
bool
list_select_flag
=
false
;
ParseIndexingSlice
(
self_tensor
,
index_ptr
,
&
axes
,
&
starts
,
&
ends
,
&
steps
,
&
decrease_axes
,
&
none_axes
,
&
infer_flags
,
&
list_select_idxs
,
&
list_select_flag
);
framework
::
AttributeMap
attrs
=
{{
"axes"
,
axes
},
{
"starts"
,
starts
},
{
"ends"
,
ends
},
{
"steps"
,
steps
},
{
"decrease_axes"
,
decrease_axes
},
{
"none_axes"
,
none_axes
}};
imperative
::
NameVarBaseMap
ins
=
{{
"Input"
,
{
self
}}};
imperative
::
NameVarBaseMap
outs
=
{{
"Out"
,
{
self
}}};
const
auto
&
tracer
=
imperative
::
GetCurrentTracer
();
if
(
tracer
->
HasGrad
())
{
PADDLE_ENFORCE_EQ
(
self
->
IsLeaf
()
&&
!
self
->
OverridedStopGradient
(),
false
,
platform
::
errors
::
InvalidArgument
(
"Leaf Tensor (%s) that doesn't stop gradient can't use "
"inplace strategy."
,
self
->
Name
()));
}
if
(
py
::
isinstance
<
imperative
::
VarBase
>
(
value_obj
.
ptr
()))
{
auto
value_tensor
=
value_obj
.
cast
<
std
::
shared_ptr
<
imperative
::
VarBase
>>
();
ins
.
insert
({
"ValueTensor"
,
{
value_tensor
}});
// pass the stop_gradient from value to tensor
if
(
!
value_tensor
->
OverridedStopGradient
()
&&
self
->
OverridedStopGradient
())
{
self
->
SetOverridedStopGradient
(
false
);
}
}
else
if
(
py
::
isinstance
<
py
::
array
>
(
value_obj
))
{
auto
value_tensor
=
std
::
shared_ptr
<
imperative
::
VarBase
>
(
new
imperative
::
VarBase
(
false
,
tracer
->
GenerateUniqueName
()));
py
::
object
value
=
value_obj
;
if
(
self
->
DataType
()
==
framework
::
proto
::
VarType
::
FP32
)
{
if
(
!
py
::
isinstance
<
py
::
array_t
<
float
>>
(
value_obj
))
{
value
=
pybind11
::
detail
::
CastNumpyArray
<
float
>
(
value_obj
);
}
}
else
if
(
self
->
DataType
()
==
framework
::
proto
::
VarType
::
FP64
)
{
if
(
!
py
::
isinstance
<
py
::
array_t
<
double
>>
(
value_obj
))
{
value
=
pybind11
::
detail
::
CastNumpyArray
<
double
>
(
value_obj
);
}
}
else
if
(
self
->
DataType
()
==
framework
::
proto
::
VarType
::
INT32
)
{
if
(
!
py
::
isinstance
<
py
::
array_t
<
int32_t
>>
(
value_obj
))
{
value
=
pybind11
::
detail
::
CastNumpyArray
<
int32_t
>
(
value_obj
);
}
}
else
if
(
self
->
DataType
()
==
framework
::
proto
::
VarType
::
INT64
)
{
if
(
!
py
::
isinstance
<
py
::
array_t
<
int64_t
>>
(
value_obj
))
{
value
=
pybind11
::
detail
::
CastNumpyArray
<
int64_t
>
(
value_obj
);
}
}
else
if
(
self
->
DataType
()
==
framework
::
proto
::
VarType
::
BOOL
)
{
if
(
!
py
::
isinstance
<
py
::
array_t
<
bool
>>
(
value_obj
))
{
value
=
pybind11
::
detail
::
CastNumpyArray
<
bool
>
(
value_obj
);
}
}
else
if
(
self
->
DataType
()
==
framework
::
proto
::
VarType
::
COMPLEX64
)
{
if
(
!
py
::
isinstance
<
py
::
array_t
<
std
::
complex
<
float
>>>
(
value_obj
))
{
value
=
pybind11
::
detail
::
CastNumpyArray
<
std
::
complex
<
float
>>
(
value_obj
);
}
}
else
if
(
self
->
DataType
()
==
framework
::
proto
::
VarType
::
COMPLEX128
)
{
if
(
!
py
::
isinstance
<
py
::
array_t
<
std
::
complex
<
double
>>>
(
value_obj
))
{
value
=
pybind11
::
detail
::
CastNumpyArray
<
std
::
complex
<
double
>>
(
value_obj
);
}
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"When assign a numpy.np value to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, "
"float32, float64, complex64, complex128, int32 or "
"int64, "
"please check the type of tensor."
));
}
SetTensorFromPyArray
(
value_tensor
->
MutableVar
()
->
GetMutable
<
phi
::
DenseTensor
>
(),
value
,
self
->
Place
(),
false
);
ins
.
insert
({
"ValueTensor"
,
{
value_tensor
}});
}
else
{
// convert the value to self data type
if
(
py
::
isinstance
<
py
::
float_
>
(
value_obj
)
||
py
::
isinstance
<
py
::
int_
>
(
value_obj
)
||
py
::
isinstance
<
py
::
bool_
>
(
value_obj
)
||
PyComplex_Check
(
value_obj
.
ptr
()))
{
if
(
self
->
DataType
()
==
framework
::
proto
::
VarType
::
FP32
)
{
attrs
[
"values"
]
=
std
::
vector
<
paddle
::
experimental
::
Scalar
>
{
value_obj
.
cast
<
float
>
()};
}
else
if
(
self
->
DataType
()
==
framework
::
proto
::
VarType
::
FP64
)
{
attrs
[
"values"
]
=
std
::
vector
<
paddle
::
experimental
::
Scalar
>
{
value_obj
.
cast
<
double
>
()};
}
else
if
(
self
->
DataType
()
==
framework
::
proto
::
VarType
::
INT32
)
{
attrs
[
"values"
]
=
std
::
vector
<
paddle
::
experimental
::
Scalar
>
{
value_obj
.
cast
<
int32_t
>
()};
}
else
if
(
self
->
DataType
()
==
framework
::
proto
::
VarType
::
INT64
)
{
attrs
[
"values"
]
=
std
::
vector
<
paddle
::
experimental
::
Scalar
>
{
value_obj
.
cast
<
int64_t
>
()};
}
else
if
(
self
->
DataType
()
==
framework
::
proto
::
VarType
::
BOOL
)
{
attrs
[
"values"
]
=
std
::
vector
<
paddle
::
experimental
::
Scalar
>
{
value_obj
.
cast
<
bool
>
()};
}
else
if
(
self
->
DataType
()
==
framework
::
proto
::
VarType
::
FP16
)
{
attrs
[
"values"
]
=
std
::
vector
<
paddle
::
experimental
::
Scalar
>
{
value_obj
.
cast
<
float
>
()};
}
else
if
(
self
->
DataType
()
==
framework
::
proto
::
VarType
::
COMPLEX64
)
{
attrs
[
"values"
]
=
std
::
vector
<
paddle
::
experimental
::
Scalar
>
{
value_obj
.
cast
<
std
::
complex
<
float
>>
()};
}
else
if
(
self
->
DataType
()
==
framework
::
proto
::
VarType
::
COMPLEX128
)
{
attrs
[
"values"
]
=
std
::
vector
<
paddle
::
experimental
::
Scalar
>
{
value_obj
.
cast
<
std
::
complex
<
double
>>
()};
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"When assign a value to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, "
"float32, float64, complex64, complex128, int32, int64 "
"or float16, "
"please check the type of tensor."
));
}
attrs
[
"shape"
]
=
std
::
vector
<
int64_t
>
{
1
};
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Value type error. The assign value allows "
"numpy.ndarray, integer, float or bool, "
"but received %s."
,
Py_TYPE
(
value_obj
.
ptr
())));
}
}
{
// Release gil and do tracing
py
::
gil_scoped_release
release
;
tracer
->
TraceOp
(
"set_value"
,
ins
,
outs
,
std
::
move
(
attrs
),
{{
"Input"
,
"Out"
}});
}
}
else
{
auto
self_numpy
=
TensorToPyArray
(
*
self_tensor
);
VLOG
(
4
)
<<
"parse_index is false"
;
if
(
is_tensor
(
_index
))
{
VLOG
(
4
)
<<
"index is tensor"
;
auto
index_var
=
py
::
cast
<
std
::
shared_ptr
<
imperative
::
VarBase
>>
(
_index
);
auto
index_tensor
=
index_var
->
MutableVar
()
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
index_numpy
=
TensorToPyArray
(
*
index_tensor
);
self_numpy
[
index_numpy
]
=
value_obj
;
}
else
{
VLOG
(
4
)
<<
"index is not tensor"
;
self_numpy
[
_index
]
=
value_obj
;
}
SetTensorFromPyArray
(
self_tensor
,
self_numpy
,
self_tensor
->
place
(),
false
);
}
})
.
def
(
"_getitem_index_not_tensor"
,
[](
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
,
py
::
handle
_index
)
{
VLOG
(
4
)
<<
"Call _getitem_index_not_tensor"
;
std
::
vector
<
int
>
slice_axes
,
slice_starts
,
slice_ends
,
slice_strides
,
decrease_axis
,
none_axes
,
infer_flags
;
std
::
vector
<
int64_t
>
list_select_idxs
;
// if index is a list, list_select_flag will be true
bool
list_select_flag
=
false
;
auto
tensor
=
self
->
MutableVar
()
->
GetMutable
<
phi
::
DenseTensor
>
();
ParseIndexingSlice
(
tensor
,
_index
.
ptr
(),
&
slice_axes
,
&
slice_starts
,
&
slice_ends
,
&
slice_strides
,
&
decrease_axis
,
&
none_axes
,
&
infer_flags
,
&
list_select_idxs
,
&
list_select_flag
);
// release gil and do tracing
py
::
gil_scoped_release
release
;
const
auto
&
tracer
=
imperative
::
GetCurrentTracer
();
auto
out
=
slice_axes
.
empty
()
&&
!
list_select_flag
?
self
:
std
::
shared_ptr
<
imperative
::
VarBase
>
(
new
imperative
::
VarBase
(
tracer
->
GenerateUniqueName
()));
if
(
!
slice_axes
.
empty
())
{
imperative
::
NameVarBaseMap
ins
=
{{
"Input"
,
{
self
}}};
framework
::
AttributeMap
attrs
=
{
{
"axes"
,
slice_axes
},
{
"starts"
,
slice_starts
},
{
"ends"
,
slice_ends
},
{
"infer_flags"
,
infer_flags
},
{
"decrease_axis"
,
decrease_axis
}};
imperative
::
NameVarBaseMap
outs
=
{{
"Out"
,
{
out
}}};
std
::
string
op_type
=
"slice"
;
for
(
auto
stride
:
slice_strides
)
{
if
(
stride
!=
1
)
{
op_type
=
"strided_slice"
;
attrs
.
insert
({
"strides"
,
slice_strides
});
attrs
.
erase
(
"decrease_axis"
);
break
;
}
}
tracer
->
TraceOp
(
op_type
,
ins
,
outs
,
std
::
move
(
attrs
));
}
bool
set_to_1d
=
FLAGS_set_to_1d
;
if
(
set_to_1d
)
{
// NOTE(zoooo0820): When all axes are decreased, the output
// will be 1-D with FLAGS_set_to_1d=True. In this case, one
// `None` should be pop out, otherwise the output shape will be
// not correct.
if
(
static_cast
<
int
>
(
decrease_axis
.
size
())
==
tensor
->
dims
().
size
())
{
VLOG
(
1
)
<<
"Warning: In Tensor '__getitem__', if the number "
"of scalar "
"elements "
"in the index is equal to the rank of the Tensor, "
"the output "
"should "
"be 0-D. In order to be consistent with the "
"behavior of previous "
"versions, it will be processed to 1-D. But it is "
"not correct and "
"will be "
"removed in release 2.6. "
"If 1-D is still wanted, please modify the index "
"element from "
"scalar to slice "
"(e.g. 'x[i]' => 'x[i:i+1]'). "
;
if
(
!
none_axes
.
empty
())
{
none_axes
.
pop_back
();
}
}
}
if
(
!
none_axes
.
empty
())
{
// Deal with cases that decrease_axes is not empty
// For example:
// # x.shape: (2,3,4)
// out = x[0, 0:2, None] # out.shape : (2, 1, 4)
for
(
auto
&
axis
:
none_axes
)
{
int
len
=
0
;
for
(
int
da
:
decrease_axis
)
{
if
(
da
<
axis
)
{
len
++
;
}
}
axis
-=
len
;
}
imperative
::
NameVarBaseMap
ins
=
{{
"X"
,
{
out
}}};
framework
::
AttributeMap
attrs
=
{{
"axes"
,
none_axes
}};
auto
new_out
=
std
::
shared_ptr
<
imperative
::
VarBase
>
(
new
imperative
::
VarBase
(
tracer
->
GenerateUniqueName
()));
auto
out_xshape
=
std
::
shared_ptr
<
imperative
::
VarBase
>
(
new
imperative
::
VarBase
(
tracer
->
GenerateUniqueName
()));
imperative
::
NameVarBaseMap
outs
=
{{
"Out"
,
{
new_out
}},
{
"XShape"
,
{
out_xshape
}}};
tracer
->
TraceOp
(
"unsqueeze2"
,
ins
,
outs
,
std
::
move
(
attrs
));
return
new_out
;
}
// the index is a list
if
(
list_select_flag
)
{
auto
select_index
=
std
::
shared_ptr
<
imperative
::
VarBase
>
(
new
imperative
::
VarBase
(
tracer
->
GenerateUniqueName
()));
auto
*
idx_tensor
=
select_index
->
MutableVar
()
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
*
dev_ctx
=
platform
::
DeviceContextPool
::
Instance
().
Get
(
tracer
->
ExpectedPlace
());
paddle
::
framework
::
TensorFromVector
(
list_select_idxs
,
*
dev_ctx
,
idx_tensor
);
imperative
::
NameVarBaseMap
ins
=
{{
"X"
,
{
self
}},
{
"Index"
,
{
select_index
}}};
imperative
::
NameVarBaseMap
outs
=
{{
"Out"
,
{
out
}}};
tracer
->
TraceOp
(
"index_select"
,
ins
,
outs
,
{{
"dim"
,
0
}});
}
return
out
;
})
.
def
(
"_getitem_from_offset"
,
[](
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
,
const
py
::
args
&
args
)
{
const
auto
&
tensor
=
self
->
Var
().
Get
<
phi
::
DenseTensor
>
();
PADDLE_ENFORCE_EQ
(
tensor
.
IsInitialized
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Tensor of %s is Empty, please check if it has no data."
,
self
->
Name
()));
const
auto
&
tensor_dims
=
tensor
.
dims
();
std
::
vector
<
size_t
>
dims
(
tensor_dims
.
size
());
std
::
vector
<
size_t
>
strides
(
tensor_dims
.
size
());
size_t
numel
=
1
;
for
(
int
i
=
tensor_dims
.
size
()
-
1
;
i
>=
0
;
--
i
)
{
strides
[
i
]
=
numel
;
dims
[
i
]
=
static_cast
<
size_t
>
(
tensor_dims
[
i
]);
numel
*=
dims
[
i
];
}
size_t
offset
=
0
;
if
(
args
.
empty
())
{
PADDLE_ENFORCE_EQ
(
numel
,
1
,
platform
::
errors
::
InvalidArgument
(
"only one element tensors can be converted to Python "
"scalars when no input coordinates"
));
}
else
if
(
args
.
size
()
==
1
)
{
offset
=
args
[
0
].
cast
<
size_t
>
();
PADDLE_ENFORCE_LT
(
offset
,
numel
,
platform
::
errors
::
InvalidArgument
(
"index %d is out of bounds for size %d"
,
offset
,
numel
));
}
else
{
PADDLE_ENFORCE_EQ
(
args
.
size
(),
dims
.
size
(),
platform
::
errors
::
InvalidArgument
(
"incorrect number of indices for Tensor"
));
for
(
size_t
i
=
0
;
i
<
args
.
size
();
++
i
)
{
size_t
index
=
args
[
i
].
cast
<
size_t
>
();
PADDLE_ENFORCE_LT
(
index
,
dims
[
i
],
platform
::
errors
::
InvalidArgument
(
"index %d is out fo bounds for axis %d with size %d"
,
index
,
i
,
dims
[
i
]));
offset
+=
index
*
strides
[
i
];
}
}
#define TENSOR_TO_PY_SCALAR(T, proto_type) \
if (framework::TransToProtoVarType(tensor.dtype()) == proto_type) { \
std::string py_dtype_str = details::TensorDTypeToPyDTypeStr(proto_type); \
T b = TensorGetElement<T>(tensor, offset); \
return py::array( \
py::dtype(py_dtype_str.c_str()), {}, {}, static_cast<void *>(&b)); \
}
_ForEachDataType_
(
TENSOR_TO_PY_SCALAR
);
#undef TENSOR_TO_PY_SCALAR
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported tensor data type: %s"
,
tensor
.
dtype
()));
},
py
::
return_value_policy
::
copy
)
.
def
(
"_inplace_version"
,
[](
imperative
::
VarBase
&
self
)
->
uint32_t
{
const
auto
&
var
=
self
.
MutableVar
();
PADDLE_ENFORCE_EQ
(
var
->
IsInitialized
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Tensor of %s is Empty, please check if it has no data."
,
self
.
Name
()));
return
var
->
CurrentInplaceVersion
();
})
.
def
(
"_bump_inplace_version"
,
[](
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
)
{
// NOTE(liym27): _bump_inplace_version is only used for inplace
// operation
self
->
BumpInplaceVersion
();
},
R"DOC(
**Notes**:
**This API is ONLY available in Dygraph mode.**
**This is a very low level API. Users should not use it directly. **
Bump the version whenever the Tensor is modified through an inplace operation.
)DOC"
)
.
def
(
"numpy"
,
[](
imperative
::
VarBase
&
self
)
->
py
::
array
{
const
auto
&
tensor
=
self
.
MutableVar
()
->
Get
<
phi
::
DenseTensor
>
();
PADDLE_ENFORCE_EQ
(
tensor
.
IsInitialized
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Tensor of %s is Empty, please check if it has no data."
,
self
.
Name
()));
return
TensorToPyArray
(
tensor
,
true
);
},
R"DOC(
Returns a numpy array shows the value of current Tensor.
Returns:
ndarray: The numpy value of current Tensor.
Returns type:
ndarray: dtype is same as current Tensor
Examples:
.. code-block:: python
import paddle
import numpy as np
data = np.random.uniform(-1, 1, [30, 10, 32]).astype('float32')
linear = paddle.nn.Linear(32, 64)
data = paddle.to_tensor(data)
x = linear(data)
print(x.numpy())
)DOC"
)
.
def
(
"detach"
,
[](
const
imperative
::
VarBase
&
self
)
->
std
::
shared_ptr
<
imperative
::
VarBase
>
{
PADDLE_ENFORCE_EQ
(
self
.
Var
().
IsInitialized
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Tensor %s has not been initialized!"
,
self
.
Name
()));
PADDLE_ENFORCE_EQ
(
self
.
Var
().
IsType
<
phi
::
DenseTensor
>
()
||
self
.
Var
().
IsType
<
phi
::
SelectedRows
>
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Type of Tensor[%s] must be LoDTensor or SelectedRows!"
,
self
.
Name
()));
auto
detach_var
=
std
::
make_shared
<
imperative
::
VarBase
>
(
true
,
"detach_"
+
self
.
Name
());
detach_var
->
SetPersistable
(
self
.
Persistable
());
detach_var
->
SetType
(
self
.
Type
());
detach_var
->
SetDataType
(
self
.
DataType
());
if
(
self
.
Var
().
IsType
<
phi
::
DenseTensor
>
())
{
const
auto
&
origin_tensor
=
self
.
Var
().
Get
<
phi
::
DenseTensor
>
();
PADDLE_ENFORCE_EQ
(
origin_tensor
.
IsInitialized
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Tensor %s has not been initialized!"
,
self
.
Name
()));
auto
*
detach_tensor
=
detach_var
->
MutableVar
()
->
GetMutable
<
phi
::
DenseTensor
>
();
detach_tensor
->
ShareDataWith
(
origin_tensor
);
// NOTE(liym27): Call ShareInplaceVersionCounterWith to share the
// same TensorInplaceVersion, which is used to check whether
// inplace
// operations are correct.
detach_tensor
->
ShareInplaceVersionCounterWith
(
origin_tensor
);
}
else
{
const
auto
&
origin_selected_rows
=
self
.
Var
().
Get
<
phi
::
SelectedRows
>
();
PADDLE_ENFORCE_EQ
(
origin_selected_rows
.
value
().
IsInitialized
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Tensor %s has not been initialized!"
,
self
.
Name
()));
auto
*
detach_selected_rows
=
detach_var
->
MutableVar
()
->
GetMutable
<
phi
::
SelectedRows
>
();
detach_selected_rows
->
set_height
(
origin_selected_rows
.
height
());
detach_selected_rows
->
set_rows
(
origin_selected_rows
.
rows
());
detach_selected_rows
->
mutable_value
()
->
ShareDataWith
(
origin_selected_rows
.
value
());
detach_selected_rows
->
mutable_value
()
->
ShareInplaceVersionCounterWith
(
origin_selected_rows
.
value
());
}
VLOG
(
3
)
<<
"The detached Tensor("
<<
detach_var
->
Name
()
<<
") share data with "
<<
self
.
Name
();
return
detach_var
;
},
py
::
return_value_policy
::
take_ownership
,
R"DOC(
Returns a new Tensor, detached from the current graph.
It will share data with origin Tensor and always doesn't have a Tensor copy.
In addition, the detached Tensor doesn't provide gradient propagation.
Returns: The detached Tensor.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor([1.0], stop_gradient=False)
detach_x = x.detach()
detach_x[:] = 10.0
print(x) # Tensor(shape=[1], dtype=float32, place=CPUPlace, stop_gradient=False,
# [10.])
y = x**2
y.backward()
print(x.grad) # [20.0]
print(detach_x.grad) # None, 'stop_gradient=True' by default
detach_x.stop_gradient = False # Set stop_gradient to be False, supported auto-grad
z = detach_x**3
z.backward()
print(x.grad) # [20.0], detach_x is detached from x's graph, not affect each other
print(detach_x.grad) # [300.0], detach_x has its own graph
# Due to sharing of data with origin Tensor, There are some unsafe operations:
y = 2 * x
detach_x[:] = 5.0
y.backward()
# It will raise Error:
# one of the variables needed for gradient computation has been modified by an inplace operation.
)DOC"
)
.
def
(
"clear_gradient"
,
&
imperative
::
VarBase
::
ClearGradient
,
py
::
arg
(
"set_to_zero"
)
=
true
,
R"DOC(
Only for Tensor that has gradient, normally we use this for Parameters since other temporary Tensor doesen't has gradient.
The Gradient of current Tensor will be set to ``0`` .
Returns: None
Examples:
.. code-block:: python
import paddle
input = paddle.uniform([10, 2])
linear = paddle.nn.Linear(2, 3)
out = linear(input)
out.backward()
print("Before clear_gradient, linear.weight.grad: {}".format(linear.weight.grad))
linear.weight.clear_gradient()
print("After clear_gradient, linear.weight.grad: {}".format(linear.weight.grad))
)DOC"
)
.
def
(
"_gradient_set_empty"
,
&
imperative
::
VarBase
::
_GradientSetEmpty
,
py
::
arg
(
"set_is_empty"
)
=
true
)
.
def
(
"_is_gradient_set_empty"
,
&
imperative
::
VarBase
::
_IsGradientSetEmpty
)
.
def
(
"clone"
,
[](
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
)
{
const
auto
&
tensor
=
self
->
Var
().
Get
<
phi
::
DenseTensor
>
();
PADDLE_ENFORCE_EQ
(
tensor
.
IsInitialized
(),
true
,
platform
::
errors
::
InvalidArgument
(
"%s has not been initialized"
,
self
->
Name
()));
auto
tracer
=
imperative
::
GetCurrentTracer
();
auto
new_var
=
std
::
make_shared
<
imperative
::
VarBase
>
(
true
,
tracer
->
GenerateUniqueName
(
self
->
Name
()
+
"_clone"
));
framework
::
AttributeMap
attrs
;
imperative
::
NameVarBaseMap
ins
=
{{
"X"
,
{
self
}}};
imperative
::
NameVarBaseMap
outs
=
{{
"Out"
,
{
new_var
}}};
tracer
->
TraceOp
(
"assign"
,
ins
,
outs
,
attrs
);
return
new_var
;
},
py
::
return_value_policy
::
copy
,
R"DOC(
Returns a new Tensor, which is clone of origin Tensor, and it remains in the current graph.
It will always have a Tensor copy.
Tn addition, the cloned Tensor provides gradient propagation.
Returns: The cloned Tensor.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor(1.0, stop_gradient=False)
clone_x = x.clone()
y = clone_x**2
y.backward()
print(clone_x.stop_gradient) # False
print(clone_x.grad) # [2.0], support gradient propagation
print(x.stop_gradient) # False
print(x.grad) # [2.0], clone_x support gradient propagation for x
x = paddle.to_tensor(1.0)
clone_x = x.clone()
clone_x.stop_gradient = False
z = clone_x**3
z.backward()
print(clone_x.stop_gradient) # False
print(clone_x.grad) # [3.0], support gradient propagation
print(x.stop_gradient) # True
print(x.grad) # None
)DOC"
)
.
def
(
"_grad_name"
,
&
imperative
::
VarBase
::
GradVarName
)
.
def
(
"_grad_value"
,
[](
imperative
::
VarBase
&
self
)
{
return
self
.
MutableGradVar
()
->
Get
<
phi
::
DenseTensor
>
();
},
py
::
return_value_policy
::
reference
)
.
def
(
"_set_grad_type"
,
[](
imperative
::
VarBase
&
self
,
framework
::
proto
::
VarType
::
Type
type
)
{
self
.
MutableGradVarBase
()
->
SetType
(
type
);
})
.
def
(
"_reset_grad_inplace_version"
,
[](
imperative
::
VarBase
&
self
,
bool
set_to_zero
)
{
/*
*** This interfaceis a complete hack ***
reset_grad_inplace_version removes all inplace related records to
Grad VarBase/VariableWrapper,
the essential purpose of which is to let you use inplace operations
as if using its non-inplaced version,
which of course will cause unexpected consequences if not used with
care.
Make sure you fully understand what you're doing before make use of
this interface, and prepare for the worst.
*/
py
::
gil_scoped_release
release
;
if
(
self
.
HasGradVar
())
{
auto
grad_var
=
self
.
GradVarBase
();
auto
var_wrapper
=
grad_var
->
SharedVar
();
if
(
var_wrapper
)
{
var_wrapper
->
ResetInplaceVersion
(
set_to_zero
);
}
}
})
.
def
(
"_grad_ivar"
,
[](
const
imperative
::
VarBase
&
self
)
{
auto
&
grad_var
=
self
.
GradVarBase
();
if
(
grad_var
&&
grad_var
->
Var
().
IsInitialized
())
{
auto
*
tensor
=
grad_var
->
MutableVar
()
->
IsType
<
phi
::
DenseTensor
>
()
?
grad_var
->
MutableVar
()
->
GetMutable
<
phi
::
DenseTensor
>
()
:
grad_var
->
MutableVar
()
->
GetMutable
<
phi
::
SelectedRows
>
()
->
mutable_value
();
if
(
tensor
->
IsInitialized
())
{
return
grad_var
;
}
}
return
std
::
shared_ptr
<
imperative
::
VarBase
>
(
nullptr
);
},
py
::
return_value_policy
::
copy
)
.
def
(
"_set_grad_ivar"
,
[](
imperative
::
VarBase
&
self
,
imperative
::
VarBase
&
grad
)
{
self
.
SetGradVarBase
(
grad
);
})
.
def
(
"_is_sparse"
,
[](
imperative
::
VarBase
&
self
)
{
return
self
.
Var
().
IsType
<
phi
::
SelectedRows
>
();
})
.
def
(
"_allreduce"
,
[](
imperative
::
VarBase
&
self
,
const
imperative
::
ParallelStrategy
&
strategy
)
{
if
(
strategy
.
nranks_
>
1
)
{
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#if NCCL_VERSION_CODE >= 2212
imperative
::
AllReduce
(
self
.
Var
(),
self
.
MutableVar
(),
strategy
);
#else
if
(
!
self
.
Var
().
IsType
<
phi
::
SelectedRows
>
())
{
imperative
::
AllReduce
(
self
.
Var
(),
self
.
MutableVar
(),
strategy
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Imperative SelectedRows allreduce is not supported when "
"paddle is compiled with NCCL version lower than v2.2.12. "
"You can set is_sparse=False for the Layer containing "
"this argument, such as Embedding(is_sparse=False)."
));
}
#endif // NCCL_VERSION_CODE
#else
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Imperative allreduce is not supported when paddle is "
"not compiled with NCCL."
));
#endif // PADDLE_WITH_NCCL or PADDLE_WITH_RCCL
}
},
py
::
call_guard
<
py
::
gil_scoped_release
>
())
.
def
(
"_register_grad_hook"
,
[](
imperative
::
VarBase
&
self
,
const
py
::
handle
&
hook
)
{
PADDLE_ENFORCE_EQ
(
!
self
.
OverridedStopGradient
()
&&
self
.
HasGradVar
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Cannot register gradient hook on a Tensor that stop "
"gradient or without gradient."
));
return
self
.
GradVarBase
()
->
AddVariableWrapperHook
(
std
::
make_shared
<
PyVariableWrapperHook
>
(
hook
.
ptr
()));
})
.
def
(
"_remove_grad_hook"
,
[](
imperative
::
VarBase
&
self
,
int64_t
hook_id
)
{
PADDLE_ENFORCE_EQ
(
!
self
.
OverridedStopGradient
()
&&
self
.
HasGradVar
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Cannot remove gradient hook on a Tensor that stop "
"gradient or without gradient."
));
return
self
.
GradVarBase
()
->
RemoveVariableWrapperHook
(
hook_id
);
})
.
def
(
"_register_void_function_post_hook"
,
[](
imperative
::
VarBase
&
self
,
const
py
::
handle
&
hook
)
{
PADDLE_ENFORCE_EQ
(
!
self
.
OverridedStopGradient
()
&&
self
.
HasGradVar
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Cannot register void function post hook on a Tensor that "
"stop "
"gradient or without gradient."
));
auto
py_func
=
PyObjectCast
<
std
::
function
<
void
()
>>
(
hook
.
ptr
());
auto
grad_node
=
self
.
MutableGradVarBase
()
->
GradNode
();
for
(
auto
&
cur_op
:
*
grad_node
)
{
cur_op
.
AddVoidFunctionPostHook
(
std
::
make_shared
<
std
::
function
<
void
()
>>
(
py_func
));
}
})
.
def
(
"_register_backward_hook"
,
[](
imperative
::
VarBase
&
self
,
const
py
::
handle
&
hook
)
{
PADDLE_ENFORCE_EQ
(
self
.
IsLeaf
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Only can register backward hook for leaf Tensor."
));
PADDLE_ENFORCE_EQ
(
!
self
.
OverridedStopGradient
()
&&
self
.
HasGradVar
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Cannot register backward hook on a Tensor that stop "
"gradient or without gradient."
));
auto
py_func
=
PyObjectCast
<
std
::
function
<
void
()
>>
(
hook
.
ptr
());
self
.
GradVarBase
()
->
AddVoidHook
(
std
::
make_shared
<
std
::
function
<
void
()
>>
(
py_func
));
},
R"DOC(
Registers a backward hook for current Tensor.
This hook will be called every time the gradient of current Tensor has been fully calculated.
There are two differences with `_register_grad_hook`:
1. This backward hook will be executed after the gradient accumulation completed across batches,
but the hook registered by `_register_grad_hook` will be executed the gradient accumulation
completed in current batch.
2. This backward hook function should have the following signature:
hook() -> None
It requires no input and no return value.
Args:
hook(function): A backward hook to be registered for Tensor.gradient
Returns:
None
)DOC"
)
.
def
(
"cpu"
,
[](
const
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
)
{
if
(
platform
::
is_cpu_place
(
self
->
Place
()))
{
return
self
;
}
else
{
auto
new_var
=
self
->
NewVarBase
(
platform
::
CPUPlace
(),
true
);
new_var
->
SetOverridedStopGradient
(
self
->
OverridedStopGradient
());
return
new_var
;
}
},
R"DOC(
Returns a copy of this Tensor in CPU memory.
If this Tensor is already in CPU memory, then no copy is performed and the original Tensor is returned.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor(1.0, place=paddle.CUDAPlace(0))
print(x.place) # CUDAPlace(0)
y = x.cpu()
print(y.place) # CPUPlace
)DOC"
)
.
def
(
"pin_memory"
,
[](
const
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
)
{
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"Cannot copy this Tensor to pinned memory in CPU version "
"Paddle, "
"Please recompile or reinstall Paddle with CUDA support."
));
#endif
if
(
platform
::
is_cuda_pinned_place
(
self
->
Place
()))
{
return
self
;
}
else
{
auto
new_var
=
self
->
NewVarBase
(
platform
::
CUDAPinnedPlace
(),
true
);
new_var
->
SetOverridedStopGradient
(
self
->
OverridedStopGradient
());
return
new_var
;
}
},
R"DOC(
Returns a copy of this Tensor in pin memory.
If this Tensor is already in pin memory, then no copy is performed and the original Tensor is returned.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor(1.0, place=paddle.CUDAPlace(0))
print(x.place) # CUDAPlace(0)
y = x.pin_memory()
print(y.place) # CUDAPinnedPlace
)DOC"
)
.
def
(
"cuda"
,
[](
const
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
,
py
::
handle
&
handle
,
bool
blocking
)
{
#if !defined(PADDLE_WITH_CUDA) && !defined(PADDLE_WITH_HIP)
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"Cannot copy this Tensor to GPU in CPU version Paddle, "
"Please recompile or reinstall Paddle with CUDA support."
));
#else
int
device_count
=
platform
::
GetGPUDeviceCount
();
int
device_id
=
0
;
if
(
handle
==
py
::
none
())
{
auto
default_place
=
imperative
::
GetCurrentTracer
()
->
ExpectedPlace
();
device_id
=
default_place
.
GetDeviceId
();
}
else
{
PyObject
*
py_obj
=
handle
.
ptr
();
PADDLE_ENFORCE_EQ
(
PyCheckInteger
(
py_obj
),
true
,
platform
::
errors
::
InvalidArgument
(
" 'device_id' must be a positive integer"
));
device_id
=
py
::
cast
<
int
>
(
handle
);
}
PADDLE_ENFORCE_GE
(
device_id
,
0
,
platform
::
errors
::
InvalidArgument
(
"Can not copy Tensor to Invalid CUDAPlace(%d), device id "
"must inside [0, %d)"
,
device_id
,
device_count
));
PADDLE_ENFORCE_LT
(
device_id
,
device_count
,
platform
::
errors
::
InvalidArgument
(
"Can not copy Tensor to Invalid CUDAPlace(%d), device id "
"must inside [0, %d)"
,
device_id
,
device_count
));
platform
::
CUDAPlace
place
=
platform
::
CUDAPlace
(
device_id
);
if
(
platform
::
is_same_place
(
self
->
Place
(),
place
))
{
return
self
;
}
else
{
auto
new_var
=
self
->
NewVarBase
(
place
,
blocking
);
new_var
->
SetOverridedStopGradient
(
self
->
OverridedStopGradient
());
return
new_var
;
}
#endif
},
py
::
arg
(
"device_id"
)
=
py
::
none
(),
py
::
arg
(
"blocking"
)
=
true
,
R"DOC(
Returns a copy of this Tensor in GPU memory.
If this Tensor is already in GPU memory and device_id is default,
then no copy is performed and the original Tensor is returned.
Args:
device_id(int, optional): The destination GPU device id. Default: None, means current device.
blocking(bool, optional): If False and the source is in pinned memory, the copy will be
asynchronous with respect to the host. Otherwise, the argument has no effect. Default: False.
Examples:
.. code-block:: python
# required: gpu
import paddle
x = paddle.to_tensor(1.0, place=paddle.CPUPlace())
print(x.place) # Place(cpu)
y = x.cuda()
print(y.place) # Place(gpu:0)
y = x.cuda(None)
print(y.place) # Place(gpu:0)
paddle.device.set_device("gpu:1")
y = x.cuda(None)
print(y.place) # Place(gpu:1)
)DOC"
)
.
def
(
"_share_memory"
,
[](
const
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
)
{
#ifndef _WIN32
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
self
->
Place
()),
true
,
platform
::
errors
::
InvalidArgument
(
"Sharing memory only support CPU Tensor currently"
));
// 1. get LoDTensor
auto
*
t
=
self
->
MutableVar
()
->
GetMutable
<
phi
::
DenseTensor
>
();
// 2. allocate shared memory
void
*
data_ptr
=
t
->
data
();
size_t
data_size
=
t
->
numel
()
*
framework
::
SizeOfType
(
framework
::
TransToProtoVarType
(
t
->
dtype
()));
auto
shared_writer_holder
=
memory
::
allocation
::
AllocateMemoryMapWriterAllocation
(
data_size
);
// 3. maintain mmap fd set & backup ipc_name
const
std
::
string
&
ipc_name
=
shared_writer_holder
->
ipc_name
();
memory
::
allocation
::
MemoryMapFdSet
::
Instance
().
Insert
(
ipc_name
);
// 4. copy data & reset holder
memory
::
Copy
(
platform
::
CPUPlace
(),
shared_writer_holder
->
ptr
(),
platform
::
CPUPlace
(),
data_ptr
,
data_size
);
t
->
ResetHolder
(
shared_writer_holder
);
return
*
t
;
#else
PADDLE_THROW
(
platform
::
errors
::
PermissionDenied
(
"Sharing memory in Windows OS is not supported currently"
));
#endif
},
py
::
return_value_policy
::
reference
)
#if defined(PADDLE_WITH_CUDA)
.
def
(
"_uva"
,
[](
const
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
,
int
device_id
)
{
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
self
->
Place
()),
true
,
platform
::
errors
::
InvalidArgument
(
"Unified virtual addressing only support "
"CPU Tensor currently."
));
auto
*
self_tensor
=
self
->
MutableVar
()
->
GetMutable
<
phi
::
DenseTensor
>
();
tensor_uva
(
self_tensor
,
device_id
);
},
py
::
arg
(
"device_id"
)
=
0
,
py
::
return_value_policy
::
reference
,
R"DOC(
Returns self tensor with the UVA(unified virtual addressing).
Args:
device_id(int, optional): The destination GPU device id. Default: None, means current device.
Examples:
.. code-block:: python
# required: gpu
import paddle
x = paddle.to_tensor([1, 2, 3], place=paddle.CPUPlace())
x._uva()
print(x)
)DOC"
)
#endif
.
def
(
"copy_"
,
&
imperative
::
VarBase
::
CopyFrom
)
.
def
(
"_copy_to"
,
[](
const
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
,
const
platform
::
CPUPlace
&
place
,
bool
blocking
)
{
auto
new_var
=
self
->
NewVarBase
(
place
,
blocking
);
// Note(zhiqiu): Since NewVarBase may use GpuCopyAsync to
// copy data from the tensor of self to the tensor of new varbase,
// we need to ensure that the varbase self is not destructed until
// the GpuCopyAsync is completed. Otherwise, the memory may be
// freed
// when varbase self is destructed.
// To do that, we increase the reference count of self by 1 and
// add a cuda event to wait the GpuCopyAsync's completion.
if
(
!
blocking
)
{
IncreaseVarbaseReferenceCountUntilCopyComplete
(
self
,
place
);
}
return
new_var
;
},
py
::
return_value_policy
::
copy
)
.
def
(
"_copy_to"
,
[](
const
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
,
const
platform
::
CUDAPinnedPlace
&
place
,
bool
blocking
)
{
auto
new_var
=
self
->
NewVarBase
(
place
,
blocking
);
if
(
!
blocking
)
{
IncreaseVarbaseReferenceCountUntilCopyComplete
(
self
,
place
);
}
return
new_var
;
},
py
::
return_value_policy
::
copy
)
.
def
(
"_copy_to"
,
[](
const
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
,
const
platform
::
XPUPlace
&
place
,
bool
blocking
)
{
auto
new_var
=
self
->
NewVarBase
(
place
,
blocking
);
if
(
!
blocking
)
{
IncreaseVarbaseReferenceCountUntilCopyComplete
(
self
,
place
);
}
return
new_var
;
},
py
::
return_value_policy
::
copy
)
.
def
(
"_copy_to"
,
[](
const
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
,
const
platform
::
CUDAPlace
&
place
,
bool
blocking
)
{
auto
new_var
=
self
->
NewVarBase
(
place
,
blocking
);
if
(
!
blocking
)
{
IncreaseVarbaseReferenceCountUntilCopyComplete
(
self
,
place
);
}
return
new_var
;
},
py
::
return_value_policy
::
copy
)
.
def
(
"_copy_to"
,
[](
const
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
,
const
platform
::
IPUPlace
&
place
,
bool
blocking
)
{
auto
new_var
=
self
->
NewVarBase
(
place
,
blocking
);
if
(
!
blocking
)
{
IncreaseVarbaseReferenceCountUntilCopyComplete
(
self
,
place
);
}
return
new_var
;
},
py
::
return_value_policy
::
copy
)
.
def
(
"_copy_to"
,
[](
const
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
,
const
platform
::
CustomPlace
&
place
,
bool
blocking
)
{
auto
new_var
=
self
->
NewVarBase
(
place
,
blocking
);
if
(
!
blocking
)
{
IncreaseVarbaseReferenceCountUntilCopyComplete
(
self
,
place
);
}
return
new_var
;
},
py
::
return_value_policy
::
copy
)
.
def
(
"_copy_to"
,
[](
const
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
,
const
platform
::
Place
&
place
,
bool
blocking
)
{
auto
new_var
=
self
->
NewVarBase
(
place
,
blocking
);
if
(
!
blocking
)
{
IncreaseVarbaseReferenceCountUntilCopyComplete
(
self
,
place
);
}
return
new_var
;
},
py
::
return_value_policy
::
copy
)
.
def
(
"value"
,
[](
imperative
::
VarBase
&
self
)
{
return
self
.
MutableVar
();
},
py
::
return_value_policy
::
reference
)
.
def
(
"_clear"
,
[](
const
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
)
{
auto
*
t
=
self
->
MutableVar
()
->
GetMutable
<
phi
::
DenseTensor
>
();
PADDLE_ENFORCE_EQ
(
t
->
IsInitialized
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Tensor %s has not been initialized!"
,
self
->
Name
()));
t
->
clear
();
})
.
def
(
"_offset"
,
[](
const
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
)
{
auto
*
t
=
self
->
MutableVar
()
->
GetMutable
<
phi
::
DenseTensor
>
();
PADDLE_ENFORCE_EQ
(
t
->
IsInitialized
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Tensor %s has not been initialized!"
,
self
->
Name
()));
return
t
->
offset
();
})
.
def
(
"_share_buffer_to"
,
[](
const
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
,
std
::
shared_ptr
<
imperative
::
VarBase
>
&
dst
)
{
auto
*
src
=
self
->
MutableVar
()
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
*
dst_
=
dst
->
MutableVar
()
->
GetMutable
<
phi
::
DenseTensor
>
();
PADDLE_ENFORCE_EQ
(
src
->
IsInitialized
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Tensor %s has not been initialized!"
,
self
->
Name
()));
dst_
->
ShareBufferWith
(
*
src
);
dst_
->
ShareDataTypeWith
(
*
src
);
})
.
def
(
"_is_shared_buffer_with"
,
[](
const
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
,
std
::
shared_ptr
<
imperative
::
VarBase
>
&
dst
)
{
auto
*
src
=
self
->
MutableVar
()
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
*
dst_
=
dst
->
MutableVar
()
->
GetMutable
<
phi
::
DenseTensor
>
();
if
(
!
src
->
IsInitialized
()
||
!
dst_
->
IsInitialized
())
{
return
false
;
}
return
dst_
->
IsSharedBufferWith
(
*
src
);
})
.
def
(
"_share_underline_tensor_to"
,
[](
const
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
,
std
::
shared_ptr
<
imperative
::
VarBase
>
&
dst
)
{
auto
*
src
=
self
->
MutableVar
()
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
*
dst_
=
dst
->
MutableVar
()
->
GetMutable
<
phi
::
DenseTensor
>
();
PADDLE_ENFORCE_EQ
(
src
->
IsInitialized
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Tensor %s has not been initialized!"
,
self
->
Name
()));
dst_
->
ShareBufferWith
(
*
src
);
dst_
->
ShareDataTypeWith
(
*
src
);
dst_
->
Resize
(
src
->
dims
());
})
.
def
(
"_is_shared_underline_tensor_with"
,
[](
const
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
,
std
::
shared_ptr
<
imperative
::
VarBase
>
&
dst
)
{
auto
*
src
=
self
->
MutableVar
()
->
GetMutable
<
phi
::
DenseTensor
>
();
auto
*
dst_
=
dst
->
MutableVar
()
->
GetMutable
<
phi
::
DenseTensor
>
();
if
(
!
src
->
IsInitialized
()
||
!
dst_
->
IsInitialized
())
{
return
false
;
}
return
dst_
->
IsSharedBufferWith
(
*
src
);
})
.
def
(
"_slice"
,
[](
const
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
,
int64_t
begin_idx
,
int64_t
end_idx
)
{
auto
*
t
=
self
->
MutableVar
()
->
GetMutable
<
phi
::
DenseTensor
>
();
PADDLE_ENFORCE_EQ
(
t
->
IsInitialized
(),
true
,
platform
::
errors
::
InvalidArgument
(
"Tensor %s has not been initialized!"
,
self
->
Name
()));
return
t
->
Slice
(
begin_idx
,
end_idx
);
})
.
def
(
"_copy_gradient_from"
,
[](
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
,
const
imperative
::
VarBase
&
src
)
{
self
->
_CopyGradientFrom
(
src
);
})
.
def
(
"_numel"
,
[](
std
::
shared_ptr
<
imperative
::
VarBase
>
&
self
)
{
auto
*
t
=
self
->
MutableVar
()
->
GetMutable
<
phi
::
DenseTensor
>
();
return
t
->
numel
();
})
.
def
(
"element_size"
,
&
imperative
::
VarBase
::
ElementSize
,
R"DOC(
Returns the size in bytes of an element in the Tensor.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor(1, dtype='bool')
x.element_size() # 1
x = paddle.to_tensor(1, dtype='float16')
x.element_size() # 2
x = paddle.to_tensor(1, dtype='float32')
x.element_size() # 4
x = paddle.to_tensor(1, dtype='float64')
x.element_size() # 8
x = paddle.to_tensor(1, dtype='complex128')
x.element_size() # 16
)DOC"
)
.
def_property
(
"name"
,
&
imperative
::
VarBase
::
Name
,
&
imperative
::
VarBase
::
SetName
)
.
def_property
(
"stop_gradient"
,
&
imperative
::
VarBase
::
OverridedStopGradient
,
&
imperative
::
VarBase
::
SetOverridedStopGradient
)
.
def_property
(
"persistable"
,
&
imperative
::
VarBase
::
Persistable
,
&
imperative
::
VarBase
::
SetPersistable
)
.
def_property_readonly
(
"shape"
,
[](
imperative
::
VarBase
&
self
)
{
if
(
self
.
Var
().
IsType
<
phi
::
DenseTensor
>
())
{
auto
value
=
phi
::
vectorize
<
int
>
(
self
.
Var
().
Get
<
phi
::
DenseTensor
>
().
dims
());
auto
tensor
=
self
.
Var
().
Get
<
phi
::
DenseTensor
>
();
auto
tmp_value
=
value
;
auto
desired_layout
=
paddle
::
imperative
::
LayoutAutoTune
::
Instance
()
.
GetDesiredLayout
();
auto
default_layout
=
paddle
::
imperative
::
LayoutAutoTune
::
Instance
()
.
GetDefaultLayout
();
bool
change_dim
=
(
desired_layout
!=
default_layout
&&
tensor
.
layout
()
==
desired_layout
&&
value
.
size
()
==
4
);
VLOG
(
6
)
<<
"'Shape' method, layout autotune,"
<<
" desired_layout: "
<<
desired_layout
<<
" default_layout: "
<<
default_layout
<<
" tensor layout: "
<<
tensor
.
layout
()
<<
" tensor's shape size is : "
<<
value
.
size
();
if
(
change_dim
&&
phi
::
DataLayoutToString
(
desired_layout
)
==
"NCHW"
)
{
VLOG
(
6
)
<<
"layout autotune get Shape from NHWC -> NCHW "
<<
value
[
0
]
<<
" "
<<
value
[
1
]
<<
" "
<<
value
[
2
]
<<
" "
<<
value
[
3
]
<<
" to "
<<
tmp_value
[
3
]
<<
" "
<<
tmp_value
[
1
]
<<
" "
<<
tmp_value
[
2
]
<<
" "
<<
tmp_value
[
1
];
// NCHW -> NHWC
value
[
1
]
=
tmp_value
[
2
];
value
[
2
]
=
tmp_value
[
3
];
value
[
3
]
=
tmp_value
[
1
];
}
else
if
(
change_dim
&&
phi
::
DataLayoutToString
(
desired_layout
)
==
"NHWC"
)
{
VLOG
(
6
)
<<
"layout autotune get Shape from NHWC -> NCHW "
<<
value
[
0
]
<<
" "
<<
value
[
1
]
<<
" "
<<
value
[
2
]
<<
" "
<<
value
[
3
]
<<
" to "
<<
tmp_value
[
0
]
<<
" "
<<
tmp_value
[
3
]
<<
" "
<<
tmp_value
[
1
]
<<
" "
<<
tmp_value
[
2
];
// NHWC -> NCHW
value
[
1
]
=
tmp_value
[
3
];
value
[
2
]
=
tmp_value
[
1
];
value
[
3
]
=
tmp_value
[
2
];
}
return
value
;
}
else
if
(
self
.
Var
().
IsType
<
phi
::
SelectedRows
>
())
{
return
phi
::
vectorize
<
int
>
(
self
.
Var
().
Get
<
phi
::
SelectedRows
>
().
value
().
dims
());
}
else
if
(
self
.
Var
().
IsType
<
framework
::
Strings
>
())
{
return
std
::
vector
<
int
>
{
static_cast
<
int
>
(
self
.
Var
().
Get
<
framework
::
Strings
>
().
size
())};
}
else
if
(
self
.
Var
().
IsType
<
framework
::
Vocab
>
())
{
return
std
::
vector
<
int
>
{
static_cast
<
int
>
(
self
.
Var
().
Get
<
framework
::
Vocab
>
().
size
())};
}
else
{
VLOG
(
2
)
<<
"It is meaningless to get shape of "
"variable type "
<<
GetTypeName
(
self
);
return
std
::
vector
<
int
>
();
}
})
.
def_property_readonly
(
"layout"
,
[](
imperative
::
VarBase
&
self
)
{
if
(
self
.
Var
().
IsType
<
phi
::
DenseTensor
>
())
{
auto
layout
=
self
.
Var
().
Get
<
phi
::
DenseTensor
>
().
layout
();
return
phi
::
DataLayoutToString
(
layout
);
}
return
std
::
string
(
""
);
})
.
def_property_readonly
(
"is_leaf"
,
&
imperative
::
VarBase
::
IsLeaf
,
R"DOC(
Whether a Tensor is leaf Tensor.
For the Tensor whose stop_gradient is ``True`` , it will be leaf Tensor.
For the Tensor whose stop_gradient is ``False`` , it will be leaf Tensor too if it is created by user.
Returns:
bool: Whether a Tensor is leaf Tensor.
Examples:
.. code-block:: python
import paddle
x = paddle.to_tensor(1.)
print(x.is_leaf) # True
x = paddle.to_tensor(1., stop_gradient=True)
y = x + 1
print(x.is_leaf) # True
print(y.is_leaf) # True
x = paddle.to_tensor(1., stop_gradient=False)
y = x + 1
print(x.is_leaf) # True
print(y.is_leaf) # False
)DOC"
)
.
def_property_readonly
(
"place"
,
[](
imperative
::
VarBase
&
self
)
{
return
self
.
Place
();
},
py
::
return_value_policy
::
copy
)
.
def_property_readonly
(
"_place_str"
,
[](
imperative
::
VarBase
&
self
)
{
std
::
stringstream
ostr
;
ostr
<<
self
.
Place
();
return
ostr
.
str
();
})
.
def_property_readonly
(
"type"
,
&
imperative
::
VarBase
::
Type
)
.
def_property_readonly
(
"dtype"
,
&
imperative
::
VarBase
::
DataType
);
py
::
class_
<
imperative
::
jit
::
ProgramDescTracer
>
(
m
,
"ProgramDescTracer"
,
""
)
.
def
(
"create_program_desc"
,
&
imperative
::
jit
::
ProgramDescTracer
::
CreateProgramDesc
)
...
...
paddle/fluid/pybind/op_function_common.cc
浏览文件 @
3f630658
...
...
@@ -61,7 +61,6 @@ class OpAttrTypeMap {
ops_attrtype_map_
;
};
extern
PyTypeObject
*
g_varbase_pytype
;
extern
PyTypeObject
*
g_vartype_pytype
;
extern
PyTypeObject
*
g_blockdesc_pytype
;
extern
PyTypeObject
*
p_tensor_type
;
...
...
@@ -71,7 +70,6 @@ bool PyObject_CheckBool(PyObject** obj) { return PyBool_Check(*obj); }
bool
PyObject_CheckLongOrToLong
(
PyObject
**
obj
)
{
if
((
PyLong_Check
(
*
obj
)
&&
!
PyBool_Check
(
*
obj
))
||
PyObject_TypeCheck
(
*
obj
,
g_vartype_pytype
)
||
// NOLINT
PyObject_TypeCheck
(
*
obj
,
g_varbase_pytype
)
||
// NOLINT
(
PyObject_TypeCheck
(
*
obj
,
p_tensor_type
)
&&
// NOLINT
(((
TensorObject
*
)(
*
obj
))
->
tensor
.
numel
()
==
1
)))
{
// NOLINT
return
true
;
...
...
@@ -92,7 +90,6 @@ bool PyObject_CheckLongOrToLong(PyObject** obj) {
bool
PyObject_CheckFloatOrToFloat
(
PyObject
**
obj
)
{
// sometimes users provide PyLong or numpy.int64 but attr is float
if
(
PyFloat_Check
(
*
obj
)
||
PyLong_Check
(
*
obj
)
||
PyObject_TypeCheck
(
*
obj
,
g_varbase_pytype
)
||
// NOLINT
(
PyObject_TypeCheck
(
*
obj
,
p_tensor_type
)
&&
// NOLINT
(((
TensorObject
*
)(
*
obj
))
->
tensor
.
numel
()
==
1
)))
{
// NOLINT
return
true
;
...
...
@@ -111,7 +108,6 @@ bool PyObject_CheckFloatOrToFloat(PyObject** obj) {
bool
PyObject_CheckComplexOrToComplex
(
PyObject
**
obj
)
{
if
(
PyComplex_Check
(
*
obj
)
||
PyLong_Check
(
*
obj
)
||
PyFloat_Check
(
*
obj
)
||
PyObject_TypeCheck
(
*
obj
,
g_vartype_pytype
)
||
// NOLINT
PyObject_TypeCheck
(
*
obj
,
g_varbase_pytype
)
||
// NOLINT
PyObject_TypeCheck
(
*
obj
,
p_tensor_type
))
{
// NOLINT
return
true
;
}
...
...
@@ -926,138 +922,6 @@ void ConstructAttrMapFromPyArgs(
}
}
std
::
shared_ptr
<
imperative
::
VarBase
>
GetVarBaseFromArgs
(
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
,
PyObject
*
args
,
ssize_t
arg_idx
,
bool
dispensable
)
{
::
pybind11
::
detail
::
instance
*
inst
=
(
::
pybind11
::
detail
::
instance
*
)
PyTuple_GET_ITEM
(
args
,
arg_idx
);
if
(
PyTuple_Check
((
PyObject
*
)
inst
))
{
// NOLINT
inst
=
(
::
pybind11
::
detail
::
instance
*
)
PyTuple_GET_ITEM
(
inst
,
0
);
}
if
(
inst
==
nullptr
||
(
PyObject
*
)
inst
==
Py_None
)
{
// NOLINT
if
(
!
dispensable
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument '%s' (position %d) must be Tensor, but got None"
,
op_type
,
arg_name
,
arg_idx
));
}
return
nullptr
;
}
if
(
!
PyObject_TypeCheck
((
PyObject
*
)
inst
,
g_varbase_pytype
))
{
// NOLINT
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument '%s' (position %d) must be Tensor, but got "
"%s"
,
op_type
,
arg_name
,
arg_idx
,
((
PyTypeObject
*
)((
PyObject
*
)
inst
)
->
ob_type
)
->
tp_name
));
// NOLINT
}
void
**
vh
=
inst
->
simple_layout
?
inst
->
simple_value_holder
:
&
inst
->
nonsimple
.
values_and_holders
[
0
];
return
reinterpret_cast
<
std
::
shared_ptr
<
paddle
::
imperative
::
VarBase
>&>
(
vh
[
1
]);
}
std
::
vector
<
std
::
shared_ptr
<
imperative
::
VarBase
>>
GetVarBaseListFromArgs
(
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
,
PyObject
*
args
,
ssize_t
arg_idx
,
bool
dispensable
)
{
PyObject
*
list
=
PyTuple_GET_ITEM
(
args
,
arg_idx
);
if
(
list
==
nullptr
||
list
==
Py_None
)
{
if
(
!
dispensable
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument '%s' (position %d) must be list of Tensor, but got "
"None"
,
op_type
,
arg_name
,
arg_idx
));
// NOLINT
}
return
{};
}
std
::
vector
<
std
::
shared_ptr
<
imperative
::
VarBase
>>
result
;
if
(
PyList_Check
(
list
))
{
Py_ssize_t
len
=
PyList_Size
(
list
);
if
(
len
==
0
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument '%s' (position %d) must be list of Tensors, but got "
"empty list"
,
op_type
,
arg_name
,
arg_idx
));
}
::
pybind11
::
detail
::
instance
*
item
=
nullptr
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
(
::
pybind11
::
detail
::
instance
*
)
PyList_GetItem
(
list
,
i
);
if
(
!
PyObject_TypeCheck
((
PyObject
*
)
item
,
g_varbase_pytype
))
{
// NOLINT
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument '%s' (position %d) must be list of Tensors, but "
"got list of "
"%s"
,
op_type
,
arg_name
,
arg_idx
,
((
PyTypeObject
*
)((
PyObject
*
)
item
)
->
ob_type
)
->
tp_name
));
// NOLINT
}
void
**
vh
=
item
->
simple_layout
?
item
->
simple_value_holder
:
&
item
->
nonsimple
.
values_and_holders
[
0
];
result
.
emplace_back
(
reinterpret_cast
<
std
::
shared_ptr
<
paddle
::
imperative
::
VarBase
>&>
(
vh
[
1
]));
}
}
else
if
(
PyTuple_Check
(
list
))
{
Py_ssize_t
len
=
PyTuple_Size
(
list
);
if
(
len
==
0
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument '%s' (position %d) must be list of Tensors, but got "
"empty list"
,
op_type
,
arg_name
,
arg_idx
));
}
::
pybind11
::
detail
::
instance
*
item
=
nullptr
;
for
(
Py_ssize_t
i
=
0
;
i
<
len
;
i
++
)
{
item
=
(
::
pybind11
::
detail
::
instance
*
)
PyTuple_GetItem
(
list
,
i
);
// NOLINT
if
(
!
PyObject_TypeCheck
((
PyObject
*
)
item
,
g_varbase_pytype
))
{
// NOLINT
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument '%s' (position %d) must be list of Tensors, but "
"got list of "
"%s"
,
op_type
,
arg_name
,
arg_idx
,
((
PyTypeObject
*
)((
PyObject
*
)
item
)
->
ob_type
)
->
tp_name
));
// NOLINT
}
void
**
vh
=
item
->
simple_layout
?
item
->
simple_value_holder
:
&
item
->
nonsimple
.
values_and_holders
[
0
];
result
.
emplace_back
(
reinterpret_cast
<
std
::
shared_ptr
<
paddle
::
imperative
::
VarBase
>&>
(
vh
[
1
]));
}
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"%s(): argument '%s' (position %d) must be list of Tensors, but got "
"%s"
,
op_type
,
arg_name
,
arg_idx
,
((
PyTypeObject
*
)
list
->
ob_type
)
->
tp_name
));
// NOLINT
}
return
result
;
}
unsigned
long
GetUnsignedLongFromArgs
(
// NOLINT
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
,
...
...
paddle/fluid/pybind/op_function_common.h
浏览文件 @
3f630658
...
...
@@ -194,20 +194,6 @@ void ConstructAttrMapFromPyArgs(
ssize_t
attr_end
,
paddle
::
framework
::
AttributeMap
&
attrs
);
// NOLINT
std
::
shared_ptr
<
imperative
::
VarBase
>
GetVarBaseFromArgs
(
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
,
PyObject
*
args
,
ssize_t
arg_idx
,
bool
dispensable
=
false
);
std
::
vector
<
std
::
shared_ptr
<
imperative
::
VarBase
>>
GetVarBaseListFromArgs
(
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
,
PyObject
*
args
,
ssize_t
arg_idx
,
bool
dispensable
=
false
);
unsigned
long
GetUnsignedLongFromArgs
(
// NOLINT
const
std
::
string
&
op_type
,
const
std
::
string
&
arg_name
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录