Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
7506e481
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
7506e481
编写于
10月 10, 2017
作者:
Y
Yu Yang
提交者:
GitHub
10月 10, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #4660 from reyoung/feature/polish_infer_shape
Polish CompileTime InferShape
上级
a281b383
805639b1
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
44 addition
and
20 deletion
+44
-20
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+1
-1
paddle/framework/op_desc.cc
paddle/framework/op_desc.cc
+36
-0
paddle/framework/op_desc.h
paddle/framework/op_desc.h
+2
-0
paddle/pybind/protobuf.cc
paddle/pybind/protobuf.cc
+2
-1
paddle/pybind/pybind.cc
paddle/pybind/pybind.cc
+0
-15
python/paddle/v2/framework/tests/test_infer_shape.py
python/paddle/v2/framework/tests/test_infer_shape.py
+3
-3
未找到文件。
paddle/framework/CMakeLists.txt
浏览文件 @
7506e481
...
@@ -19,7 +19,7 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope)
...
@@ -19,7 +19,7 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope)
proto_library
(
framework_proto SRCS framework.proto
)
proto_library
(
framework_proto SRCS framework.proto
)
cc_library
(
attribute SRCS attribute.cc DEPS framework_proto
)
cc_library
(
attribute SRCS attribute.cc DEPS framework_proto
)
cc_library
(
proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute
)
cc_library
(
proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute
ddim
)
cc_library
(
op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute
)
cc_library
(
op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute
)
cc_test
(
op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker
)
cc_test
(
op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker
)
cc_library
(
op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc
)
cc_library
(
op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc
)
...
...
paddle/framework/op_desc.cc
浏览文件 @
7506e481
...
@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
...
@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/framework/op_desc.h"
#include "paddle/framework/op_desc.h"
#include <functional>
#include <unordered_map>
#include "paddle/framework/block_desc.h"
#include "paddle/framework/block_desc.h"
#include "paddle/framework/operator.h"
namespace
paddle
{
namespace
paddle
{
namespace
framework
{
namespace
framework
{
...
@@ -185,5 +188,38 @@ void OpDescBind::Sync() {
...
@@ -185,5 +188,38 @@ void OpDescBind::Sync() {
need_update_
=
false
;
need_update_
=
false
;
}
}
}
}
using
InferShapeFuncMap
=
std
::
unordered_map
<
std
::
string
/*op_type*/
,
std
::
function
<
void
(
InferShapeContext
*
)
>>
;
static
InferShapeFuncMap
&
InferShapeFuncs
()
{
static
InferShapeFuncMap
*
g_map
=
nullptr
;
if
(
g_map
==
nullptr
)
{
g_map
=
new
InferShapeFuncMap
();
auto
&
info_map
=
OpInfoMap
::
Instance
();
// all registered kernels
for
(
auto
&
pair
:
OperatorWithKernel
::
AllOpKernels
())
{
auto
&
info
=
info_map
.
Get
(
pair
.
first
);
// use empty type here to avoid runtime checks.
auto
op
=
static_cast
<
OperatorWithKernel
*>
(
info
.
Creator
()(
""
,
{},
{},
{}));
g_map
->
insert
(
{
pair
.
first
,
[
op
](
InferShapeContext
*
ctx
)
{
op
->
InferShape
(
ctx
);
}});
}
}
return
*
g_map
;
}
void
OpDescBind
::
InferShape
(
const
BlockDescBind
&
block
)
const
{
auto
&
funcs
=
InferShapeFuncs
();
auto
it
=
funcs
.
find
(
this
->
Type
());
if
(
it
==
funcs
.
end
())
{
PADDLE_THROW
(
"Operator %s has not been registered"
,
this
->
Type
());
}
CompileTimeInferShapeContext
ctx
(
*
this
,
block
);
it
->
second
(
&
ctx
);
}
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/framework/op_desc.h
浏览文件 @
7506e481
...
@@ -100,6 +100,8 @@ class OpDescBind {
...
@@ -100,6 +100,8 @@ class OpDescBind {
return
&
this
->
attrs_
;
return
&
this
->
attrs_
;
}
}
void
InferShape
(
const
BlockDescBind
&
block
)
const
;
private:
private:
template
<
typename
MapType
>
template
<
typename
MapType
>
static
std
::
vector
<
typename
MapType
::
key_type
>
MapKeys
(
const
MapType
&
map
)
{
static
std
::
vector
<
typename
MapType
::
key_type
>
MapKeys
(
const
MapType
&
map
)
{
...
...
paddle/pybind/protobuf.cc
浏览文件 @
7506e481
...
@@ -196,7 +196,8 @@ void BindOpDesc(py::module &m) {
...
@@ -196,7 +196,8 @@ void BindOpDesc(py::module &m) {
.
def
(
"set_attr"
,
&
OpDescBind
::
SetAttr
)
.
def
(
"set_attr"
,
&
OpDescBind
::
SetAttr
)
.
def
(
"attr"
,
&
OpDescBind
::
GetAttr
)
.
def
(
"attr"
,
&
OpDescBind
::
GetAttr
)
.
def
(
"set_block_attr"
,
&
OpDescBind
::
SetBlockAttr
)
.
def
(
"set_block_attr"
,
&
OpDescBind
::
SetBlockAttr
)
.
def
(
"get_block_attr"
,
&
OpDescBind
::
GetBlockAttr
);
.
def
(
"get_block_attr"
,
&
OpDescBind
::
GetBlockAttr
)
.
def
(
"infer_shape"
,
&
OpDescBind
::
InferShape
);
}
}
}
// namespace pybind
}
// namespace pybind
...
...
paddle/pybind/pybind.cc
浏览文件 @
7506e481
...
@@ -231,21 +231,6 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -231,21 +231,6 @@ All parameter, weight, gradient are variables in Paddle.
desc
.
InitializationErrorString
());
desc
.
InitializationErrorString
());
return
OpRegistry
::
CreateOp
(
desc
);
return
OpRegistry
::
CreateOp
(
desc
);
})
})
.
def_static
(
"infer_shape"
,
[](
OpDescBind
&
op_desc
,
BlockDescBind
&
block
)
{
auto
op
=
OpRegistry
::
CreateOp
(
*
op_desc
.
Proto
());
auto
*
op_with_kernel
=
dynamic_cast
<
OperatorWithKernel
*>
(
op
.
get
());
if
(
op_with_kernel
!=
nullptr
)
{
auto
ctx
=
CompileTimeInferShapeContext
(
op_desc
,
block
);
op_with_kernel
->
InferShape
(
&
ctx
);
}
else
{
PADDLE_THROW
(
"OP(%s) is not type of OperatorWithKernel, "
"should not call this function"
,
op_desc
.
Type
());
}
})
.
def
(
"backward"
,
.
def
(
"backward"
,
[](
const
OperatorBase
&
forwardOp
,
[](
const
OperatorBase
&
forwardOp
,
const
std
::
unordered_set
<
std
::
string
>
&
no_grad_vars
)
{
const
std
::
unordered_set
<
std
::
string
>
&
no_grad_vars
)
{
...
...
python/paddle/v2/framework/tests/test_infer_shape.py
浏览文件 @
7506e481
import
unittest
import
unittest
import
paddle.v2.framework.core
as
core
import
paddle.v2.framework.core
as
core
from
paddle.v2.framework.op
import
Operator
class
TestInferShape
(
unittest
.
TestCase
):
class
TestInferShape
(
unittest
.
TestCase
):
...
@@ -26,7 +26,7 @@ class TestInferShape(unittest.TestCase):
...
@@ -26,7 +26,7 @@ class TestInferShape(unittest.TestCase):
sum_op_desc
.
set_input
(
"X"
,
[
"x1"
,
"x2"
])
sum_op_desc
.
set_input
(
"X"
,
[
"x1"
,
"x2"
])
sum_op_desc
.
set_output
(
"Out"
,
[
"out"
])
sum_op_desc
.
set_output
(
"Out"
,
[
"out"
])
core
.
Operator
.
infer_shape
(
sum_op_desc
,
block
)
sum_op_desc
.
infer_shape
(
block
)
self
.
assertEqual
(
out
.
shape
(),
shape
)
self
.
assertEqual
(
out
.
shape
(),
shape
)
def
test_mul_op
(
self
):
def
test_mul_op
(
self
):
...
@@ -55,7 +55,7 @@ class TestInferShape(unittest.TestCase):
...
@@ -55,7 +55,7 @@ class TestInferShape(unittest.TestCase):
mul_op_desc
.
set_attr
(
"x_num_col_dims"
,
1
)
mul_op_desc
.
set_attr
(
"x_num_col_dims"
,
1
)
mul_op_desc
.
set_attr
(
"y_num_col_dims"
,
1
)
mul_op_desc
.
set_attr
(
"y_num_col_dims"
,
1
)
core
.
Operator
.
infer_shape
(
mul_op_desc
,
block
)
mul_op_desc
.
infer_shape
(
block
)
self
.
assertEqual
(
out
.
shape
(),
[
x_shape
[
0
],
y_shape
[
1
]])
self
.
assertEqual
(
out
.
shape
(),
[
x_shape
[
0
],
y_shape
[
1
]])
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录