Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
69252fd8
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
69252fd8
编写于
12月 13, 2021
作者:
J
jianghaicheng
提交者:
GitHub
12月 13, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add popart_canonicalization p4 (#37967)
上级
bdf5834e
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
572 addition
and
0 deletion
+572
-0
paddle/fluid/platform/device/ipu/popart_canonicalization/reduce_ops.cc
...platform/device/ipu/popart_canonicalization/reduce_ops.cc
+68
-0
paddle/fluid/platform/device/ipu/popart_canonicalization/search_ops.cc
...platform/device/ipu/popart_canonicalization/search_ops.cc
+109
-0
paddle/fluid/platform/device/ipu/popart_canonicalization/tensor_ops.cc
...platform/device/ipu/popart_canonicalization/tensor_ops.cc
+395
-0
未找到文件。
paddle/fluid/platform/device/ipu/popart_canonicalization/reduce_ops.cc
0 → 100644
浏览文件 @
69252fd8
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
platform
{
namespace
ipu
{
namespace
{
Node
*
reduce_op_handler
(
Graph
*
graph
,
Node
*
node
,
const
std
::
string
&
op_name
)
{
auto
*
op
=
node
->
Op
();
auto
attrs
=
AttributeMap
{};
auto
reduce_all
=
BOOST_GET_CONST
(
bool
,
op
->
GetAttr
(
"reduce_all"
));
if
(
!
reduce_all
)
{
auto
axes_
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op
->
GetAttr
(
"dim"
));
auto
axes
=
std
::
vector
<
int64_t
>
{
axes_
.
begin
(),
axes_
.
end
()};
attrs
.
emplace
(
"axes"
,
axes
);
}
auto
keepdims_
=
BOOST_GET_CONST
(
bool
,
op
->
GetAttr
(
"keep_dim"
));
auto
keepdims
=
int64_t
{
keepdims_
};
attrs
.
emplace
(
"keepdims"
,
keepdims
);
return
CreateBaseOp
(
graph
,
node
,
op_name
,
node
->
inputs
,
node
->
outputs
,
attrs
);
}
Node
*
reduce_mean_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
reduce_op_handler
(
graph
,
node
,
"popart_reducemean"
);
}
Node
*
reduce_min_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
reduce_op_handler
(
graph
,
node
,
"popart_reducemin"
);
}
Node
*
reduce_sum_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
reduce_op_handler
(
graph
,
node
,
"popart_reducesum"
);
}
Node
*
reduce_max_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
reduce_op_handler
(
graph
,
node
,
"popart_reducemax"
);
}
Node
*
reduce_prod_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
reduce_op_handler
(
graph
,
node
,
"popart_reduceprod"
);
}
REGISTER_HANDLER
(
reduce_mean
,
reduce_mean_handler
);
REGISTER_HANDLER
(
reduce_min
,
reduce_min_handler
);
REGISTER_HANDLER
(
reduce_sum
,
reduce_sum_handler
);
REGISTER_HANDLER
(
reduce_max
,
reduce_max_handler
);
REGISTER_HANDLER
(
reduce_prod
,
reduce_prod_handler
);
}
// namespace
}
// namespace ipu
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/device/ipu/popart_canonicalization/search_ops.cc
0 → 100644
浏览文件 @
69252fd8
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
platform
{
namespace
ipu
{
namespace
{
Node
*
topK_op_handler
(
Graph
*
graph
,
Node
*
node
)
{
VLOG
(
10
)
<<
"[topK_op_handler] entering to handler ..."
;
auto
*
op
=
node
->
Op
();
auto
attrs
=
AttributeMap
{};
int
axis_32INT
=
-
1
;
if
(
op
->
HasAttr
(
"axis"
))
{
axis_32INT
=
BOOST_GET_CONST
(
int
,
op
->
GetAttr
(
"axis"
));
}
if
(
axis_32INT
==
-
1
)
{
auto
shape
=
GetInputVarNode
(
"X"
,
node
)
->
Var
()
->
GetShape
();
int
rank
=
shape
.
size
();
if
(
rank
<
1
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"The dimension of the shape of topK input should be large than 1"
));
}
axis_32INT
=
rank
-
1
;
}
int64_t
axis
=
int64_t
{
axis_32INT
};
attrs
.
emplace
(
"axis"
,
axis
);
bool
largest
=
true
;
if
(
op
->
HasAttr
(
"largest"
))
{
largest
=
BOOST_GET_CONST
(
bool
,
op
->
GetAttr
(
"largest"
));
}
if
(
largest
)
{
// defaults to 1, largest values
attrs
.
emplace
(
"largest"
,
1
);
}
else
{
attrs
.
emplace
(
"largest"
,
0
);
}
bool
sorted
=
true
;
if
(
op
->
HasAttr
(
"sorted"
))
{
sorted
=
BOOST_GET_CONST
(
bool
,
op
->
GetAttr
(
"sorted"
));
}
if
(
sorted
)
{
// defaults to 1, sorted results
attrs
.
emplace
(
"sorted"
,
1
);
}
else
{
attrs
.
emplace
(
"sorted"
,
0
);
}
std
::
vector
<
paddle
::
framework
::
ir
::
Node
*>
inputs
=
node
->
inputs
;
if
(
node
->
inputs
.
size
()
==
2
)
{
// Input X tensor and K const tensor
VLOG
(
10
)
<<
"[topK_op_handler] get 2 input tensors."
;
inputs
[
0
]
=
node
->
inputs
[
1
];
// K_t
VLOG
(
10
)
<<
"[topK_op_handler] input node("
<<
inputs
[
0
]
->
Var
()
->
Name
()
<<
")"
;
inputs
[
1
]
=
node
->
inputs
[
0
];
// X
VLOG
(
10
)
<<
"[topK_op_handler] input node("
<<
inputs
[
1
]
->
Var
()
->
Name
()
<<
")"
;
}
else
if
(
node
->
inputs
.
size
()
==
1
)
{
// Input X tensor with k integer
VLOG
(
10
)
<<
"[topK_op_handler] get 1 input tensor."
;
int
k_32INT
=
BOOST_GET_CONST
(
int
,
op
->
GetAttr
(
"k"
));
int64_t
k
=
int64_t
{
k_32INT
};
attrs
.
emplace
(
"k"
,
k
);
}
// show output node dtype
for
(
auto
*
o_node
:
node
->
outputs
)
{
auto
*
var
=
o_node
->
Var
();
// see framework.pb.h
// VarType_Type_INT64 = 3,
// VarType_Type_FP32 = 5,
auto
dtype
=
var
->
GetDataType
();
if
(
dtype
==
3
)
{
// poplar does not support int64_t
var
->
SetDataType
(
framework
::
proto
::
VarType
::
INT32
);
}
std
::
string
name
=
var
->
Name
();
VLOG
(
10
)
<<
"[topK_op_handler] output node("
<<
name
<<
") dtype : "
<<
dtype
;
}
VLOG
(
10
)
<<
"[topK_op_handler] leave the handler."
;
return
CreateBaseOp
(
graph
,
node
,
"popart_topk"
,
inputs
,
{
node
->
outputs
[
1
],
node
->
outputs
[
0
]},
attrs
);
}
REGISTER_HANDLER
(
top_k
,
topK_op_handler
);
REGISTER_HANDLER
(
top_k_v2
,
topK_op_handler
);
}
// namespace
}
// namespace ipu
}
// namespace platform
}
// namespace paddle
paddle/fluid/platform/device/ipu/popart_canonicalization/tensor_ops.cc
0 → 100644
浏览文件 @
69252fd8
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
platform
{
namespace
ipu
{
namespace
{
using
framework
::
Attribute
;
using
framework
::
AttributeMap
;
Node
*
fill_constant_handler
(
Graph
*
graph
,
Node
*
node
)
{
auto
*
op
=
node
->
Op
();
if
(
op
->
HasInput
(
"ShapeTensor"
)
&&
!
op
->
Input
(
"ShapeTensor"
).
empty
())
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"op fill_constant with ShapeTensor"
));
}
auto
dtype_
=
BOOST_GET_CONST
(
int
,
op
->
GetAttr
(
"dtype"
));
auto
dtype
=
VarType2OnnxDtype
(
dtype_
);
auto
dims
=
BOOST_GET_CONST
(
std
::
vector
<
int64_t
>
,
op
->
GetAttr
(
"shape"
));
auto
value_
=
BOOST_GET_CONST
(
float
,
op
->
GetAttr
(
"value"
));
size_t
size
=
1
;
for
(
auto
&
dim
:
dims
)
{
size
*=
dim
;
}
Attribute
value
;
switch
(
dtype_
)
{
case
framework
::
proto
::
VarType
::
FP32
:
value
=
std
::
vector
<
float
>
(
size
,
value_
);
break
;
case
framework
::
proto
::
VarType
::
FP64
:
value
=
std
::
vector
<
double
>
(
size
,
value_
);
break
;
case
framework
::
proto
::
VarType
::
INT32
:
value
=
std
::
vector
<
int
>
(
size
,
value_
);
break
;
case
framework
::
proto
::
VarType
::
INT64
:
value
=
std
::
vector
<
int64_t
>
(
size
,
value_
);
break
;
default:
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"fill_constant dtype: %d"
,
dtype_
));
}
return
CreateConst
(
graph
,
node
,
node
->
inputs
,
node
->
outputs
,
AttributeMap
{
{
"value"
,
value
},
{
"dims"
,
dims
},
{
"dtype"
,
dtype
},
});
}
Node
*
gaussian_random_handler
(
Graph
*
graph
,
Node
*
node
)
{
auto
*
op
=
node
->
Op
();
auto
shape
=
BOOST_GET_CONST
(
std
::
vector
<
int64_t
>
,
op
->
GetAttr
(
"shape"
));
auto
dtype_
=
BOOST_GET_CONST
(
int
,
op
->
GetAttr
(
"dtype"
));
auto
dtype
=
VarType2OnnxDtype
(
dtype_
);
auto
mean
=
BOOST_GET_CONST
(
float
,
op
->
GetAttr
(
"mean"
));
auto
scale
=
BOOST_GET_CONST
(
float
,
op
->
GetAttr
(
"std"
));
// seed not work
auto
seed_
=
BOOST_GET_CONST
(
int
,
op
->
GetAttr
(
"seed"
));
auto
seed
=
static_cast
<
float
>
(
seed_
);
return
CreateBaseOp
(
graph
,
node
,
"popart_randomnormal"
,
node
->
inputs
,
node
->
outputs
,
{
{
"shape"
,
shape
},
{
"dtype"
,
dtype
},
{
"mean"
,
mean
},
{
"scale"
,
scale
},
{
"seed"
,
seed
},
});
}
Node
*
uniform_random_handler
(
Graph
*
graph
,
Node
*
node
)
{
auto
*
op
=
node
->
Op
();
auto
shape
=
BOOST_GET_CONST
(
std
::
vector
<
int64_t
>
,
op
->
GetAttr
(
"shape"
));
auto
dtype_
=
BOOST_GET_CONST
(
int
,
op
->
GetAttr
(
"dtype"
));
auto
dtype
=
VarType2OnnxDtype
(
dtype_
);
auto
high
=
BOOST_GET_CONST
(
float
,
op
->
GetAttr
(
"max"
));
auto
low
=
BOOST_GET_CONST
(
float
,
op
->
GetAttr
(
"min"
));
// seed not work
auto
seed_
=
BOOST_GET_CONST
(
int
,
op
->
GetAttr
(
"seed"
));
auto
seed
=
static_cast
<
float
>
(
seed_
);
return
CreateBaseOp
(
graph
,
node
,
"popart_randomuniform"
,
node
->
inputs
,
node
->
outputs
,
{
{
"shape"
,
shape
},
{
"dtype"
,
dtype
},
{
"high"
,
high
},
{
"low"
,
low
},
{
"seed"
,
seed
},
});
}
Node
*
transpose_handler
(
Graph
*
graph
,
Node
*
node
)
{
auto
*
op
=
node
->
Op
();
auto
axis_
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op
->
GetAttr
(
"axis"
));
std
::
vector
<
int64_t
>
perm
(
axis_
.
begin
(),
axis_
.
end
());
auto
attrs
=
AttributeMap
{{
"perm"
,
perm
}};
auto
new_node_transpose
=
CreateBaseOp
(
graph
,
node
,
"popart_transpose"
,
node
->
inputs
,
{
GetOutputVarNode
(
"Out"
,
node
)},
attrs
);
return
new_node_transpose
;
}
Node
*
reshape_handler
(
Graph
*
graph
,
Node
*
node
)
{
auto
*
op
=
node
->
Op
();
auto
shape_
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op
->
GetAttr
(
"shape"
));
std
::
vector
<
int64_t
>
shape
(
shape_
.
begin
(),
shape_
.
end
());
auto
attrs
=
AttributeMap
{
{
"value"
,
shape
},
{
"dims"
,
std
::
vector
<
int64_t
>
{
static_cast
<
int64_t
>
(
shape
.
size
())}},
{
"dtype"
,
ONNXDataType
::
INT64
}};
auto
new_node_const
=
CreateBaseOp
(
graph
,
node
,
"popart_constant"
,
{},
{},
attrs
);
auto
new_node_reshape
=
CreateBaseOp
(
graph
,
node
,
"popart_reshape"
,
{
GetInputVarNode
(
"X"
,
node
),
new_node_const
->
outputs
[
0
]},
{
GetOutputVarNode
(
"Out"
,
node
)},
{});
return
new_node_reshape
;
}
Node
*
gather_handler
(
Graph
*
graph
,
Node
*
node
)
{
auto
new_node_gather
=
CreateBaseOp
(
graph
,
node
,
"popart_gather"
,
{
GetInputVarNode
(
"X"
,
node
),
GetInputVarNode
(
"Index"
,
node
)},
{
GetOutputVarNode
(
"Out"
,
node
)},
{});
return
new_node_gather
;
}
Node
*
squeeze_handler
(
Graph
*
graph
,
Node
*
node
)
{
auto
*
op
=
node
->
Op
();
auto
axes_
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op
->
GetAttr
(
"axes"
));
auto
input_shape_
=
GetInputVarNode
(
"X"
,
node
)
->
Var
()
->
GetShape
();
std
::
vector
<
int64_t
>
axes
{
axes_
.
begin
(),
axes_
.
end
()};
if
(
axes_
.
empty
())
{
for
(
int
i
=
0
;
i
<
input_shape_
.
size
();
i
++
)
{
if
(
input_shape_
[
i
]
==
1
)
{
axes
.
push_back
(
i
);
}
}
}
auto
new_node_squeeze
=
CreateBaseOp
(
graph
,
node
,
"popart_squeeze"
,
{
GetInputVarNode
(
"X"
,
node
)},
{
GetOutputVarNode
(
"Out"
,
node
)},
{{
"axes"
,
axes
}});
return
new_node_squeeze
;
}
Node
*
cast_handler
(
Graph
*
graph
,
Node
*
node
)
{
auto
*
op
=
node
->
Op
();
auto
otype
=
BOOST_GET_CONST
(
int
,
op
->
GetAttr
(
"out_dtype"
));
auto
new_node_cast
=
CreateCast
(
graph
,
node
,
node
->
inputs
,
node
->
outputs
,
otype
);
return
new_node_cast
;
}
Node
*
lookup_table_handler
(
Graph
*
graph
,
Node
*
node
)
{
auto
*
op
=
node
->
Op
();
auto
padding_idx_
=
BOOST_GET_CONST
(
int64_t
,
op
->
GetAttr
(
"padding_idx"
));
auto
w_shape_
=
GetInputVarNode
(
"W"
,
node
)
->
Var
()
->
GetShape
();
auto
table_size_
=
w_shape_
[
0
];
auto
emb_size_
=
w_shape_
[
1
];
Node
*
w_node
;
if
(
padding_idx_
>=
0
&&
padding_idx_
<
table_size_
)
{
std
::
vector
<
float
>
const_value_
(
emb_size_
,
0
);
std
::
vector
<
int64_t
>
const_shape_
{
1
,
emb_size_
};
auto
concat_const
=
CreateConst
(
graph
,
node
,
{},
{},
{{
"value"
,
const_value_
},
{
"dims"
,
const_shape_
},
{
"dtype"
,
ONNXDataType
::
FLOAT
}});
auto
axes
=
CreateConst
(
graph
,
node
,
{},
{},
{{
"value"
,
std
::
vector
<
int64_t
>
{
0
}},
{
"dims"
,
std
::
vector
<
int64_t
>
{
1
}},
{
"dtype"
,
ONNXDataType
::
INT64
}});
auto
step
=
CreateConst
(
graph
,
node
,
{},
{},
{{
"value"
,
std
::
vector
<
int64_t
>
{
1
}},
{
"dims"
,
std
::
vector
<
int64_t
>
{
1
}},
{
"dtype"
,
ONNXDataType
::
INT64
}});
auto
left_start
=
CreateConst
(
graph
,
node
,
{},
{},
{{
"value"
,
std
::
vector
<
int64_t
>
{
0
}},
{
"dims"
,
std
::
vector
<
int64_t
>
{
1
}},
{
"dtype"
,
ONNXDataType
::
INT64
}});
auto
left_end
=
CreateConst
(
graph
,
node
,
{},
{},
{{
"value"
,
std
::
vector
<
int64_t
>
{
padding_idx_
}},
{
"dims"
,
std
::
vector
<
int64_t
>
{
1
}},
{
"dtype"
,
ONNXDataType
::
INT64
}});
auto
right_start
=
CreateConst
(
graph
,
node
,
{},
{},
{{
"value"
,
std
::
vector
<
int64_t
>
{
padding_idx_
+
1
}},
{
"dims"
,
std
::
vector
<
int64_t
>
{
1
}},
{
"dtype"
,
ONNXDataType
::
INT64
}});
auto
right_end
=
CreateConst
(
graph
,
node
,
{},
{},
{{
"value"
,
std
::
vector
<
int64_t
>
{
table_size_
}},
{
"dims"
,
std
::
vector
<
int64_t
>
{
1
}},
{
"dtype"
,
ONNXDataType
::
INT64
}});
auto
left_slice
=
CreateBaseOp
(
graph
,
node
,
"popart_slice"
,
{
GetInputVarNode
(
"W"
,
node
),
left_start
->
outputs
[
0
],
left_end
->
outputs
[
0
],
axes
->
outputs
[
0
],
step
->
outputs
[
0
]},
{},
{});
auto
right_slice
=
CreateBaseOp
(
graph
,
node
,
"popart_slice"
,
{
GetInputVarNode
(
"W"
,
node
),
right_start
->
outputs
[
0
],
right_end
->
outputs
[
0
],
axes
->
outputs
[
0
],
step
->
outputs
[
0
]},
{},
{});
if
(
padding_idx_
==
0
)
{
w_node
=
CreateBaseOp
(
graph
,
node
,
"popart_concat"
,
{
concat_const
->
outputs
[
0
],
right_slice
->
outputs
[
0
]},
{},
{{
"axis"
,
int64_t
(
0
)}});
ClearNode
(
left_start
);
ClearNode
(
left_end
);
ClearNode
(
left_slice
);
}
else
if
(
padding_idx_
==
table_size_
-
1
)
{
w_node
=
CreateBaseOp
(
graph
,
node
,
"popart_concat"
,
{
left_slice
->
outputs
[
0
],
concat_const
->
outputs
[
0
]},
{},
{{
"axis"
,
int64_t
{
0
}}});
ClearNode
(
right_start
);
ClearNode
(
right_end
);
ClearNode
(
right_slice
);
}
else
{
w_node
=
CreateBaseOp
(
graph
,
node
,
"popart_concat"
,
{
left_slice
->
outputs
[
0
],
concat_const
->
outputs
[
0
],
right_slice
->
outputs
[
0
]},
{},
{{
"axis"
,
int64_t
{
0
}}});
}
w_node
=
w_node
->
outputs
[
0
];
}
else
{
w_node
=
GetInputVarNode
(
"W"
,
node
);
}
auto
squeeze
=
CreateBaseOp
(
graph
,
node
,
"popart_squeeze"
,
{
GetInputVarNode
(
"Ids"
,
node
)},
{},
{{
"axes"
,
std
::
vector
<
int64_t
>
{
-
1
}}});
auto
gather
=
CreateBaseOp
(
graph
,
node
,
"popart_gather"
,
{
w_node
,
squeeze
->
outputs
[
0
]},
{
GetOutputVarNode
(
"Out"
,
node
)},
{});
return
gather
;
}
Node
*
unsqueeze_handler
(
Graph
*
graph
,
Node
*
node
)
{
auto
*
op
=
node
->
Op
();
auto
axes_
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op
->
GetAttr
(
"axes"
));
std
::
vector
<
int64_t
>
axes
{
axes_
.
begin
(),
axes_
.
end
()};
auto
new_node_unsqueeze
=
CreateBaseOp
(
graph
,
node
,
"popart_unsqueeze"
,
{
GetInputVarNode
(
"X"
,
node
)},
{
GetOutputVarNode
(
"Out"
,
node
)},
{{
"axes"
,
axes
}});
return
new_node_unsqueeze
;
}
Node
*
concat_handler
(
Graph
*
graph
,
Node
*
node
)
{
auto
*
op
=
node
->
Op
();
int64_t
axis_
{
BOOST_GET_CONST
(
int
,
op
->
GetAttr
(
"axis"
))};
auto
new_node_concat
=
CreateBaseOp
(
graph
,
node
,
"popart_concat"
,
node
->
inputs
,
node
->
outputs
,
{{
"axis"
,
axis_
}});
return
new_node_concat
;
}
Node
*
stack_handler
(
Graph
*
graph
,
Node
*
node
)
{
auto
*
op
=
node
->
Op
();
int64_t
axis_
{
BOOST_GET_CONST
(
int
,
op
->
GetAttr
(
"axis"
))};
std
::
vector
<
int64_t
>
axes_
{
axis_
};
std
::
vector
<
Node
*>
unsqueeze_outputs_
{};
for
(
auto
input
:
node
->
inputs
)
{
auto
new_unsqueeze_node
=
CreateBaseOp
(
graph
,
node
,
"popart_unsqueeze"
,
{
input
},
{},
{{
"axes"
,
axes_
}});
unsqueeze_outputs_
.
push_back
(
new_unsqueeze_node
->
outputs
[
0
]);
for
(
size_t
i
=
0
;
i
<
input
->
outputs
.
size
();
++
i
)
{
if
(
input
->
outputs
[
i
]
==
node
)
{
input
->
outputs
[
i
]
=
new_unsqueeze_node
;
break
;
}
}
}
auto
new_node_concat
=
CreateBaseOp
(
graph
,
node
,
"popart_concat"
,
unsqueeze_outputs_
,
{
GetOutputVarNode
(
"Y"
,
node
)},
{{
"axis"
,
axis_
}});
return
new_node_concat
;
}
Node
*
shape_handler
(
Graph
*
graph
,
Node
*
node
)
{
auto
new_node
=
CreateBaseOp
(
graph
,
node
,
"popart_shape"
,
node
->
inputs
,
node
->
outputs
);
return
new_node
;
}
Node
*
slice_handler
(
Graph
*
graph
,
Node
*
node
)
{
auto
*
op
=
node
->
Op
();
Node
*
starts
=
nullptr
;
if
(
op
->
HasInput
(
"StartsTensor"
)
&&
!
op
->
Input
(
"StartsTensor"
).
empty
())
{
starts
=
GetInputVarNode
(
"StartsTensor"
,
node
);
}
else
{
auto
starts_
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op
->
GetAttr
(
"starts"
));
auto
dim
=
int64_t
(
starts_
.
size
());
auto
attr
=
MakeConstAttrMap
<
int
>
(
starts_
,
{
dim
},
ONNXDataType
::
INT32
);
starts
=
CreateConst
(
graph
,
node
,
{},
{},
attr
);
starts
=
starts
->
outputs
[
0
];
}
Node
*
ends
=
nullptr
;
if
(
op
->
HasInput
(
"EndsTensor"
)
&&
!
op
->
Input
(
"EndsTensor"
).
empty
())
{
ends
=
GetInputVarNode
(
"EndsTensor"
,
node
);
}
else
{
auto
ends_
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op
->
GetAttr
(
"ends"
));
auto
dim
=
int64_t
(
ends_
.
size
());
auto
attr
=
MakeConstAttrMap
<
int
>
(
ends_
,
{
dim
},
ONNXDataType
::
INT32
);
ends
=
CreateConst
(
graph
,
node
,
{},
{},
attr
);
ends
=
ends
->
outputs
[
0
];
}
Node
*
axes
=
nullptr
;
{
auto
axes_
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op
->
GetAttr
(
"axes"
));
auto
dim
=
int64_t
(
axes_
.
size
());
auto
attr
=
MakeConstAttrMap
<
int
>
(
axes_
,
{
dim
},
ONNXDataType
::
INT32
);
axes
=
CreateConst
(
graph
,
node
,
{},
{},
attr
);
}
auto
new_node
=
CreateBaseOp
(
graph
,
node
,
"popart_slice"
,
{
GetInputVarNode
(
"Input"
,
node
),
starts
,
ends
,
axes
->
outputs
[
0
]},
node
->
outputs
);
return
new_node
;
}
Node
*
expand_handler
(
Graph
*
graph
,
Node
*
node
)
{
auto
*
op
=
node
->
Op
();
if
(
op
->
HasInput
(
"expand_times_tensor"
)
&&
!
op
->
Input
(
"expand_times_tensor"
).
empty
())
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Expand op with expand_times_tensor"
));
}
Node
*
expand_times
=
nullptr
;
if
(
op
->
HasInput
(
"ExpandTimes"
)
&&
!
op
->
Input
(
"ExpandTimes"
).
empty
())
{
// cast to int64
expand_times
=
CreateCast
(
graph
,
node
,
{
GetInputVarNode
(
"ExpandTimes"
,
node
)},
{},
framework
::
proto
::
VarType
::
INT64
);
}
else
{
auto
expand_times_i32
=
BOOST_GET_CONST
(
std
::
vector
<
int
>
,
op
->
GetAttr
(
"expand_times"
));
auto
expand_times_
=
std
::
vector
<
int64_t
>
{
expand_times_i32
.
begin
(),
expand_times_i32
.
end
()};
auto
dim
=
int64_t
(
expand_times_
.
size
());
auto
attr
=
MakeConstAttrMap
<
int64_t
>
(
expand_times_
,
{
dim
},
ONNXDataType
::
INT64
);
expand_times
=
CreateConst
(
graph
,
node
,
{},
{},
attr
);
}
auto
new_node
=
CreateBaseOp
(
graph
,
node
,
"popart_tile"
,
{
GetInputVarNode
(
"X"
,
node
),
expand_times
->
outputs
[
0
]},
node
->
outputs
);
return
new_node
;
}
REGISTER_HANDLER
(
fill_constant
,
fill_constant_handler
);
REGISTER_HANDLER
(
gaussian_random
,
gaussian_random_handler
);
REGISTER_HANDLER
(
uniform_random
,
uniform_random_handler
);
REGISTER_HANDLER
(
transpose2
,
transpose_handler
);
REGISTER_HANDLER
(
reshape2
,
reshape_handler
);
REGISTER_HANDLER
(
gather
,
gather_handler
);
REGISTER_HANDLER
(
squeeze2
,
squeeze_handler
);
REGISTER_HANDLER
(
cast
,
cast_handler
);
REGISTER_HANDLER
(
lookup_table
,
lookup_table_handler
);
REGISTER_HANDLER
(
unsqueeze2
,
unsqueeze_handler
);
REGISTER_HANDLER
(
concat
,
concat_handler
);
REGISTER_HANDLER
(
stack
,
stack_handler
);
REGISTER_HANDLER
(
shape
,
shape_handler
);
REGISTER_HANDLER
(
slice
,
slice_handler
);
REGISTER_HANDLER
(
expand
,
expand_handler
);
}
// namespace
}
// namespace ipu
}
// namespace platform
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录