Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
438bca9c
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
438bca9c
编写于
3月 15, 2019
作者:
M
minqiyang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement Runtime Var Type Inference
test=develop
上级
ca392c7e
变更
18
隐藏空白更改
内联
并排
Showing
18 changed file
with
265 addition
and
77 deletion
+265
-77
paddle/fluid/framework/details/op_registry.h
paddle/fluid/framework/details/op_registry.h
+2
-0
paddle/fluid/framework/var_type_inference.h
paddle/fluid/framework/var_type_inference.h
+15
-7
paddle/fluid/imperative/layer.cc
paddle/fluid/imperative/layer.cc
+64
-20
paddle/fluid/imperative/layer.h
paddle/fluid/imperative/layer.h
+133
-9
paddle/fluid/imperative/tracer.cc
paddle/fluid/imperative/tracer.cc
+16
-9
paddle/fluid/imperative/type_defs.h
paddle/fluid/imperative/type_defs.h
+1
-0
paddle/fluid/operators/controlflow/while_op.cc
paddle/fluid/operators/controlflow/while_op.cc
+7
-10
paddle/fluid/operators/distributed_ops/split_ids_op.cc
paddle/fluid/operators/distributed_ops/split_ids_op.cc
+2
-0
paddle/fluid/operators/nccl/nccl_op.cc
paddle/fluid/operators/nccl/nccl_op.cc
+3
-6
paddle/fluid/operators/py_func_op.cc
paddle/fluid/operators/py_func_op.cc
+3
-0
paddle/fluid/operators/reader/create_custom_reader_op.cc
paddle/fluid/operators/reader/create_custom_reader_op.cc
+1
-1
paddle/fluid/operators/reader/read_op.cc
paddle/fluid/operators/reader/read_op.cc
+1
-1
paddle/fluid/operators/reader/reader_op_registry.cc
paddle/fluid/operators/reader/reader_op_registry.cc
+9
-12
paddle/fluid/operators/reader/reader_op_registry.h
paddle/fluid/operators/reader/reader_op_registry.h
+2
-0
paddle/fluid/operators/scale_op.cc
paddle/fluid/operators/scale_op.cc
+1
-0
paddle/fluid/operators/split_selected_rows_op.cc
paddle/fluid/operators/split_selected_rows_op.cc
+2
-0
paddle/fluid/operators/sum_op.cc
paddle/fluid/operators/sum_op.cc
+1
-0
paddle/fluid/pybind/imperative.cc
paddle/fluid/pybind/imperative.cc
+2
-2
未找到文件。
paddle/fluid/framework/details/op_registry.h
浏览文件 @
438bca9c
...
@@ -16,6 +16,8 @@ limitations under the License. */
...
@@ -16,6 +16,8 @@ limitations under the License. */
#include <string>
#include <string>
#include <tuple>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/grad_op_desc_maker.h"
#include "paddle/fluid/framework/grad_op_desc_maker.h"
#include "paddle/fluid/framework/inplace_op_inference.h"
#include "paddle/fluid/framework/inplace_op_inference.h"
...
...
paddle/fluid/framework/var_type_inference.h
浏览文件 @
438bca9c
...
@@ -14,6 +14,7 @@ limitations under the License. */
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#pragma once
#include <string>
#include <string>
#include <unordered_map>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_desc.h"
...
@@ -80,6 +81,19 @@ class InferVarTypeContext {
...
@@ -80,6 +81,19 @@ class InferVarTypeContext {
block_
->
FindRecursiveOrCreateVar
(
name
).
SetDataType
(
type
);
block_
->
FindRecursiveOrCreateVar
(
name
).
SetDataType
(
type
);
}
}
inline
std
::
vector
<
proto
::
VarType
::
Type
>
GetDataTypes
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
block_
);
return
block_
->
FindRecursiveOrCreateVar
(
name
).
GetDataTypes
();
}
inline
void
SetDataTypes
(
const
std
::
string
&
name
,
const
std
::
vector
<
proto
::
VarType
::
Type
>&
multiple_data_type
)
{
PADDLE_ENFORCE_NOT_NULL
(
block_
);
block_
->
FindRecursiveOrCreateVar
(
name
).
SetDataTypes
(
multiple_data_type
);
}
inline
std
::
vector
<
int64_t
>
GetShape
(
const
std
::
string
&
name
)
const
{
inline
std
::
vector
<
int64_t
>
GetShape
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
block_
);
PADDLE_ENFORCE_NOT_NULL
(
block_
);
return
block_
->
FindRecursiveOrCreateVar
(
name
).
GetShape
();
return
block_
->
FindRecursiveOrCreateVar
(
name
).
GetShape
();
...
@@ -101,17 +115,11 @@ class InferVarTypeContext {
...
@@ -101,17 +115,11 @@ class InferVarTypeContext {
block_
->
FindRecursiveOrCreateVar
(
name
).
SetLoDLevel
(
lod_level
);
block_
->
FindRecursiveOrCreateVar
(
name
).
SetLoDLevel
(
lod_level
);
}
}
pr
ivate
:
pr
otected
:
const
OpDesc
*
op_
;
const
OpDesc
*
op_
;
BlockDesc
*
block_
;
BlockDesc
*
block_
;
};
};
// infer var type context for imperative mode
class
RuntimeInferVarTypeContext
:
public
InferVarTypeContext
{
public:
RuntimeInferVarTypeContext
()
:
InferVarTypeContext
(
nullptr
,
nullptr
)
{}
};
class
VarTypeInference
{
class
VarTypeInference
{
public:
public:
virtual
~
VarTypeInference
()
{}
virtual
~
VarTypeInference
()
{}
...
...
paddle/fluid/imperative/layer.cc
浏览文件 @
438bca9c
...
@@ -220,7 +220,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
...
@@ -220,7 +220,7 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
}
}
VLOG
(
3
)
<<
"apply op grad: "
<<
Type
();
VLOG
(
3
)
<<
"apply op grad: "
<<
Type
();
std
::
vector
<
framework
::
VariableValue
Map
>
tmp_grad_outputs
;
std
::
vector
<
VarBasePtr
Map
>
tmp_grad_outputs
;
if
(
backward_id_
>
0
)
{
if
(
backward_id_
>
0
)
{
VLOG
(
3
)
<<
"py_layer_grad"
;
VLOG
(
3
)
<<
"py_layer_grad"
;
tmp_grad_outputs
.
resize
(
1
);
tmp_grad_outputs
.
resize
(
1
);
...
@@ -246,23 +246,59 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
...
@@ -246,23 +246,59 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
// Allocate a new variable
// Allocate a new variable
Variable
*
tmp_var
=
new
framework
::
Variable
();
Variable
*
tmp_var
=
new
framework
::
Variable
();
tmp_var
->
GetMutable
<
framework
::
LoDTensor
>
();
tmp_var
->
GetMutable
<
framework
::
LoDTensor
>
();
outputs
.
emplace_back
(
tmp_var
);
VarBase
*
tmp_var_base
=
new
VarBase
(
it
.
second
[
i
]
->
Name
(),
tmp_var
,
nullptr
,
true
);
outputs
.
emplace_back
(
tmp_var_base
);
}
}
}
}
// Run grad op
framework
::
RuntimeContext
ctx
(
grad_input_vars_
[
k
],
tmp_grad_outputs
[
k
]);
// No need to do compile time infer shape here.
// No need to do compile time infer shape here.
// grad_op_desc_->InferShape(*block_);
// grad_op_desc_->InferShape(*block_);
// grad_op_desc->InferVarType(block_);
// grad_op_desc->InferVarType(block_);
std
::
unique_ptr
<
framework
::
OperatorBase
>
opbase
=
std
::
unique_ptr
<
framework
::
OperatorBase
>
opbase
=
framework
::
OpRegistry
::
CreateOp
(
*
grad_op_desc
);
framework
::
OpRegistry
::
CreateOp
(
*
grad_op_desc
);
// auto& info =
// framework::OpInfoMap::Instance().Get(grad_op_desc->Type());
// if (info.infer_var_type_) {
// framework::RuntimeInferVarTypeContext infer_var_type_ctx(
// this, &grad_inputs, &outputs, &attrs_map);
// info.infer_var_type_(infer_var_type_ctx);
// }
framework
::
OperatorWithKernel
*
op_kernel
=
framework
::
OperatorWithKernel
*
op_kernel
=
dynamic_cast
<
framework
::
OperatorWithKernel
*>
(
opbase
.
get
());
dynamic_cast
<
framework
::
OperatorWithKernel
*>
(
opbase
.
get
());
PADDLE_ENFORCE_NOT_NULL
(
op_kernel
,
"only support op with kernel"
);
PADDLE_ENFORCE_NOT_NULL
(
op_kernel
,
"only support op with kernel"
);
// Run grad op
framework
::
VariableValueMap
grad_invars_map
;
framework
::
VariableValueMap
grad_outvars_map
;
for
(
const
auto
&
it
:
grad_input_vars_
[
k
])
{
auto
&
grad_invars
=
grad_invars_map
[
it
.
first
];
grad_invars
.
reserve
(
it
.
second
.
size
());
for
(
const
VarBase
*
grad_inp
:
it
.
second
)
{
PADDLE_ENFORCE_NOT_NULL
(
grad_inp
->
var_
,
"op %s input %s nullptr"
,
grad_op_desc
->
Type
(),
grad_inp
->
Name
());
grad_invars
.
emplace_back
(
grad_inp
->
var_
);
}
}
for
(
const
auto
&
it
:
tmp_grad_outputs
[
k
])
{
auto
&
grad_outvars
=
grad_outvars_map
[
it
.
first
];
grad_outvars
.
reserve
(
it
.
second
.
size
());
for
(
VarBase
*
grad_out
:
it
.
second
)
{
PADDLE_ENFORCE_NOT_NULL
(
grad_out
->
var_
,
"op %s output %s nullptr"
,
grad_op_desc
->
Type
(),
grad_out
->
Name
());
grad_outvars
.
emplace_back
(
grad_out
->
var_
);
}
}
framework
::
RuntimeContext
ctx
(
grad_invars_map
,
grad_outvars_map
);
framework
::
Scope
scope
;
framework
::
Scope
scope
;
PreparedOp
p
=
PreparedOp
::
Prepare
(
ctx
,
*
op_kernel
,
place_
);
PreparedOp
p
=
PreparedOp
::
Prepare
(
ctx
,
*
op_kernel
,
place_
);
p
.
op
.
RuntimeInferShape
(
scope
,
place_
,
ctx
);
p
.
op
.
RuntimeInferShape
(
scope
,
place_
,
ctx
);
...
@@ -279,8 +315,8 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
...
@@ -279,8 +315,8 @@ std::map<std::string, std::vector<VarBase*>> OpBase::ApplyGrad() {
PADDLE_ENFORCE_EQ
(
outputs
.
size
(),
origin_outputs
.
size
());
PADDLE_ENFORCE_EQ
(
outputs
.
size
(),
origin_outputs
.
size
());
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
outputs
.
size
();
++
i
)
{
framework
::
Variable
*
grad
=
outputs
[
i
];
framework
::
Variable
*
grad
=
outputs
[
i
]
->
var_
;
framework
::
Variable
*
orig_grad
=
origin_outputs
[
i
];
framework
::
Variable
*
orig_grad
=
origin_outputs
[
i
]
->
var_
;
AddTo
(
grad
,
orig_grad
,
place_
);
AddTo
(
grad
,
orig_grad
,
place_
);
delete
grad
;
delete
grad
;
}
}
...
@@ -328,28 +364,35 @@ void PyLayer::RegisterFunc(int func_id, const py::object& py_func) {
...
@@ -328,28 +364,35 @@ void PyLayer::RegisterFunc(int func_id, const py::object& py_func) {
int
PyLayer
::
NumFuncs
()
{
return
py_funcs_
.
size
();
}
int
PyLayer
::
NumFuncs
()
{
return
py_funcs_
.
size
();
}
std
::
vector
<
Variable
*>
PyLayer
::
Apply
(
int
func_id
,
std
::
vector
<
framework
::
Variable
*>
PyLayer
::
Apply
(
const
std
::
vector
<
VarBase
*>&
inputs
)
{
int
func_id
,
const
std
::
vector
<
VarBase
*>&
inputs
)
{
std
::
vector
<
framework
::
Variable
*>
invars
;
for
(
const
VarBase
*
in
:
inputs
)
{
invars
.
push_back
(
in
->
var_
);
}
PADDLE_ENFORCE
(
py_funcs_
.
find
(
func_id
)
!=
py_funcs_
.
end
());
PADDLE_ENFORCE
(
py_funcs_
.
find
(
func_id
)
!=
py_funcs_
.
end
());
return
CallPythonFunc
(
py_funcs_
[
func_id
],
in
var
s
);
return
CallPythonFunc
(
py_funcs_
[
func_id
],
in
put
s
);
}
}
std
::
vector
<
Var
iable
*>
PyLayer
::
ApplyGrad
(
std
::
vector
<
Var
Base
*>
PyLayer
::
ApplyGrad
(
int
func_id
,
int
func_id
,
const
std
::
vector
<
framework
::
Variabl
e
*>&
inputs
)
{
const
std
::
vector
<
VarBas
e
*>&
inputs
)
{
PADDLE_ENFORCE
(
py_funcs_
.
find
(
func_id
)
!=
py_funcs_
.
end
());
PADDLE_ENFORCE
(
py_funcs_
.
find
(
func_id
)
!=
py_funcs_
.
end
());
return
CallPythonFunc
(
py_funcs_
[
func_id
],
inputs
);
auto
rets
=
CallPythonFunc
(
py_funcs_
[
func_id
],
inputs
);
std
::
vector
<
VarBase
*>
outs
;
outs
.
reserve
(
rets
.
size
());
for
(
size_t
i
=
0U
;
i
!=
rets
.
size
();
++
i
)
{
outs
.
emplace_back
(
new
VarBase
(
string
::
Sprintf
(
"%s_out_%d"
,
framework
::
GradVarName
(
PyLayer
::
kFwdOut
),
i
),
rets
[
i
],
nullptr
,
true
));
}
return
outs
;
}
}
std
::
vector
<
framework
::
Variable
*>
PyLayer
::
CallPythonFunc
(
std
::
vector
<
framework
::
Variable
*>
PyLayer
::
CallPythonFunc
(
const
py
::
object
&
callable
,
const
std
::
vector
<
framework
::
Variabl
e
*>&
ins
)
{
const
py
::
object
&
callable
,
const
std
::
vector
<
VarBas
e
*>&
ins
)
{
py
::
gil_scoped_acquire
guard
;
py
::
gil_scoped_acquire
guard
;
py
::
tuple
in_args
(
ins
.
size
());
py
::
tuple
in_args
(
ins
.
size
());
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
++
i
)
{
const
framework
::
LoDTensor
&
t
=
ins
[
i
]
->
Get
<
framework
::
LoDTensor
>
();
const
framework
::
LoDTensor
&
t
=
ins
[
i
]
->
var_
->
Get
<
framework
::
LoDTensor
>
();
in_args
[
i
]
=
t
.
IsInitialized
()
?
py
::
cast
(
t
)
:
py
::
cast
(
nullptr
);
in_args
[
i
]
=
t
.
IsInitialized
()
?
py
::
cast
(
t
)
:
py
::
cast
(
nullptr
);
}
}
VLOG
(
3
)
<<
"pyfunc in "
<<
py
::
len
(
in_args
);
VLOG
(
3
)
<<
"pyfunc in "
<<
py
::
len
(
in_args
);
...
@@ -359,6 +402,7 @@ std::vector<framework::Variable*> PyLayer::CallPythonFunc(
...
@@ -359,6 +402,7 @@ std::vector<framework::Variable*> PyLayer::CallPythonFunc(
auto
ret_tuple
=
py
::
cast
<
py
::
tuple
>
(
ret
);
auto
ret_tuple
=
py
::
cast
<
py
::
tuple
>
(
ret
);
size_t
ret_num
=
py
::
len
(
ret_tuple
);
size_t
ret_num
=
py
::
len
(
ret_tuple
);
std
::
vector
<
framework
::
Variable
*>
outs
;
std
::
vector
<
framework
::
Variable
*>
outs
;
outs
.
reserve
(
ret_num
);
VLOG
(
3
)
<<
"pyfunc out "
<<
ret_num
;
VLOG
(
3
)
<<
"pyfunc out "
<<
ret_num
;
for
(
size_t
i
=
0
;
i
<
ret_num
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
ret_num
;
++
i
)
{
try
{
try
{
...
@@ -369,7 +413,7 @@ std::vector<framework::Variable*> PyLayer::CallPythonFunc(
...
@@ -369,7 +413,7 @@ std::vector<framework::Variable*> PyLayer::CallPythonFunc(
auto
*
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
tensor
=
var
->
GetMutable
<
framework
::
LoDTensor
>
();
tensor
->
ShareDataWith
(
*
py_out_tensor
);
tensor
->
ShareDataWith
(
*
py_out_tensor
);
tensor
->
set_lod
(
py_out_tensor
->
lod
());
tensor
->
set_lod
(
py_out_tensor
->
lod
());
outs
.
push
_back
(
var
);
outs
.
emplace
_back
(
var
);
}
catch
(
py
::
cast_error
&
)
{
}
catch
(
py
::
cast_error
&
)
{
PADDLE_THROW
(
"The %d-th output must be LoDTensor"
,
i
);
PADDLE_THROW
(
"The %d-th output must be LoDTensor"
,
i
);
}
}
...
...
paddle/fluid/imperative/layer.h
浏览文件 @
438bca9c
...
@@ -18,14 +18,16 @@
...
@@ -18,14 +18,16 @@
#include "paddle/fluid/framework/python_headers.h"
#include "paddle/fluid/framework/python_headers.h"
// clang-format on
// clang-format on
#include <map> // NOLINT
#include <map> // NOLINT
#include <string> // NOLINT
#include <string> // NOLINT
#include <vector> // NOLINT
#include <vector> // NOLINT
#include <memory> // NOLINT
#include <memory> // NOLINT
#include <unordered_map> // NOLINT
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function.h"
...
@@ -184,6 +186,10 @@ class VarBase {
...
@@ -184,6 +186,10 @@ class VarBase {
}
}
}
}
inline
void
SetDType
(
framework
::
proto
::
VarType
::
Type
type
)
{
auto
tensor
=
var_
->
GetMutable
<
framework
::
LoDTensor
>
();
tensor
->
mutable_data
(
place_
,
dtype_
);
}
inline
framework
::
proto
::
VarType
::
Type
DType
()
const
{
return
dtype_
;
}
inline
framework
::
proto
::
VarType
::
Type
DType
()
const
{
return
dtype_
;
}
inline
void
SetStopGradient
(
bool
stop_gradient
)
{
inline
void
SetStopGradient
(
bool
stop_gradient
)
{
...
@@ -328,9 +334,9 @@ class PYBIND11_HIDDEN OpBase {
...
@@ -328,9 +334,9 @@ class PYBIND11_HIDDEN OpBase {
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
pre_ops_out_idx_
;
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
pre_ops_out_idx_
;
// Inputs to a vector of bwd ops.
// Inputs to a vector of bwd ops.
std
::
vector
<
framework
::
VariableValue
Map
>
grad_input_vars_
;
std
::
vector
<
VarBasePtr
Map
>
grad_input_vars_
;
// Outputs to a vector of bwd ops.
// Outputs to a vector of bwd ops.
std
::
vector
<
framework
::
VariableValue
Map
>
grad_output_vars_
;
std
::
vector
<
VarBasePtr
Map
>
grad_output_vars_
;
std
::
vector
<
py
::
object
>
backward_hooks_
;
std
::
vector
<
py
::
object
>
backward_hooks_
;
};
};
...
@@ -359,12 +365,130 @@ class PyLayer {
...
@@ -359,12 +365,130 @@ class PyLayer {
static
std
::
vector
<
framework
::
Variable
*>
Apply
(
static
std
::
vector
<
framework
::
Variable
*>
Apply
(
int
func_id
,
const
std
::
vector
<
VarBase
*>&
inputs
);
int
func_id
,
const
std
::
vector
<
VarBase
*>&
inputs
);
static
std
::
vector
<
framework
::
Variable
*>
ApplyGrad
(
static
std
::
vector
<
VarBase
*>
ApplyGrad
(
int
func_id
,
int
func_id
,
const
std
::
vector
<
framework
::
Variabl
e
*>&
inputs
);
const
std
::
vector
<
VarBas
e
*>&
inputs
);
private:
private:
static
std
::
vector
<
framework
::
Variable
*>
CallPythonFunc
(
static
std
::
vector
<
framework
::
Variable
*>
CallPythonFunc
(
const
py
::
object
&
callable
,
const
std
::
vector
<
framework
::
Variable
*>&
ins
);
const
py
::
object
&
callable
,
const
std
::
vector
<
VarBase
*>&
ins
);
};
// infer var type context for imperative mode
class
PYBIND11_HIDDEN
RuntimeInferVarTypeContext
:
public
framework
::
InferVarTypeContext
{
public:
RuntimeInferVarTypeContext
(
imperative
::
OpBase
*
op
,
const
imperative
::
VarBasePtrMap
*
inputs
,
imperative
::
VarBasePtrMap
*
outputs
,
const
framework
::
AttributeMap
*
attrs_map
)
:
InferVarTypeContext
(
nullptr
,
nullptr
),
op_
(
op
),
inputs_
(
inputs
),
outputs_
(
outputs
),
attrs_
(
attrs_map
),
input_names_
(),
output_names_
(),
var_set_
()
{
input_names_
.
reserve
(
inputs_
->
size
());
for
(
auto
&
it
:
*
inputs_
)
{
for
(
imperative
::
VarBase
*
var
:
it
.
second
)
{
input_names_
[
it
.
first
].
emplace_back
(
var
->
Name
());
var_set_
[
var
->
Name
()]
=
var
;
}
}
output_names_
.
reserve
(
outputs_
->
size
());
for
(
auto
&
it
:
*
outputs_
)
{
for
(
imperative
::
VarBase
*
var
:
it
.
second
)
{
output_names_
[
it
.
first
].
emplace_back
(
var
->
Name
());
var_set_
[
var
->
Name
()]
=
var
;
}
}
}
framework
::
Attribute
GetAttr
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
attrs_
);
return
attrs_
->
at
(
name
);
}
inline
bool
HasVar
(
const
std
::
string
&
name
)
const
{
return
var_set_
.
count
(
name
)
>
0
;
}
inline
bool
HasInput
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
inputs_
);
return
inputs_
->
count
(
name
)
>
0
;
}
inline
bool
HasOutput
(
const
std
::
string
&
name
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
outputs_
);
return
outputs_
->
count
(
name
)
>
0
;
}
inline
const
std
::
vector
<
std
::
string
>&
Input
(
const
std
::
string
&
name
)
const
{
return
input_names_
.
at
(
name
);
}
inline
const
std
::
vector
<
std
::
string
>&
Output
(
const
std
::
string
&
name
)
const
{
return
output_names_
.
at
(
name
);
}
inline
framework
::
proto
::
VarType
::
Type
GetType
(
const
std
::
string
&
name
)
const
{
return
var_set_
.
at
(
name
)
->
DType
();
}
inline
void
SetType
(
const
std
::
string
&
name
,
framework
::
proto
::
VarType
::
Type
type
)
{
var_set_
[
name
]
->
SetDType
(
type
);
}
inline
framework
::
proto
::
VarType
::
Type
GetDataType
(
const
std
::
string
&
name
)
const
{
return
var_set_
.
at
(
name
)
->
DType
();
}
inline
void
SetDataType
(
const
std
::
string
&
name
,
framework
::
proto
::
VarType
::
Type
type
)
{
var_set_
[
name
]
->
SetDType
(
type
);
}
inline
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>
GetDataTypes
(
const
std
::
string
&
name
)
const
{
PADDLE_THROW
(
"GetDataTypes is not supported in runtime InferVarType"
);
}
inline
void
SetDataTypes
(
const
std
::
string
&
name
,
const
std
::
vector
<
framework
::
proto
::
VarType
::
Type
>&
multiple_data_type
)
{
PADDLE_THROW
(
"SetDataTypes is not supported in runtime InferVarType"
);
}
inline
std
::
vector
<
int64_t
>
GetShape
(
const
std
::
string
&
name
)
const
{
PADDLE_THROW
(
"Do not handle Shape in runtime InferVarType"
);
}
inline
void
SetShape
(
const
std
::
string
&
name
,
const
std
::
vector
<
int64_t
>&
dims
)
{
PADDLE_THROW
(
"Do not handle Shape in runtime InferVarType"
);
}
inline
int32_t
GetLoDLevel
(
const
std
::
string
&
name
)
const
{
PADDLE_THROW
(
"Do not handle LoDLevel in runtime InferVarType"
);
}
inline
void
SetLoDLevel
(
const
std
::
string
&
name
,
int32_t
lod_level
)
{
PADDLE_THROW
(
"Do not handle LoDLevel in runtime InferVarType"
);
}
private:
imperative
::
OpBase
*
op_
;
const
imperative
::
VarBasePtrMap
*
inputs_
;
imperative
::
VarBasePtrMap
*
outputs_
;
const
framework
::
AttributeMap
*
attrs_
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
input_names_
;
std
::
unordered_map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
output_names_
;
std
::
unordered_map
<
std
::
string
,
imperative
::
VarBase
*>
var_set_
;
};
};
}
// namespace imperative
}
// namespace imperative
...
...
paddle/fluid/imperative/tracer.cc
浏览文件 @
438bca9c
...
@@ -19,6 +19,7 @@
...
@@ -19,6 +19,7 @@
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/enforce.h"
...
@@ -160,7 +161,7 @@ Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {
...
@@ -160,7 +161,7 @@ Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {
}
}
std
::
set
<
std
::
string
>
Tracer
::
Trace
(
OpBase
*
op
,
const
VarBasePtrMap
&
inputs
,
std
::
set
<
std
::
string
>
Tracer
::
Trace
(
OpBase
*
op
,
const
VarBasePtrMap
&
inputs
,
const
VarBasePtrMap
&
outputs
,
VarBasePtrMap
&
outputs
,
framework
::
AttributeMap
attrs_map
,
framework
::
AttributeMap
attrs_map
,
const
platform
::
Place
expected_place
,
const
platform
::
Place
expected_place
,
const
bool
stop_gradient
)
{
const
bool
stop_gradient
)
{
...
@@ -228,6 +229,12 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
...
@@ -228,6 +229,12 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
framework
::
OpRegistry
::
CreateOp
(
op
->
Type
(),
invars_name_map
,
framework
::
OpRegistry
::
CreateOp
(
op
->
Type
(),
invars_name_map
,
outvars_name_map
,
attrs_map
);
outvars_name_map
,
attrs_map
);
if
(
info
.
infer_var_type_
)
{
RuntimeInferVarTypeContext
infer_var_type_ctx
(
op
,
&
inputs
,
&
outputs
,
&
attrs_map
);
info
.
infer_var_type_
(
infer_var_type_ctx
);
}
// TODO(minqiyang): Support infer var type in imperative mode
// TODO(minqiyang): Support infer var type in imperative mode
// Run forward op
// Run forward op
VLOG
(
3
)
<<
"tracer running "
<<
op
->
Type
();
VLOG
(
3
)
<<
"tracer running "
<<
op
->
Type
();
...
@@ -278,12 +285,12 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
...
@@ -278,12 +285,12 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
auto
fwd_var_it
=
current_vars_map
.
find
(
grad_invar
);
auto
fwd_var_it
=
current_vars_map
.
find
(
grad_invar
);
PADDLE_ENFORCE
(
fwd_var_it
!=
current_vars_map
.
end
());
PADDLE_ENFORCE
(
fwd_var_it
!=
current_vars_map
.
end
());
// Forward inputs or outputs.
// Forward inputs or outputs.
grad_in_vars
.
emplace_back
(
fwd_var_it
->
second
->
var_
);
grad_in_vars
.
emplace_back
(
fwd_var_it
->
second
);
}
else
{
}
else
{
VarBase
*
var
=
current_vars_map
[
var_it
->
second
];
VarBase
*
var
=
current_vars_map
[
var_it
->
second
];
InitGrad
(
var
,
prepared_op
.
GetDeviceContext
());
InitGrad
(
var
,
prepared_op
.
GetDeviceContext
());
// Douts.
// Douts.
grad_in_vars
.
emplace_back
(
var
->
grads_
->
var_
);
grad_in_vars
.
emplace_back
(
var
->
grads_
);
}
}
vars_saved_for_backward
.
insert
(
it
.
first
);
vars_saved_for_backward
.
insert
(
it
.
first
);
...
@@ -300,7 +307,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
...
@@ -300,7 +307,7 @@ std::set<std::string> Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
op
->
Type
());
op
->
Type
());
VarBase
*
var
=
current_vars_map
[
var_it
->
second
];
VarBase
*
var
=
current_vars_map
[
var_it
->
second
];
InitGrad
(
var
,
prepared_op
.
GetDeviceContext
());
InitGrad
(
var
,
prepared_op
.
GetDeviceContext
());
grad_out_vars
.
push_back
(
var
->
grads_
->
var_
);
grad_out_vars
.
push_back
(
var
->
grads_
);
}
}
}
}
}
}
...
@@ -342,23 +349,23 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
...
@@ -342,23 +349,23 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
auto
&
grad_output_vars
=
auto
&
grad_output_vars
=
op
->
grad_output_vars_
[
0
][
framework
::
GradVarName
(
PyLayer
::
kFwdOut
)];
op
->
grad_output_vars_
[
0
][
framework
::
GradVarName
(
PyLayer
::
kFwdOut
)];
for
(
const
VarBase
*
inp
:
inputs
)
{
for
(
VarBase
*
inp
:
inputs
)
{
grad_input_vars
.
push_back
(
inp
->
var_
);
grad_input_vars
.
push_back
(
inp
);
}
}
for
(
VarBase
*
out
:
outputs
)
{
for
(
VarBase
*
out
:
outputs
)
{
grad_input_vars
.
push_back
(
out
->
var_
);
grad_input_vars
.
push_back
(
out
);
}
}
// TODO(minqiyang): Add GPU support for PyLayer, only support CPU now
// TODO(minqiyang): Add GPU support for PyLayer, only support CPU now
platform
::
CPUPlace
place
;
platform
::
CPUPlace
place
;
for
(
VarBase
*
out
:
outputs
)
{
for
(
VarBase
*
out
:
outputs
)
{
InitGrad
(
out
,
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
));
InitGrad
(
out
,
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
));
grad_input_vars
.
push_back
(
out
->
grads_
->
var_
);
grad_input_vars
.
push_back
(
out
->
grads_
);
}
}
for
(
VarBase
*
inp
:
inputs
)
{
for
(
VarBase
*
inp
:
inputs
)
{
InitGrad
(
inp
,
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
));
InitGrad
(
inp
,
platform
::
DeviceContextPool
::
Instance
().
Get
(
place
));
grad_output_vars
.
push_back
(
inp
->
grads_
->
var_
);
grad_output_vars
.
push_back
(
inp
->
grads_
);
}
}
}
}
return
outputs
;
return
outputs
;
...
...
paddle/fluid/imperative/type_defs.h
浏览文件 @
438bca9c
...
@@ -25,6 +25,7 @@ class VarBase;
...
@@ -25,6 +25,7 @@ class VarBase;
class
OpBase
;
class
OpBase
;
typedef
std
::
map
<
std
::
string
,
std
::
vector
<
VarBase
*>>
VarBasePtrMap
;
typedef
std
::
map
<
std
::
string
,
std
::
vector
<
VarBase
*>>
VarBasePtrMap
;
typedef
std
::
map
<
std
::
string
,
std
::
vector
<
const
VarBase
*>>
ConstVarBasePtrMap
;
typedef
std
::
map
<
std
::
string
,
std
::
vector
<
OpBase
*>>
OpBasePtrMap
;
typedef
std
::
map
<
std
::
string
,
std
::
vector
<
OpBase
*>>
OpBasePtrMap
;
}
// namespace imperative
}
// namespace imperative
...
...
paddle/fluid/operators/controlflow/while_op.cc
浏览文件 @
438bca9c
...
@@ -365,19 +365,16 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
...
@@ -365,19 +365,16 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
class
WhileGradOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
WhileGradOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
framework
::
BlockDesc
*
block
)
const
override
{
auto
p_names
=
ctx
.
Input
(
kX
);
auto
p_names
=
op_desc
.
Input
(
kX
);
auto
pg_ig_names
=
ctx
.
Output
(
framework
::
GradVarName
(
kX
));
auto
pg_ig_names
=
op_desc
.
Output
(
framework
::
GradVarName
(
kX
));
for
(
size_t
i
=
0
;
i
<
p_names
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
p_names
.
size
();
++
i
)
{
auto
&
p_var
=
detail
::
Ref
(
block
->
FindVarRecursive
(
p_names
[
i
]));
if
(
ctx
.
HasVar
(
pg_ig_names
[
i
]))
{
auto
*
g_var
=
block
->
FindVarRecursive
(
pg_ig_names
[
i
]);
if
(
g_var
!=
nullptr
)
{
// Gradient could be @EMPTY@
VLOG
(
5
)
<<
"Setting "
<<
pg_ig_names
[
i
]
<<
" following "
<<
p_names
[
i
]
VLOG
(
5
)
<<
"Setting "
<<
pg_ig_names
[
i
]
<<
" following "
<<
p_names
[
i
]
<<
" type: "
<<
p_var
.
GetType
(
);
<<
" type: "
<<
ctx
.
GetType
(
p_names
[
i
]
);
g_var
->
SetType
(
p_var
.
GetType
(
));
ctx
.
SetType
(
pg_ig_names
[
i
],
ctx
.
GetType
(
p_names
[
i
]
));
g_var
->
SetDataType
(
p_var
.
GetDataType
(
));
ctx
.
SetDataType
(
pg_ig_names
[
i
],
ctx
.
GetDataType
(
p_names
[
i
]
));
}
}
}
}
}
}
...
...
paddle/fluid/operators/distributed_ops/split_ids_op.cc
浏览文件 @
438bca9c
...
@@ -14,6 +14,8 @@ limitations under the License. */
...
@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed_ops/split_ids_op.h"
#include "paddle/fluid/operators/distributed_ops/split_ids_op.h"
#include <memory>
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
...
paddle/fluid/operators/nccl/nccl_op.cc
浏览文件 @
438bca9c
...
@@ -60,12 +60,9 @@ class NCCLInitOp : public framework::OperatorBase {
...
@@ -60,12 +60,9 @@ class NCCLInitOp : public framework::OperatorBase {
class
NCCLInitOpVarTypeInference
:
public
framework
::
VarTypeInference
{
class
NCCLInitOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
framework
::
BlockDesc
*
block
)
const
override
{
auto
out_var_name
=
ctx
.
Output
(
"Communicator"
).
front
();
auto
out_var_name
=
op_desc
.
Output
(
"Communicator"
).
front
();
ctx
.
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
RAW
);
auto
&
out_var
=
block
->
FindRecursiveOrCreateVar
(
out_var_name
);
auto
var_type
=
framework
::
proto
::
VarType
::
RAW
;
out_var
.
SetType
(
var_type
);
}
}
};
};
...
...
paddle/fluid/operators/py_func_op.cc
浏览文件 @
438bca9c
...
@@ -14,8 +14,11 @@
...
@@ -14,8 +14,11 @@
#include "paddle/fluid/operators/py_func_op.h"
#include "paddle/fluid/operators/py_func_op.h"
#include <memory>
#include <set>
#include <set>
#include <string>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
...
...
paddle/fluid/operators/reader/create_custom_reader_op.cc
浏览文件 @
438bca9c
...
@@ -123,7 +123,7 @@ class CustomReaderInferShape : public framework::InferShapeBase {
...
@@ -123,7 +123,7 @@ class CustomReaderInferShape : public framework::InferShapeBase {
class
CustomReaderInferVarType
:
public
framework
::
VarTypeInference
{
class
CustomReaderInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
const
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
auto
&
out_var_name
=
ctx
.
Output
(
"Out"
)[
0
];
auto
&
out_var_name
=
ctx
.
Output
(
"Out"
)[
0
];
PADDLE_ENFORCE
(
ctx
.
HasVar
(
out_var_name
));
PADDLE_ENFORCE
(
ctx
.
HasVar
(
out_var_name
));
ctx
.
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
READER
);
ctx
.
SetType
(
out_var_name
,
framework
::
proto
::
VarType
::
READER
);
...
...
paddle/fluid/operators/reader/read_op.cc
浏览文件 @
438bca9c
...
@@ -51,7 +51,7 @@ class ReadInferShape : public framework::InferShapeBase {
...
@@ -51,7 +51,7 @@ class ReadInferShape : public framework::InferShapeBase {
class
ReadInferVarType
:
public
framework
::
VarTypeInference
{
class
ReadInferVarType
:
public
framework
::
VarTypeInference
{
public:
public:
void
operator
()(
const
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
void
operator
()(
framework
::
InferVarTypeContext
&
ctx
)
const
override
{
bool
infer_out
=
boost
::
get
<
bool
>
(
ctx
.
GetAttr
(
"infer_out"
));
bool
infer_out
=
boost
::
get
<
bool
>
(
ctx
.
GetAttr
(
"infer_out"
));
if
(
infer_out
)
{
if
(
infer_out
)
{
std
::
string
reader_name
=
ctx
.
Input
(
"Reader"
)[
0
];
std
::
string
reader_name
=
ctx
.
Input
(
"Reader"
)[
0
];
...
...
paddle/fluid/operators/reader/reader_op_registry.cc
浏览文件 @
438bca9c
...
@@ -98,11 +98,10 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
...
@@ -98,11 +98,10 @@ void FileReaderInferShape::operator()(framework::InferShapeContext* ctx) const {
}
}
}
}
void
FileReaderInferVarType
::
operator
()(
const
framework
::
OpDesc
&
op_desc
,
void
FileReaderInferVarType
::
operator
()(
framework
::
BlockDesc
*
block
)
const
{
framework
::
InferVarTypeContext
&
ctx
)
const
{
std
::
string
reader_name
=
op_desc
.
Output
(
"Out"
)[
0
];
std
::
string
reader_name
=
ctx
.
Output
(
"Out"
)[
0
];
framework
::
VarDesc
*
reader
=
block
->
FindVarRecursive
(
reader_name
);
ctx
.
SetType
(
reader_name
,
framework
::
proto
::
VarType
::
READER
);
reader
->
SetType
(
framework
::
proto
::
VarType
::
READER
);
}
}
void
DecoratedReaderInferShape
::
operator
()(
void
DecoratedReaderInferShape
::
operator
()(
...
@@ -125,13 +124,11 @@ void DecoratedReaderInferShape::operator()(
...
@@ -125,13 +124,11 @@ void DecoratedReaderInferShape::operator()(
}
}
void
DecoratedReaderInferVarType
::
operator
()(
void
DecoratedReaderInferVarType
::
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
{
framework
::
InferVarTypeContext
&
ctx
)
const
{
std
::
string
in_reader_name
=
op_desc
.
Input
(
"UnderlyingReader"
)[
0
];
const
std
::
string
&
in_reader_name
=
ctx
.
Input
(
"UnderlyingReader"
)[
0
];
framework
::
VarDesc
*
in_reader
=
block
->
FindVarRecursive
(
in_reader_name
);
const
std
::
string
&
out_reader_name
=
ctx
.
Output
(
"Out"
)[
0
];
std
::
string
out_reader_name
=
op_desc
.
Output
(
"Out"
)[
0
];
ctx
.
SetType
(
out_reader_name
,
framework
::
proto
::
VarType
::
READER
);
framework
::
VarDesc
*
out_reader
=
block
->
FindVarRecursive
(
out_reader_name
);
ctx
.
SetDataTypes
(
out_reader_name
,
ctx
.
GetDataTypes
(
in_reader_name
));
out_reader
->
SetType
(
framework
::
proto
::
VarType
::
READER
);
out_reader
->
SetDataTypes
(
in_reader
->
GetDataTypes
());
}
}
void
DecoratedReaderMakerBase
::
Make
()
{
void
DecoratedReaderMakerBase
::
Make
()
{
...
...
paddle/fluid/operators/reader/reader_op_registry.h
浏览文件 @
438bca9c
...
@@ -14,7 +14,9 @@
...
@@ -14,7 +14,9 @@
#pragma once
#pragma once
#include <memory>
#include <string>
#include <string>
#include <unordered_map>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/reader.h"
...
...
paddle/fluid/operators/scale_op.cc
浏览文件 @
438bca9c
...
@@ -14,6 +14,7 @@ limitations under the License. */
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/scale_op.h"
#include "paddle/fluid/operators/scale_op.h"
#include <memory>
#include <string>
#include <string>
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
...
...
paddle/fluid/operators/split_selected_rows_op.cc
浏览文件 @
438bca9c
...
@@ -14,6 +14,8 @@ limitations under the License. */
...
@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/split_selected_rows_op.h"
#include "paddle/fluid/operators/split_selected_rows_op.h"
#include <memory>
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
...
paddle/fluid/operators/sum_op.cc
浏览文件 @
438bca9c
...
@@ -12,6 +12,7 @@ limitations under the License. */
...
@@ -12,6 +12,7 @@ limitations under the License. */
#include "paddle/fluid/operators/sum_op.h"
#include "paddle/fluid/operators/sum_op.h"
#include <algorithm>
#include <algorithm>
#include <memory>
#include <string>
#include <string>
#include <vector>
#include <vector>
...
...
paddle/fluid/pybind/imperative.cc
浏览文件 @
438bca9c
...
@@ -38,7 +38,7 @@ void BindTracer(pybind11::module* m) {
...
@@ -38,7 +38,7 @@ void BindTracer(pybind11::module* m) {
.
def
(
"trace"
,
.
def
(
"trace"
,
[](
imperative
::
Tracer
&
self
,
imperative
::
OpBase
*
op
,
[](
imperative
::
Tracer
&
self
,
imperative
::
OpBase
*
op
,
const
imperative
::
VarBasePtrMap
&
inputs
,
const
imperative
::
VarBasePtrMap
&
inputs
,
const
imperative
::
VarBasePtrMap
&
outputs
,
imperative
::
VarBasePtrMap
&
outputs
,
framework
::
AttributeMap
attrs_map
,
framework
::
AttributeMap
attrs_map
,
const
platform
::
CPUPlace
expected_place
,
const
platform
::
CPUPlace
expected_place
,
const
bool
stop_gradient
=
false
)
{
const
bool
stop_gradient
=
false
)
{
...
@@ -48,7 +48,7 @@ void BindTracer(pybind11::module* m) {
...
@@ -48,7 +48,7 @@ void BindTracer(pybind11::module* m) {
.
def
(
"trace"
,
.
def
(
"trace"
,
[](
imperative
::
Tracer
&
self
,
imperative
::
OpBase
*
op
,
[](
imperative
::
Tracer
&
self
,
imperative
::
OpBase
*
op
,
const
imperative
::
VarBasePtrMap
&
inputs
,
const
imperative
::
VarBasePtrMap
&
inputs
,
const
imperative
::
VarBasePtrMap
&
outputs
,
imperative
::
VarBasePtrMap
&
outputs
,
framework
::
AttributeMap
attrs_map
,
framework
::
AttributeMap
attrs_map
,
const
platform
::
CUDAPlace
expected_place
,
const
platform
::
CUDAPlace
expected_place
,
const
bool
stop_gradient
=
false
)
{
const
bool
stop_gradient
=
false
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录