Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
ad8fa77c
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ad8fa77c
编写于
7月 17, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'develop' into feature/add_some_skeletons_of_ops
上级
5847b96a
df707d04
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
791 addition
and
55 deletion
+791
-55
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+33
-0
paddle/framework/operator.cc
paddle/framework/operator.cc
+71
-15
paddle/framework/operator.h
paddle/framework/operator.h
+79
-27
paddle/framework/operator_test.cc
paddle/framework/operator_test.cc
+108
-8
paddle/operators/add_op.h
paddle/operators/add_op.h
+2
-2
paddle/pybind/pybind.cc
paddle/pybind/pybind.cc
+17
-0
python/paddle/v2/framework/create_op_creation_methods.py
python/paddle/v2/framework/create_op_creation_methods.py
+235
-0
python/paddle/v2/framework/tests/test_op_creation_methods.py
python/paddle/v2/framework/tests/test_op_creation_methods.py
+241
-2
python/paddle/v2/optimizer.py
python/paddle/v2/optimizer.py
+5
-1
未找到文件。
paddle/framework/op_registry.h
浏览文件 @
ad8fa77c
#pragma once
#pragma once
#include <algorithm>
#include <algorithm>
#include <atomic>
#include <type_traits>
#include <type_traits>
#include <unordered_map>
#include <unordered_map>
#include <unordered_set>
#include <unordered_set>
...
@@ -214,29 +215,61 @@ class OpRegistry {
...
@@ -214,29 +215,61 @@ class OpRegistry {
}
}
static
OperatorPtr
CreateOp
(
const
OpDesc
&
op_desc
)
{
static
OperatorPtr
CreateOp
(
const
OpDesc
&
op_desc
)
{
//! Create a OpPtr by type.
std
::
string
op_type
=
op_desc
.
type
();
std
::
string
op_type
=
op_desc
.
type
();
OperatorPtr
op
(
creators
().
at
(
op_type
)());
OperatorPtr
op
(
creators
().
at
(
op_type
)());
//! Fill op's data member. Not use constructor because it will be noising
//! for Op developer.
const
OpProto
&
op_proto
=
protos
().
at
(
op_type
);
op
->
type_
=
op_desc
.
type
();
op
->
type_
=
op_desc
.
type
();
// set op's inputs_ from desc.
op
->
inputs_
.
reserve
((
size_t
)
op_desc
.
inputs_size
());
op
->
inputs_
.
reserve
((
size_t
)
op_desc
.
inputs_size
());
std
::
copy
(
op_desc
.
inputs
().
begin
(),
op_desc
.
inputs
().
end
(),
std
::
copy
(
op_desc
.
inputs
().
begin
(),
op_desc
.
inputs
().
end
(),
std
::
back_inserter
(
op
->
inputs_
));
std
::
back_inserter
(
op
->
inputs_
));
// set op's outputs_ from desc.
op
->
outputs_
.
reserve
((
size_t
)
op_desc
.
outputs_size
());
op
->
outputs_
.
reserve
((
size_t
)
op_desc
.
outputs_size
());
std
::
copy
(
op_desc
.
outputs
().
begin
(),
op_desc
.
outputs
().
end
(),
std
::
copy
(
op_desc
.
outputs
().
begin
(),
op_desc
.
outputs
().
end
(),
std
::
back_inserter
(
op
->
outputs_
));
std
::
back_inserter
(
op
->
outputs_
));
//! Fill attrs, and validate attrs.
for
(
auto
&
attr
:
op_desc
.
attrs
())
{
for
(
auto
&
attr
:
op_desc
.
attrs
())
{
op
->
attrs_
[
attr
.
name
()]
=
AttrTypeHelper
::
GetAttrValue
(
attr
);
op
->
attrs_
[
attr
.
name
()]
=
AttrTypeHelper
::
GetAttrValue
(
attr
);
}
}
op_checkers
().
at
(
op_type
).
Check
(
op
->
attrs_
);
op_checkers
().
at
(
op_type
).
Check
(
op
->
attrs_
);
//! Convert Temporary variable name to an unique variable name.
GenerateTempVariableName
(
op
.
get
());
// set argument offsets stored in op.
CreateInOutOffsetMap
(
op
,
op_proto
);
//! Other op's custom Init for a complex Op. For simple Op, the Init
//! method do nothing.
op
->
Init
();
op
->
Init
();
return
op
;
return
op
;
}
}
// init op.in_out_idxs_ to accelerate argument's offset lookup.
static
void
CreateInOutOffsetMap
(
OperatorPtr
op
,
const
OpProto
&
proto
)
{
op
->
CreateInOutOffsetMap
(
proto
);
}
static
std
::
unordered_map
<
std
::
string
,
OpProto
>&
protos
()
{
static
std
::
unordered_map
<
std
::
string
,
OpProto
>&
protos
()
{
static
std
::
unordered_map
<
std
::
string
,
OpProto
>
protos_
;
static
std
::
unordered_map
<
std
::
string
,
OpProto
>
protos_
;
return
protos_
;
return
protos_
;
};
};
private:
private:
static
void
GenerateTempVariableName
(
OperatorBase
*
op
)
{
static
std
::
atomic
<
size_t
>
gUniqId
(
0UL
);
for
(
auto
&
outname
:
op
->
outputs_
)
{
if
(
outname
==
OperatorBase
::
TMP_VAR_NAME
())
{
outname
+=
op
->
type_
;
outname
+=
"@"
;
outname
+=
std
::
to_string
(
gUniqId
.
fetch_add
(
1
));
}
}
}
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>&
creators
()
{
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>&
creators
()
{
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>
creators_
;
static
std
::
unordered_map
<
std
::
string
,
OpCreator
>
creators_
;
return
creators_
;
return
creators_
;
...
...
paddle/framework/operator.cc
浏览文件 @
ad8fa77c
...
@@ -12,30 +12,86 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,30 +12,86 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include <algorithm>
#include "paddle/framework/operator.h"
#include "paddle/framework/operator.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
void
OperatorBase
::
CreateInOutOffsetMap
(
const
OpProto
&
proto
)
{
PADDLE_ENFORCE
(
in_out_idxs_
.
empty
(),
"duplicate call CreateInOutOffsetMap"
);
for
(
int
i
=
0
;
i
<
proto
.
inputs_size
();
i
++
)
{
const
auto
&
name
=
proto
.
inputs
()[
i
].
name
();
in_out_idxs_
[
name
]
=
i
;
}
for
(
int
i
=
0
;
i
<
proto
.
outputs_size
();
i
++
)
{
const
auto
&
name
=
proto
.
outputs
()[
i
].
name
();
in_out_idxs_
[
name
]
=
i
;
}
}
const
std
::
string
&
OperatorBase
::
Input
(
const
std
::
string
&
name
)
const
{
auto
it
=
in_out_idxs_
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
in_out_idxs_
.
end
(),
"no key [%s] in in_out_idxs_"
,
name
);
if
(
attrs_
.
count
(
"input_format"
)
==
0
)
{
return
inputs_
[
it
->
second
];
}
else
{
const
auto
&
input_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"input_format"
);
int
idx
=
input_format
[
it
->
second
];
return
inputs_
.
at
(
idx
);
}
}
std
::
vector
<
std
::
string
>
OperatorBase
::
Inputs
(
const
std
::
string
&
name
)
const
{
auto
input_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"input_format"
);
auto
offset
=
in_out_idxs_
.
at
(
name
);
return
std
::
vector
<
std
::
string
>
{
inputs_
.
begin
()
+
input_format
.
at
(
offset
),
inputs_
.
begin
()
+
input_format
.
at
(
offset
+
1
)};
}
const
std
::
string
&
OperatorBase
::
Output
(
const
std
::
string
&
name
)
const
{
auto
it
=
in_out_idxs_
.
find
(
name
);
PADDLE_ENFORCE
(
it
!=
in_out_idxs_
.
end
(),
"no key [%s] in in_out_idxs_"
,
name
);
if
(
attrs_
.
count
(
"output_format"
)
==
0
)
{
return
outputs_
[
it
->
second
];
}
else
{
const
auto
&
output_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"output_format"
);
int
idx
=
output_format
[
it
->
second
];
return
outputs_
.
at
(
idx
);
}
}
std
::
vector
<
std
::
string
>
OperatorBase
::
Outputs
(
const
std
::
string
&
name
)
const
{
auto
output_format
=
GetAttr
<
std
::
vector
<
int
>>
(
"output_format"
);
auto
offset
=
in_out_idxs_
.
at
(
name
);
return
std
::
vector
<
std
::
string
>
{
outputs_
.
begin
()
+
output_format
.
at
(
offset
),
outputs_
.
begin
()
+
output_format
.
at
(
offset
+
1
)};
}
std
::
string
OperatorBase
::
DebugString
()
const
{
std
::
string
OperatorBase
::
DebugString
()
const
{
std
::
stringstream
ss
;
std
::
stringstream
ss
;
ss
<<
"=================
\n
"
;
ss
<<
"Op("
<<
type_
<<
"), inputs:("
;
ss
<<
"type = "
<<
type_
<<
"
\n
"
;
for
(
size_t
i
=
0
;
i
<
inputs_
.
size
();
++
i
)
{
ss
<<
"inputs = ["
;
ss
<<
inputs_
[
i
];
for
(
auto
&
ipt
:
inputs_
)
{
if
(
i
!=
inputs_
.
size
()
-
1
)
{
ss
<<
ipt
<<
", "
;
ss
<<
", "
;
}
}
ss
<<
"]
\n
"
;
}
ss
<<
"outputs = ["
;
ss
<<
"), outputs:("
;
for
(
auto
&
opt
:
outputs_
)
{
for
(
size_t
i
=
0
;
i
<
outputs_
.
size
();
++
i
)
{
ss
<<
opt
<<
", "
;
ss
<<
outputs_
[
i
];
}
if
(
i
!=
outputs_
.
size
()
-
1
)
{
ss
<<
"]
\n
"
;
ss
<<
", "
;
ss
<<
"attr_keys = ["
;
}
for
(
auto
&
attr
:
attrs_
)
{
}
ss
<<
attr
.
first
<<
", "
;
ss
<<
")."
;
}
ss
<<
"]
\n
"
;
return
ss
.
str
();
return
ss
.
str
();
}
}
...
...
paddle/framework/operator.h
浏览文件 @
ad8fa77c
...
@@ -14,18 +14,20 @@ limitations under the License. */
...
@@ -14,18 +14,20 @@ limitations under the License. */
#pragma once
#pragma once
#include <paddle/framework/attr_checker.h>
#include <paddle/framework/op_desc.pb.h>
#include <paddle/framework/scope.h>
#include <paddle/framework/tensor.h>
#include <paddle/platform/device_context.h>
#include <paddle/platform/place.h>
#include <paddle/utils/Error.h>
#include <boost/variant.hpp>
#include <boost/variant.hpp>
#include <string>
#include <string>
#include <unordered_map>
#include <unordered_map>
#include <vector>
#include <vector>
#include "paddle/framework/attr_checker.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
#include "paddle/utils/Error.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -39,6 +41,13 @@ using OperatorPtr = std::shared_ptr<OperatorBase>;
...
@@ -39,6 +41,13 @@ using OperatorPtr = std::shared_ptr<OperatorBase>;
*/
*/
class
OperatorBase
{
class
OperatorBase
{
public:
public:
/// If a variable is a empty variable, that name will be used.
static
std
::
string
EMPTY_VAR_NAME
()
{
return
"@EMPTY@"
;
}
/// If a variable is a temporary variable, that name will be set in Python,
/// but it will be convert to a unique name in scope after OpCreator.
static
std
::
string
TMP_VAR_NAME
()
{
return
"@TEMP@"
;
}
virtual
~
OperatorBase
()
{}
virtual
~
OperatorBase
()
{}
template
<
typename
T
>
template
<
typename
T
>
...
@@ -62,24 +71,32 @@ class OperatorBase {
...
@@ -62,24 +71,32 @@ class OperatorBase {
virtual
void
Run
(
const
ScopePtr
&
scope
,
virtual
void
Run
(
const
ScopePtr
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
=
0
;
const
platform
::
DeviceContext
&
dev_ctx
)
const
=
0
;
// Get a input with argument's name described in `op_proto`
const
std
::
string
&
Input
(
const
std
::
string
&
name
)
const
;
// Get a input which has multiple variables.
// TODO add a vector_view to prevent memory copy.
std
::
vector
<
std
::
string
>
Inputs
(
const
std
::
string
&
name
)
const
;
// Get a output with argument's name described in `op_proto`
const
std
::
string
&
Output
(
const
std
::
string
&
name
)
const
;
// Get an output which has multiple variables.
// TODO add a vector_view to prevent memory copy.
std
::
vector
<
std
::
string
>
Outputs
(
const
std
::
string
&
name
)
const
;
// init in_out_idxs_ to accelerate argument's offset lookup.
void
CreateInOutOffsetMap
(
const
OpProto
&
proto
);
public:
public:
std
::
string
type_
;
std
::
string
type_
;
std
::
vector
<
std
::
string
>
inputs_
;
std
::
vector
<
std
::
string
>
inputs_
;
std
::
vector
<
std
::
string
>
outputs_
;
std
::
vector
<
std
::
string
>
outputs_
;
AttributeMap
attrs_
;
AttributeMap
attrs_
;
// store the arguments' offset described in op_desc.
std
::
unordered_map
<
std
::
string
,
int
>
in_out_idxs_
;
};
};
class
OpKernel
{
class
KernelContext
{
public:
/**
* KernelContext is the only parameter of Kernel Run function.
* Run will get input/output variables, state such as momentum and
* device resource such as CUDA stream, cublas handle, etc. from
* KernelContext. User should construct it before run the Operator.
*/
class
KernelContext
{
public:
public:
KernelContext
(
const
OperatorBase
*
op
,
const
ScopePtr
&
scope
,
KernelContext
(
const
OperatorBase
*
op
,
const
std
::
shared_ptr
<
Scope
>
&
scope
,
const
platform
::
DeviceContext
&
device_context
)
const
platform
::
DeviceContext
&
device_context
)
:
op_
(
*
op
),
scope_
(
scope
),
device_context_
(
device_context
)
{}
:
op_
(
*
op
),
scope_
(
scope
),
device_context_
(
device_context
)
{}
...
@@ -91,11 +108,45 @@ class OpKernel {
...
@@ -91,11 +108,45 @@ class OpKernel {
return
scope_
->
GetVariable
(
op_
.
outputs_
[
index
]);
return
scope_
->
GetVariable
(
op_
.
outputs_
[
index
]);
}
}
const
Variable
*
Input
(
const
std
::
string
&
name
)
const
{
return
scope_
->
GetVariable
(
op_
.
Input
(
name
));
}
const
Variable
*
Output
(
const
std
::
string
&
name
)
const
{
return
scope_
->
GetVariable
(
op_
.
Output
(
name
));
}
const
std
::
vector
<
const
Variable
*>
Inputs
(
const
std
::
string
&
name
)
const
{
auto
names
=
op_
.
Inputs
(
name
);
std
::
vector
<
const
Variable
*>
res
;
std
::
transform
(
names
.
begin
(),
names
.
end
(),
res
.
begin
(),
[
this
](
const
std
::
string
&
name
)
{
return
scope_
->
GetVariable
(
name
);
});
return
res
;
}
const
std
::
vector
<
const
Variable
*>
Outputs
(
const
std
::
string
&
name
)
const
{
auto
names
=
op_
.
Outputs
(
name
);
std
::
vector
<
const
Variable
*>
res
;
std
::
transform
(
names
.
begin
(),
names
.
end
(),
res
.
begin
(),
[
this
](
const
std
::
string
&
name
)
{
return
scope_
->
GetVariable
(
name
);
});
return
res
;
}
const
OperatorBase
&
op_
;
const
OperatorBase
&
op_
;
const
ScopePtr
&
scope_
;
const
std
::
shared_ptr
<
Scope
>
&
scope_
;
const
platform
::
DeviceContext
&
device_context_
;
const
platform
::
DeviceContext
&
device_context_
;
};
};
class
OpKernel
{
public:
/**
* KernelContext is the only parameter of Kernel Run function.
* Run will get input/output variables, state such as momentum and
* device resource such as CUDA stream, cublas handle, etc. from
* KernelContext. User should construct it before run the Operator.
*/
virtual
void
Compute
(
const
KernelContext
&
context
)
const
=
0
;
virtual
void
Compute
(
const
KernelContext
&
context
)
const
=
0
;
virtual
~
OpKernel
()
{}
virtual
~
OpKernel
()
{}
...
@@ -140,7 +191,7 @@ class OperatorWithKernel : public OperatorBase {
...
@@ -140,7 +191,7 @@ class OperatorWithKernel : public OperatorBase {
void
Run
(
const
ScopePtr
&
scope
,
void
Run
(
const
ScopePtr
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
final
{
const
platform
::
DeviceContext
&
dev_ctx
)
const
final
{
auto
&
opKernel
=
AllOpKernels
().
at
(
type_
).
at
(
OpKernelKey
(
dev_ctx
));
auto
&
opKernel
=
AllOpKernels
().
at
(
type_
).
at
(
OpKernelKey
(
dev_ctx
));
opKernel
->
Compute
(
OpKernel
::
KernelContext
(
this
,
scope
,
dev_ctx
));
opKernel
->
Compute
(
KernelContext
(
this
,
scope
,
dev_ctx
));
}
}
static
std
::
unordered_map
<
std
::
string
/* op_type */
,
OpKernelMap
>&
static
std
::
unordered_map
<
std
::
string
/* op_type */
,
OpKernelMap
>&
...
@@ -148,6 +199,7 @@ class OperatorWithKernel : public OperatorBase {
...
@@ -148,6 +199,7 @@ class OperatorWithKernel : public OperatorBase {
static
std
::
unordered_map
<
std
::
string
,
OpKernelMap
>
g_all_op_kernels
;
static
std
::
unordered_map
<
std
::
string
,
OpKernelMap
>
g_all_op_kernels
;
return
g_all_op_kernels
;
return
g_all_op_kernels
;
}
}
void
InferShape
(
const
std
::
shared_ptr
<
Scope
>&
scope
)
const
final
{
void
InferShape
(
const
std
::
shared_ptr
<
Scope
>&
scope
)
const
final
{
std
::
vector
<
const
Tensor
*>
ins
;
std
::
vector
<
const
Tensor
*>
ins
;
VarNamesToTensors
(
scope
,
inputs_
,
&
ins
);
VarNamesToTensors
(
scope
,
inputs_
,
&
ins
);
...
...
paddle/framework/operator_test.cc
浏览文件 @
ad8fa77c
...
@@ -30,7 +30,6 @@ class OpWithoutKernelTest : public OperatorBase {
...
@@ -30,7 +30,6 @@ class OpWithoutKernelTest : public OperatorBase {
op_run_num
++
;
op_run_num
++
;
ASSERT_EQ
((
int
)
inputs_
.
size
(),
1
);
ASSERT_EQ
((
int
)
inputs_
.
size
(),
1
);
ASSERT_EQ
((
int
)
outputs_
.
size
(),
1
);
ASSERT_EQ
((
int
)
outputs_
.
size
(),
1
);
ASSERT_NEAR
(
GetAttr
<
float
>
(
"scale"
),
3.14
,
1e-5
);
ASSERT_EQ
(
scope
->
GetVariable
(
inputs_
[
0
]),
nullptr
);
ASSERT_EQ
(
scope
->
GetVariable
(
inputs_
[
0
]),
nullptr
);
ASSERT_EQ
(
x
,
1
);
ASSERT_EQ
(
x
,
1
);
ASSERT_NE
(
scope
->
GetVariable
(
outputs_
[
0
]),
nullptr
);
ASSERT_NE
(
scope
->
GetVariable
(
outputs_
[
0
]),
nullptr
);
...
@@ -86,9 +85,11 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
...
@@ -86,9 +85,11 @@ class OpKernelTestProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
public:
public:
OpKernelTestProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
OpKernelTestProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"input"
,
"input of test op"
);
AddInput
(
"x"
,
"input of test op"
);
AddOutput
(
"output"
,
"output of test op"
);
AddOutput
(
"y"
,
"output of test op"
);
AddAttr
<
float
>
(
"scale"
,
"scale of cosine op"
);
AddAttr
<
float
>
(
"scale"
,
"scale of cosine op"
)
.
SetDefault
(
1.0
)
.
LargerThan
(
0.0
);
AddComment
(
"This is test op"
);
AddComment
(
"This is test op"
);
}
}
};
};
...
@@ -103,11 +104,65 @@ class OpWithKernelTest : public OperatorWithKernel {
...
@@ -103,11 +104,65 @@ class OpWithKernelTest : public OperatorWithKernel {
class
CPUKernelTest
:
public
OpKernel
{
class
CPUKernelTest
:
public
OpKernel
{
public:
public:
void
Compute
(
const
KernelContext
&
context
)
const
{
void
Compute
(
const
KernelContext
&
ctx
)
const
{
std
::
cout
<<
"this is cpu kernel"
<<
std
::
endl
;
std
::
cout
<<
ctx
.
op_
.
DebugString
()
<<
std
::
endl
;
cpu_kernel_run_num
++
;
cpu_kernel_run_num
++
;
ASSERT_EQ
((
int
)
context
.
op_
.
inputs_
.
size
(),
1
);
ASSERT_EQ
(
ctx
.
op_
.
Input
(
"x"
),
"IN1"
);
ASSERT_EQ
((
int
)
context
.
op_
.
outputs_
.
size
(),
1
);
ASSERT_EQ
(
ctx
.
op_
.
Output
(
"y"
),
"OUT1"
);
ASSERT_NEAR
(
context
.
op_
.
GetAttr
<
float
>
(
"scale"
),
3.14
,
1e-5
);
}
};
// multiple inputs test
class
OperatorMultiInputsTest
:
public
OperatorBase
{
public:
void
Init
()
override
{
x
=
1
;
}
void
InferShape
(
const
std
::
shared_ptr
<
Scope
>&
scope
)
const
override
{}
void
Run
(
const
std
::
shared_ptr
<
Scope
>&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
override
{
ASSERT_EQ
(
scope
->
GetVariable
(
inputs_
[
0
]),
nullptr
);
ASSERT_EQ
(
x
,
1
);
ASSERT_NE
(
scope
->
GetVariable
(
outputs_
[
0
]),
nullptr
);
ASSERT_EQ
(
Input
(
"x"
),
"IN1"
);
ASSERT_EQ
(
Input
(
"y"
),
"OUT1"
);
}
public:
float
x
=
0
;
};
class
OpKernelTestMultiInputsProtoAndCheckerMaker
:
public
OpProtoAndCheckerMaker
{
public:
OpKernelTestMultiInputsProtoAndCheckerMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInputs
(
"xs"
,
"inputs of test op"
);
AddInput
(
"k"
,
"input of test op"
);
AddOutputs
(
"ys"
,
"outputs of test op"
);
AddAttr
<
float
>
(
"scale"
,
"scale of cosine op"
)
.
SetDefault
(
1.0
)
.
LargerThan
(
0.0
);
AddComment
(
"This is test op"
);
}
};
class
CPUKernalMultiInputsTest
:
public
OpKernel
{
public:
void
Compute
(
const
KernelContext
&
ctx
)
const
{
auto
xs
=
ctx
.
op_
.
Inputs
(
"xs"
);
ASSERT_EQ
(
xs
.
size
(),
3UL
);
ASSERT_EQ
(
xs
[
0
],
"x0"
);
ASSERT_EQ
(
xs
[
1
],
"x1"
);
ASSERT_EQ
(
xs
[
2
],
"x2"
);
auto
k
=
ctx
.
op_
.
Input
(
"k"
);
ASSERT_EQ
(
k
,
"k0"
);
auto
ys
=
ctx
.
op_
.
Outputs
(
"ys"
);
ASSERT_EQ
(
ys
.
size
(),
2UL
);
ASSERT_EQ
(
ys
[
0
],
"y0"
);
ASSERT_EQ
(
ys
[
1
],
"y1"
);
}
}
};
};
...
@@ -118,6 +173,7 @@ REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest,
...
@@ -118,6 +173,7 @@ REGISTER_OP(op_with_kernel, paddle::framework::OpWithKernelTest,
paddle
::
framework
::
OpKernelTestProtoAndCheckerMaker
);
paddle
::
framework
::
OpKernelTestProtoAndCheckerMaker
);
REGISTER_OP_CPU_KERNEL
(
op_with_kernel
,
paddle
::
framework
::
CPUKernelTest
);
REGISTER_OP_CPU_KERNEL
(
op_with_kernel
,
paddle
::
framework
::
CPUKernelTest
);
// test with single input
TEST
(
OpKernel
,
all
)
{
TEST
(
OpKernel
,
all
)
{
paddle
::
framework
::
OpDesc
op_desc
;
paddle
::
framework
::
OpDesc
op_desc
;
op_desc
.
set_type
(
"op_with_kernel"
);
op_desc
.
set_type
(
"op_with_kernel"
);
...
@@ -137,3 +193,47 @@ TEST(OpKernel, all) {
...
@@ -137,3 +193,47 @@ TEST(OpKernel, all) {
op
->
Run
(
scope
,
cpu_device_context
);
op
->
Run
(
scope
,
cpu_device_context
);
ASSERT_EQ
(
paddle
::
framework
::
cpu_kernel_run_num
,
1
);
ASSERT_EQ
(
paddle
::
framework
::
cpu_kernel_run_num
,
1
);
}
}
REGISTER_OP
(
op_multi_inputs_with_kernel
,
paddle
::
framework
::
OpWithKernelTest
,
paddle
::
framework
::
OpKernelTestMultiInputsProtoAndCheckerMaker
);
REGISTER_OP_CPU_KERNEL
(
op_multi_inputs_with_kernel
,
paddle
::
framework
::
CPUKernalMultiInputsTest
);
// test with multi inputs
TEST
(
OpKernel
,
multi_inputs
)
{
using
namespace
paddle
::
framework
;
OpDesc
op_desc
;
op_desc
.
set_type
(
"op_multi_inputs_with_kernel"
);
*
op_desc
.
mutable_inputs
()
->
Add
()
=
"x0"
;
*
op_desc
.
mutable_inputs
()
->
Add
()
=
"x1"
;
*
op_desc
.
mutable_inputs
()
->
Add
()
=
"x2"
;
*
op_desc
.
mutable_inputs
()
->
Add
()
=
"k0"
;
*
op_desc
.
mutable_outputs
()
->
Add
()
=
"y0"
;
*
op_desc
.
mutable_outputs
()
->
Add
()
=
"y1"
;
auto
attr
=
op_desc
.
mutable_attrs
()
->
Add
();
attr
->
set_name
(
"scale"
);
attr
->
set_type
(
paddle
::
framework
::
AttrType
::
FLOAT
);
attr
->
set_f
(
3.14
);
auto
attr0
=
op_desc
.
mutable_attrs
()
->
Add
();
attr0
->
set_name
(
"input_format"
);
attr0
->
set_type
(
paddle
::
framework
::
AttrType
::
INTS
);
auto
input_format
=
attr0
->
mutable_ints
();
input_format
->
Add
(
0
);
// x0
input_format
->
Add
(
3
);
// k
input_format
->
Add
(
4
);
// end
auto
attr1
=
op_desc
.
mutable_attrs
()
->
Add
();
attr1
->
set_name
(
"output_format"
);
attr1
->
set_type
(
paddle
::
framework
::
AttrType
::
INTS
);
auto
output_format
=
attr1
->
mutable_ints
();
output_format
->
Add
(
0
);
// y0
output_format
->
Add
(
2
);
// y1
paddle
::
platform
::
CPUDeviceContext
cpu_device_context
;
auto
scope
=
std
::
make_shared
<
Scope
>
();
OperatorPtr
op
(
paddle
::
framework
::
OpRegistry
::
CreateOp
(
op_desc
));
op
->
Run
(
scope
,
cpu_device_context
);
}
paddle/operators/add_op.h
浏览文件 @
ad8fa77c
...
@@ -8,10 +8,10 @@ namespace operators {
...
@@ -8,10 +8,10 @@ namespace operators {
template
<
typename
Place
>
template
<
typename
Place
>
class
AddKernel
:
public
framework
::
OpKernel
{
class
AddKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
KernelContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
KernelContext
&
context
)
const
override
{
LOG
(
INFO
)
<<
"Add kernel in "
<<
typeid
(
Place
).
name
();
LOG
(
INFO
)
<<
"Add kernel in "
<<
typeid
(
Place
).
name
();
}
}
};
};
}
// namespace op
}
// namespace op
erators
}
// namespace paddle
}
// namespace paddle
paddle/pybind/pybind.cc
浏览文件 @
ad8fa77c
...
@@ -67,6 +67,23 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -67,6 +67,23 @@ All parameter, weight, gradient are variables in Paddle.
}
}
return
ret_values
;
return
ret_values
;
});
});
m
.
def_submodule
(
"var_names"
,
"The module will return special predefined variable name in Paddle"
)
.
def
(
"empty"
,
pd
::
OperatorBase
::
EMPTY_VAR_NAME
)
.
def
(
"temp"
,
pd
::
OperatorBase
::
TMP_VAR_NAME
);
py
::
class_
<
pd
::
OperatorBase
,
pd
::
OperatorPtr
>
(
m
,
"Operator"
)
.
def
(
"__str__"
,
&
pd
::
OperatorBase
::
DebugString
)
.
def_static
(
"create"
,
[](
const
std
::
string
&
protobin
)
{
pd
::
OpDesc
desc
;
PADDLE_ENFORCE
(
desc
.
ParsePartialFromString
(
protobin
),
"Cannot parse user input to OpDesc"
);
PADDLE_ENFORCE
(
desc
.
IsInitialized
(),
"User OpDesc is not initialized, reason %s"
,
desc
.
InitializationErrorString
());
return
pd
::
OpRegistry
::
CreateOp
(
desc
);
});
return
m
.
ptr
();
return
m
.
ptr
();
}
}
python/paddle/v2/framework/create_op_creation_methods.py
浏览文件 @
ad8fa77c
import
paddle.v2.framework.core
as
core
import
paddle.v2.framework.core
as
core
import
paddle.v2.framework.proto.op_proto_pb2
as
op_proto_pb2
import
paddle.v2.framework.proto.op_proto_pb2
as
op_proto_pb2
import
paddle.v2.framework.proto.op_desc_pb2
as
op_desc_pb2
import
paddle.v2.framework.proto.attr_type_pb2
as
attr_type_pb2
import
cStringIO
def
get_all_op_protos
():
def
get_all_op_protos
():
"""
Get all registered op proto from Paddle C++
:return: list of OpProto
"""
protostrs
=
core
.
get_all_op_protos
()
protostrs
=
core
.
get_all_op_protos
()
ret_values
=
[]
ret_values
=
[]
for
pbstr
in
protostrs
:
for
pbstr
in
protostrs
:
op_proto
=
op_proto_pb2
.
OpProto
.
FromString
(
str
(
pbstr
))
op_proto
=
op_proto_pb2
.
OpProto
.
FromString
(
str
(
pbstr
))
ret_values
.
append
(
op_proto
)
ret_values
.
append
(
op_proto
)
return
ret_values
return
ret_values
class
OpDescCreationMethod
(
object
):
"""
A Functor object to convert user input(use key word args) to OpDesc based on
OpProto.
:param op_proto: The OpProto object.
:type op_proto: op_proto_pb2.OpProto
"""
def
__init__
(
self
,
op_proto
):
if
not
isinstance
(
op_proto
,
op_proto_pb2
.
OpProto
):
raise
TypeError
(
"Argument should be OpProto"
)
self
.
__op_proto__
=
op_proto
def
__call__
(
self
,
*
args
,
**
kwargs
):
"""
Convert user input to OpDesc. Only key-word args are supported.
:return: OpDesc based on user input
:rtype: op_desc_pb2.OpDesc
"""
if
len
(
args
)
!=
0
:
raise
ValueError
(
"Only keyword arguments is supported by Paddle"
)
op_desc
=
op_desc_pb2
.
OpDesc
()
# Inputs
ipts
,
ipt_format
,
_
=
OpDescCreationMethod
.
extract_input_or_output
(
"input"
,
kwargs
,
self
.
__op_proto__
.
inputs
)
op_desc
.
inputs
.
extend
(
ipts
)
if
ipt_format
is
not
None
:
op_desc
.
attrs
.
extend
([
ipt_format
])
# Outputs
outs
,
out_format
,
tmp_index
=
OpDescCreationMethod
.
extract_input_or_output
(
"output"
,
kwargs
,
self
.
__op_proto__
.
outputs
)
op_desc
.
outputs
.
extend
(
outs
)
if
out_format
is
not
None
:
op_desc
.
attrs
.
extend
([
out_format
])
if
len
(
tmp_index
)
!=
0
:
tmp_index_attr
=
op_desc
.
attrs
.
add
()
tmp_index_attr
.
type
=
attr_type_pb2
.
INTS
tmp_index_attr
.
name
=
"temporary_index"
tmp_index_attr
.
ints
.
extend
(
tmp_index
)
# Types
op_desc
.
type
=
self
.
__op_proto__
.
type
# Attrs
for
attr
in
self
.
__op_proto__
.
attrs
:
if
attr
.
generated
:
continue
user_defined_attr
=
kwargs
.
get
(
attr
.
name
,
None
)
if
user_defined_attr
is
not
None
:
new_attr
=
op_desc
.
attrs
.
add
()
new_attr
.
name
=
attr
.
name
new_attr
.
type
=
attr
.
type
if
attr
.
type
==
attr_type_pb2
.
INT
:
new_attr
.
i
=
user_defined_attr
elif
attr
.
type
==
attr_type_pb2
.
FLOAT
:
new_attr
.
f
=
user_defined_attr
elif
attr
.
type
==
attr_type_pb2
.
STRING
:
new_attr
.
s
=
user_defined_attr
elif
attr
.
type
==
attr_type_pb2
.
INTS
:
new_attr
.
ints
.
extend
(
user_defined_attr
)
elif
attr
.
type
==
attr_type_pb2
.
FLOATS
:
new_attr
.
floats
.
extend
(
user_defined_attr
)
elif
attr
.
type
==
attr_type_pb2
.
STRINGS
:
new_attr
.
strings
.
extend
(
user_defined_attr
)
else
:
raise
NotImplementedError
(
"Not support attribute type "
+
attr
.
type
)
return
op_desc
@
staticmethod
def
extract_input_or_output
(
in_out
,
kwargs
,
meta
):
"""
Extract input variable names or output variable names from key-word
arguments, which base on VarProtos.
:param in_out: "input" or "output"
:param kwargs: key-word arguments that user inputted.
:param meta: a list of VarProto
:return: The three object will be return. The variable names. The
input_format or output_format attribute(None if the input or output is
not multiple). The temporary variable index list.
"""
multiple
=
OpDescCreationMethod
.
any_is_true
((
m
.
multiple
for
m
in
meta
))
tmp_index
=
[]
retv
=
[]
if
multiple
:
var_format
=
op_desc_pb2
.
AttrDesc
()
var_format
.
type
=
attr_type_pb2
.
INTS
var_format
.
name
=
"%s_format"
%
in_out
var_format
.
ints
.
append
(
0
)
for
var
in
meta
:
var_name
=
var
.
name
if
var
.
temporary
:
var_name
=
[
core
.
var_names
.
temp
()]
tmp_index
.
append
(
len
(
retv
))
else
:
var_name
=
kwargs
.
get
(
var_name
,
[])
if
not
isinstance
(
var_name
,
list
):
var_name
=
[
var_name
]
retv
.
extend
(
var_name
)
var_format
.
ints
.
append
(
len
(
var_name
)
+
var_format
.
ints
[
-
1
])
return
retv
,
var_format
,
tmp_index
else
:
for
var
in
meta
:
if
var
.
temporary
:
retv
.
append
(
kwargs
.
get
(
var
.
name
,
core
.
var_names
.
temp
()))
tmp_index
.
append
(
len
(
retv
))
else
:
retv
.
append
(
kwargs
.
get
(
var
.
name
,
core
.
var_names
.
empty
()))
return
retv
,
None
,
tmp_index
@
staticmethod
def
any_is_true
(
generator
):
"""
Reduce a bool array to one. If any of them is True, then return True.
"""
for
flag
in
generator
:
if
flag
:
return
True
return
False
def
get_docstring_from_op_proto
(
op_proto
):
"""
Generate docstring from a OpProto
:param op_proto: a OpProto instance.
:type op_proto: op_proto_pb2.OpProto
:return: docstring
"""
if
not
isinstance
(
op_proto
,
op_proto_pb2
.
OpProto
):
raise
TypeError
(
"Input must be OpProto"
)
f
=
cStringIO
.
StringIO
()
f
.
write
(
op_proto
.
comment
)
f
.
write
(
"
\n
"
)
def
__append_param__
(
name
,
comment
,
type
):
# Maybe replace the following line with template engine is better.
f
.
write
(
":param "
)
f
.
write
(
name
)
f
.
write
(
": "
)
f
.
write
(
comment
)
f
.
write
(
"
\n
"
)
f
.
write
(
":type "
)
f
.
write
(
name
)
f
.
write
(
": "
)
f
.
write
(
type
)
f
.
write
(
"
\n
"
)
for
ipt
in
op_proto
.
inputs
:
__append_param__
(
ipt
.
name
,
ipt
.
comment
,
"list | basestr"
if
ipt
.
multiple
else
"basestr"
)
temp_var_prefix
=
\
"This is a temporary variable. It does not have to set by user. "
for
opt
in
op_proto
.
outputs
:
__append_param__
(
opt
.
name
,
opt
.
comment
if
not
opt
.
temporary
else
temp_var_prefix
+
opt
.
comment
,
"list | basestr"
if
opt
.
multiple
else
"basestr"
)
for
attr
in
op_proto
.
attrs
:
attr_type
=
None
if
attr
.
type
==
attr_type_pb2
.
INT
:
attr_type
=
"int"
elif
attr
.
type
==
attr_type_pb2
.
FLOAT
:
attr_type
=
"float"
elif
attr
.
type
==
attr_type_pb2
.
STRING
:
attr_type
=
"basestr"
elif
attr
.
type
==
attr_type_pb2
.
INTS
:
attr_type
=
"list of int"
elif
attr
.
type
==
attr_type_pb2
.
FLOATS
:
attr_type
=
"list of float"
elif
attr
.
type
==
attr_type_pb2
.
STRINGS
:
attr_type
=
"list of basestr"
if
attr_type
is
None
:
raise
RuntimeError
(
"Not supported attribute type "
+
attr
.
type
)
__append_param__
(
attr
.
name
,
attr
.
comment
,
attr_type
)
return
f
.
getvalue
()
def
create_op_creation_method
(
op_proto
):
"""
Generate op creation method for an OpProto
"""
method
=
OpDescCreationMethod
(
op_proto
)
def
__impl__
(
*
args
,
**
kwargs
):
opdesc
=
method
(
*
args
,
**
kwargs
)
return
core
.
Operator
.
create
(
opdesc
.
SerializeToString
())
__impl__
.
__doc__
=
get_docstring_from_op_proto
(
op_proto
)
return
__impl__
class
OpCreationsHolder
(
object
):
"""
A object will holds all op creation methods.
Use `op_creations.xxx_op` to access them.
"""
pass
op_creations
=
OpCreationsHolder
()
def
__bootstrap__
():
"""
Bootstrap function for this module. It will dynamic create all op creation
methods in runtime.
"""
for
op_proto
in
get_all_op_protos
():
func
=
create_op_creation_method
(
op_proto
)
func
.
__name__
=
str
(
op_proto
.
type
)
setattr
(
op_creations
,
func
.
__name__
,
func
)
__bootstrap__
()
python/paddle/v2/framework/tests/test_op_creation_methods.py
浏览文件 @
ad8fa77c
import
unittest
import
unittest
import
paddle.v2.framework.create_op_creation_methods
as
creation
import
paddle.v2.framework.create_op_creation_methods
as
creation
import
paddle.v2.framework.core
as
core
import
paddle.v2.framework.proto.op_proto_pb2
as
op_proto_pb2
import
paddle.v2.framework.proto.op_desc_pb2
as
op_desc_pb2
import
paddle.v2.framework.proto.attr_type_pb2
as
attr_type_pb2
class
Test
OpCreationsMethod
s
(
unittest
.
TestCase
):
class
Test
GetAllProto
s
(
unittest
.
TestCase
):
def
test_all
_protos
(
self
):
def
test_all
(
self
):
all_protos
=
creation
.
get_all_op_protos
()
all_protos
=
creation
.
get_all_op_protos
()
self
.
assertNotEqual
(
0
,
len
(
all_protos
))
self
.
assertNotEqual
(
0
,
len
(
all_protos
))
...
@@ -11,5 +15,240 @@ class TestOpCreationsMethods(unittest.TestCase):
...
@@ -11,5 +15,240 @@ class TestOpCreationsMethods(unittest.TestCase):
self
.
assertTrue
(
each
.
IsInitialized
())
self
.
assertTrue
(
each
.
IsInitialized
())
class
TestOpDescCreationMethod
(
unittest
.
TestCase
):
def
test_plain_input_output
(
self
):
op
=
op_proto_pb2
.
OpProto
()
op
.
type
=
"test"
ipt
=
op
.
inputs
.
add
()
ipt
.
name
=
"X"
ipt
.
comment
=
"not matter"
ipt
=
op
.
inputs
.
add
()
ipt
.
name
=
"Y"
ipt
.
comment
=
"not matter"
opt
=
op
.
outputs
.
add
()
opt
.
name
=
"Z"
opt
.
comment
=
"not matter"
op
.
comment
=
"not matter"
self
.
assertTrue
(
op
.
IsInitialized
())
method
=
creation
.
OpDescCreationMethod
(
op
)
output
=
method
(
X
=
"a"
,
Y
=
"b"
,
Z
=
"c"
)
expected
=
op_desc_pb2
.
OpDesc
()
expected
.
type
=
"test"
expected
.
inputs
.
extend
([
"a"
,
"b"
])
expected
.
outputs
.
append
(
"c"
)
self
.
assertEqual
(
expected
,
output
)
def
test_multiple_input_plain_output
(
self
):
op
=
op_proto_pb2
.
OpProto
()
op
.
type
=
"fc"
ipt
=
op
.
inputs
.
add
()
ipt
.
name
=
"X"
ipt
.
comment
=
""
ipt
.
multiple
=
True
ipt
=
op
.
inputs
.
add
()
ipt
.
name
=
"W"
ipt
.
comment
=
""
ipt
.
multiple
=
True
ipt
=
op
.
inputs
.
add
()
ipt
.
name
=
"b"
ipt
.
comment
=
""
out
=
op
.
outputs
.
add
()
out
.
name
=
"Y"
out
.
comment
=
""
op
.
comment
=
""
self
.
assertTrue
(
op
.
IsInitialized
())
method
=
creation
.
OpDescCreationMethod
(
op
)
generated1
=
method
(
X
=
"x"
,
W
=
"w"
,
b
=
"b"
,
Y
=
"y"
)
expected1
=
op_desc_pb2
.
OpDesc
()
expected1
.
inputs
.
extend
([
'x'
,
'w'
,
'b'
])
expected1
.
outputs
.
extend
([
'y'
])
expected1
.
type
=
'fc'
attr
=
expected1
.
attrs
.
add
()
attr
.
name
=
'input_format'
attr
.
type
=
attr_type_pb2
.
INTS
attr
.
ints
.
extend
([
0
,
1
,
2
,
3
])
self
.
assertEqual
(
expected1
,
generated1
)
generated2
=
method
(
X
=
[
'x1'
,
'x2'
,
'x3'
],
b
=
'b'
,
W
=
[
'w1'
,
'w2'
,
'w3'
],
Y
=
'y'
)
expected2
=
op_desc_pb2
.
OpDesc
()
expected2
.
inputs
.
extend
([
'x1'
,
'x2'
,
'x3'
,
'w1'
,
'w2'
,
'w3'
,
'b'
])
expected2
.
outputs
.
extend
([
'y'
])
expected2
.
type
=
'fc'
attr
=
expected2
.
attrs
.
add
()
attr
.
name
=
'input_format'
attr
.
type
=
attr_type_pb2
.
INTS
attr
.
ints
.
extend
([
0
,
3
,
6
,
7
])
self
.
assertEqual
(
expected2
,
generated2
)
def
test_attrs
(
self
):
op
=
op_proto_pb2
.
OpProto
()
op
.
type
=
"test"
ipt
=
op
.
inputs
.
add
()
ipt
.
name
=
'X'
ipt
.
comment
=
""
def
__add_attr__
(
name
,
type
):
attr
=
op
.
attrs
.
add
()
attr
.
name
=
name
attr
.
comment
=
""
attr
.
type
=
type
__add_attr__
(
"int_attr"
,
attr_type_pb2
.
INT
)
__add_attr__
(
"float_attr"
,
attr_type_pb2
.
FLOAT
)
__add_attr__
(
"string_attr"
,
attr_type_pb2
.
STRING
)
__add_attr__
(
"ints_attr"
,
attr_type_pb2
.
INTS
)
__add_attr__
(
"floats_attr"
,
attr_type_pb2
.
FLOATS
)
__add_attr__
(
"strings_attr"
,
attr_type_pb2
.
STRINGS
)
op
.
comment
=
""
self
.
assertTrue
(
op
.
IsInitialized
())
method
=
creation
.
OpDescCreationMethod
(
op
)
generated
=
method
(
X
=
"a"
,
int_attr
=
10
,
float_attr
=
3.2
,
string_attr
=
"test_str"
,
ints_attr
=
[
0
,
1
,
2
,
3
,
4
],
floats_attr
=
[
0.2
,
3.2
,
4.5
],
strings_attr
=
[
"a"
,
"b"
,
"c"
])
expected
=
op_desc_pb2
.
OpDesc
()
expected
.
type
=
"test"
expected
.
inputs
.
extend
([
'a'
])
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
"int_attr"
attr
.
type
=
attr_type_pb2
.
INT
attr
.
i
=
10
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
"float_attr"
attr
.
type
=
attr_type_pb2
.
FLOAT
attr
.
f
=
3.2
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
"string_attr"
attr
.
type
=
attr_type_pb2
.
STRING
attr
.
s
=
"test_str"
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
"ints_attr"
attr
.
type
=
attr_type_pb2
.
INTS
attr
.
ints
.
extend
([
0
,
1
,
2
,
3
,
4
])
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
'floats_attr'
attr
.
type
=
attr_type_pb2
.
FLOATS
attr
.
floats
.
extend
([
0.2
,
3.2
,
4.5
])
attr
=
expected
.
attrs
.
add
()
attr
.
name
=
'strings_attr'
attr
.
type
=
attr_type_pb2
.
STRINGS
attr
.
strings
.
extend
([
'a'
,
'b'
,
'c'
])
self
.
assertEqual
(
expected
,
generated
)
def
test_input_temporary_output
(
self
):
op
=
op_proto_pb2
.
OpProto
()
op
.
type
=
"test"
out
=
op
.
outputs
.
add
()
out
.
name
=
"OUT"
out
.
comment
=
""
out
=
op
.
outputs
.
add
()
out
.
name
=
"TMP"
out
.
comment
=
""
out
.
temporary
=
True
out
=
op
.
outputs
.
add
()
out
.
name
=
"OUT2"
out
.
comment
=
""
op
.
comment
=
""
method
=
creation
.
OpDescCreationMethod
(
op
)
generated
=
method
(
OUT
=
"a"
,
OUT2
=
"b"
)
desc
=
op_desc_pb2
.
OpDesc
()
desc
.
outputs
.
extend
([
"a"
,
core
.
var_names
.
temp
(),
"b"
])
desc
.
type
=
"test"
attr
=
desc
.
attrs
.
add
()
attr
.
name
=
"temporary_index"
attr
.
type
=
attr_type_pb2
.
INTS
attr
.
ints
.
append
(
2
)
self
.
assertEqual
(
generated
,
desc
)
class
TestOpCreationDocStr
(
unittest
.
TestCase
):
def
test_all
(
self
):
op
=
op_proto_pb2
.
OpProto
()
op
.
type
=
"test"
op
.
comment
=
"""Test Op.
This op is used for unit test, not a real op.
"""
a
=
op
.
inputs
.
add
()
a
.
name
=
"a"
a
.
comment
=
"Input a for test op"
a
.
multiple
=
True
b
=
op
.
inputs
.
add
()
b
.
name
=
"b"
b
.
comment
=
"Input b for test op"
self
.
assertTrue
(
op
.
IsInitialized
())
o1
=
op
.
outputs
.
add
()
o1
.
name
=
"output"
o1
.
comment
=
"The output of test op"
o2
=
op
.
outputs
.
add
()
o2
.
name
=
"temp output"
o2
.
comment
=
"The temporary output of test op"
o2
.
temporary
=
True
test_str
=
op
.
attrs
.
add
()
test_str
.
name
=
"str_attr"
test_str
.
type
=
attr_type_pb2
.
STRING
test_str
.
comment
=
"A string attribute for test op"
actual
=
creation
.
get_docstring_from_op_proto
(
op
)
expected_docstring
=
'''Test Op.
This op is used for unit test, not a real op.
:param a: Input a for test op
:type a: list | basestr
:param b: Input b for test op
:type b: basestr
:param output: The output of test op
:type output: basestr
:param temp output: This is a temporary variable. It does not have to set by user. The temporary output of test op
:type temp output: basestr
:param str_attr: A string attribute for test op
:type str_attr: basestr
'''
self
.
assertEqual
(
expected_docstring
,
actual
)
class
TestOpCreations
(
unittest
.
TestCase
):
def
test_all
(
self
):
add_op
=
creation
.
op_creations
.
add_two
(
X
=
"a"
,
Y
=
"b"
,
Out
=
"z"
)
self
.
assertIsNotNone
(
add_op
)
# Invoke C++ DebugString()
self
.
assertEqual
(
'Op(add_two), inputs:(a, b), outputs:(z).'
,
str
(
add_op
))
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
python/paddle/v2/optimizer.py
浏览文件 @
ad8fa77c
import
py_paddle.swig_paddle
as
swig_api
import
paddle.trainer_config_helpers.config_parser_utils
as
config_parser_utils
import
paddle.trainer_config_helpers.config_parser_utils
as
config_parser_utils
import
paddle.trainer_config_helpers.optimizers
as
v1_optimizers
import
paddle.trainer_config_helpers.optimizers
as
v1_optimizers
"""
"""
...
@@ -17,6 +16,7 @@ __all__ = [
...
@@ -17,6 +16,7 @@ __all__ = [
class
Optimizer
(
object
):
class
Optimizer
(
object
):
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
**
kwargs
):
import
py_paddle.swig_paddle
as
swig_api
if
'batch_size'
in
kwargs
:
if
'batch_size'
in
kwargs
:
del
kwargs
[
'batch_size'
]
# not important for python library.
del
kwargs
[
'batch_size'
]
# not important for python library.
...
@@ -35,18 +35,22 @@ class Optimizer(object):
...
@@ -35,18 +35,22 @@ class Optimizer(object):
For each optimizer(SGD, Adam), GradientMachine should enable different
For each optimizer(SGD, Adam), GradientMachine should enable different
buffers.
buffers.
"""
"""
import
py_paddle.swig_paddle
as
swig_api
tmp
=
swig_api
.
ParameterOptimizer
.
create
(
self
.
__opt_conf__
)
tmp
=
swig_api
.
ParameterOptimizer
.
create
(
self
.
__opt_conf__
)
assert
isinstance
(
tmp
,
swig_api
.
ParameterOptimizer
)
assert
isinstance
(
tmp
,
swig_api
.
ParameterOptimizer
)
return
tmp
.
getParameterTypes
()
return
tmp
.
getParameterTypes
()
def
__create_local_updater__
(
self
):
def
__create_local_updater__
(
self
):
import
py_paddle.swig_paddle
as
swig_api
return
swig_api
.
ParameterUpdater
.
createLocalUpdater
(
self
.
__opt_conf__
)
return
swig_api
.
ParameterUpdater
.
createLocalUpdater
(
self
.
__opt_conf__
)
def
__create_remote_updater__
(
self
,
pass_num
,
use_sparse_updater
):
def
__create_remote_updater__
(
self
,
pass_num
,
use_sparse_updater
):
import
py_paddle.swig_paddle
as
swig_api
return
swig_api
.
ParameterUpdater
.
createRemoteUpdater
(
return
swig_api
.
ParameterUpdater
.
createRemoteUpdater
(
self
.
__opt_conf__
,
pass_num
,
use_sparse_updater
)
self
.
__opt_conf__
,
pass_num
,
use_sparse_updater
)
def
__create_new_remote_updater__
(
self
,
pserver_spec
,
use_etcd
):
def
__create_new_remote_updater__
(
self
,
pserver_spec
,
use_etcd
):
import
py_paddle.swig_paddle
as
swig_api
return
swig_api
.
ParameterUpdater
.
createNewRemoteUpdater
(
return
swig_api
.
ParameterUpdater
.
createNewRemoteUpdater
(
self
.
__opt_conf__
,
pserver_spec
,
use_etcd
)
self
.
__opt_conf__
,
pserver_spec
,
use_etcd
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录