Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
0c37a061
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0c37a061
编写于
10月 10, 2017
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'feature/change_proto_to_desc' into feature/complete_variable_bind
上级
83dbc150
fb2ad4c9
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
70 addition
and
47 deletion
+70
-47
doc/design/python_api.md
doc/design/python_api.md
+6
-6
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+1
-1
paddle/framework/op_desc.cc
paddle/framework/op_desc.cc
+36
-0
paddle/framework/op_desc.h
paddle/framework/op_desc.h
+2
-0
paddle/pybind/protobuf.cc
paddle/pybind/protobuf.cc
+2
-1
paddle/pybind/pybind.cc
paddle/pybind/pybind.cc
+0
-15
python/paddle/v2/framework/graph.py
python/paddle/v2/framework/graph.py
+20
-21
python/paddle/v2/framework/tests/test_infer_shape.py
python/paddle/v2/framework/tests/test_infer_shape.py
+3
-3
未找到文件。
doc/design/python_api.md
浏览文件 @
0c37a061
...
@@ -22,7 +22,7 @@ Whenever we create a block, we need to set its parent block to the current block
...
@@ -22,7 +22,7 @@ Whenever we create a block, we need to set its parent block to the current block
```
python
```
python
class
Program
(
objects
):
class
Program
(
objects
):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
proto
=
core
.
NewProgram
()
# a C++ ProgramDesc pointer.
self
.
desc
=
core
.
NewProgram
()
# a C++ ProgramDesc pointer.
self
.
blocks
=
vector
<
Block
>
()
self
.
blocks
=
vector
<
Block
>
()
self
.
blocks
.
append
(
Block
(
self
,
-
1
))
# the global block
self
.
blocks
.
append
(
Block
(
self
,
-
1
))
# the global block
self
.
current_block
=
0
# initialized to the global block
self
.
current_block
=
0
# initialized to the global block
...
@@ -57,7 +57,7 @@ A [Block](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.m
...
@@ -57,7 +57,7 @@ A [Block](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/block.m
```
python
```
python
class
Block
(
objects
):
class
Block
(
objects
):
def
__init__
(
self
,
program
,
parent_idx
):
def
__init__
(
self
,
program
,
parent_idx
):
self
.
proto
=
core
.
NewBlock
(
program
.
proto
)
self
.
desc
=
core
.
NewBlock
(
program
.
desc
)
self
.
program
=
program
self
.
program
=
program
self
.
vars
=
map
<
string
,
Variable
>
()
self
.
vars
=
map
<
string
,
Variable
>
()
self
.
ops
=
vector
<
Operator
>
()
self
.
ops
=
vector
<
Operator
>
()
...
@@ -98,11 +98,11 @@ class Operator(object):
...
@@ -98,11 +98,11 @@ class Operator(object):
outputs
,
# dict<stirng, Variable>
outputs
,
# dict<stirng, Variable>
attrs
# dict<string, Any>
attrs
# dict<string, Any>
):
):
self
.
proto
=
core
.
NewOpDesc
(
block
.
proto
,
type
,
inputs
,
outputs
,
attrs
)
self
.
desc
=
core
.
NewOpDesc
(
block
.
desc
,
type
,
inputs
,
outputs
,
attrs
)
core
.
infer_shape
(
self
.
proto
,
inputs
,
outputs
)
core
.
infer_shape
(
self
.
desc
,
inputs
,
outputs
)
def
type
(
self
):
def
type
(
self
):
return
self
.
proto
.
type
()
return
self
.
desc
.
type
()
```
```
`Operator`
creates the
`OpDesc`
message in C++ space, so that it can call the
`InferShape`
function, which is in C++.
`Operator`
creates the
`OpDesc`
message in C++ space, so that it can call the
`InferShape`
function, which is in C++.
...
@@ -124,7 +124,7 @@ class Variable(object):
...
@@ -124,7 +124,7 @@ class Variable(object):
name
=
unique_name_generator
()
name
=
unique_name_generator
()
self
.
name
=
name
self
.
name
=
name
self
.
block
=
block
self
.
block
=
block
self
.
proto
=
core
.
NewVarDesc
(
block
.
proto
,
name
,
shape
,
lod_level
)
self
.
desc
=
core
.
NewVarDesc
(
block
.
desc
,
name
,
shape
,
lod_level
)
self
.
writer
=
None
self
.
writer
=
None
```
```
...
...
paddle/framework/CMakeLists.txt
浏览文件 @
0c37a061
...
@@ -19,7 +19,7 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope)
...
@@ -19,7 +19,7 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope)
proto_library
(
framework_proto SRCS framework.proto
)
proto_library
(
framework_proto SRCS framework.proto
)
cc_library
(
attribute SRCS attribute.cc DEPS framework_proto
)
cc_library
(
attribute SRCS attribute.cc DEPS framework_proto
)
cc_library
(
proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute
)
cc_library
(
proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute
ddim
)
cc_library
(
op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute
)
cc_library
(
op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute
)
cc_test
(
op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker
)
cc_test
(
op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker
)
cc_library
(
op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc
)
cc_library
(
op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc
)
...
...
paddle/framework/op_desc.cc
浏览文件 @
0c37a061
...
@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
...
@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/framework/op_desc.h"
#include "paddle/framework/op_desc.h"
#include <functional>
#include <unordered_map>
#include "paddle/framework/block_desc.h"
#include "paddle/framework/block_desc.h"
#include "paddle/framework/operator.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -185,5 +188,38 @@ void OpDescBind::Sync() {
...
@@ -185,5 +188,38 @@ void OpDescBind::Sync() {
need_update_
=
false
;
need_update_
=
false
;
}
}
}
}
using
InferShapeFuncMap
=
std
::
unordered_map
<
std
::
string
/*op_type*/
,
std
::
function
<
void
(
InferShapeContext
*
)
>>
;
static
InferShapeFuncMap
&
InferShapeFuncs
()
{
static
InferShapeFuncMap
*
g_map
=
nullptr
;
if
(
g_map
==
nullptr
)
{
g_map
=
new
InferShapeFuncMap
();
auto
&
info_map
=
OpInfoMap
::
Instance
();
// all registered kernels
for
(
auto
&
pair
:
OperatorWithKernel
::
AllOpKernels
())
{
auto
&
info
=
info_map
.
Get
(
pair
.
first
);
// use empty type here to avoid runtime checks.
auto
op
=
static_cast
<
OperatorWithKernel
*>
(
info
.
Creator
()(
""
,
{},
{},
{}));
g_map
->
insert
(
{
pair
.
first
,
[
op
](
InferShapeContext
*
ctx
)
{
op
->
InferShape
(
ctx
);
}});
}
}
return
*
g_map
;
}
void
OpDescBind
::
InferShape
(
const
BlockDescBind
&
block
)
const
{
auto
&
funcs
=
InferShapeFuncs
();
auto
it
=
funcs
.
find
(
this
->
Type
());
if
(
it
==
funcs
.
end
())
{
PADDLE_THROW
(
"Operator %s has not been registered"
,
this
->
Type
());
}
CompileTimeInferShapeContext
ctx
(
*
this
,
block
);
it
->
second
(
&
ctx
);
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/framework/op_desc.h
浏览文件 @
0c37a061
...
@@ -100,6 +100,8 @@ class OpDescBind {
...
@@ -100,6 +100,8 @@ class OpDescBind {
return
&
this
->
attrs_
;
return
&
this
->
attrs_
;
}
}
void
InferShape
(
const
BlockDescBind
&
block
)
const
;
private:
private:
template
<
typename
MapType
>
template
<
typename
MapType
>
static
std
::
vector
<
typename
MapType
::
key_type
>
MapKeys
(
const
MapType
&
map
)
{
static
std
::
vector
<
typename
MapType
::
key_type
>
MapKeys
(
const
MapType
&
map
)
{
...
...
paddle/pybind/protobuf.cc
浏览文件 @
0c37a061
...
@@ -198,7 +198,8 @@ void BindOpDesc(py::module &m) {
...
@@ -198,7 +198,8 @@ void BindOpDesc(py::module &m) {
.
def
(
"set_attr"
,
&
OpDescBind
::
SetAttr
)
.
def
(
"set_attr"
,
&
OpDescBind
::
SetAttr
)
.
def
(
"attr"
,
&
OpDescBind
::
GetAttr
)
.
def
(
"attr"
,
&
OpDescBind
::
GetAttr
)
.
def
(
"set_block_attr"
,
&
OpDescBind
::
SetBlockAttr
)
.
def
(
"set_block_attr"
,
&
OpDescBind
::
SetBlockAttr
)
.
def
(
"get_block_attr"
,
&
OpDescBind
::
GetBlockAttr
);
.
def
(
"get_block_attr"
,
&
OpDescBind
::
GetBlockAttr
)
.
def
(
"infer_shape"
,
&
OpDescBind
::
InferShape
);
}
}
}
// namespace pybind
}
// namespace pybind
...
...
paddle/pybind/pybind.cc
浏览文件 @
0c37a061
...
@@ -231,21 +231,6 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -231,21 +231,6 @@ All parameter, weight, gradient are variables in Paddle.
desc
.
InitializationErrorString
());
desc
.
InitializationErrorString
());
return
OpRegistry
::
CreateOp
(
desc
);
return
OpRegistry
::
CreateOp
(
desc
);
})
})
.
def_static
(
"infer_shape"
,
[](
OpDescBind
&
op_desc
,
BlockDescBind
&
block
)
{
auto
op
=
OpRegistry
::
CreateOp
(
*
op_desc
.
Proto
());
auto
*
op_with_kernel
=
dynamic_cast
<
OperatorWithKernel
*>
(
op
.
get
());
if
(
op_with_kernel
!=
nullptr
)
{
auto
ctx
=
CompileTimeInferShapeContext
(
op_desc
,
block
);
op_with_kernel
->
InferShape
(
&
ctx
);
}
else
{
PADDLE_THROW
(
"OP(%s) is not type of OperatorWithKernel, "
"should not call this function"
,
op_desc
.
Type
());
}
})
.
def
(
"backward"
,
.
def
(
"backward"
,
[](
const
OperatorBase
&
forwardOp
,
[](
const
OperatorBase
&
forwardOp
,
const
std
::
unordered_set
<
std
::
string
>
&
no_grad_vars
)
{
const
std
::
unordered_set
<
std
::
string
>
&
no_grad_vars
)
{
...
...
python/paddle/v2/framework/graph.py
浏览文件 @
0c37a061
...
@@ -13,15 +13,15 @@ class Variable(object):
...
@@ -13,15 +13,15 @@ class Variable(object):
if
name
is
None
:
if
name
is
None
:
name
=
Variable
.
_unique_var_name_
()
name
=
Variable
.
_unique_var_name_
()
try
:
try
:
self
.
proto
=
self
.
block
.
proto
.
var
(
name
)
self
.
desc
=
self
.
block
.
desc
.
var
(
name
)
is_new_var
=
False
is_new_var
=
False
except
core
.
EnforceNotMet
:
except
core
.
EnforceNotMet
:
self
.
proto
=
self
.
block
.
proto
.
new_var
(
name
)
self
.
desc
=
self
.
block
.
desc
.
new_var
(
name
)
is_new_var
=
True
is_new_var
=
True
if
shape
is
not
None
:
if
shape
is
not
None
:
if
is_new_var
:
if
is_new_var
:
self
.
proto
.
set_shape
(
shape
)
self
.
desc
.
set_shape
(
shape
)
else
:
else
:
old_shape
=
self
.
shape
old_shape
=
self
.
shape
shape
=
tuple
(
shape
)
shape
=
tuple
(
shape
)
...
@@ -34,7 +34,7 @@ class Variable(object):
...
@@ -34,7 +34,7 @@ class Variable(object):
if
not
isinstance
(
dtype
,
core
.
DataType
):
if
not
isinstance
(
dtype
,
core
.
DataType
):
dtype
=
Variable
.
_convert_np_dtype_to_dtype_
(
dtype
)
dtype
=
Variable
.
_convert_np_dtype_to_dtype_
(
dtype
)
if
is_new_var
:
if
is_new_var
:
self
.
proto
.
set_data_type
(
dtype
)
self
.
desc
.
set_data_type
(
dtype
)
else
:
else
:
old_dtype
=
self
.
data_type
()
old_dtype
=
self
.
data_type
()
if
dtype
!=
old_shape
:
if
dtype
!=
old_shape
:
...
@@ -46,7 +46,7 @@ class Variable(object):
...
@@ -46,7 +46,7 @@ class Variable(object):
if
lod_level
is
not
None
:
if
lod_level
is
not
None
:
if
is_new_var
:
if
is_new_var
:
self
.
proto
.
set_lod_level
(
lod_level
)
self
.
desc
.
set_lod_level
(
lod_level
)
else
:
else
:
if
lod_level
!=
self
.
lod_level
:
if
lod_level
!=
self
.
lod_level
:
raise
ValueError
(
"Variable {0} has been created before. "
raise
ValueError
(
"Variable {0} has been created before. "
...
@@ -54,26 +54,25 @@ class Variable(object):
...
@@ -54,26 +54,25 @@ class Variable(object):
"lod_level is {2}. They are not "
"lod_level is {2}. They are not "
"matched"
.
format
(
self
.
name
,
self
.
lod_level
,
"matched"
.
format
(
self
.
name
,
self
.
lod_level
,
lod_level
))
lod_level
))
self
.
block
.
vars
[
name
]
=
self
self
.
block
.
vars
[
name
]
=
self
self
.
op
=
None
self
.
op
=
None
@
property
@
property
def
name
(
self
):
def
name
(
self
):
return
self
.
proto
.
name
()
return
self
.
desc
.
name
()
@
property
@
property
def
shape
(
self
):
def
shape
(
self
):
# convert to tuple, make it as same as numpy API.
# convert to tuple, make it as same as numpy API.
return
tuple
(
self
.
proto
.
shape
())
return
tuple
(
self
.
desc
.
shape
())
@
property
@
property
def
data_type
(
self
):
def
data_type
(
self
):
return
self
.
proto
.
data_type
()
return
self
.
desc
.
data_type
()
@
property
@
property
def
lod_level
(
self
):
def
lod_level
(
self
):
return
self
.
proto
.
lod_level
()
return
self
.
desc
.
lod_level
()
@
staticmethod
@
staticmethod
def
_unique_var_name_
():
def
_unique_var_name_
():
...
@@ -104,13 +103,13 @@ class Variable(object):
...
@@ -104,13 +103,13 @@ class Variable(object):
class
Operator
(
object
):
class
Operator
(
object
):
def
__init__
(
self
,
def
__init__
(
self
,
block
,
block
,
proto
,
desc
,
type
=
None
,
type
=
None
,
inputs
=
None
,
inputs
=
None
,
outputs
=
None
,
outputs
=
None
,
attrs
=
None
):
attrs
=
None
):
self
.
block
=
block
self
.
block
=
block
self
.
proto
=
proto
self
.
desc
=
desc
if
type
is
not
None
:
if
type
is
not
None
:
# TODO.
# TODO.
pass
pass
...
@@ -129,31 +128,31 @@ class Operator(object):
...
@@ -129,31 +128,31 @@ class Operator(object):
class
Block
(
object
):
class
Block
(
object
):
def
__init__
(
self
,
program
,
idx
):
def
__init__
(
self
,
program
,
idx
):
self
.
proto
=
program
.
proto
.
block
(
idx
)
self
.
desc
=
program
.
desc
.
block
(
idx
)
self
.
vars
=
dict
()
# var_name --> var
self
.
vars
=
dict
()
# var_name --> var
self
.
ops
=
collections
.
deque
()
# operator list
self
.
ops
=
collections
.
deque
()
# operator list
self
.
program
=
program
self
.
program
=
program
@
property
@
property
def
parent_idx
(
self
):
def
parent_idx
(
self
):
return
self
.
proto
.
parent
return
self
.
desc
.
parent
@
property
@
property
def
idx
(
self
):
def
idx
(
self
):
return
self
.
proto
.
id
return
self
.
desc
.
id
def
create_var
(
self
,
*
args
,
**
kwargs
):
def
create_var
(
self
,
*
args
,
**
kwargs
):
return
Variable
(
self
,
*
args
,
**
kwargs
)
return
Variable
(
self
,
*
args
,
**
kwargs
)
def
append_op
(
self
,
*
args
,
**
kwargs
):
def
append_op
(
self
,
*
args
,
**
kwargs
):
op_
proto
=
self
.
proto
.
append_op
()
op_
desc
=
self
.
desc
.
append_op
()
op
=
Operator
(
self
,
op_
proto
,
*
args
,
**
kwargs
)
op
=
Operator
(
self
,
op_
desc
,
*
args
,
**
kwargs
)
self
.
ops
.
append
(
op
)
self
.
ops
.
append
(
op
)
return
op
return
op
def
prepend_op
(
self
,
*
args
,
**
kwargs
):
def
prepend_op
(
self
,
*
args
,
**
kwargs
):
op_
proto
=
self
.
proto
.
prepend_op
()
op_
desc
=
self
.
desc
.
prepend_op
()
op
=
Operator
(
self
,
op_
proto
,
*
args
,
**
kwargs
)
op
=
Operator
(
self
,
op_
desc
,
*
args
,
**
kwargs
)
self
.
ops
.
appendleft
(
op
)
self
.
ops
.
appendleft
(
op
)
return
op
return
op
...
@@ -170,7 +169,7 @@ class Program(object):
...
@@ -170,7 +169,7 @@ class Program(object):
def
__init__
(
self
):
def
__init__
(
self
):
assert
not
hasattr
(
self
.
__class__
,
assert
not
hasattr
(
self
.
__class__
,
'_instance'
),
'Do not call constructor directly!'
'_instance'
),
'Do not call constructor directly!'
self
.
proto
=
core
.
ProgramDesc
.
instance
()
self
.
desc
=
core
.
ProgramDesc
.
instance
()
self
.
blocks
=
[
Block
(
self
,
0
)]
self
.
blocks
=
[
Block
(
self
,
0
)]
self
.
current_block_idx
=
0
self
.
current_block_idx
=
0
...
@@ -182,7 +181,7 @@ class Program(object):
...
@@ -182,7 +181,7 @@ class Program(object):
def
create_block
(
self
):
def
create_block
(
self
):
new_block_idx
=
len
(
self
.
blocks
)
new_block_idx
=
len
(
self
.
blocks
)
self
.
proto
.
append_block
(
self
.
current_block
().
proto
)
self
.
desc
.
append_block
(
self
.
current_block
().
desc
)
self
.
current_block_idx
=
new_block_idx
self
.
current_block_idx
=
new_block_idx
self
.
blocks
.
append
(
Block
(
self
,
self
.
current_block_idx
))
self
.
blocks
.
append
(
Block
(
self
,
self
.
current_block_idx
))
return
self
.
current_block
()
return
self
.
current_block
()
...
...
python/paddle/v2/framework/tests/test_infer_shape.py
浏览文件 @
0c37a061
import
unittest
import
unittest
import
paddle.v2.framework.core
as
core
import
paddle.v2.framework.core
as
core
from
paddle.v2.framework.op
import
Operator
class
TestInferShape
(
unittest
.
TestCase
):
class
TestInferShape
(
unittest
.
TestCase
):
...
@@ -26,7 +26,7 @@ class TestInferShape(unittest.TestCase):
...
@@ -26,7 +26,7 @@ class TestInferShape(unittest.TestCase):
sum_op_desc
.
set_input
(
"X"
,
[
"x1"
,
"x2"
])
sum_op_desc
.
set_input
(
"X"
,
[
"x1"
,
"x2"
])
sum_op_desc
.
set_output
(
"Out"
,
[
"out"
])
sum_op_desc
.
set_output
(
"Out"
,
[
"out"
])
core
.
Operator
.
infer_shape
(
sum_op_desc
,
block
)
sum_op_desc
.
infer_shape
(
block
)
self
.
assertEqual
(
out
.
shape
(),
shape
)
self
.
assertEqual
(
out
.
shape
(),
shape
)
def
test_mul_op
(
self
):
def
test_mul_op
(
self
):
...
@@ -55,7 +55,7 @@ class TestInferShape(unittest.TestCase):
...
@@ -55,7 +55,7 @@ class TestInferShape(unittest.TestCase):
mul_op_desc
.
set_attr
(
"x_num_col_dims"
,
1
)
mul_op_desc
.
set_attr
(
"x_num_col_dims"
,
1
)
mul_op_desc
.
set_attr
(
"y_num_col_dims"
,
1
)
mul_op_desc
.
set_attr
(
"y_num_col_dims"
,
1
)
core
.
Operator
.
infer_shape
(
mul_op_desc
,
block
)
mul_op_desc
.
infer_shape
(
block
)
self
.
assertEqual
(
out
.
shape
(),
[
x_shape
[
0
],
y_shape
[
1
]])
self
.
assertEqual
(
out
.
shape
(),
[
x_shape
[
0
],
y_shape
[
1
]])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录