Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
af6ef888
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
af6ef888
编写于
3月 15, 2022
作者:
石
石晓伟
提交者:
GitHub
3月 15, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
adjusts the mlir attrs order, test=develop (#40514)
上级
e7057932
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
131 addition
and
70 deletion
+131
-70
paddle/infrt/dialect/phi/data_type.cc
paddle/infrt/dialect/phi/data_type.cc
+14
-14
paddle/infrt/dialect/phi/data_type.h
paddle/infrt/dialect/phi/data_type.h
+8
-8
paddle/infrt/dialect/phi/pass/kernel_op_desc.cc
paddle/infrt/dialect/phi/pass/kernel_op_desc.cc
+3
-3
paddle/infrt/host_context/kernel_registry.cc
paddle/infrt/host_context/kernel_registry.cc
+20
-10
paddle/infrt/host_context/kernel_registry.h
paddle/infrt/host_context/kernel_registry.h
+6
-2
paddle/infrt/host_context/mlir_function_executable.cc
paddle/infrt/host_context/mlir_function_executable.cc
+3
-1
paddle/infrt/host_context/mlir_function_executable.h
paddle/infrt/host_context/mlir_function_executable.h
+1
-0
paddle/infrt/host_context/mlir_to_runtime_translate.cc
paddle/infrt/host_context/mlir_to_runtime_translate.cc
+55
-16
paddle/infrt/host_context/mlir_to_runtime_translate.h
paddle/infrt/host_context/mlir_to_runtime_translate.h
+2
-1
paddle/infrt/kernel/phi/dense_tensor_kernels.cc
paddle/infrt/kernel/phi/dense_tensor_kernels.cc
+6
-6
paddle/infrt/kernel/phi/dense_tensor_kernels.h
paddle/infrt/kernel/phi/dense_tensor_kernels.h
+1
-1
paddle/infrt/kernel/phi/registry.cc
paddle/infrt/kernel/phi/registry.cc
+8
-4
paddle/infrt/kernel/tensor_kernels.cc
paddle/infrt/kernel/tensor_kernels.cc
+3
-3
paddle/infrt/tests/dialect/phi/dense_tensor.mlir
paddle/infrt/tests/dialect/phi/dense_tensor.mlir
+1
-1
未找到文件。
paddle/infrt/dialect/phi/data_type.cc
浏览文件 @
af6ef888
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
namespace
infrt
{
namespace
infrt
{
phi
::
Backend
cvtTarget2
Phi
(
TargetType
target
)
{
phi
::
Backend
ConvertTargetTo
Phi
(
TargetType
target
)
{
switch
(
target
)
{
switch
(
target
)
{
case
TargetType
::
CPU
:
case
TargetType
::
CPU
:
return
phi
::
Backend
::
CPU
;
return
phi
::
Backend
::
CPU
;
...
@@ -27,7 +27,7 @@ phi::Backend cvtTarget2Phi(TargetType target) {
...
@@ -27,7 +27,7 @@ phi::Backend cvtTarget2Phi(TargetType target) {
}
}
}
}
TargetType
cv
tTargetFromPhi
(
phi
::
Backend
backend
)
{
TargetType
Conver
tTargetFromPhi
(
phi
::
Backend
backend
)
{
switch
(
backend
)
{
switch
(
backend
)
{
case
phi
::
Backend
::
CPU
:
case
phi
::
Backend
::
CPU
:
return
TargetType
::
CPU
;
return
TargetType
::
CPU
;
...
@@ -38,7 +38,7 @@ TargetType cvtTargetFromPhi(phi::Backend backend) {
...
@@ -38,7 +38,7 @@ TargetType cvtTargetFromPhi(phi::Backend backend) {
}
}
}
}
phi
::
DataType
cvtPrecision2
Phi
(
PrecisionType
precision
)
{
phi
::
DataType
ConvertPrecisionTo
Phi
(
PrecisionType
precision
)
{
#define CONVERT_PRECISION_TO_PHI(Precision) \
#define CONVERT_PRECISION_TO_PHI(Precision) \
case PrecisionType::Precision: \
case PrecisionType::Precision: \
return phi::DataType::Precision;
return phi::DataType::Precision;
...
@@ -61,7 +61,7 @@ phi::DataType cvtPrecision2Phi(PrecisionType precision) {
...
@@ -61,7 +61,7 @@ phi::DataType cvtPrecision2Phi(PrecisionType precision) {
#undef CONVERT_PRECISION_TO_PHI
#undef CONVERT_PRECISION_TO_PHI
}
}
PrecisionType
cv
tPrecisionFromPhi
(
phi
::
DataType
datatype
)
{
PrecisionType
Conver
tPrecisionFromPhi
(
phi
::
DataType
datatype
)
{
#define CONVERT_PRECISION_FROM_PHI(Precision) \
#define CONVERT_PRECISION_FROM_PHI(Precision) \
case phi::DataType::Precision: \
case phi::DataType::Precision: \
return PrecisionType::Precision;
return PrecisionType::Precision;
...
@@ -84,7 +84,7 @@ PrecisionType cvtPrecisionFromPhi(phi::DataType datatype) {
...
@@ -84,7 +84,7 @@ PrecisionType cvtPrecisionFromPhi(phi::DataType datatype) {
#undef CONVERT_PRECISION_FROM_PHI
#undef CONVERT_PRECISION_FROM_PHI
}
}
phi
::
DataLayout
cvtLayout2
Phi
(
LayoutType
layout
)
{
phi
::
DataLayout
ConvertLayoutTo
Phi
(
LayoutType
layout
)
{
switch
(
layout
)
{
switch
(
layout
)
{
case
LayoutType
::
NCHW
:
case
LayoutType
::
NCHW
:
return
phi
::
DataLayout
::
NCHW
;
return
phi
::
DataLayout
::
NCHW
;
...
@@ -97,7 +97,7 @@ phi::DataLayout cvtLayout2Phi(LayoutType layout) {
...
@@ -97,7 +97,7 @@ phi::DataLayout cvtLayout2Phi(LayoutType layout) {
}
}
}
}
LayoutType
cv
tLayoutFromPhi
(
phi
::
DataLayout
layout
)
{
LayoutType
Conver
tLayoutFromPhi
(
phi
::
DataLayout
layout
)
{
switch
(
layout
)
{
switch
(
layout
)
{
case
phi
::
DataLayout
::
NCHW
:
case
phi
::
DataLayout
::
NCHW
:
return
LayoutType
::
NCHW
;
return
LayoutType
::
NCHW
;
...
@@ -110,16 +110,16 @@ LayoutType cvtLayoutFromPhi(phi::DataLayout layout) {
...
@@ -110,16 +110,16 @@ LayoutType cvtLayoutFromPhi(phi::DataLayout layout) {
}
}
}
}
phi
::
KernelKey
cvtPlace2
Phi
(
const
Place
&
place
)
{
phi
::
KernelKey
ConvertPlaceTo
Phi
(
const
Place
&
place
)
{
return
phi
::
KernelKey
(
cvtTarget2
Phi
(
place
.
target
),
return
phi
::
KernelKey
(
ConvertTargetTo
Phi
(
place
.
target
),
cvtLayout2
Phi
(
place
.
layout
),
ConvertLayoutTo
Phi
(
place
.
layout
),
cvtPrecision2
Phi
(
place
.
precision
));
ConvertPrecisionTo
Phi
(
place
.
precision
));
}
}
Place
cv
tPlaceFromPhi
(
phi
::
TensorArgDef
tensor_arg
)
{
Place
Conver
tPlaceFromPhi
(
phi
::
TensorArgDef
tensor_arg
)
{
return
Place
(
cv
tTargetFromPhi
(
tensor_arg
.
backend
),
return
Place
(
Conver
tTargetFromPhi
(
tensor_arg
.
backend
),
cv
tPrecisionFromPhi
(
tensor_arg
.
dtype
),
Conver
tPrecisionFromPhi
(
tensor_arg
.
dtype
),
cv
tLayoutFromPhi
(
tensor_arg
.
layout
));
Conver
tLayoutFromPhi
(
tensor_arg
.
layout
));
}
}
}
// namespace infrt
}
// namespace infrt
paddle/infrt/dialect/phi/data_type.h
浏览文件 @
af6ef888
...
@@ -23,16 +23,16 @@
...
@@ -23,16 +23,16 @@
namespace
infrt
{
namespace
infrt
{
phi
::
Backend
cvtTarget2
Phi
(
TargetType
target
);
phi
::
Backend
ConvertTargetTo
Phi
(
TargetType
target
);
TargetType
cv
tTargetFromPhi
(
phi
::
Backend
backend
);
TargetType
Conver
tTargetFromPhi
(
phi
::
Backend
backend
);
phi
::
DataType
cvtPrecision2
Phi
(
PrecisionType
precision
);
phi
::
DataType
ConvertPrecisionTo
Phi
(
PrecisionType
precision
);
PrecisionType
cv
tPrecisionFromPhi
(
phi
::
DataType
datatype
);
PrecisionType
Conver
tPrecisionFromPhi
(
phi
::
DataType
datatype
);
phi
::
DataLayout
cvtLayout2
Phi
(
LayoutType
layout
);
phi
::
DataLayout
ConvertLayoutTo
Phi
(
LayoutType
layout
);
LayoutType
cv
tLayoutFromPhi
(
phi
::
DataLayout
layout
);
LayoutType
Conver
tLayoutFromPhi
(
phi
::
DataLayout
layout
);
phi
::
KernelKey
cvtPlace2
Phi
(
const
Place
&
place
);
phi
::
KernelKey
ConvertPlaceTo
Phi
(
const
Place
&
place
);
Place
cv
tPlaceFromPhi
(
phi
::
TensorArgDef
tensor_arg
);
Place
Conver
tPlaceFromPhi
(
phi
::
TensorArgDef
tensor_arg
);
}
// namespace infrt
}
// namespace infrt
paddle/infrt/dialect/phi/pass/kernel_op_desc.cc
浏览文件 @
af6ef888
...
@@ -80,7 +80,7 @@ std::vector<PhiKernelDesc> getCandidateKernels(
...
@@ -80,7 +80,7 @@ std::vector<PhiKernelDesc> getCandidateKernels(
phi
::
KernelKeyMap
kernel_key_map
=
phi
::
KernelKeyMap
kernel_key_map
=
phi
::
KernelFactory
::
Instance
().
SelectKernelMap
(
name
);
phi
::
KernelFactory
::
Instance
().
SelectKernelMap
(
name
);
for
(
Place
place
:
valid_palces
)
{
for
(
Place
place
:
valid_palces
)
{
phi
::
KernelKey
kernel_key
=
cvtPlace2
Phi
(
place
);
phi
::
KernelKey
kernel_key
=
ConvertPlaceTo
Phi
(
place
);
if
(
kernel_key_map
.
find
(
kernel_key
)
==
kernel_key_map
.
end
())
{
if
(
kernel_key_map
.
find
(
kernel_key
)
==
kernel_key_map
.
end
())
{
kernel_key
=
phi
::
KernelKey
(
kernel_key
.
backend
(),
kernel_key
=
phi
::
KernelKey
(
kernel_key
.
backend
(),
phi
::
DataLayout
::
ALL_LAYOUT
,
phi
::
DataLayout
::
ALL_LAYOUT
,
...
@@ -97,10 +97,10 @@ std::vector<PhiKernelDesc> getCandidateKernels(
...
@@ -97,10 +97,10 @@ std::vector<PhiKernelDesc> getCandidateKernels(
const
paddle
::
SmallVector
<
phi
::
TensorArgDef
>&
output_arg
=
const
paddle
::
SmallVector
<
phi
::
TensorArgDef
>&
output_arg
=
args_def
.
output_defs
();
args_def
.
output_defs
();
for
(
auto
tensor_arg
:
input_arg
)
{
for
(
auto
tensor_arg
:
input_arg
)
{
phi_kernel_desc
.
inputsType
.
emplace_back
(
cv
tPlaceFromPhi
(
tensor_arg
));
phi_kernel_desc
.
inputsType
.
emplace_back
(
Conver
tPlaceFromPhi
(
tensor_arg
));
}
}
for
(
auto
tensor_arg
:
output_arg
)
{
for
(
auto
tensor_arg
:
output_arg
)
{
phi_kernel_desc
.
outputsType
.
emplace_back
(
cv
tPlaceFromPhi
(
tensor_arg
));
phi_kernel_desc
.
outputsType
.
emplace_back
(
Conver
tPlaceFromPhi
(
tensor_arg
));
}
}
candidate_kernels
.
emplace_back
(
phi_kernel_desc
);
candidate_kernels
.
emplace_back
(
phi_kernel_desc
);
}
}
...
...
paddle/infrt/host_context/kernel_registry.cc
浏览文件 @
af6ef888
...
@@ -23,8 +23,9 @@ namespace infrt {
...
@@ -23,8 +23,9 @@ namespace infrt {
namespace
host_context
{
namespace
host_context
{
struct
KernelRegistry
::
Impl
{
struct
KernelRegistry
::
Impl
{
std
::
unordered_map
<
std
::
string
,
KernelImplementation
>
data
;
std
::
unordered_map
<
std
::
string
,
std
::
unordered_map
<
std
::
string
,
llvm
::
SmallVector
<
std
::
string
,
4
>>
attr_names
;
std
::
pair
<
KernelImplementation
,
std
::
vector
<
const
char
*>>>
data
;
};
};
KernelRegistry
::
KernelRegistry
()
:
impl_
(
std
::
make_unique
<
Impl
>
())
{}
KernelRegistry
::
KernelRegistry
()
:
impl_
(
std
::
make_unique
<
Impl
>
())
{}
...
@@ -33,20 +34,29 @@ void KernelRegistry::AddKernel(const std::string &key,
...
@@ -33,20 +34,29 @@ void KernelRegistry::AddKernel(const std::string &key,
KernelImplementation
fn
)
{
KernelImplementation
fn
)
{
CHECK
(
!
impl_
->
data
.
count
(
key
))
<<
"kernel ["
<<
key
CHECK
(
!
impl_
->
data
.
count
(
key
))
<<
"kernel ["
<<
key
<<
"] is registered twice"
;
<<
"] is registered twice"
;
impl_
->
data
.
emplace
(
key
,
fn
);
impl_
->
data
.
emplace
(
key
,
std
::
make_pair
(
std
::
move
(
fn
),
std
::
vector
<
const
char
*>
{}));
}
}
void
KernelRegistry
::
AddKernelAttrNameList
(
const
std
::
vector
<
const
char
*>
&
KernelRegistry
::
GetAttrNameList
(
const
std
::
string
&
key
,
const
std
::
vector
<
std
::
string
>
&
names
)
{
const
std
::
string
&
key
)
const
{
CHECK
(
!
impl_
->
attr_names
.
count
(
key
))
CHECK
(
impl_
->
data
.
count
(
key
));
<<
"kernel ["
<<
key
<<
"] is registered twice in attribute names"
;
return
impl_
->
data
[
key
].
second
;
impl_
->
attr_names
.
emplace
(
}
key
,
llvm
::
SmallVector
<
std
::
string
,
4
>
(
names
.
begin
(),
names
.
end
()));
void
KernelRegistry
::
AddKernelWithAttrs
(
const
std
::
string
&
key
,
KernelImplementation
fn
,
std
::
vector
<
const
char
*>
&&
attr_order
)
{
CHECK
(
!
impl_
->
data
.
count
(
key
))
<<
"kernel ["
<<
key
<<
"] is registered twice"
;
impl_
->
data
.
emplace
(
key
,
std
::
make_pair
(
std
::
move
(
fn
),
std
::
move
(
attr_order
)));
}
}
KernelImplementation
KernelRegistry
::
GetKernel
(
const
std
::
string
&
key
)
const
{
KernelImplementation
KernelRegistry
::
GetKernel
(
const
std
::
string
&
key
)
const
{
auto
it
=
impl_
->
data
.
find
(
key
);
auto
it
=
impl_
->
data
.
find
(
key
);
return
it
!=
impl_
->
data
.
end
()
?
it
->
second
:
KernelImplementation
{};
return
it
!=
impl_
->
data
.
end
()
?
it
->
second
.
first
:
KernelImplementation
{};
}
}
std
::
vector
<
std
::
string
>
KernelRegistry
::
GetKernelList
()
const
{
std
::
vector
<
std
::
string
>
KernelRegistry
::
GetKernelList
()
const
{
...
...
paddle/infrt/host_context/kernel_registry.h
浏览文件 @
af6ef888
...
@@ -34,10 +34,14 @@ class KernelRegistry {
...
@@ -34,10 +34,14 @@ class KernelRegistry {
KernelRegistry
();
KernelRegistry
();
void
AddKernel
(
const
std
::
string
&
key
,
KernelImplementation
fn
);
void
AddKernel
(
const
std
::
string
&
key
,
KernelImplementation
fn
);
void
AddKernelAttrNameList
(
const
std
::
string
&
key
,
void
AddKernelWithAttrs
(
const
std
::
string
&
key
,
const
std
::
vector
<
std
::
string
>
&
names
);
KernelImplementation
fn
,
std
::
vector
<
const
char
*>
&&
attrs_order
);
KernelImplementation
GetKernel
(
const
std
::
string
&
key
)
const
;
KernelImplementation
GetKernel
(
const
std
::
string
&
key
)
const
;
const
std
::
vector
<
const
char
*>
&
GetAttrNameList
(
const
std
::
string
&
key
)
const
;
std
::
vector
<
std
::
string
>
GetKernelList
()
const
;
std
::
vector
<
std
::
string
>
GetKernelList
()
const
;
size_t
size
()
const
;
size_t
size
()
const
;
...
...
paddle/infrt/host_context/mlir_function_executable.cc
浏览文件 @
af6ef888
...
@@ -43,6 +43,7 @@ MlirFunctionExecutable::MlirFunctionExecutable(
...
@@ -43,6 +43,7 @@ MlirFunctionExecutable::MlirFunctionExecutable(
func_op
.
getNumResults
()),
func_op
.
getNumResults
()),
MlirToRuntimeTranslator
(
&
core_runtime_builder_
),
MlirToRuntimeTranslator
(
&
core_runtime_builder_
),
region_
(
&
func_op
.
getRegion
()),
region_
(
&
func_op
.
getRegion
()),
kernel_registry_
(
kernel_registry
),
core_runtime_builder_
(
kernel_registry
),
core_runtime_builder_
(
kernel_registry
),
function_table_
(
function_table
)
{}
function_table_
(
function_table
)
{}
...
@@ -54,6 +55,7 @@ MlirFunctionExecutable::MlirFunctionExecutable(
...
@@ -54,6 +55,7 @@ MlirFunctionExecutable::MlirFunctionExecutable(
:
Function
(
""
,
func_type
.
getNumInputs
(),
func_type
.
getNumResults
()),
:
Function
(
""
,
func_type
.
getNumInputs
(),
func_type
.
getNumResults
()),
MlirToRuntimeTranslator
(
&
core_runtime_builder_
),
MlirToRuntimeTranslator
(
&
core_runtime_builder_
),
region_
(
region
),
region_
(
region
),
kernel_registry_
(
kernel_registry
),
core_runtime_builder_
(
kernel_registry
),
core_runtime_builder_
(
kernel_registry
),
function_table_
(
function_table
)
{}
function_table_
(
function_table
)
{}
...
@@ -90,7 +92,7 @@ void MlirFunctionExecutable::BuildExecutables(
...
@@ -90,7 +92,7 @@ void MlirFunctionExecutable::BuildExecutables(
if
(
EmitCallOp
(
&
op
,
&
function_table_
))
continue
;
if
(
EmitCallOp
(
&
op
,
&
function_table_
))
continue
;
if
(
EmitGeneralOp
(
&
op
))
continue
;
if
(
EmitGeneralOp
(
&
op
,
*
kernel_registry_
))
continue
;
LOG
(
FATAL
)
<<
"Not supported op: "
<<
DumpToString
(
op
);
LOG
(
FATAL
)
<<
"Not supported op: "
<<
DumpToString
(
op
);
}
}
...
...
paddle/infrt/host_context/mlir_function_executable.h
浏览文件 @
af6ef888
...
@@ -70,6 +70,7 @@ class MlirFunctionExecutable : public Function, public MlirToRuntimeTranslator {
...
@@ -70,6 +70,7 @@ class MlirFunctionExecutable : public Function, public MlirToRuntimeTranslator {
private:
private:
mlir
::
Region
*
region_
{};
mlir
::
Region
*
region_
{};
KernelRegistry
*
kernel_registry_
{};
CoreRuntimeBuilder
core_runtime_builder_
;
CoreRuntimeBuilder
core_runtime_builder_
;
MlirToRuntimeTranslator
::
function_defs_t
&
function_table_
;
MlirToRuntimeTranslator
::
function_defs_t
&
function_table_
;
std
::
function
<
void
()
>
copy_res_fn_
;
std
::
function
<
void
()
>
copy_res_fn_
;
...
...
paddle/infrt/host_context/mlir_to_runtime_translate.cc
浏览文件 @
af6ef888
...
@@ -270,7 +270,8 @@ static bool IsReturn(mlir::Operation* op) {
...
@@ -270,7 +270,8 @@ static bool IsReturn(mlir::Operation* op) {
return
op
->
getName
().
getStringRef
()
==
"infrt.return"
;
return
op
->
getName
().
getStringRef
()
==
"infrt.return"
;
}
}
bool
MlirToRuntimeTranslator
::
EmitGeneralOp
(
mlir
::
Operation
*
op
)
{
bool
MlirToRuntimeTranslator
::
EmitGeneralOp
(
mlir
::
Operation
*
op
,
const
KernelRegistry
&
kernel_registry
)
{
CHECK
(
impl_
->
runtime
);
CHECK
(
impl_
->
runtime
);
impl_
->
cur_op
=
impl_
->
cur_op
=
impl_
->
runtime
->
NewOpExecutable
(
op
->
getName
().
getStringRef
().
str
());
impl_
->
runtime
->
NewOpExecutable
(
op
->
getName
().
getStringRef
().
str
());
...
@@ -308,42 +309,80 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
...
@@ -308,42 +309,80 @@ bool MlirToRuntimeTranslator::EmitGeneralOp(mlir::Operation* op) {
// process attributes
// process attributes
auto
attrs
=
op
->
getAttrs
();
auto
attrs
=
op
->
getAttrs
();
// MLIR's underlying attr storage type is `Builtin_Dictionary`, and its
// elements
// are sorted by name. The following code adapts the order of function
// signatures
// of the phi operator library.
llvm
::
SmallVector
<
Value
*
,
4
>
tmp
;
tmp
.
resize
(
attrs
.
size
());
const
std
::
string
&
kernel_name
=
op
->
getName
().
getStringRef
().
str
();
const
auto
&
attr_names
=
kernel_registry
.
GetAttrNameList
(
kernel_name
);
if
(
attrs
.
size
()
&&
attr_names
.
empty
())
{
LOG
(
WARNING
)
<<
"The kernel `"
<<
kernel_name
<<
"` has no specified attr order."
;
}
auto
get_offset
=
[](
const
char
*
attr
,
const
std
::
vector
<
const
char
*>&
names
,
const
std
::
string
&
kernel_name
)
->
int
{
for
(
size_t
i
=
0
;
i
<
names
.
size
();
++
i
)
{
if
(
!
std
::
strcmp
(
attr
,
names
[
i
]))
{
return
i
;
}
}
LOG
(
WARNING
)
<<
"The attribute `"
<<
attr
<<
"` of kernel `"
<<
kernel_name
<<
"` is not properly registered with "
"`KernelRegistry::AddKernelWithAttrs()`."
;
return
-
1
;
};
for
(
size_t
i
=
0
;
i
<
attrs
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
attrs
.
size
();
i
++
)
{
auto
&
attr
=
attrs
[
i
];
auto
&
attr
=
attrs
[
i
];
int
offset
{};
if
(
attr_names
.
size
())
{
offset
=
get_offset
(
attr
.
getName
().
data
(),
attr_names
,
kernel_name
);
}
else
{
offset
=
i
;
}
CHECK_NE
(
offset
,
-
1
);
if
(
auto
v
=
EmitAttribute
<
int32_t
>
(
attr
.
getValue
()))
{
if
(
auto
v
=
EmitAttribute
<
int32_t
>
(
attr
.
getValue
()))
{
impl_
->
cur_op
->
AppendAttribute
(
new
Value
(
*
v
)
);
tmp
[
offset
]
=
new
Value
(
*
v
);
}
else
if
(
auto
v
=
EmitAttribute
<
int64_t
>
(
attr
.
getValue
()))
{
}
else
if
(
auto
v
=
EmitAttribute
<
int64_t
>
(
attr
.
getValue
()))
{
impl_
->
cur_op
->
AppendAttribute
(
new
Value
(
*
v
)
);
tmp
[
offset
]
=
new
Value
(
*
v
);
}
else
if
(
auto
v
=
EmitAttribute
<
float
>
(
attr
.
getValue
()))
{
}
else
if
(
auto
v
=
EmitAttribute
<
float
>
(
attr
.
getValue
()))
{
impl_
->
cur_op
->
AppendAttribute
(
new
Value
(
*
v
)
);
tmp
[
offset
]
=
new
Value
(
*
v
);
}
else
if
(
auto
v
=
EmitAttribute
<
double
>
(
attr
.
getValue
()))
{
}
else
if
(
auto
v
=
EmitAttribute
<
double
>
(
attr
.
getValue
()))
{
impl_
->
cur_op
->
AppendAttribute
(
new
Value
(
*
v
)
);
tmp
[
offset
]
=
new
Value
(
*
v
);
}
else
if
(
auto
v
=
EmitAttribute
<
std
::
string
>
(
attr
.
getValue
()))
{
}
else
if
(
auto
v
=
EmitAttribute
<
std
::
string
>
(
attr
.
getValue
()))
{
impl_
->
cur_op
->
AppendAttribute
(
new
Value
(
std
::
move
(
*
v
)
));
tmp
[
offset
]
=
new
Value
(
std
::
move
(
*
v
));
}
else
if
(
auto
v
=
EmitAttribute
<
bool
>
(
attr
.
getValue
()))
{
}
else
if
(
auto
v
=
EmitAttribute
<
bool
>
(
attr
.
getValue
()))
{
impl_
->
cur_op
->
AppendAttribute
(
new
Value
(
*
v
)
);
tmp
[
offset
]
=
new
Value
(
*
v
);
}
else
if
(
auto
v
=
EmitAttribute
<::
infrt
::
TargetType
>
(
attr
.
getValue
()))
{
}
else
if
(
auto
v
=
EmitAttribute
<::
infrt
::
TargetType
>
(
attr
.
getValue
()))
{
impl_
->
cur_op
->
AppendAttribute
(
new
Value
(
*
v
)
);
tmp
[
offset
]
=
new
Value
(
*
v
);
}
else
if
(
auto
v
=
}
else
if
(
auto
v
=
EmitAttribute
<::
infrt
::
PrecisionType
>
(
attr
.
getValue
()))
{
EmitAttribute
<::
infrt
::
PrecisionType
>
(
attr
.
getValue
()))
{
impl_
->
cur_op
->
AppendAttribute
(
new
Value
(
*
v
)
);
tmp
[
offset
]
=
new
Value
(
*
v
);
}
else
if
(
auto
v
=
EmitAttribute
<::
infrt
::
LayoutType
>
(
attr
.
getValue
()))
{
}
else
if
(
auto
v
=
EmitAttribute
<::
infrt
::
LayoutType
>
(
attr
.
getValue
()))
{
impl_
->
cur_op
->
AppendAttribute
(
new
Value
(
*
v
)
);
tmp
[
offset
]
=
new
Value
(
*
v
);
}
else
if
(
auto
v
=
EmitAttribute
<
std
::
vector
<
int16_t
>>
(
attr
.
getValue
()))
{
}
else
if
(
auto
v
=
EmitAttribute
<
std
::
vector
<
int16_t
>>
(
attr
.
getValue
()))
{
impl_
->
cur_op
->
AppendAttribute
(
new
Value
(
std
::
move
(
*
v
)
));
tmp
[
offset
]
=
new
Value
(
std
::
move
(
*
v
));
}
else
if
(
auto
v
=
EmitAttribute
<
std
::
vector
<
int32_t
>>
(
attr
.
getValue
()))
{
}
else
if
(
auto
v
=
EmitAttribute
<
std
::
vector
<
int32_t
>>
(
attr
.
getValue
()))
{
impl_
->
cur_op
->
AppendAttribute
(
new
Value
(
std
::
move
(
*
v
)
));
tmp
[
offset
]
=
new
Value
(
std
::
move
(
*
v
));
}
else
if
(
auto
v
=
EmitAttribute
<
std
::
vector
<
int64_t
>>
(
attr
.
getValue
()))
{
}
else
if
(
auto
v
=
EmitAttribute
<
std
::
vector
<
int64_t
>>
(
attr
.
getValue
()))
{
impl_
->
cur_op
->
AppendAttribute
(
new
Value
(
std
::
move
(
*
v
)
));
tmp
[
offset
]
=
new
Value
(
std
::
move
(
*
v
));
}
else
if
(
auto
v
=
EmitAttribute
<
std
::
vector
<
float
>>
(
attr
.
getValue
()))
{
}
else
if
(
auto
v
=
EmitAttribute
<
std
::
vector
<
float
>>
(
attr
.
getValue
()))
{
impl_
->
cur_op
->
AppendAttribute
(
new
Value
(
std
::
move
(
*
v
)
));
tmp
[
offset
]
=
new
Value
(
std
::
move
(
*
v
));
}
else
if
(
auto
v
=
EmitAttribute
<
std
::
vector
<
double
>>
(
attr
.
getValue
()))
{
}
else
if
(
auto
v
=
EmitAttribute
<
std
::
vector
<
double
>>
(
attr
.
getValue
()))
{
impl_
->
cur_op
->
AppendAttribute
(
new
Value
(
std
::
move
(
*
v
)
));
tmp
[
offset
]
=
new
Value
(
std
::
move
(
*
v
));
}
else
{
}
else
{
LOG
(
FATAL
)
<<
"Not supported attribute type"
;
LOG
(
FATAL
)
<<
"Not supported attribute type"
;
}
}
}
}
for
(
size_t
i
=
0
;
i
<
tmp
.
size
();
i
++
)
{
impl_
->
cur_op
->
AppendAttribute
(
tmp
[
i
]);
}
// process results
// process results
llvm
::
SmallVector
<
Value
*
,
4
>
res_values
;
llvm
::
SmallVector
<
Value
*
,
4
>
res_values
;
for
(
int
i
=
0
,
e
=
op
->
getNumResults
();
i
<
e
;
i
++
)
{
for
(
int
i
=
0
,
e
=
op
->
getNumResults
();
i
<
e
;
i
++
)
{
...
@@ -598,7 +637,7 @@ class MlirProgramTestExecutor : public MlirToRuntimeTranslator {
...
@@ -598,7 +637,7 @@ class MlirProgramTestExecutor : public MlirToRuntimeTranslator {
llvm
::
SmallVector
<
mlir
::
Value
,
3
>
results
;
llvm
::
SmallVector
<
mlir
::
Value
,
3
>
results
;
if
(
EmitReturnOp
(
&
op
,
&
results
))
continue
;
if
(
EmitReturnOp
(
&
op
,
&
results
))
continue
;
if
(
EmitCallOp
(
&
op
,
&
impl_
->
func_defs
))
continue
;
if
(
EmitCallOp
(
&
op
,
&
impl_
->
func_defs
))
continue
;
if
(
EmitGeneralOp
(
&
op
))
continue
;
if
(
EmitGeneralOp
(
&
op
,
*
registry
))
continue
;
LOG
(
FATAL
)
<<
"Not supported op: "
<<
DumpToString
(
op
);
LOG
(
FATAL
)
<<
"Not supported op: "
<<
DumpToString
(
op
);
}
}
...
...
paddle/infrt/host_context/mlir_to_runtime_translate.h
浏览文件 @
af6ef888
...
@@ -63,7 +63,8 @@ class MlirToRuntimeTranslator {
...
@@ -63,7 +63,8 @@ class MlirToRuntimeTranslator {
//! Emit a "ts.build_shape" operation.
//! Emit a "ts.build_shape" operation.
bool
EmitBuildShapeOp
(
mlir
::
Operation
*
op
);
bool
EmitBuildShapeOp
(
mlir
::
Operation
*
op
);
//! Emit an operation other than the special cases above.
//! Emit an operation other than the special cases above.
bool
EmitGeneralOp
(
mlir
::
Operation
*
op
);
bool
EmitGeneralOp
(
mlir
::
Operation
*
op
,
const
KernelRegistry
&
kernel_registry
);
//! Emit all the functions.
//! Emit all the functions.
bool
EmitFunctions
();
bool
EmitFunctions
();
...
...
paddle/infrt/kernel/phi/dense_tensor_kernels.cc
浏览文件 @
af6ef888
...
@@ -23,23 +23,23 @@ namespace phi {
...
@@ -23,23 +23,23 @@ namespace phi {
::
phi
::
DenseTensor
CreateDenseTensor
(
::
phi
::
DenseTensor
CreateDenseTensor
(
const
::
phi
::
CPUContext
&
context
,
const
::
phi
::
CPUContext
&
context
,
host_context
::
Attribute
<
std
::
vector
<
int64_t
>>
dims
,
host_context
::
Attribute
<
std
::
vector
<
int64_t
>>
dims
,
host_context
::
Attribute
<::
infrt
::
LayoutType
>
layout
,
host_context
::
Attribute
<
std
::
vector
<
int64_t
>>
lod
,
host_context
::
Attribute
<
std
::
vector
<
int64_t
>>
lod
,
host_context
::
Attribute
<::
infrt
::
LayoutType
>
layout
,
host_context
::
Attribute
<::
infrt
::
PrecisionType
>
precision
)
{
host_context
::
Attribute
<::
infrt
::
PrecisionType
>
precision
)
{
return
::
phi
::
DenseTensor
(
return
::
phi
::
DenseTensor
(
const_cast
<::
phi
::
Allocator
*>
(
&
context
.
GetAllocator
()),
const_cast
<::
phi
::
Allocator
*>
(
&
context
.
GetAllocator
()),
::
phi
::
DenseTensorMeta
(
cvtPrecision2
Phi
(
precision
.
get
()),
::
phi
::
DenseTensorMeta
(
ConvertPrecisionTo
Phi
(
precision
.
get
()),
::
phi
::
make_ddim
(
dims
.
get
()),
::
phi
::
make_ddim
(
dims
.
get
()),
cvtLayout2
Phi
(
layout
.
get
()),
ConvertLayoutTo
Phi
(
layout
.
get
()),
{}));
{}));
}
}
void
FillDenseTensorF32
(
::
phi
::
DenseTensor
*
dense_tensor
,
void
FillDenseTensorF32
(
::
phi
::
DenseTensor
*
dense_tensor
,
host_context
::
Attribute
<
std
::
vector
<
float
>>
value
s
)
{
host_context
::
Attribute
<
std
::
vector
<
float
>>
value
)
{
auto
place
=
::
phi
::
CPUPlace
();
auto
place
=
::
phi
::
CPUPlace
();
float
*
a_data
=
dense_tensor
->
mutable_data
<
float
>
(
place
);
float
*
a_data
=
dense_tensor
->
mutable_data
<
float
>
(
place
);
for
(
int64_t
i
=
0
;
i
<
dense_tensor
->
numel
();
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
dense_tensor
->
numel
();
++
i
)
{
a_data
[
i
]
=
(
value
s
.
get
())[
i
];
a_data
[
i
]
=
(
value
.
get
())[
i
];
}
}
}
}
...
@@ -57,7 +57,7 @@ void PrintDenseTensor(::phi::DenseTensor* dense_tensor) {
...
@@ -57,7 +57,7 @@ void PrintDenseTensor(::phi::DenseTensor* dense_tensor) {
::
phi
::
DDim
dims
=
dense_tensor
->
dims
();
::
phi
::
DDim
dims
=
dense_tensor
->
dims
();
std
::
cout
<<
"dense_tensor: shape=shape"
<<
dims
.
to_str
()
<<
","
std
::
cout
<<
"dense_tensor: shape=shape"
<<
dims
.
to_str
()
<<
","
<<
" value
s
=["
;
<<
" value=["
;
switch
(
dense_tensor
->
dtype
())
{
switch
(
dense_tensor
->
dtype
())
{
PRINT_META_DATA
(
FLOAT32
,
float
);
PRINT_META_DATA
(
FLOAT32
,
float
);
PRINT_META_DATA
(
INT32
,
int32_t
);
PRINT_META_DATA
(
INT32
,
int32_t
);
...
...
paddle/infrt/kernel/phi/dense_tensor_kernels.h
浏览文件 @
af6ef888
...
@@ -26,8 +26,8 @@ namespace phi {
...
@@ -26,8 +26,8 @@ namespace phi {
::
phi
::
DenseTensor
CreateDenseTensor
(
::
phi
::
DenseTensor
CreateDenseTensor
(
const
::
phi
::
CPUContext
&
context
,
const
::
phi
::
CPUContext
&
context
,
host_context
::
Attribute
<
std
::
vector
<
int64_t
>>
dims
,
host_context
::
Attribute
<
std
::
vector
<
int64_t
>>
dims
,
host_context
::
Attribute
<::
infrt
::
LayoutType
>
layout
,
host_context
::
Attribute
<
std
::
vector
<
int64_t
>>
lod
,
host_context
::
Attribute
<
std
::
vector
<
int64_t
>>
lod
,
host_context
::
Attribute
<::
infrt
::
LayoutType
>
layout
,
host_context
::
Attribute
<::
infrt
::
PrecisionType
>
precision
);
host_context
::
Attribute
<::
infrt
::
PrecisionType
>
precision
);
void
FillDenseTensorF32
(
::
phi
::
DenseTensor
*
dense_tensor
,
void
FillDenseTensorF32
(
::
phi
::
DenseTensor
*
dense_tensor
,
...
...
paddle/infrt/kernel/phi/registry.cc
浏览文件 @
af6ef888
...
@@ -34,10 +34,14 @@ namespace kernel {
...
@@ -34,10 +34,14 @@ namespace kernel {
void
RegisterPhiKernels
(
host_context
::
KernelRegistry
*
registry
)
{
void
RegisterPhiKernels
(
host_context
::
KernelRegistry
*
registry
)
{
registry
->
AddKernel
(
"phi_dt.create_context.cpu"
,
registry
->
AddKernel
(
"phi_dt.create_context.cpu"
,
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
CreateCPUContext
));
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
CreateCPUContext
));
registry
->
AddKernel
(
"phi_dt.create_dense_tensor"
,
registry
->
AddKernelWithAttrs
(
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
CreateDenseTensor
));
"phi_dt.create_dense_tensor"
,
registry
->
AddKernel
(
"phi_dt.fill_dense_tensor.f32"
,
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
CreateDenseTensor
),
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
FillDenseTensorF32
));
{
"dims"
,
"lod"
,
"layout"
,
"precision"
});
registry
->
AddKernelWithAttrs
(
"phi_dt.fill_dense_tensor.f32"
,
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
FillDenseTensorF32
),
{
"value"
});
registry
->
AddKernel
(
"phi_dt.print_tensor"
,
registry
->
AddKernel
(
"phi_dt.print_tensor"
,
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
PrintDenseTensor
));
INFRT_KERNEL
(
infrt
::
kernel
::
phi
::
PrintDenseTensor
));
}
}
...
...
paddle/infrt/kernel/tensor_kernels.cc
浏览文件 @
af6ef888
...
@@ -111,9 +111,9 @@ void NaiveMatmul(const DenseHostTensor &x,
...
@@ -111,9 +111,9 @@ void NaiveMatmul(const DenseHostTensor &x,
/// ===== Kernel end ====
/// ===== Kernel end ====
void
RegisterTensorKernels
(
host_context
::
KernelRegistry
*
registry
)
{
void
RegisterTensorKernels
(
host_context
::
KernelRegistry
*
registry
)
{
registry
->
AddKernel
(
"dt.create_uninit_tensor.f32"
,
registry
->
AddKernel
WithAttrs
(
"dt.create_uninit_tensor.f32"
,
INFRT_KERNEL
(
CreateUninitTensor
<
float
>
));
INFRT_KERNEL
(
CreateUninitTensor
<
float
>
),
registry
->
AddKernelAttrNameList
(
"dt.create_uninit_tensor.f32"
,
{
"shape"
});
{
"shape"
});
registry
->
AddKernel
(
"dt.print_tensor"
,
INFRT_KERNEL
(
PrintTensor
));
registry
->
AddKernel
(
"dt.print_tensor"
,
INFRT_KERNEL
(
PrintTensor
));
registry
->
AddKernel
(
"dt.fill_tensor_with_constant.f32"
,
registry
->
AddKernel
(
"dt.fill_tensor_with_constant.f32"
,
INFRT_KERNEL
(
FillTensorWithConstant
<
float
>
));
INFRT_KERNEL
(
FillTensorWithConstant
<
float
>
));
...
...
paddle/infrt/tests/dialect/phi/dense_tensor.mlir
浏览文件 @
af6ef888
...
@@ -9,7 +9,7 @@ func @sign_any_float32_execute() {
...
@@ -9,7 +9,7 @@ func @sign_any_float32_execute() {
"phi_dt.fill_dense_tensor.f32"(%t) {value=[3.8:f32]} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
"phi_dt.fill_dense_tensor.f32"(%t) {value=[3.8:f32]} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
%e = "phi_cpu.sign.float32.any"(%ctx, %t) : (!phi.context<CPU>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
%e = "phi_cpu.sign.float32.any"(%ctx, %t) : (!phi.context<CPU>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> (!infrt.dense_tensor<CPU, FP32, NCHW>)
// CHECK: dense_tensor: shape=shape[1], value
s
=[1]
// CHECK: dense_tensor: shape=shape[1], value=[1]
"phi_dt.print_tensor" (%e) : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
"phi_dt.print_tensor" (%e) : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> ()
infrt.return
infrt.return
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录