Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
b44f4ccb
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b44f4ccb
编写于
10月 26, 2017
作者:
Y
Yu Yang
提交者:
Yang Yang(Tony)
10月 26, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Make InferShape as a field in OpInfo (#5139)
* Op developer can add `InferShape` to any operator
上级
7f8574c0
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
64 addition
and
36 deletion
+64
-36
paddle/framework/details/op_registry.h
paddle/framework/details/op_registry.h
+16
-2
paddle/framework/op_desc.cc
paddle/framework/op_desc.cc
+24
-24
paddle/framework/op_info.h
paddle/framework/op_info.h
+11
-4
paddle/framework/operator.h
paddle/framework/operator.h
+3
-1
paddle/framework/type_defs.h
paddle/framework/type_defs.h
+4
-0
paddle/operators/mul_op.cc
paddle/operators/mul_op.cc
+6
-5
未找到文件。
paddle/framework/details/op_registry.h
浏览文件 @
b44f4ccb
...
...
@@ -28,7 +28,8 @@ enum OpInfoFillType {
kOperator
=
0
,
kOpProtoAndCheckerMaker
=
1
,
kGradOpDescMaker
=
2
,
kVarTypeInference
=
3
kVarTypeInference
=
3
,
kShapeInference
=
4
};
template
<
typename
T
>
...
...
@@ -42,7 +43,10 @@ struct OpInfoFillTypeID {
?
kGradOpDescMaker
:
(
std
::
is_base_of
<
VarTypeInference
,
T
>::
value
?
kVarTypeInference
:
static_cast
<
OpInfoFillType
>
(
-
1
))));
:
(
std
::
is_base_of
<
InferShapeBase
,
T
>::
value
?
kShapeInference
:
static_cast
<
OpInfoFillType
>
(
-
1
)))));
}
};
...
...
@@ -121,6 +125,16 @@ struct OpInfoFiller<T, kVarTypeInference> {
}
};
template
<
typename
T
>
struct
OpInfoFiller
<
T
,
kShapeInference
>
{
void
operator
()(
const
char
*
op_type
,
OpInfo
*
info
)
const
{
info
->
infer_shape_
=
[](
InferShapeContext
*
ctx
)
{
T
inference
;
inference
(
ctx
);
};
}
};
}
// namespace details
}
// namespace framework
...
...
paddle/framework/op_desc.cc
浏览文件 @
b44f4ccb
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/framework/op_desc.h"
#include <functional>
#include <mutex>
#include <unordered_map>
#include "paddle/framework/block_desc.h"
#include "paddle/framework/operator.h"
...
...
@@ -229,26 +230,26 @@ void OpDescBind::Flush() {
}
}
using
InferShapeFuncMap
=
std
::
unordered_map
<
std
::
string
/*op_type*/
,
std
::
function
<
void
(
InferShapeContext
*
)
>>
;
static
InferShapeFuncMap
&
InferShapeFuncs
()
{
static
InferShapeFuncMap
*
g_map
=
nullptr
;
if
(
g_map
==
nullptr
)
{
g_map
=
new
InferShapeFuncMap
();
auto
&
info_map
=
OpInfoMap
::
Instance
();
// all registered kernels
for
(
auto
&
pair
:
OperatorWithKernel
::
AllOpKernels
())
{
auto
&
info
=
info_map
.
Get
(
pair
.
first
);
// use empty type here to avoid runtime checks.
static
std
::
once_flag
init_infer_shape_funcs
;
static
void
InitInferShapeFuncs
()
{
std
::
call_once
(
init_infer_shape_funcs
,
[]
{
auto
&
map
=
OpInfoMap
::
Instance
();
auto
&
info_map
=
*
map
.
mutable_map
();
for
(
auto
&
kern_pair
:
OperatorWithKernel
::
AllOpKernels
())
{
auto
op_type
=
kern_pair
.
first
;
auto
&
op_info
=
info_map
.
at
(
op_type
);
auto
op
=
static_cast
<
OperatorWithKernel
*>
(
info
.
Creator
()(
""
,
{},
{},
{}));
g_map
->
insert
(
{
pair
.
first
,
[
op
](
InferShapeContext
*
ctx
)
{
op
->
InferShape
(
ctx
);
}})
;
static_cast
<
OperatorWithKernel
*>
(
op_
info
.
Creator
()(
""
,
{},
{},
{}));
if
(
op_info
.
infer_shape_
)
{
// infer_shape has been registered.
continue
;
}
op_info
.
infer_shape_
=
[
op
](
InferShapeContext
*
ctx
)
{
op
->
InferShape
(
ctx
);
};
}
return
*
g_map
;
})
;
}
void
OpDescBind
::
CheckAttrs
()
{
...
...
@@ -265,13 +266,12 @@ void OpDescBind::CheckAttrs() {
void
OpDescBind
::
InferShape
(
const
BlockDescBind
&
block
)
const
{
VLOG
(
3
)
<<
"CompileTime infer shape on "
<<
Type
();
auto
&
funcs
=
InferShapeFuncs
();
auto
it
=
funcs
.
find
(
this
->
Type
());
if
(
it
==
funcs
.
end
())
{
PADDLE_THROW
(
"Operator %s has not been registered"
,
this
->
Type
());
}
InitInferShapeFuncs
();
auto
&
infer_shape
=
OpInfoMap
::
Instance
().
Get
(
this
->
Type
()).
infer_shape_
;
PADDLE_ENFORCE
(
static_cast
<
bool
>
(
infer_shape
),
"%s's infer_shape has not been registered"
,
this
->
Type
());
CompileTimeInferShapeContext
ctx
(
*
this
,
block
);
i
t
->
second
(
&
ctx
);
i
nfer_shape
(
&
ctx
);
}
void
OpDescBind
::
InferVarType
(
BlockDescBind
*
block
)
const
{
...
...
paddle/framework/op_info.h
浏览文件 @
b44f4ccb
...
...
@@ -25,12 +25,19 @@
namespace
paddle
{
namespace
framework
{
class
InferShapeBase
{
public:
virtual
~
InferShapeBase
()
=
default
;
virtual
void
operator
()(
InferShapeContext
*
)
const
=
0
;
};
struct
OpInfo
{
OpCreator
creator_
;
GradOpMakerFN
grad_op_maker_
;
OpProto
*
proto_
{
nullptr
};
OpAttrChecker
*
checker_
{
nullptr
};
InferVarTypeFN
infer_var_type_
;
InferShapeFN
infer_shape_
;
bool
HasOpProtoAndChecker
()
const
{
return
proto_
!=
nullptr
&&
checker_
!=
nullptr
;
...
...
@@ -87,13 +94,13 @@ class OpInfoMap {
}
}
const
std
::
unordered_map
<
std
::
string
,
const
OpInfo
>&
map
()
const
{
return
map_
;
}
const
std
::
unordered_map
<
std
::
string
,
OpInfo
>&
map
()
const
{
return
map_
;
}
std
::
unordered_map
<
std
::
string
,
OpInfo
>*
mutable_map
()
{
return
&
map_
;
}
private:
OpInfoMap
()
=
default
;
std
::
unordered_map
<
std
::
string
,
const
OpInfo
>
map_
;
std
::
unordered_map
<
std
::
string
,
OpInfo
>
map_
;
DISABLE_COPY_AND_ASSIGN
(
OpInfoMap
);
};
...
...
paddle/framework/operator.h
浏览文件 @
b44f4ccb
...
...
@@ -636,7 +636,9 @@ class OperatorWithKernel : public OperatorBase {
});
}
virtual
void
InferShape
(
InferShapeContext
*
ctx
)
const
=
0
;
virtual
void
InferShape
(
InferShapeContext
*
ctx
)
const
{
OpInfoMap
::
Instance
().
Get
(
Type
()).
infer_shape_
(
ctx
);
}
protected:
// indicate kernel DataType by input data. Defaultly all input data must be
...
...
paddle/framework/type_defs.h
浏览文件 @
b44f4ccb
...
...
@@ -28,6 +28,8 @@ class OperatorBase;
class
OpDescBind
;
class
BlockDescBind
;
class
BlockDesc
;
class
InferShapeContext
;
using
VariableNameMap
=
std
::
map
<
std
::
string
,
std
::
vector
<
std
::
string
>>
;
// The order should be as same as framework.proto
...
...
@@ -49,5 +51,7 @@ using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDescBind>>(
using
InferVarTypeFN
=
std
::
function
<
void
(
const
OpDescBind
&
/*op_desc*/
,
BlockDescBind
*
/*block*/
)
>
;
using
InferShapeFN
=
std
::
function
<
void
(
InferShapeContext
*
)
>
;
}
// namespace framework
}
// namespace paddle
paddle/operators/mul_op.cc
浏览文件 @
b44f4ccb
...
...
@@ -19,11 +19,9 @@ namespace operators {
using
framework
::
Tensor
;
class
MulOp
:
public
framework
::
OperatorWithKernel
{
class
MulOp
ShapeInference
:
public
framework
::
InferShapeBase
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
operator
()(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of MulOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"Y"
),
"Input(Y) of MulOp should not be null."
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
...
...
@@ -137,7 +135,10 @@ class MulOpGrad : public framework::OperatorWithKernel {
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP
(
mul
,
ops
::
MulOp
,
ops
::
MulOpMaker
,
mul_grad
,
ops
::
MulOpGrad
);
REGISTER_OPERATOR
(
mul
,
paddle
::
framework
::
OperatorWithKernel
,
ops
::
MulOpMaker
,
ops
::
MulOpShapeInference
,
paddle
::
framework
::
DefaultGradOpDescMaker
<
true
>
);
REGISTER_OPERATOR
(
mul_grad
,
ops
::
MulOpGrad
);
REGISTER_OP_CPU_KERNEL
(
mul
,
ops
::
MulKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
mul_grad
,
ops
::
MulGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录