Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
8f6c0a0f
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8f6c0a0f
编写于
10月 28, 2017
作者:
Y
Yu Yang
提交者:
GitHub
10月 28, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Extract InferShape to many cc files (#5174)
* Shrink Operator.h * Fix CI compile
上级
5906baa3
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
334 addition
and
288 deletion
+334
-288
paddle/framework/CMakeLists.txt
paddle/framework/CMakeLists.txt
+3
-2
paddle/framework/op_desc.cc
paddle/framework/op_desc.cc
+130
-2
paddle/framework/op_registry.h
paddle/framework/op_registry.h
+1
-0
paddle/framework/operator.cc
paddle/framework/operator.cc
+132
-0
paddle/framework/operator.h
paddle/framework/operator.h
+1
-247
paddle/framework/shape_inference.cc
paddle/framework/shape_inference.cc
+54
-0
paddle/framework/shape_inference.h
paddle/framework/shape_inference.h
+13
-37
未找到文件。
paddle/framework/CMakeLists.txt
浏览文件 @
8f6c0a0f
...
...
@@ -24,9 +24,10 @@ cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc)
cc_library
(
op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute
)
cc_test
(
op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker
)
cc_library
(
op_info SRCS op_info.cc DEPS attribute framework_proto
)
cc_library
(
operator SRCS operator.cc DEPS op_info device_context tensor scope glog
)
cc_library
(
shape_inference SRCS shape_inference.cc DEPS ddim attribute
)
cc_library
(
operator SRCS operator.cc DEPS op_info device_context tensor scope glog shape_inference
)
cc_test
(
operator_test SRCS operator_test.cc DEPS operator op_registry
)
cc_library
(
proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS
attribute ddim
op_info operator glog
)
cc_library
(
proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS
shape_inference
op_info operator glog
)
cc_library
(
op_registry SRCS op_registry.cc DEPS op_proto_maker op_info operator glog proto_desc
)
cc_test
(
op_registry_test SRCS op_registry_test.cc DEPS op_registry
)
...
...
paddle/framework/op_desc.cc
浏览文件 @
8f6c0a0f
...
...
@@ -16,15 +16,51 @@ limitations under the License. */
#include <functional>
#include <mutex>
#include <unordered_map>
#include "glog/logging.h"
#include "paddle/framework/block_desc.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/program_desc.h"
#include "glog/logging.h"
#include "paddle/framework/shape_inference.h"
namespace
paddle
{
namespace
framework
{
class
OpDescBind
;
class
BlockDescBind
;
class
CompileTimeInferShapeContext
:
public
InferShapeContext
{
public:
CompileTimeInferShapeContext
(
const
OpDescBind
&
op
,
const
BlockDescBind
&
block
);
bool
HasInput
(
const
std
::
string
&
name
)
const
override
;
bool
HasOutput
(
const
std
::
string
&
name
)
const
override
;
bool
HasInputs
(
const
std
::
string
&
name
)
const
override
;
bool
HasOutputs
(
const
std
::
string
&
name
)
const
override
;
DDim
GetInputDim
(
const
std
::
string
&
name
)
const
override
;
void
SetOutputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
;
AttrReader
Attrs
()
const
override
;
const
std
::
vector
<
std
::
string
>
&
Inputs
(
const
std
::
string
&
name
)
const
override
;
const
std
::
vector
<
std
::
string
>
&
Outputs
(
const
std
::
string
&
name
)
const
override
;
private:
DDim
GetDim
(
const
std
::
string
&
name
)
const
override
;
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
;
const
OpDescBind
&
op_
;
const
BlockDescBind
&
block_
;
};
OpDescBind
::
OpDescBind
(
const
std
::
string
&
type
,
const
VariableNameMap
&
inputs
,
const
VariableNameMap
&
outputs
,
const
AttributeMap
&
attrs
)
{
...
...
@@ -288,5 +324,97 @@ void OpDescBind::InferVarType(BlockDescBind *block) const {
}
}
CompileTimeInferShapeContext
::
CompileTimeInferShapeContext
(
const
OpDescBind
&
op
,
const
BlockDescBind
&
block
)
:
op_
(
op
),
block_
(
block
)
{}
bool
CompileTimeInferShapeContext
::
HasInput
(
const
std
::
string
&
name
)
const
{
const
std
::
vector
<
std
::
string
>
&
input_names
=
op_
.
Input
(
name
);
auto
length
=
input_names
.
size
();
if
(
length
==
0
)
{
return
false
;
}
PADDLE_ENFORCE_EQ
(
length
,
1UL
,
"Input(%s) should have only one value, "
"but it have %d now"
,
name
,
length
);
return
block_
.
HasVarRecursive
(
input_names
[
0
]);
}
bool
CompileTimeInferShapeContext
::
HasOutput
(
const
std
::
string
&
name
)
const
{
const
std
::
vector
<
std
::
string
>
&
output_names
=
op_
.
Output
(
name
);
auto
length
=
output_names
.
size
();
if
(
length
==
0
)
{
return
false
;
}
PADDLE_ENFORCE_EQ
(
length
,
1UL
,
"Output(%s) should have only one value, "
"but it have %d now"
,
name
,
length
);
return
block_
.
HasVarRecursive
(
output_names
[
0
]);
}
bool
CompileTimeInferShapeContext
::
HasInputs
(
const
std
::
string
&
name
)
const
{
const
std
::
vector
<
std
::
string
>
&
input_names
=
op_
.
Input
(
name
);
if
(
input_names
.
empty
())
{
return
false
;
}
for
(
auto
&
input
:
input_names
)
{
if
(
!
block_
.
HasVarRecursive
(
input
))
return
false
;
}
return
true
;
}
bool
CompileTimeInferShapeContext
::
HasOutputs
(
const
std
::
string
&
name
)
const
{
const
std
::
vector
<
std
::
string
>
&
output_names
=
op_
.
Output
(
name
);
if
(
output_names
.
empty
())
{
return
false
;
}
for
(
auto
&
output
:
output_names
)
{
if
(
!
block_
.
HasVarRecursive
(
output
))
return
false
;
}
return
true
;
}
DDim
CompileTimeInferShapeContext
::
GetInputDim
(
const
std
::
string
&
name
)
const
{
std
::
vector
<
DDim
>
ddims
=
GetInputsDim
(
name
);
auto
length
=
ddims
.
size
();
PADDLE_ENFORCE_EQ
(
length
,
1UL
,
"Input(%s) should have 1 value, "
"but it has %d now"
,
name
,
length
);
return
ddims
[
0
];
}
void
CompileTimeInferShapeContext
::
SetOutputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
{
SetOutputsDim
(
name
,
{
dim
});
}
AttrReader
CompileTimeInferShapeContext
::
Attrs
()
const
{
return
AttrReader
(
op_
.
GetAttrMap
());
}
const
std
::
vector
<
std
::
string
>
&
CompileTimeInferShapeContext
::
Inputs
(
const
std
::
string
&
name
)
const
{
return
op_
.
Input
(
name
);
}
const
std
::
vector
<
std
::
string
>
&
CompileTimeInferShapeContext
::
Outputs
(
const
std
::
string
&
name
)
const
{
return
op_
.
Output
(
name
);
}
DDim
CompileTimeInferShapeContext
::
GetDim
(
const
std
::
string
&
name
)
const
{
auto
var
=
block_
.
FindVarRecursive
(
name
);
PADDLE_ENFORCE
(
var
!=
nullptr
,
"Cannot find variable %s"
,
name
);
return
framework
::
make_ddim
(
var
->
Shape
());
}
void
CompileTimeInferShapeContext
::
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
{
block_
.
FindVarRecursive
(
name
)
->
SetShape
(
framework
::
vectorize
(
dim
));
}
}
// namespace framework
}
// namespace paddle
paddle/framework/op_registry.h
浏览文件 @
8f6c0a0f
...
...
@@ -29,6 +29,7 @@ limitations under the License. */
#include "paddle/framework/op_desc.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/shape_inference.h"
namespace
paddle
{
namespace
framework
{
...
...
paddle/framework/operator.cc
浏览文件 @
8f6c0a0f
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/framework/operator.h"
#include <algorithm>
#include <atomic>
#include "paddle/framework/shape_inference.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -273,5 +274,136 @@ bool OpSupportGPU(const std::string& op_type) {
return
false
;
}
class
RuntimeInferShapeContext
:
public
InferShapeContext
{
public:
RuntimeInferShapeContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
)
:
op_
(
op
),
scope_
(
scope
)
{}
bool
HasInput
(
const
std
::
string
&
name
)
const
override
{
auto
&
ins
=
Inputs
(
name
);
size_t
length
=
ins
.
size
();
if
(
length
==
0
)
{
return
false
;
}
PADDLE_ENFORCE_EQ
(
length
,
1UL
,
"Input %s should have more than one inputs"
,
name
);
auto
ipt
=
ins
[
0
];
auto
*
var
=
ipt
==
kEmptyVarName
?
nullptr
:
scope_
.
FindVar
(
ipt
);
return
var
!=
nullptr
;
}
bool
HasOutput
(
const
std
::
string
&
name
)
const
override
{
auto
&
outs
=
Outputs
(
name
);
size_t
length
=
outs
.
size
();
if
(
length
==
0
)
{
return
false
;
}
PADDLE_ENFORCE_EQ
(
length
,
1UL
,
"Output %s should have more than one inputs"
,
name
);
auto
ipt
=
outs
[
0
];
auto
*
var
=
ipt
==
kEmptyVarName
?
nullptr
:
scope_
.
FindVar
(
ipt
);
return
var
!=
nullptr
;
}
bool
HasInputs
(
const
std
::
string
&
name
)
const
override
{
auto
inputs
=
op_
.
Inputs
(
name
);
if
(
inputs
.
empty
())
{
return
false
;
}
for
(
auto
&
input
:
inputs
)
{
if
(
scope_
.
FindVar
(
input
)
==
nullptr
)
{
return
false
;
}
}
return
true
;
}
bool
HasOutputs
(
const
std
::
string
&
name
)
const
override
{
auto
outputs
=
op_
.
Outputs
(
name
);
if
(
outputs
.
empty
())
{
return
false
;
}
for
(
auto
&
output
:
outputs
)
{
if
(
scope_
.
FindVar
(
output
)
==
nullptr
)
{
return
false
;
}
}
return
true
;
}
DDim
GetInputDim
(
const
std
::
string
&
name
)
const
override
{
return
GetDim
(
op_
.
Input
(
name
));
}
void
SetOutputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
{
SetDim
(
op_
.
Output
(
name
),
dim
);
}
AttrReader
Attrs
()
const
override
{
return
AttrReader
(
op_
.
Attrs
());
}
const
std
::
vector
<
std
::
string
>&
Inputs
(
const
std
::
string
&
name
)
const
override
{
return
op_
.
Inputs
(
name
);
}
const
std
::
vector
<
std
::
string
>&
Outputs
(
const
std
::
string
&
name
)
const
override
{
return
op_
.
Outputs
(
name
);
}
private:
DDim
GetDim
(
const
std
::
string
&
name
)
const
override
{
Variable
*
var
=
scope_
.
FindVar
(
name
);
if
(
var
->
IsType
<
LoDTensor
>
())
{
return
var
->
Get
<
LoDTensor
>
().
dims
();
}
else
if
(
var
->
IsType
<
SelectedRows
>
())
{
return
var
->
Get
<
SelectedRows
>
().
GetCompleteDims
();
}
else
{
PADDLE_THROW
(
"Variable type must be LoDTensor/SelectedRows."
);
}
}
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
{
Variable
*
var
=
scope_
.
FindVar
(
name
);
if
(
var
->
IsType
<
LoDTensor
>
())
{
var
->
GetMutable
<
LoDTensor
>
()
->
Resize
(
dim
);
}
else
if
(
var
->
IsType
<
SelectedRows
>
())
{
var
->
GetMutable
<
SelectedRows
>
()
->
set_height
(
dim
[
0
]);
}
else
{
PADDLE_THROW
(
"Variable type must be LoDTensor/SelectedRows."
);
}
}
const
OperatorBase
&
op_
;
const
Scope
&
scope_
;
};
void
OperatorWithKernel
::
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
{
VLOG
(
3
)
<<
"Running operator "
<<
this
->
Type
();
RuntimeInferShapeContext
infer_shape_ctx
(
*
this
,
scope
);
this
->
InferShape
(
&
infer_shape_ctx
);
ExecutionContext
ctx
(
*
this
,
scope
,
dev_ctx
);
// check if op[type] has kernel registered.
auto
&
all_op_kernels
=
AllOpKernels
();
auto
kernels_iter
=
all_op_kernels
.
find
(
type_
);
if
(
kernels_iter
==
all_op_kernels
.
end
())
{
PADDLE_THROW
(
"op[%s] has no kernel"
,
type_
);
}
// check if op[type] have kernel for kernel_key
OpKernelMap
&
kernels
=
kernels_iter
->
second
;
auto
kernel_key
=
OpKernelKey
(
IndicateDataType
(
ctx
),
dev_ctx
);
auto
kernel_iter
=
kernels
.
find
(
kernel_key
);
if
(
kernel_iter
==
kernels
.
end
())
{
PADDLE_THROW
(
"op[%s] has no kernel with kernel_key[%s]"
,
type_
,
kernel_key
);
}
kernel_iter
->
second
->
Compute
(
ctx
);
}
}
// namespace framework
}
// namespace paddle
paddle/framework/operator.h
浏览文件 @
8f6c0a0f
...
...
@@ -29,7 +29,6 @@ limitations under the License. */
#include "paddle/framework/op_info.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/shape_inference.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
...
...
@@ -317,226 +316,6 @@ template <>
std
::
vector
<
Tensor
*>
ExecutionContext
::
MultiOutput
<
Tensor
>
(
const
std
::
string
&
name
)
const
;
class
CompileTimeInferShapeContext
:
public
InferShapeContext
{
public:
CompileTimeInferShapeContext
(
const
OpDescBind
&
op
,
const
BlockDescBind
&
block
)
:
op_
(
op
),
block_
(
block
)
{}
bool
HasInput
(
const
std
::
string
&
name
)
const
override
{
const
std
::
vector
<
std
::
string
>&
input_names
=
op_
.
Input
(
name
);
auto
length
=
input_names
.
size
();
if
(
length
==
0
)
{
return
false
;
}
PADDLE_ENFORCE_EQ
(
length
,
1UL
,
"Input(%s) should have only one value, "
"but it have %d now"
,
name
,
length
);
return
block_
.
HasVarRecursive
(
input_names
[
0
]);
}
bool
HasOutput
(
const
std
::
string
&
name
)
const
override
{
const
std
::
vector
<
std
::
string
>&
output_names
=
op_
.
Output
(
name
);
auto
length
=
output_names
.
size
();
if
(
length
==
0
)
{
return
false
;
}
PADDLE_ENFORCE_EQ
(
length
,
1UL
,
"Output(%s) should have only one value, "
"but it have %d now"
,
name
,
length
);
return
block_
.
HasVarRecursive
(
output_names
[
0
]);
}
bool
HasInputs
(
const
std
::
string
&
name
)
const
override
{
const
std
::
vector
<
std
::
string
>&
input_names
=
op_
.
Input
(
name
);
if
(
input_names
.
empty
())
{
return
false
;
}
for
(
auto
&
input
:
input_names
)
{
if
(
!
block_
.
HasVarRecursive
(
input
))
return
false
;
}
return
true
;
}
bool
HasOutputs
(
const
std
::
string
&
name
)
const
override
{
const
std
::
vector
<
std
::
string
>&
output_names
=
op_
.
Output
(
name
);
if
(
output_names
.
empty
())
{
return
false
;
}
for
(
auto
&
output
:
output_names
)
{
if
(
!
block_
.
HasVarRecursive
(
output
))
return
false
;
}
return
true
;
}
DDim
GetInputDim
(
const
std
::
string
&
name
)
const
override
{
std
::
vector
<
DDim
>
ddims
=
GetInputsDim
(
name
);
auto
length
=
ddims
.
size
();
PADDLE_ENFORCE_EQ
(
length
,
1UL
,
"Input(%s) should have 1 value, "
"but it has %d now"
,
name
,
length
);
return
ddims
[
0
];
}
void
SetInputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
{
SetInputsDim
(
name
,
{
dim
});
}
DDim
GetOutputDim
(
const
std
::
string
&
name
)
const
override
{
std
::
vector
<
DDim
>
ddims
=
GetOutputsDim
(
name
);
auto
length
=
ddims
.
size
();
PADDLE_ENFORCE_EQ
(
length
,
1UL
,
"Output(%s) should have 1 value, "
"but it has %d now"
,
name
,
length
);
return
ddims
[
0
];
}
void
SetOutputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
{
SetOutputsDim
(
name
,
{
dim
});
}
AttrReader
Attrs
()
const
override
{
return
AttrReader
(
op_
.
GetAttrMap
());
}
const
std
::
vector
<
std
::
string
>&
Inputs
(
const
std
::
string
&
name
)
const
override
{
return
op_
.
Input
(
name
);
}
const
std
::
vector
<
std
::
string
>&
Outputs
(
const
std
::
string
&
name
)
const
override
{
return
op_
.
Output
(
name
);
}
private:
DDim
GetDim
(
const
std
::
string
&
name
)
const
override
{
auto
var
=
block_
.
FindVarRecursive
(
name
);
PADDLE_ENFORCE
(
var
!=
nullptr
,
"Cannot find variable %s"
,
name
);
return
framework
::
make_ddim
(
var
->
Shape
());
}
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
{
block_
.
FindVarRecursive
(
name
)
->
SetShape
(
framework
::
vectorize
(
dim
));
}
const
OpDescBind
&
op_
;
const
BlockDescBind
&
block_
;
};
class
RuntimeInferShapeContext
:
public
InferShapeContext
{
public:
RuntimeInferShapeContext
(
const
OperatorBase
&
op
,
const
Scope
&
scope
)
:
op_
(
op
),
scope_
(
scope
)
{}
bool
HasInput
(
const
std
::
string
&
name
)
const
override
{
auto
&
ins
=
Inputs
(
name
);
size_t
length
=
ins
.
size
();
if
(
length
==
0
)
{
return
false
;
}
PADDLE_ENFORCE_EQ
(
length
,
1UL
,
"Input %s should have more than one inputs"
,
name
);
auto
ipt
=
ins
[
0
];
auto
*
var
=
ipt
==
kEmptyVarName
?
nullptr
:
scope_
.
FindVar
(
ipt
);
return
var
!=
nullptr
;
}
bool
HasOutput
(
const
std
::
string
&
name
)
const
override
{
auto
&
outs
=
Outputs
(
name
);
size_t
length
=
outs
.
size
();
if
(
length
==
0
)
{
return
false
;
}
PADDLE_ENFORCE_EQ
(
length
,
1UL
,
"Output %s should have more than one inputs"
,
name
);
auto
ipt
=
outs
[
0
];
auto
*
var
=
ipt
==
kEmptyVarName
?
nullptr
:
scope_
.
FindVar
(
ipt
);
return
var
!=
nullptr
;
}
bool
HasInputs
(
const
std
::
string
&
name
)
const
override
{
auto
inputs
=
op_
.
Inputs
(
name
);
if
(
inputs
.
empty
())
{
return
false
;
}
for
(
auto
&
input
:
inputs
)
{
if
(
scope_
.
FindVar
(
input
)
==
nullptr
)
{
return
false
;
}
}
return
true
;
}
bool
HasOutputs
(
const
std
::
string
&
name
)
const
override
{
auto
outputs
=
op_
.
Outputs
(
name
);
if
(
outputs
.
empty
())
{
return
false
;
}
for
(
auto
&
output
:
outputs
)
{
if
(
scope_
.
FindVar
(
output
)
==
nullptr
)
{
return
false
;
}
}
return
true
;
}
DDim
GetInputDim
(
const
std
::
string
&
name
)
const
override
{
return
GetDim
(
op_
.
Input
(
name
));
}
void
SetInputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
{
SetDim
(
op_
.
Input
(
name
),
dim
);
}
DDim
GetOutputDim
(
const
std
::
string
&
name
)
const
override
{
return
GetDim
(
op_
.
Output
(
name
));
}
void
SetOutputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
{
SetDim
(
op_
.
Output
(
name
),
dim
);
}
AttrReader
Attrs
()
const
override
{
return
AttrReader
(
op_
.
Attrs
());
}
const
std
::
vector
<
std
::
string
>&
Inputs
(
const
std
::
string
&
name
)
const
override
{
return
op_
.
Inputs
(
name
);
}
const
std
::
vector
<
std
::
string
>&
Outputs
(
const
std
::
string
&
name
)
const
override
{
return
op_
.
Outputs
(
name
);
}
private:
DDim
GetDim
(
const
std
::
string
&
name
)
const
override
{
Variable
*
var
=
scope_
.
FindVar
(
name
);
if
(
var
->
IsType
<
LoDTensor
>
())
{
return
var
->
Get
<
LoDTensor
>
().
dims
();
}
else
if
(
var
->
IsType
<
SelectedRows
>
())
{
return
var
->
Get
<
SelectedRows
>
().
GetCompleteDims
();
}
else
{
PADDLE_THROW
(
"Variable type must be LoDTensor/SelectedRows."
);
}
}
void
SetDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
override
{
Variable
*
var
=
scope_
.
FindVar
(
name
);
if
(
var
->
IsType
<
LoDTensor
>
())
{
var
->
GetMutable
<
LoDTensor
>
()
->
Resize
(
dim
);
}
else
if
(
var
->
IsType
<
SelectedRows
>
())
{
var
->
GetMutable
<
SelectedRows
>
()
->
set_height
(
dim
[
0
]);
}
else
{
PADDLE_THROW
(
"Variable type must be LoDTensor/SelectedRows."
);
}
}
const
OperatorBase
&
op_
;
const
Scope
&
scope_
;
};
class
OpKernelBase
{
public:
/**
...
...
@@ -595,32 +374,7 @@ class OperatorWithKernel : public OperatorBase {
:
OperatorBase
(
type
,
inputs
,
outputs
,
attrs
)
{}
void
Run
(
const
Scope
&
scope
,
const
platform
::
DeviceContext
&
dev_ctx
)
const
final
{
VLOG
(
3
)
<<
"Running operator "
<<
this
->
Type
();
RuntimeInferShapeContext
infer_shape_ctx
(
*
this
,
scope
);
this
->
InferShape
(
&
infer_shape_ctx
);
ExecutionContext
ctx
(
*
this
,
scope
,
dev_ctx
);
// check if op[type] has kernel registered.
auto
&
all_op_kernels
=
AllOpKernels
();
auto
kernels_iter
=
all_op_kernels
.
find
(
type_
);
if
(
kernels_iter
==
all_op_kernels
.
end
())
{
PADDLE_THROW
(
"op[%s] has no kernel"
,
type_
);
}
// check if op[type] have kernel for kernel_key
OpKernelMap
&
kernels
=
kernels_iter
->
second
;
auto
kernel_key
=
OpKernelKey
(
IndicateDataType
(
ctx
),
dev_ctx
);
auto
kernel_iter
=
kernels
.
find
(
kernel_key
);
if
(
kernel_iter
==
kernels
.
end
())
{
PADDLE_THROW
(
"op[%s] has no kernel with kernel_key[%s]"
,
type_
,
kernel_key
);
}
kernel_iter
->
second
->
Compute
(
ctx
);
}
const
platform
::
DeviceContext
&
dev_ctx
)
const
final
;
static
std
::
unordered_map
<
std
::
string
/* op_type */
,
OpKernelMap
>&
AllOpKernels
()
{
...
...
paddle/framework/shape_inference.cc
0 → 100644
浏览文件 @
8f6c0a0f
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/shape_inference.h"
namespace
paddle
{
namespace
framework
{
std
::
vector
<
framework
::
DDim
>
InferShapeContext
::
GetInputsDim
(
const
std
::
string
&
name
)
const
{
const
std
::
vector
<
std
::
string
>
&
names
=
Inputs
(
name
);
return
GetDims
(
names
);
}
void
InferShapeContext
::
SetOutputsDim
(
const
std
::
string
&
name
,
const
std
::
vector
<
framework
::
DDim
>
&
dims
)
{
auto
&
names
=
Outputs
(
name
);
SetDims
(
names
,
dims
);
}
void
InferShapeContext
::
ShareLoD
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
,
size_t
j
)
const
{}
std
::
vector
<
framework
::
DDim
>
InferShapeContext
::
GetDims
(
const
std
::
vector
<
std
::
string
>
&
names
)
const
{
std
::
vector
<
framework
::
DDim
>
ret
;
ret
.
reserve
(
names
.
size
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
std
::
back_inserter
(
ret
),
[
this
](
const
std
::
string
&
name
)
{
return
this
->
GetDim
(
name
);
});
return
ret
;
}
void
InferShapeContext
::
SetDims
(
const
std
::
vector
<
std
::
string
>
&
names
,
const
std
::
vector
<
framework
::
DDim
>
&
dims
)
{
size_t
length
=
names
.
size
();
PADDLE_ENFORCE_EQ
(
length
,
dims
.
size
());
for
(
size_t
i
=
0
;
i
<
length
;
++
i
)
{
SetDim
(
names
[
i
],
dims
[
i
]);
}
}
}
// namespace framework
}
// namespace paddle
paddle/framework/shape_inference.h
浏览文件 @
8f6c0a0f
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/framework/attribute.h"
#include "paddle/framework/ddim.h"
namespace
paddle
{
...
...
@@ -21,7 +22,7 @@ namespace framework {
class
InferShapeContext
{
public:
virtual
~
InferShapeContext
()
{}
virtual
~
InferShapeContext
()
=
default
;
virtual
bool
HasInput
(
const
std
::
string
&
name
)
const
=
0
;
virtual
bool
HasOutput
(
const
std
::
string
&
name
)
const
=
0
;
...
...
@@ -29,57 +30,32 @@ class InferShapeContext {
virtual
bool
HasOutputs
(
const
std
::
string
&
name
)
const
=
0
;
virtual
framework
::
DDim
GetInputDim
(
const
std
::
string
&
name
)
const
=
0
;
std
::
vector
<
framework
::
DDim
>
GetInputsDim
(
const
std
::
string
&
name
)
const
{
const
std
::
vector
<
std
::
string
>
&
names
=
Inputs
(
name
);
return
GetDims
(
names
);
}
virtual
void
SetInputDim
(
const
std
::
string
&
name
,
const
framework
::
DDim
&
dim
)
=
0
;
void
SetInputsDim
(
const
std
::
string
&
name
,
const
std
::
vector
<
framework
::
DDim
>
&
dims
)
{
auto
&
names
=
Inputs
(
name
);
SetDims
(
names
,
dims
);
}
virtual
framework
::
DDim
GetOutputDim
(
const
std
::
string
&
name
)
const
=
0
;
std
::
vector
<
framework
::
DDim
>
GetOutputsDim
(
const
std
::
string
&
name
)
const
{
const
std
::
vector
<
std
::
string
>
&
names
=
Outputs
(
name
);
return
GetDims
(
names
);
}
std
::
vector
<
framework
::
DDim
>
GetInputsDim
(
const
std
::
string
&
name
)
const
;
virtual
void
SetOutputDim
(
const
std
::
string
&
name
,
const
DDim
&
dim
)
=
0
;
void
SetOutputsDim
(
const
std
::
string
&
name
,
const
std
::
vector
<
framework
::
DDim
>
&
dims
)
{
auto
&
names
=
Outputs
(
name
);
SetDims
(
names
,
dims
);
}
const
std
::
vector
<
framework
::
DDim
>
&
dims
);
virtual
AttrReader
Attrs
()
const
=
0
;
virtual
const
std
::
vector
<
std
::
string
>
&
Inputs
(
const
std
::
string
&
name
)
const
=
0
;
virtual
const
std
::
vector
<
std
::
string
>
&
Outputs
(
const
std
::
string
&
name
)
const
=
0
;
// TODO(qiao) implement this function
void
ShareLoD
(
const
std
::
string
&
in
,
const
std
::
string
&
out
,
size_t
i
=
0
,
size_t
j
=
0
)
const
{}
size_t
j
=
0
)
const
;
protected:
virtual
framework
::
DDim
GetDim
(
const
std
::
string
&
name
)
const
=
0
;
virtual
void
SetDim
(
const
std
::
string
&
name
,
const
framework
::
DDim
&
dim
)
=
0
;
std
::
vector
<
framework
::
DDim
>
GetDims
(
const
std
::
vector
<
std
::
string
>
&
names
)
const
{
std
::
vector
<
framework
::
DDim
>
ret
;
ret
.
reserve
(
names
.
size
());
std
::
transform
(
names
.
begin
(),
names
.
end
(),
std
::
back_inserter
(
ret
),
[
this
](
const
std
::
string
&
name
)
{
return
this
->
GetDim
(
name
);
});
return
ret
;
}
const
std
::
vector
<
std
::
string
>
&
names
)
const
;
void
SetDims
(
const
std
::
vector
<
std
::
string
>
&
names
,
const
std
::
vector
<
framework
::
DDim
>
&
dims
)
{
size_t
length
=
names
.
size
();
PADDLE_ENFORCE_EQ
(
length
,
dims
.
size
());
for
(
size_t
i
=
0
;
i
<
length
;
++
i
)
{
SetDim
(
names
[
i
],
dims
[
i
]);
}
}
const
std
::
vector
<
framework
::
DDim
>
&
dims
);
};
}
// namespace framework
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录