Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
59b3df31
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
59b3df31
编写于
8月 20, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Extract OpInfo into a library
Fix cycle dependencies, Fix
#3583
.
上级
0d9846f3
变更
15
隐藏空白更改
内联
并排
Showing
15 changed file
with
152 addition
and
95 deletion
+152
-95
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+2
-2
paddle/framework/backward_test.cc
paddle/framework/backward_test.cc
+2
-2
paddle/framework/grad_op_builder.cc
paddle/framework/grad_op_builder.cc
+10
-10
paddle/framework/op_info.cc
paddle/framework/op_info.cc
+30
-0
paddle/framework/op_info.h
paddle/framework/op_info.h
+42
-0
paddle/framework/op_registry.cc
paddle/framework/op_registry.cc
+18
-19
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+8
-27
paddle/framework/operator.cc
paddle/framework/operator.cc
+4
-4
paddle/framework/operator.h
paddle/framework/operator.h
+14
-13
paddle/framework/operator_test.cc
paddle/framework/operator_test.cc
+5
-4
paddle/framework/pybind.cc
paddle/framework/pybind.cc
+1
-1
paddle/operators/net_op.cc
paddle/operators/net_op.cc
+2
-3
paddle/operators/net_op.h
paddle/operators/net_op.h
+4
-2
paddle/operators/recurrent_op.cc
paddle/operators/recurrent_op.cc
+4
-4
paddle/operators/recurrent_op.h
paddle/operators/recurrent_op.h
+6
-4
未找到文件。
paddle/framework/CMakeLists.txt
浏览文件 @
59b3df31
...
...
@@ -18,8 +18,8 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope)
proto_library
(
framework_proto SRCS framework.proto
)
cc_library
(
attribute SRCS attribute.cc DEPS framework_proto
)
cc_library
(
operator SRCS operator.cc DEPS
framework_proto device_context tensor scope attribut
e
)
cc_library
(
op_info SRCS op_info.cc DEPS attribute framework_proto
)
cc_library
(
operator SRCS operator.cc DEPS
op_info device_context tensor scop
e
)
cc_test
(
operator_test SRCS operator_test.cc DEPS operator op_registry
)
cc_library
(
grad_op_builder SRCS grad_op_builder.cc DEPS operator
)
...
...
paddle/framework/backward_test.cc
浏览文件 @
59b3df31
...
...
@@ -72,8 +72,8 @@ class NoGradOpMaker : public OpProtoAndCheckerMaker {
class
FcOp
:
public
operators
::
NetOp
{
public:
FcOp
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
FcOp
(
const
std
::
string
&
type
,
const
Var
iable
NameMap
&
inputs
,
const
Var
iable
NameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
NetOp
(
type
,
inputs
,
outputs
,
attrs
)
{
AddOp
(
OpRegistry
::
CreateOp
(
"mul"
,
{{
"X"
,
{
Input
(
"X"
)}},
{
"Y"
,
{
Input
(
"W"
)}}},
...
...
paddle/framework/grad_op_builder.cc
浏览文件 @
59b3df31
...
...
@@ -20,11 +20,11 @@ namespace framework {
enum
class
OpArgType
{
IN
,
OUT
};
static
void
TransOpArg
(
const
OperatorBase
*
src_op
,
const
OpArgType
&
src_type
,
bool
is_grad
,
OperatorBase
::
Var
NameMap
*
vars
)
{
bool
is_grad
,
Variable
NameMap
*
vars
)
{
const
auto
&
src_inout
=
src_type
==
OpArgType
::
IN
?
src_op
->
Inputs
()
:
src_op
->
Outputs
();
auto
&
dst_inout
=
*
vars
;
const
OpProto
*
proto
=
Op
Registry
::
op_info_m
ap
().
at
(
src_op
->
Type
()).
proto_
;
const
OpProto
*
proto
=
Op
InfoM
ap
().
at
(
src_op
->
Type
()).
proto_
;
const
auto
&
src_arg_list
=
src_type
==
OpArgType
::
IN
?
proto
->
inputs
()
:
proto
->
outputs
();
for
(
const
auto
&
arg
:
src_arg_list
)
{
...
...
@@ -40,25 +40,25 @@ static void TransOpArg(const OperatorBase* src_op, const OpArgType& src_type,
}
OperatorBase
*
BuildGradOp
(
const
OperatorBase
*
op
)
{
auto
it
=
Op
Registry
::
op_info_m
ap
().
find
(
op
->
Type
());
PADDLE_ENFORCE
(
it
!=
Op
Registry
::
op_info_map
().
end
()
,
"'%s' has not been registered."
,
op
->
Type
());
auto
it
=
Op
InfoM
ap
().
find
(
op
->
Type
());
PADDLE_ENFORCE
(
it
!=
Op
InfoMap
().
end
(),
"'%s' has not been registered."
,
op
->
Type
());
PADDLE_ENFORCE
(
it
->
second
.
proto_
!=
nullptr
,
"'%s' has no OpProto."
,
op
->
Type
());
std
::
string
grad_op_type
=
it
->
second
.
grad_op_type_
;
PADDLE_ENFORCE
(
!
grad_op_type
.
empty
(),
"'%s' has no gradient operator."
,
op
->
Type
());
OperatorBase
::
Var
NameMap
inputs
;
OperatorBase
::
Var
NameMap
outputs
;
Variable
NameMap
inputs
;
Variable
NameMap
outputs
;
TransOpArg
(
op
,
OpArgType
::
IN
,
false
,
&
inputs
);
// I
TransOpArg
(
op
,
OpArgType
::
OUT
,
false
,
&
inputs
);
// O
TransOpArg
(
op
,
OpArgType
::
OUT
,
true
,
&
inputs
);
// OG
TransOpArg
(
op
,
OpArgType
::
IN
,
true
,
&
outputs
);
// IG
it
=
Op
Registry
::
op_info_m
ap
().
find
(
grad_op_type
);
PADDLE_ENFORCE
(
it
!=
Op
Registry
::
op_info_map
().
end
()
,
"'%s' has not been registered."
,
grad_op_type
);
it
=
Op
InfoM
ap
().
find
(
grad_op_type
);
PADDLE_ENFORCE
(
it
!=
Op
InfoMap
().
end
(),
"'%s' has not been registered."
,
grad_op_type
);
return
it
->
second
.
creator_
(
grad_op_type
,
inputs
,
outputs
,
op
->
Attrs
());
}
...
...
paddle/framework/op_info.cc
0 → 100644
浏览文件 @
59b3df31
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_info.h"
namespace
paddle
{
namespace
framework
{
static
std
::
unordered_map
<
std
::
string
,
const
paddle
::
framework
::
OpInfo
>*
g_op_info_map
=
nullptr
;
std
::
unordered_map
<
std
::
string
,
const
paddle
::
framework
::
OpInfo
>&
OpInfoMap
()
{
if
(
g_op_info_map
==
nullptr
)
{
g_op_info_map
=
new
std
::
unordered_map
<
std
::
string
,
const
paddle
::
framework
::
OpInfo
>
();
}
return
*
g_op_info_map
;
}
}
// namespace framework
}
// namespace paddle
paddle/framework/op_info.h
0 → 100644
浏览文件 @
59b3df31
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include <map>
#include <string>
#include <unordered_map>
#include "paddle/framework/attribute.h"
namespace
paddle
{
namespace
framework
{
class
OperatorBase
;
using
VariableNameMap
=
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
;
using
OpCreator
=
std
::
function
<
OperatorBase
*
(
const
std
::
string
&
/*type*/
,
const
VariableNameMap
&
/*inputs*/
,
const
VariableNameMap
&
/*outputs*/
,
const
AttributeMap
&
/*attrs*/
)
>
;
struct
OpInfo
{
OpCreator
creator_
;
std
::
string
grad_op_type_
;
OpProto
*
proto_
;
OpAttrChecker
*
checker_
;
};
extern
std
::
unordered_map
<
std
::
string
,
const
OpInfo
>&
OpInfoMap
();
}
// namespace framework
}
// namespace paddle
paddle/framework/op_registry.cc
浏览文件 @
59b3df31
...
...
@@ -19,32 +19,20 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
std
::
unique_ptr
<
OperatorBase
>
OpRegistry
::
CreateOp
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
AttributeMap
attrs
)
{
auto
it
=
op_info_map
().
find
(
type
);
PADDLE_ENFORCE
(
it
!=
op_info_map
().
end
(),
std
::
unique_ptr
<
OperatorBase
>
OpRegistry
::
CreateOp
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
AttributeMap
attrs
)
{
auto
it
=
OpInfoMap
().
find
(
type
);
PADDLE_ENFORCE
(
it
!=
OpInfoMap
().
end
(),
"Operator '%s' has not been registered."
,
type
);
it
->
second
.
checker_
->
Check
(
attrs
);
auto
op
=
it
->
second
.
creator_
(
type
,
inputs
,
outputs
,
attrs
);
return
std
::
unique_ptr
<
OperatorBase
>
(
op
);
}
std
::
unique_ptr
<
OperatorBase
>
OpRegistry
::
CreateOp
(
const
OpDesc
&
op_desc
)
{
VarNameMap
inputs
=
ConvertOpDescVarsToVarNameMap
(
op_desc
.
inputs
());
VarNameMap
outputs
=
ConvertOpDescVarsToVarNameMap
(
op_desc
.
outputs
());
AttributeMap
attrs
;
for
(
auto
&
attr
:
op_desc
.
attrs
())
{
attrs
[
attr
.
name
()]
=
GetAttrValue
(
attr
);
}
return
CreateOp
(
op_desc
.
type
(),
inputs
,
outputs
,
attrs
);
}
OperatorBase
::
VarNameMap
OpRegistry
::
ConvertOpDescVarsToVarNameMap
(
static
VariableNameMap
ConvertOpDescVarsToVarNameMap
(
const
google
::
protobuf
::
RepeatedPtrField
<
OpDesc
::
Var
>&
op_desc_vars
)
{
VarNameMap
ret_val
;
Var
iable
NameMap
ret_val
;
for
(
auto
&
var
:
op_desc_vars
)
{
auto
&
var_names
=
ret_val
[
var
.
parameter
()];
auto
&
var_names_in_proto
=
var
.
arguments
();
...
...
@@ -55,6 +43,17 @@ OperatorBase::VarNameMap OpRegistry::ConvertOpDescVarsToVarNameMap(
return
ret_val
;
}
std
::
unique_ptr
<
OperatorBase
>
OpRegistry
::
CreateOp
(
const
OpDesc
&
op_desc
)
{
VariableNameMap
inputs
=
ConvertOpDescVarsToVarNameMap
(
op_desc
.
inputs
());
VariableNameMap
outputs
=
ConvertOpDescVarsToVarNameMap
(
op_desc
.
outputs
());
AttributeMap
attrs
;
for
(
auto
&
attr
:
op_desc
.
attrs
())
{
attrs
[
attr
.
name
()]
=
GetAttrValue
(
attr
);
}
return
CreateOp
(
op_desc
.
type
(),
inputs
,
outputs
,
attrs
);
}
std
::
unique_ptr
<
OperatorBase
>
OpRegistry
::
CreateGradOp
(
const
OperatorBase
&
op
)
{
PADDLE_ENFORCE
(
!
op
.
IsNetOp
(),
"Use framework::Backward to get backward ops"
);
return
std
::
unique_ptr
<
OperatorBase
>
(
BuildGradOp
(
&
op
));
...
...
paddle/framework/op_registry.h
浏览文件 @
59b3df31
...
...
@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/framework/attribute.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_info.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
...
...
@@ -30,28 +31,16 @@ namespace paddle {
namespace
framework
{
class
OpRegistry
{
using
VarNameMap
=
OperatorBase
::
VarNameMap
;
using
OpCreator
=
std
::
function
<
OperatorBase
*
(
const
std
::
string
&
/*type*/
,
const
VarNameMap
&
/*inputs*/
,
const
VarNameMap
&
/*outputs*/
,
const
AttributeMap
&
/*attrs*/
)
>
;
public:
struct
OpInfo
{
OpCreator
creator_
;
std
::
string
grad_op_type_
;
OpProto
*
proto_
;
OpAttrChecker
*
checker_
;
};
template
<
typename
OpType
,
typename
ProtoMakerType
,
typename
GradOpType
>
static
void
RegisterOp
(
const
std
::
string
&
op_type
,
const
std
::
string
&
grad_op_type
)
{
PADDLE_ENFORCE
(
op_info_m
ap
().
count
(
op_type
)
==
0
,
PADDLE_ENFORCE
(
OpInfoM
ap
().
count
(
op_type
)
==
0
,
"'%s' is registered more than once."
,
op_type
);
OpInfo
op_info
;
op_info
.
creator_
=
[](
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
out
puts
,
const
AttributeMap
&
attrs
)
{
op_info
.
creator_
=
[](
const
std
::
string
&
type
,
const
VariableNameMap
&
in
puts
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
{
return
new
OpType
(
type
,
inputs
,
outputs
,
attrs
);
};
op_info
.
grad_op_type_
=
grad_op_type
;
...
...
@@ -70,7 +59,7 @@ class OpRegistry {
op_info
.
proto_
=
nullptr
;
op_info
.
checker_
=
nullptr
;
}
op_info_m
ap
().
insert
(
std
::
make_pair
(
op_type
,
op_info
));
OpInfoM
ap
().
insert
(
std
::
make_pair
(
op_type
,
op_info
));
// register gradient op
if
(
!
grad_op_type
.
empty
())
{
RegisterOp
<
GradOpType
,
NOPMaker
,
NOP
>
(
grad_op_type
,
""
);
...
...
@@ -78,21 +67,13 @@ class OpRegistry {
}
static
std
::
unique_ptr
<
OperatorBase
>
CreateOp
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
Var
iable
NameMap
&
inputs
,
const
Var
iable
NameMap
&
outputs
,
AttributeMap
attrs
);
static
std
::
unique_ptr
<
OperatorBase
>
CreateOp
(
const
OpDesc
&
op_desc
);
static
VarNameMap
ConvertOpDescVarsToVarNameMap
(
const
google
::
protobuf
::
RepeatedPtrField
<
OpDesc
::
Var
>&
op_desc_vars
);
static
std
::
unique_ptr
<
OperatorBase
>
CreateGradOp
(
const
OperatorBase
&
op
);
static
std
::
unordered_map
<
std
::
string
,
const
OpInfo
>&
op_info_map
()
{
static
std
::
unordered_map
<
std
::
string
,
const
OpInfo
>
op_info_map_
;
return
op_info_map_
;
}
};
class
Registrar
{
...
...
paddle/framework/operator.cc
浏览文件 @
59b3df31
...
...
@@ -115,8 +115,8 @@ void OperatorBase::Rename(const std::string& old_name,
}
OperatorBase
::
OperatorBase
(
const
std
::
string
&
type
,
const
OperatorBase
::
Var
NameMap
&
inputs
,
const
OperatorBase
::
Var
NameMap
&
outputs
,
const
Variable
NameMap
&
inputs
,
const
Variable
NameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
type_
(
type
),
inputs_
(
inputs
),
outputs_
(
outputs
),
attrs_
(
attrs
)
{
static
std
::
atomic
<
size_t
>
gUniqId
(
0UL
);
...
...
@@ -141,9 +141,9 @@ std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
}
return
ret_val
;
}
auto
it
=
Op
Registry
::
op_info_m
ap
().
find
(
type_
);
auto
it
=
Op
InfoM
ap
().
find
(
type_
);
PADDLE_ENFORCE
(
it
!=
Op
Registry
::
op_info_m
ap
().
end
(),
it
!=
Op
InfoM
ap
().
end
(),
"Operator %s not registered, cannot figure out intermediate outputs"
,
type_
);
PADDLE_ENFORCE
(
...
...
paddle/framework/operator.h
浏览文件 @
59b3df31
...
...
@@ -19,6 +19,7 @@ limitations under the License. */
#include <unordered_map>
#include <vector>
#include "op_info.h"
#include "paddle/framework/attribute.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/scope.h"
...
...
@@ -62,10 +63,8 @@ class ExecutionContext;
*/
class
OperatorBase
{
public:
using
VarNameMap
=
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
;
OperatorBase
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
AttributeMap
&
attrs
);
OperatorBase
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
);
virtual
~
OperatorBase
()
{}
...
...
@@ -93,8 +92,8 @@ class OperatorBase {
/// rename inputs outputs name
void
Rename
(
const
std
::
string
&
old_name
,
const
std
::
string
&
new_name
);
const
VarNameMap
&
Inputs
()
const
{
return
inputs_
;
}
const
VarNameMap
&
Outputs
()
const
{
return
outputs_
;
}
const
Var
iable
NameMap
&
Inputs
()
const
{
return
inputs_
;
}
const
Var
iable
NameMap
&
Outputs
()
const
{
return
outputs_
;
}
//! Get a input with argument's name described in `op_proto`
const
std
::
string
&
Input
(
const
std
::
string
&
name
)
const
;
//! Get a input which has multiple variables.
...
...
@@ -122,11 +121,11 @@ class OperatorBase {
// I (Inputs)opear
// O (Outputs)
// OG (Output Gradients)
VarNameMap
inputs_
;
Var
iable
NameMap
inputs_
;
// NOTE: in case of OpGrad, outputs_ contains
// IG (Inputs Gradients)
VarNameMap
outputs_
;
Var
iable
NameMap
outputs_
;
AttributeMap
attrs_
;
};
...
...
@@ -142,9 +141,11 @@ class OperatorBase {
// You can also use
// using PARENT_CLASS::PARENT_CLASS;
// to use parent's constructor.
#define DEFINE_OP_CONSTRUCTOR(CLS, PARENT_CLS) \
CLS(const std::string& type, const VarNameMap& inputs, \
const VarNameMap& outputs, const paddle::framework::AttributeMap& attrs) \
#define DEFINE_OP_CONSTRUCTOR(CLS, PARENT_CLS) \
CLS(const std::string& type, \
const ::paddle::framework::VariableNameMap& inputs, \
const ::paddle::framework::VariableNameMap& outputs, \
const paddle::framework::AttributeMap& attrs) \
: PARENT_CLS(type, inputs, outputs, attrs) {}
class
NOP
:
public
OperatorBase
{
...
...
@@ -389,8 +390,8 @@ class OperatorWithKernel : public OperatorBase {
using
OpKernelMap
=
std
::
unordered_map
<
OpKernelKey
,
std
::
unique_ptr
<
OpKernel
>
,
OpKernelHash
>
;
OperatorWithKernel
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
OperatorWithKernel
(
const
std
::
string
&
type
,
const
Var
iable
NameMap
&
inputs
,
const
Var
iable
NameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
const
Scope
&
scope
)
const
override
{
...
...
paddle/framework/operator_test.cc
浏览文件 @
59b3df31
...
...
@@ -23,8 +23,8 @@ static int op_run_num = 0;
class
OpWithoutKernelTest
:
public
OperatorBase
{
public:
OpWithoutKernelTest
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
OpWithoutKernelTest
(
const
std
::
string
&
type
,
const
Var
iable
NameMap
&
inputs
,
const
Var
iable
NameMap
&
outputs
,
const
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
),
x
(
1
)
{}
void
InferShape
(
const
Scope
&
scope
)
const
override
{}
void
Run
(
const
Scope
&
scope
,
...
...
@@ -249,8 +249,9 @@ TEST(OpKernel, multi_inputs) {
class
OperatorClone
:
public
paddle
::
framework
::
OperatorBase
{
public:
DEFINE_OP_CLONE_METHOD
(
OperatorClone
);
OperatorClone
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
OperatorClone
(
const
std
::
string
&
type
,
const
paddle
::
framework
::
VariableNameMap
&
inputs
,
const
paddle
::
framework
::
VariableNameMap
&
outputs
,
const
paddle
::
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
InferShape
(
const
paddle
::
framework
::
Scope
&
scope
)
const
override
{}
...
...
paddle/framework/pybind.cc
浏览文件 @
59b3df31
...
...
@@ -138,7 +138,7 @@ All parameter, weight, gradient are variables in Paddle.
//! @note: Be careful! PyBind will return std::string as an unicode, not
//! Python str. If you want a str object, you should cast them in Python.
m
.
def
(
"get_all_op_protos"
,
[]()
->
std
::
vector
<
py
::
bytes
>
{
auto
&
op_info_map
=
Op
Registry
::
op_info_m
ap
();
auto
&
op_info_map
=
Op
InfoM
ap
();
std
::
vector
<
py
::
bytes
>
ret_values
;
for
(
auto
it
=
op_info_map
.
begin
();
it
!=
op_info_map
.
end
();
++
it
)
{
const
OpProto
*
proto
=
it
->
second
.
proto_
;
...
...
paddle/operators/net_op.cc
浏览文件 @
59b3df31
...
...
@@ -81,9 +81,8 @@ std::vector<std::string> NetOp::OutputVars(bool has_intermediate) const {
return
ret_val
;
}
NetOp
::
NetOp
(
const
std
::
string
&
type
,
const
framework
::
OperatorBase
::
VarNameMap
&
inputs
,
const
framework
::
OperatorBase
::
VarNameMap
&
outputs
,
NetOp
::
NetOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
framework
::
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
...
...
paddle/operators/net_op.h
浏览文件 @
59b3df31
...
...
@@ -38,8 +38,10 @@ class NetOp : public framework::OperatorBase {
public:
static
const
char
kAll
[];
NetOp
()
:
framework
::
OperatorBase
(
"plain_net"
,
{},
{},
{})
{}
NetOp
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
);
NetOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
);
NetOp
(
const
NetOp
&
o
)
:
framework
::
OperatorBase
(
o
.
type_
,
{},
{},
o
.
attrs_
)
{
this
->
ops_
.
reserve
(
o
.
ops_
.
size
());
...
...
paddle/operators/recurrent_op.cc
浏览文件 @
59b3df31
...
...
@@ -131,8 +131,8 @@ const rnn::ArgumentName RecurrentGradientOp::kArgName{
"memories"
,
"pre_memories"
,
"boot_memories@grad"
};
RecurrentOp
::
RecurrentOp
(
const
std
::
string
&
type
,
const
framework
::
OperatorBase
::
Var
NameMap
&
inputs
,
const
framework
::
OperatorBase
::
Var
NameMap
&
outputs
,
const
framework
::
Variable
NameMap
&
inputs
,
const
framework
::
Variable
NameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{
rnn
::
InitArgument
(
kArgName
,
&
arg_
,
*
this
);
...
...
@@ -223,8 +223,8 @@ void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
}
RecurrentGradientOp
::
RecurrentGradientOp
(
const
std
::
string
&
type
,
const
framework
::
OperatorBase
::
Var
NameMap
&
inputs
,
const
framework
::
OperatorBase
::
Var
NameMap
&
outputs
,
const
std
::
string
&
type
,
const
framework
::
Variable
NameMap
&
inputs
,
const
framework
::
Variable
NameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
)
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{
rnn
::
InitArgument
(
kArgName
,
&
arg_
,
*
this
);
...
...
paddle/operators/recurrent_op.h
浏览文件 @
59b3df31
...
...
@@ -114,8 +114,9 @@ class RecurrentGradientAlgorithm {
class
RecurrentOp
:
public
framework
::
OperatorBase
{
public:
RecurrentOp
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
);
RecurrentOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
);
RecurrentOp
(
const
RecurrentOp
&
o
)
:
framework
::
OperatorBase
(
...
...
@@ -150,8 +151,9 @@ class RecurrentOp : public framework::OperatorBase {
class
RecurrentGradientOp
:
public
framework
::
OperatorBase
{
public:
RecurrentGradientOp
(
const
std
::
string
&
type
,
const
VarNameMap
&
inputs
,
const
VarNameMap
&
outputs
,
RecurrentGradientOp
(
const
std
::
string
&
type
,
const
framework
::
VariableNameMap
&
inputs
,
const
framework
::
VariableNameMap
&
outputs
,
const
framework
::
AttributeMap
&
attrs
);
RecurrentGradientOp
(
const
RecurrentGradientOp
&
o
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录