Unverified commit 1b5647d7
Authored on May 06, 2022 by Allen Guo; committed via GitHub on May 06, 2022.
[IPU] clean ipu related code (#42511)
* clean code
* fix ci
* fix ci
* fix ci 2
Parent: 6ff35e17
Showing 19 changed files with 561 additions and 502 deletions (+561 −502)
paddle/fluid/platform/device/ipu/ipu_backend.cc  +7 −9
paddle/fluid/platform/device/ipu/ipu_backend.h  +30 −32
paddle/fluid/platform/device/ipu/ipu_compiler.cc  +147 −37
paddle/fluid/platform/device/ipu/ipu_compiler.h  +3 −33
paddle/fluid/platform/device/ipu/ipu_device.cc  +6 −2
paddle/fluid/platform/device/ipu/ipu_device.h  +1 −1
paddle/fluid/platform/device/ipu/ipu_executor.cc  +51 −17
paddle/fluid/platform/device/ipu/ipu_executor.h  +17 −13
paddle/fluid/platform/device/ipu/ipu_strategy.cc  +2 −0
paddle/fluid/platform/device/ipu/ipu_utils.cc  +109 −75
paddle/fluid/platform/device/ipu/ipu_utils.h  +46 −176
paddle/fluid/platform/device/ipu/popart_canonicalization/activation_ops.cc  +3 −3
paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.cc  +2 −65
paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h  +1 −4
paddle/fluid/platform/device/ipu/popart_canonicalization/logic_ops.cc  +9 −0
paddle/fluid/platform/device/ipu/popart_canonicalization/math_ops.cc  +78 −2
paddle/fluid/platform/device/ipu/popart_canonicalization/nn_ops.cc  +1 −1
paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.cc  +1 −1
paddle/fluid/platform/device/ipu/popart_canonicalization/tensor_ops.cc  +47 −31
paddle/fluid/platform/device/ipu/ipu_backend.cc

@@ -13,12 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/platform/device/ipu/ipu_backend.h"
-#include "paddle/fluid/platform/device/ipu/ipu_utils.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/platform/device/ipu/ipu_compiler.h"
-#include "paddle/fluid/platform/device/ipu/ipu_executor.h"
+#include "paddle/fluid/framework/framework.pb.h"
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/framework/ir/node.h"
 
 namespace paddle {
 namespace platform {
@@ -40,7 +38,7 @@ IpuBackend::~IpuBackend() {
   executor_.reset();
 }
 
-void IpuBackend::Compile(Graph* graph,
+void IpuBackend::Compile(framework::ir::Graph* graph,
                          const std::vector<std::string>& feed_list,
                          const std::vector<std::string>& fetch_list) {
   VLOG(10) << "enter IpuBackend::Compile";
@@ -63,8 +61,8 @@ void IpuBackend::Compile(Graph* graph,
   VLOG(10) << "leave IpuBackend::Compile";
 }
 
-void IpuBackend::Run(const std::vector<const Tensor*>& inputs,
-                     const std::vector<Tensor*>& outputs,
+void IpuBackend::Run(const std::vector<const framework::Tensor*>& inputs,
+                     const std::vector<framework::Tensor*>& outputs,
                      const framework::ExecutionContext& ctx) {
   timer_->Start();
   executor_->Run(inputs, outputs, ctx);
@@ -82,7 +80,7 @@ void IpuBackend::Reset() {
   executor_.reset();
 }
 
-void IpuBackend::SetScope(const Scope& scope) {
+void IpuBackend::SetScope(const framework::Scope& scope) {
   scope_ = &scope;
   executor_->SetScope(&scope);
 }
paddle/fluid/platform/device/ipu/ipu_backend.h

@@ -18,26 +18,25 @@ limitations under the License. */
 #include <popart/names.hpp>
 #include <popart/tensorinfo.hpp>
 
-#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/platform/device/ipu/ipu_compiler.h"
-#include "paddle/fluid/platform/device/ipu/ipu_device.h"
-#include "paddle/fluid/platform/device/ipu/ipu_executor.h"
 #include "paddle/fluid/platform/device/ipu/ipu_strategy.h"
-#include "paddle/fluid/platform/device/ipu/ipu_utils.h"
+#include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/timer.h"
 
+namespace paddle {
+namespace framework {
+class ExecutionContext;
+}  // namespace framework
+}  // namespace paddle
+
 namespace paddle {
 namespace platform {
 namespace ipu {
 
-// IpuBackend is the center of paddle-ipu, its function include:
-// 1. Compile paddle model to popart model
-// 2. Run popart model, inference or training
-// 3. Request and release device
-// 4. Other helper function
+class IpuStrategy;
+class Compiler;
+class Executor;
 
 class IpuBackend {
  public:
  static IpuBackend* GetInstance();
@@ -46,47 +45,46 @@ class IpuBackend {
   IpuBackend();
   ~IpuBackend();
 
-  // what compile does include(call compiler_):
-  // 1. map paddle-op -> poart op
-  // 2. construct popart onnx compute graph
-  void Compile(Graph* graph,
+  // What compile method does:
+  // Convert paddle ops to popart ops;
+  // Construct a popart graph, which is a onnx compute graph;
+  // Load the graph and weights to ipu.
+  void Compile(framework::ir::Graph* graph,
               const std::vector<std::string>& feed_list,
               const std::vector<std::string>& fetch_list);
 
-  // what run does include:
-  // 1. construct forward onnx graph
-  // 2. graph-level optimization
-  // 3. autodiff
-  void Run(const std::vector<const Tensor*>& inputs,
-           const std::vector<Tensor*>& outputs,
+  // Run the compiled graph on ipu
+  void Run(const std::vector<const framework::Tensor*>& inputs,
+           const std::vector<framework::Tensor*>& outputs,
            const framework::ExecutionContext& ctx);
 
   // Sync weights from IPU while training
   void WeightsToHost();
 
-  // detach IPU manually
+  // Detach IPU manually
   void Detach();
 
-  // reset manually
-  // call it before destruct works
+  // Reset manually
+  // Call it before destruct works
   void Reset();
 
-  void SetScope(const Scope& scope);
-  const Scope* GetScope() { return scope_; }
+  void SetScope(const framework::Scope& scope);
+  const framework::Scope* GetScope() { return scope_; }
   void SetIpuStrategy(const IpuStrategy& strategy);
   const IpuStrategy* GetIpuStrategy() { return ipu_strategy_; }
 
-  // save compiled model to onnx
+  // Save compiled model to onnx
   void SaveModelProto(const std::string& path);
 
  private:
-  // not own
-  const Scope* scope_ = nullptr;
+  // Not own
+  const framework::Scope* scope_ = nullptr;
   const IpuStrategy* ipu_strategy_ = nullptr;
 
-  // own
+  // Own
   std::unique_ptr<Compiler> compiler_;
   std::unique_ptr<Executor> executor_;
-  std::unique_ptr<platform::Timer> timer_;
+  std::unique_ptr<Timer> timer_;
 
   bool is_compiled_ = false;
paddle/fluid/platform/device/ipu/ipu_compiler.cc

@@ -20,12 +20,110 @@
 #include <popart/sgd.hpp>
 
 #include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/platform/device/ipu/ipu_names.h"
+#include "paddle/fluid/platform/device/ipu/ipu_strategy.h"
 #include "paddle/fluid/platform/device/ipu/ipu_utils.h"
 
 namespace paddle {
 namespace platform {
 namespace ipu {
 
+namespace {
+
+struct CustomOpAttrVisitor : public boost::static_visitor<void> {
+  CustomOpAttrVisitor(std::map<std::string, popart::any>* attr,
+                      const std::string& attr_name)
+      : attrs_(attr), attr_name_(attr_name) {}
+
+  mutable std::map<std::string, popart::any>* attrs_;
+  std::string attr_name_;
+
+  void operator()(int v) const { attrs_->emplace(attr_name_, v); }
+  void operator()(float v) const { attrs_->emplace(attr_name_, v); }
+  void operator()(const std::string& v) const {
+    attrs_->emplace(attr_name_, v);
+  }
+  void operator()(const std::vector<int>& v) const {
+    attrs_->emplace(attr_name_, v);
+  }
+  void operator()(const std::vector<float>& v) const {
+    attrs_->emplace(attr_name_, v);
+  }
+  void operator()(const std::vector<std::string>& v) const {
+    attrs_->emplace(attr_name_, v);
+  }
+  void operator()(bool v) const { attrs_->emplace(attr_name_, v); }
+  void operator()(const std::vector<bool>& v) const {
+    attrs_->emplace(attr_name_, v);
+  }
+  void operator()(BlockDesc* desc) const {
+    PADDLE_THROW(platform::errors::Unavailable(
+        "Unsupported calling method for `BlockDesc` type when extracting "
+        "custom operator attributes."));
+  }
+  void operator()(const std::vector<BlockDesc*>& v) const {
+    PADDLE_THROW(platform::errors::Unavailable(
+        "Unsupported calling method for `BlockDesc` type when extracting "
+        "custom operator attributes."));
+  }
+  void operator()(int64_t v) const { attrs_->emplace(attr_name_, v); }
+  void operator()(const std::vector<int64_t>& v) const {
+    attrs_->emplace(attr_name_, v);
+  }
+  void operator()(const std::vector<double>& v) const {
+    attrs_->emplace(attr_name_, v);
+  }
+  void operator()(boost::blank) const {
+    PADDLE_THROW(platform::errors::Unavailable(
+        "Unsupported calling method for `boost::blank` type when extracting "
+        "custom operator attributes."));
+  }
+};
+
+struct ConstantOpAttrVisitor : public boost::static_visitor<void> {
+  ConstantOpAttrVisitor(framework::LoDTensor* tensor, VarType::Type dtype)
+      : tensor_(tensor), dtype_(dtype) {}
+
+  framework::LoDTensor* tensor_;
+  VarType::Type dtype_;
+
+  void operator()(const std::vector<int>& vec) const {
+    framework::TensorFromVector<int>(vec, tensor_);
+  }
+  void operator()(const std::vector<float>& vec) const {
+    if (dtype_ == VarType::FP16) {
+      std::vector<float16> vec_fp16;
+      std::transform(vec.begin(), vec.end(), std::back_inserter(vec_fp16),
+                     [](float f) -> float16 { return float16(f); });
+      framework::TensorFromVector<float16>(vec_fp16, tensor_);
+    } else {
+      framework::TensorFromVector<float>(vec, tensor_);
+    }
+  }
+  void operator()(const std::vector<bool>& vec) const {
+    framework::TensorFromVector<bool>(vec, tensor_);
+  }
+  void operator()(const std::vector<int64_t>& vec) const {
+    framework::TensorFromVector<int64_t>(vec, tensor_);
+  }
+  void operator()(const std::vector<double>& vec) const {
+    framework::TensorFromVector<double>(vec, tensor_);
+  }
+#define RAISE_ERROR \
+  PADDLE_THROW(     \
+      platform::errors::InvalidArgument("Constant value must be a vector"))
+  void operator()(int v) const { RAISE_ERROR; }
+  void operator()(float v) const { RAISE_ERROR; }
+  void operator()(const std::string& v) const { RAISE_ERROR; }
+  void operator()(const std::vector<std::string>& v) const { RAISE_ERROR; }
+  void operator()(bool v) const { RAISE_ERROR; }
+  void operator()(BlockDesc* desc) const { RAISE_ERROR; }
+  void operator()(const std::vector<BlockDesc*>& v) const { RAISE_ERROR; }
+  void operator()(int64_t v) const { RAISE_ERROR; }
+  void operator()(boost::blank) const { RAISE_ERROR; }
+#undef RAISE_ERROR
+};
+
 popart::AdamMode AdamModeFromStr(const std::string& str,
                                  const bool& use_no_bias_optimizer) {
   if (str == "adam") {
@@ -117,6 +215,34 @@ TO GetCastSigAttrAllowNull(std::string attr, OpDesc* op_desc) {
   }
 }
 
+// Helper for adding namescope info
+struct NameScopeHelper {
+  NameScopeHelper(const OpDesc* op, popart::Builder* builder);
+
+  ~NameScopeHelper() {
+    if (pushed_) {
+      builder_->popNameScope();
+    }
+  }
+
+  bool pushed_ = false;
+  popart::Builder* builder_;
+};
+
+NameScopeHelper::NameScopeHelper(const OpDesc* op, popart::Builder* builder)
+    : builder_(builder) {
+  auto op_namescope = BOOST_GET_CONST(std::string, op->GetAttr(sOpNamescope));
+  if (op_namescope.empty() || op_namescope == "/") {
+    return;
+  }
+  op_namescope.pop_back();
+  op_namescope.erase(op_namescope.begin());
+  builder->pushNameScope(op_namescope);
+  pushed_ = true;
+}
+
+}  // namespace
+
 GraphHelper::GraphHelper(const Graph* g) {
   graph = g;
   sorted_ops = framework::ir::TopologySortOperations(*g);
@@ -181,14 +307,12 @@ void Compiler::RegisterOpFunc() {
     auto op_type = op_desc->Type(); \
     VLOG(10) << "build op:" << op_type << " args " << #Args; \
     auto inputs = GetOpInputs(op_desc); \
-    auto output_names = GetOpOutputs(op_desc); \
     auto debug_context = BuildDebugContext(op_desc); \
     auto aiGraphcoreOpset = builder_->aiGraphcoreOpset1(); \
     auto aiOnnxOpset = builder_->aiOnnxOpset11(); \
     NameScopeHelper ns_helper(op_desc, builder_.get()); \
     auto output_ids = OnnxImpl(inputs Args, debug_context); \
     PostLower(output_ids, op_desc); \
-    InsertTensors(output_names, output_ids); \
   }}, // NOLINT
 
 #include "paddle/fluid/platform/device/ipu/supported_ops_autogen.h"
 #include "paddle/fluid/platform/device/ipu/supported_ops_custom.h"
@@ -219,7 +343,7 @@ void Compiler::InitInputs(const std::vector<std::string>& feed_list) {
     auto* node = graph_helper_->vars_name_map[feed_name];
     auto* var_desc = node->Var();
     VLOG(10) << "feed_name= " << var_desc->Name();
-    auto data_type = VarType2PopartType(var_desc->GetDataType());
+    auto data_type = VarType2PopartDType(var_desc->GetDataType());
     popart::TensorInfo input_info{data_type, var_desc->GetShape()};
     VLOG(10) << "popart input_info = " << input_info;
     popart::TensorId tensor_id =
@@ -255,8 +379,9 @@ void Compiler::LowerConstants(const Scope* scope) {
     auto shape =
         BOOST_GET_CONST(std::vector<int64_t>, op_desc->GetAttr("dims"));
     auto dtype_ = BOOST_GET_CONST(int, op_desc->GetAttr("dtype"));
-    auto dtype = PopartType2VarType(OnnxDtype2PopartType(dtype_));
-    auto tensor_name = op_desc->Output("__outputs__")[0];
+    auto dtype = PopartDType2VarType(
+        OnnxDType2PopartType(static_cast<ONNXDataType>(dtype_)));
+    auto tensor_name = GetOpOutputs(op_desc).front();
     auto* var = kid_scope.Var(tensor_name);
     VLOG(10) << "lowering constant: " << tensor_name;
     auto* tensor = var->GetMutable<framework::LoDTensor>();
@@ -267,7 +392,7 @@ void Compiler::LowerConstants(const Scope* scope) {
     tensor->Resize(ddim);
 
     auto const_data = std::unique_ptr<popart::ConstVoidData>();
-    popart::TensorInfo tensor_info(PdDataType2PopartType(tensor->dtype()),
+    popart::TensorInfo tensor_info(PhiDType2PopartDType(tensor->dtype()),
                                    shape);
     const_data.reset(new popart::ConstVoidData(tensor->data(), tensor_info));
     NameScopeHelper ns_helper(op_desc, builder_.get());
@@ -303,7 +428,7 @@ void Compiler::LowerWeights(const Scope* scope) {
           var, platform::errors::NotFound("Tensor %s is not found in the scope",
                                           var_name));
       auto tensor = var->Get<framework::LoDTensor>();
-      auto dtype = PdDataType2PopartType(tensor.dtype());
+      auto dtype = PhiDType2PopartDType(tensor.dtype());
       auto shape = std::vector<int64_t>();
       for (size_t i = 0; i < tensor.dims().size(); ++i) {
         shape.push_back(tensor.dims().at(i));
@@ -336,11 +461,9 @@ void Compiler::LowerBody() {
       // pass
     } else if (op_type == "popart_checkpointoutput") {
       auto inputs = GetOpInputs(op_desc);
-      auto outputs = GetOpOutputs(op_desc);
       NameScopeHelper ns_helper(op_desc, builder_.get());
       auto output_ids = builder_->checkpointOutput(inputs);
       PostLower(output_ids, op_desc);
-      InsertTensors(outputs, output_ids);
     } else if (op_type == "popart_custom_op") {
       auto inputs = GetOpInputs(op_desc);
       auto outputs = GetOpOutputs(op_desc);
@@ -359,10 +482,8 @@ void Compiler::LowerBody() {
           builder_->customOp(it->second.popart_op, it->second.popart_op.version,
                              inputs, outputs.size(), attributes, debug_context);
       PostLower(output_ids, op_desc);
-      InsertTensors(outputs, output_ids);
     } else if (op_type == "popart_printtensor") {
       auto inputs = GetOpInputs(op_desc);
-      auto outputs = GetOpOutputs(op_desc);
       auto debug_context = BuildDebugContext(op_desc);
       auto print_gradient =
           BOOST_GET_CONST(int64_t, op_desc->GetAttr("print_gradient"));
@@ -371,7 +492,6 @@ void Compiler::LowerBody() {
       auto output_ids = builder_->aiGraphcoreOpset1().printtensor(
           inputs, print_gradient, debug_context, title);
       PostLower(output_ids, op_desc);
-      InsertTensors(outputs, output_ids);
     } else {
       auto itr = name_function_.find(op_type);
       if (itr != name_function_.end()) {
@@ -601,23 +721,6 @@ void Compiler::LowerOptimizer(const Scope* scope) {
   }
 }
 
-void Compiler::InsertTensors(const std::vector<std::string>& output_names,
-                             const std::vector<std::string>& tensor_ids) {
-  PADDLE_ENFORCE_EQ(output_names.size(), tensor_ids.size(),
-                    platform::errors::Fatal("InsertTensors size mismatch"));
-  for (int i = 0; i < tensor_ids.size(); i++) {
-    std::string tensor_id = tensor_ids[i];
-    resources_->tensors.emplace(output_names[i], tensor_ids[i]);
-  }
-}
-
-void Compiler::InsertTensors(const std::vector<std::string>& output_names,
-                             const std::string& tensor_id) {
-  PADDLE_ENFORCE_EQ(output_names.size(), 1,
-                    platform::errors::Fatal("InsertTensors size mismatch"));
-  resources_->tensors.emplace(output_names[0], tensor_id);
-}
-
 void Compiler::PostLower(const std::vector<std::string>& tensor_ids,
                          const OpDesc* op_desc) {
   // Set pipline
@@ -637,13 +740,26 @@ void Compiler::PostLower(const std::vector<std::string>& tensor_ids,
               << " for op: " << op_desc->Type();
     }
   }
+  // Record output tensors
+  auto pd_outs = GetOpOutputs(op_desc);
+  PADDLE_ENFORCE_EQ(
+      pd_outs.size(), tensor_ids.size(),
+      platform::errors::Fatal("paddle and popart op have different outputs"));
+  for (int i = 0; i < tensor_ids.size(); ++i) {
+    resources_->tensors.emplace(pd_outs[i], tensor_ids[i]);
+  }
   for (auto& tensor_id : tensor_ids) {
     PostLower(tensor_id, op_desc, true);
   }
 }
 
 void Compiler::PostLower(const std::string& tensor_id, const OpDesc* op_desc) {
+  // Record output tensor
+  auto pd_outs = GetOpOutputs(op_desc);
+  PADDLE_ENFORCE_EQ(
+      pd_outs.size(), 1,
+      platform::errors::Fatal("paddle and popart op have different outputs"));
+  resources_->tensors.emplace(pd_outs[0], tensor_id);
   PostLower(tensor_id, op_desc, false);
 }
@@ -718,13 +834,7 @@ std::string Compiler::GetFP16ModelProto() {
   return graph_transformer.getModelProto();
 }
 
-std::string Compiler::GetModelProto() {
-  if (ipu_strategy_->enable_fp16) {
-    return GetFP16ModelProto();
-  } else {
-    return builder_->getModelProto();
-  }
-}
+std::string Compiler::GetModelProto() { return builder_->getModelProto(); }
 
 void Compiler::SaveModelProto(const std::string& path) {
   builder_->saveModelProto(path);
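The NameScopeHelper moved into ipu_compiler.cc above is an RAII guard: it pushes a builder name scope on construction (skipping empty or root "/" scopes) and pops it on destruction. Below is a minimal, self-contained sketch of that pattern; ToyBuilder and ScopedNameScope are made-up stand-ins introduced only so the example compiles on its own, not PopART's or Paddle's actual types.

// Sketch of the push-on-construct / pop-on-destruct pattern used by
// NameScopeHelper. ToyBuilder is a hypothetical stand-in for popart::Builder.
#include <iostream>
#include <string>

struct ToyBuilder {
  void pushNameScope(const std::string& s) { std::cout << "push " << s << "\n"; }
  void popNameScope() { std::cout << "pop\n"; }
};

struct ScopedNameScope {
  ScopedNameScope(const std::string& op_namescope, ToyBuilder* builder)
      : builder_(builder) {
    // Mirror the guard in the diff: skip empty or root ("/") scopes.
    if (op_namescope.empty() || op_namescope == "/") return;
    builder_->pushNameScope(op_namescope);
    pushed_ = true;
  }
  ~ScopedNameScope() {
    if (pushed_) builder_->popNameScope();
  }
  bool pushed_ = false;
  ToyBuilder* builder_;
};

int main() {
  ToyBuilder builder;
  {
    ScopedNameScope guard("block0/conv1", &builder);  // pushed here, popped at }
  }
  ScopedNameScope no_op("/", &builder);  // root scope: nothing pushed or popped
  return 0;
}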
paddle/fluid/platform/device/ipu/ipu_compiler.h

@@ -17,16 +17,15 @@
 #include <popart/builder.hpp>
 #include <popart/graphtransformer.hpp>
 #include <popart/optimizer.hpp>
 
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/scope.h"
-#include "paddle/fluid/platform/device/ipu/ipu_names.h"
-#include "paddle/fluid/platform/device/ipu/ipu_strategy.h"
 #include "paddle/fluid/platform/device/ipu/ipu_utils.h"
 
 namespace paddle {
 namespace platform {
 namespace ipu {
 
+class IpuStrategy;
+
 struct CompilerResources {
   // popart input tensor_ids
   std::vector<popart::TensorId> inputs;
@@ -81,30 +80,6 @@ struct GraphHelper {
   std::vector<int> sorted_vars_id;
 };
 
-// Helper for adding namescope info
-struct NameScopeHelper {
-  NameScopeHelper(const OpDesc* op, popart::Builder* builder)
-      : builder_(builder) {
-    auto op_namescope = BOOST_GET_CONST(std::string, op->GetAttr(sOpNamescope));
-    if (op_namescope.empty() || op_namescope == "/") {
-      return;
-    }
-    op_namescope.pop_back();
-    op_namescope.erase(op_namescope.begin());
-    builder->pushNameScope(op_namescope);
-    pushed_ = true;
-  }
-
-  ~NameScopeHelper() {
-    if (pushed_) {
-      builder_->popNameScope();
-    }
-  }
-
-  bool pushed_ = false;
-  popart::Builder* builder_;
-};
-
 class Compiler {
  public:
  Compiler();
@@ -138,11 +113,6 @@ class Compiler {
   const std::vector<std::string>& GetOpOutputs(const OpDesc* op);
   const std::string GetNameScope(const OpDesc* op);
   popart::DebugContext BuildDebugContext(const OpDesc* op);
 
-  void InsertTensors(const std::vector<std::string>& output_names,
-                     const std::vector<std::string>& tensor_ids);
-  void InsertTensors(const std::vector<std::string>& output_names,
-                     const std::string& tensor_id);
-
   void PostLower(const std::vector<std::string>&, const OpDesc*);
   void PostLower(const std::string&, const OpDesc*);
   void PostLower(const std::string&, const OpDesc*, bool);
paddle/fluid/platform/device/ipu/ipu_device.cc

@@ -13,14 +13,17 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/platform/device/ipu/ipu_device.h"
+#include <popart/devicemanager.hpp>
 #include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
 namespace platform {
 namespace ipu {
 
-static bool GetBoolEnv(std::string str) {
+// TODO(alleng) merge with ipu_utils
+namespace {
+const bool GetBoolEnv(const std::string& str) {
   char* str_val = getenv(str.c_str());
   if (str_val == NULL) {
     return false;
@@ -32,6 +35,7 @@ static bool GetBoolEnv(std::string str) {
     return val;
   }
 }
+}  // namespace
 
 int GetNumDevices() {
   bool ipu_model = GetBoolEnv("POPLAR_IPUMODEL");
paddle/fluid/platform/device/ipu/ipu_device.h

@@ -14,7 +14,7 @@ limitations under the License. */
 #pragma once
 
-#include <popart/devicemanager.hpp>
+#include <vector>
 
 namespace paddle {
 namespace platform {
paddle/fluid/platform/device/ipu/ipu_executor.cc

@@ -14,12 +14,17 @@ limitations under the License. */
 
 #include "paddle/fluid/platform/device/ipu/ipu_executor.h"
 
-using float16 = paddle::platform::float16;
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/platform/device/ipu/ipu_compiler.h"
+#include "paddle/fluid/platform/device/ipu/ipu_names.h"
+#include "paddle/fluid/platform/device/ipu/ipu_strategy.h"
 
 namespace paddle {
 namespace platform {
 namespace ipu {
 
+namespace {
+
 // Get paddle prefix and popart postfix of weight states
 // Format: {popart_postfix, paddle_prefix}
 std::vector<std::pair<std::string, std::string>> GetOptPrePostfix(
@@ -54,6 +59,35 @@ std::vector<std::pair<std::string, std::string>> GetOptPrePostfix(
   return pre_post_fix;
 }
 
+class PdIArray final : public popart::IArray {
+ public:
+  explicit PdIArray(const Tensor* tensor) {
+    tensor_.ShareDataWith(*tensor);
+    for (int i = 0; i < tensor->dims().size(); ++i) {
+      shape_.push_back(tensor->dims().at(i));
+    }
+  }
+
+ public:
+  void* data() { return tensor_.data(); }
+  popart::DataType dataType() const {
+    return PhiDType2PopartDType(tensor_.dtype());
+  }
+  std::size_t rank() const { return tensor_.dims().size(); }
+  int64_t dim(size_t index) const { return tensor_.dims().at(index); }
+  std::size_t nelms() const {
+    return std::accumulate(shape_.begin(), shape_.end(),
+                           static_cast<int64_t>(1),
+                           std::multiplies<int64_t>());
+  }
+  const popart::Shape shape() const { return shape_; }
+
+ private:
+  Tensor tensor_;
+  std::vector<int64_t> shape_;
+};
+
+}  // namespace
+
 Executor::~Executor() {
   Detach();
   session_.reset();
@@ -110,15 +144,15 @@ void Executor::Run(const std::vector<const Tensor *> &inputs,
   VLOG(10) << "enter Executor::Run";
   // inputs
   std::map<popart::TensorId, popart::IArray&> popart_inputs;
-  std::map<popart::TensorId, PaddleIArray> input_wrappers;
+  std::map<popart::TensorId, PdIArray> input_wrappers;
   for (size_t i = 0; i < inputs.size(); i++) {
     auto tensor_id = compiler_resources_->inputs[i];
-    input_wrappers.emplace(tensor_id, PaddleIArray(inputs[i]));
+    input_wrappers.emplace(tensor_id, PdIArray(inputs[i]));
     popart_inputs.emplace(tensor_id, input_wrappers.at(tensor_id));
   }
   // anchors
   std::map<popart::TensorId, popart::IArray&> popart_anchors;
-  std::map<popart::TensorId, PaddleIArray> anchor_wrappers;
+  std::map<popart::TensorId, PdIArray> anchor_wrappers;
   for (size_t i = 0; i < outputs.size(); i++) {
     auto tensor_id = compiler_resources_->outputs[i];
     // get dims & dtype from session
@@ -140,10 +174,10 @@ void Executor::Run(const std::vector<const Tensor *> &inputs,
     auto* tensor = outputs[i];
     tensor->Resize(phi::make_ddim(output_shape));
     auto fetch_dtype = fetch_info.dataType();
-    auto paddle_type = PopartType2VarType(fetch_dtype);
+    auto paddle_type = PopartDType2VarType(fetch_dtype);
     tensor->mutable_data(ctx.GetPlace(),
                          framework::TransToPhiDataType(paddle_type));
-    anchor_wrappers.emplace(tensor_id, PaddleIArray(tensor));
+    anchor_wrappers.emplace(tensor_id, PdIArray(tensor));
     popart_anchors.emplace(tensor_id, anchor_wrappers.at(tensor_id));
   }
   VLOG(10) << "Prepared inputs/anchors";
@@ -203,16 +237,16 @@ void Executor::AcquireDevice() {
     device_ = popart::DeviceManager::createDeviceManager().acquireDeviceById(
         device_id);
     PADDLE_ENFORCE_NOT_NULL(
-        device_, platform::errors::Unavailable(
-                     "Can't attach IPU in distribution, ipu_num = %d.",
-                     RequestIpus(ipu_strategy_->num_ipus)));
+        device_,
+        errors::Unavailable("Can't attach IPU in distribution, ipu_num = %d.",
+                            RequestIpus(ipu_strategy_->num_ipus)));
   } else {
     device_ =
         popart::DeviceManager::createDeviceManager().acquireAvailableDevice(
            RequestIpus(ipu_strategy_->num_ipus));
     PADDLE_ENFORCE_NOT_NULL(
-        device_, platform::errors::Unavailable(
-                     "Can't attach IPU, ipu_num = %d.",
-                     RequestIpus(ipu_strategy_->num_ipus)));
+        device_,
+        errors::Unavailable("Can't attach IPU, ipu_num = %d.",
+                            RequestIpus(ipu_strategy_->num_ipus)));
   }
   VLOG(10) << "leave Executor::AcquireDevice";
 }
@@ -260,13 +294,13 @@ void Executor::SetWeightsIO() {
 void Executor::ConvertWeights(bool align_to_popart) {
   for (auto weight_pair : executor_resources_->weights_and_opt_state) {
     auto paddle_var = scope_->GetVar(weight_pair.second);
-    auto paddle_var_dtype = PdDataType2PopartType(
+    auto paddle_var_dtype = PhiDType2PopartDType(
         paddle_var->GetMutable<framework::LoDTensor>()->dtype());
     PADDLE_ENFORCE_EQ((paddle_var_dtype == popart::DataType::FLOAT ||
                        paddle_var_dtype == popart::DataType::FLOAT16),
                       true,
-                      platform::errors::InvalidArgument(
+                      errors::InvalidArgument(
                           "Currently, we only support FLOAT16 and FLOAT with "
                           "Paddle, but received type is %s.",
                           paddle_var_dtype));
@@ -276,7 +310,7 @@ void Executor::ConvertWeights(bool align_to_popart) {
     PADDLE_ENFORCE_EQ((popart_var_dtype == popart::DataType::FLOAT ||
                        popart_var_dtype == popart::DataType::FLOAT16),
                       true,
-                      platform::errors::InvalidArgument(
+                      errors::InvalidArgument(
                           "Currently, we only support FLOAT16 and FLOAT with "
                           "popart, but received type is %s.",
                           popart_var_dtype));
@@ -310,8 +344,8 @@ void Executor::ConvertWeights(bool align_to_popart) {
                     num_elem * sizeof(float));
       }
     } else {
-      PADDLE_THROW(platform::errors::Unimplemented(
-          "Convert Paddle FLOAT16 to popart FLOAT"));
+      PADDLE_THROW(
+          errors::Unimplemented("Convert Paddle FLOAT16 to popart FLOAT"));
     }
   }
 }
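The PdIArray wrapper added above computes its element count by folding the tensor shape with multiplication. A tiny standalone illustration of that nelms() calculation follows; it uses only standard C++, with no Paddle or PopART types.

// Sketch: fold a shape vector into an element count, as PdIArray::nelms() does.
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int64_t NumElements(const std::vector<int64_t>& shape) {
  // Multiply all dimension sizes together, starting from 1.
  return std::accumulate(shape.begin(), shape.end(), static_cast<int64_t>(1),
                         std::multiplies<int64_t>());
}

int main() {
  std::cout << NumElements({2, 3, 4}) << "\n";  // prints 24
  return 0;
}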
paddle/fluid/platform/device/ipu/ipu_executor.h

@@ -22,17 +22,21 @@ limitations under the License. */
 #include <popart/tensorinfo.hpp>
 #include <popdist/popdist_poplar.hpp>
 
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/platform/device/ipu/ipu_compiler.h"
-#include "paddle/fluid/platform/device/ipu/ipu_names.h"
-#include "paddle/fluid/platform/device/ipu/ipu_strategy.h"
+#include "paddle/fluid/framework/scope.h"
 #include "paddle/fluid/platform/device/ipu/ipu_utils.h"
 
+namespace paddle {
+namespace framework {
+class ExecutionContext;
+}  // namespace framework
+}  // namespace paddle
+
 namespace paddle {
 namespace platform {
 namespace ipu {
 
+struct CompilerResources;
+class IpuStrategy;
+
 struct ExecutorResources {
   // map<tensor_id, paddle_var_ptr>
   popart::WeightsIO weights_io;
@@ -45,18 +49,18 @@ class Executor {
   Executor() = default;
   ~Executor();
 
-  // build popart session
+  // Build popart session
   void Prepare(const std::string &proto);
 
-  // run popart session
+  // Run popart session
   void Run(const std::vector<const Tensor *> &inputs,
            const std::vector<Tensor *> &outputs,
            const framework::ExecutionContext &ctx);
 
-  // sync weights from popart to paddle
+  // Sync weights from popart to paddle
   void WeightsToHost();
 
-  // detach IPU
+  // Detach IPU
   void Detach();
 
   // Scope
@@ -83,16 +87,16 @@ class Executor {
   void WeightsToPaddle();
 
  private:
-  // not own
+  // Not own
   const Scope *scope_ = nullptr;
   const IpuStrategy *ipu_strategy_ = nullptr;
   CompilerResources *compiler_resources_ = nullptr;
 
-  // deviceinfo for popart session
+  // Deviceinfo for popart session
   std::shared_ptr<popart::DeviceInfo> device_;
-  // popart session, where graph running
+  // Popart session, where graph running
   std::unique_ptr<popart::Session> session_;
-  // one OneSession means a graph
+  // A ExecutorResources corresponds to a graph
   std::unique_ptr<ExecutorResources> executor_resources_;
 };
paddle/fluid/platform/device/ipu/ipu_strategy.cc

@@ -316,8 +316,10 @@ IpuStrategy::IpuStrategy() {
   RegisterSetter(bool_options, "enable_half_partial", [&](bool value) {
     if (value) {
       popart_options.partialsTypeMatMuls = "half";
+      popart_options.convolutionOptions.insert({{"partialsType", "half"}});
     } else {
       popart_options.partialsTypeMatMuls = "float";
+      popart_options.convolutionOptions.insert({{"partialsType", "float"}});
     }
   });
paddle/fluid/platform/device/ipu/ipu_utils.cc

@@ -13,133 +13,111 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/platform/device/ipu/ipu_utils.h"
 
 #include <cmath>
 
 namespace paddle {
 namespace platform {
 namespace ipu {
 
-void* PaddleIArray::data() { return tensor_.data(); }
-
-popart::DataType PaddleIArray::dataType() const {
-  return PdDataType2PopartType(tensor_.dtype());
-}
-
-std::size_t PaddleIArray::rank() const { return tensor_.dims().size(); }
-
-int64_t PaddleIArray::dim(size_t index) const {
-  return tensor_.dims().at(index);
-}
-
-std::size_t PaddleIArray::nelms() const {
-  return std::accumulate(shape_.begin(), shape_.end(),
-                         static_cast<int64_t>(1), std::multiplies<int64_t>());
-}
-
-const popart::Shape PaddleIArray::shape() const { return shape_; }
-
-popart::DataType VarType2PopartType(
-    const framework::proto::VarType::Type type) {
+const popart::DataType VarType2PopartDType(const VarType::Type type) {
   switch (type) {
-    case framework::proto::VarType::UINT8:
+    case VarType::UINT8:
       return popart::DataType::UINT8;
-    case framework::proto::VarType::INT8:
+    case VarType::INT8:
       return popart::DataType::INT8;
-    case framework::proto::VarType::INT16:
+    case VarType::INT16:
      return popart::DataType::INT16;
-    case framework::proto::VarType::INT32:
+    case VarType::INT32:
      return popart::DataType::INT32;
-    case framework::proto::VarType::INT64:
+    case VarType::INT64:
      return popart::DataType::INT64;
-    case framework::proto::VarType::BOOL:
+    case VarType::BOOL:
      return popart::DataType::BOOL;
-    case framework::proto::VarType::FP64:
+    case VarType::FP64:
      return popart::DataType::DOUBLE;
-    case framework::proto::VarType::FP32:
+    case VarType::FP32:
      return popart::DataType::FLOAT;
-    case framework::proto::VarType::FP16:
+    case VarType::FP16:
      return popart::DataType::FLOAT16;
-    case framework::proto::VarType::BF16:
+    case VarType::BF16:
      return popart::DataType::BFLOAT16;
-    case framework::proto::VarType::COMPLEX64:
+    case VarType::COMPLEX64:
      return popart::DataType::COMPLEX64;
-    case framework::proto::VarType::COMPLEX128:
+    case VarType::COMPLEX128:
      return popart::DataType::COMPLEX128;
    default:
-      PADDLE_THROW(paddle::platform::errors::Unimplemented(
-          "Unsupported Paddle var type."));
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Unsupported VarType::Type when converting to popart data type."));
   }
 }
 
-popart::DataType PdDataType2PopartType(
-    const paddle::experimental::DataType type) {
+const popart::DataType PhiDType2PopartDType(const phi::DataType type) {
   switch (type) {
-    case paddle::experimental::DataType::UINT8:
+    case phi::DataType::UINT8:
      return popart::DataType::UINT8;
-    case paddle::experimental::DataType::INT8:
+    case phi::DataType::INT8:
      return popart::DataType::INT8;
-    case paddle::experimental::DataType::INT16:
+    case phi::DataType::INT16:
      return popart::DataType::INT16;
-    case paddle::experimental::DataType::INT32:
+    case phi::DataType::INT32:
      return popart::DataType::INT32;
-    case paddle::experimental::DataType::INT64:
+    case phi::DataType::INT64:
      return popart::DataType::INT64;
-    case paddle::experimental::DataType::BOOL:
+    case phi::DataType::BOOL:
      return popart::DataType::BOOL;
-    case paddle::experimental::DataType::FLOAT64:
+    case phi::DataType::FLOAT64:
      return popart::DataType::DOUBLE;
-    case paddle::experimental::DataType::FLOAT32:
+    case phi::DataType::FLOAT32:
      return popart::DataType::FLOAT;
-    case paddle::experimental::DataType::FLOAT16:
+    case phi::DataType::FLOAT16:
      return popart::DataType::FLOAT16;
-    case paddle::experimental::DataType::BFLOAT16:
+    case phi::DataType::BFLOAT16:
      return popart::DataType::BFLOAT16;
-    case paddle::experimental::DataType::COMPLEX64:
+    case phi::DataType::COMPLEX64:
      return popart::DataType::COMPLEX64;
-    case paddle::experimental::DataType::COMPLEX128:
+    case phi::DataType::COMPLEX128:
      return popart::DataType::COMPLEX128;
    default:
-      PADDLE_THROW(paddle::platform::errors::Unimplemented(
-          "Unsupported Paddle data type."));
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Unsupported phi::DataType when converting to popart data type."));
   }
 }
 
-framework::proto::VarType::Type PopartType2VarType(
-    const popart::DataType type) {
+const VarType::Type PopartDType2VarType(const popart::DataType type) {
   switch (type) {
    case popart::DataType::UINT8:
-      return framework::proto::VarType::UINT8;
+      return VarType::UINT8;
    case popart::DataType::INT8:
-      return framework::proto::VarType::INT8;
+      return VarType::INT8;
    case popart::DataType::INT16:
-      return framework::proto::VarType::INT16;
+      return VarType::INT16;
    case popart::DataType::INT32:
-      return framework::proto::VarType::INT32;
+      return VarType::INT32;
    case popart::DataType::INT64:
-      return framework::proto::VarType::INT64;
+      return VarType::INT64;
    case popart::DataType::BOOL:
-      return framework::proto::VarType::BOOL;
+      return VarType::BOOL;
    case popart::DataType::DOUBLE:
-      return framework::proto::VarType::FP64;
+      return VarType::FP64;
    case popart::DataType::FLOAT:
-      return framework::proto::VarType::FP32;
+      return VarType::FP32;
    case popart::DataType::FLOAT16:
-      return framework::proto::VarType::FP16;
+      return VarType::FP16;
    case popart::DataType::BFLOAT16:
-      return framework::proto::VarType::BF16;
+      return VarType::BF16;
    case popart::DataType::COMPLEX64:
-      return framework::proto::VarType::COMPLEX64;
+      return VarType::COMPLEX64;
    case popart::DataType::COMPLEX128:
-      return framework::proto::VarType::COMPLEX128;
+      return VarType::COMPLEX128;
    default:
-      PADDLE_THROW(paddle::platform::errors::Unavailable(
-          "Unsupported Paddle var type."));
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Unsupported popart::DataType when converting to var type."));
   }
 }
 
-popart::DataType OnnxDtype2PopartType(const int type) {
-  auto dtype = static_cast<ONNXDataType>(type);
-  switch (dtype) {
+const popart::DataType OnnxDType2PopartType(const ONNXDataType type) {
+  switch (type) {
    case ONNXDataType::BOOL:
      return popart::DataType::BOOL;
    case ONNXDataType::INT16:
@@ -166,12 +144,69 @@ popart::DataType OnnxDtype2PopartType(const int type) {
      return popart::DataType::COMPLEX128;
    default:
      PADDLE_THROW(platform::errors::Unimplemented(
-          "Unsupported ONNX data type: %d.", dtype));
+          "Unsupported ONNXDataType when converting to popart data type."));
  }
 }
 
+const ONNXDataType VarType2OnnxDType(const VarType::Type type) {
+  switch (type) {
+    case VarType::BOOL:
+      return ONNXDataType::BOOL;
+    case VarType::INT16:
+      return ONNXDataType::INT16;
+    case VarType::INT32:
+      return ONNXDataType::INT32;
+    case VarType::INT64:
+      return ONNXDataType::INT64;
+    case VarType::FP16:
+      return ONNXDataType::FLOAT16;
+    case VarType::FP32:
+      return ONNXDataType::FLOAT;
+    case VarType::FP64:
+      return ONNXDataType::DOUBLE;
+    case VarType::UINT8:
+      return ONNXDataType::UINT8;
+    case VarType::INT8:
+      return ONNXDataType::INT8;
+    case VarType::BF16:
+      return ONNXDataType::BFLOAT16;
+    case VarType::COMPLEX64:
+      return ONNXDataType::COMPLEX64;
+    case VarType::COMPLEX128:
+      return ONNXDataType::COMPLEX128;
+    default:
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Unsupported VarType::Type when converting to onnx data type."));
+  }
+}
+
+const std::string VarType2PopartStr(const VarType::Type type) {
+  switch (type) {
+    case VarType::UINT8:
+      return "UINT8";
+    case VarType::INT8:
+      return "INT8";
+    case VarType::INT16:
+      return "INT16";
+    case VarType::INT32:
+      return "INT32";
+    case VarType::INT64:
+      return "INT64";
+    case VarType::BOOL:
+      return "BOOL";
+    case VarType::FP64:
+      return "DOUBLE";
+    case VarType::FP32:
+      return "FLOAT";
+    case VarType::FP16:
+      return "FLOAT16";
+    default:
+      PADDLE_THROW(platform::errors::Unavailable(
+          "Unsupported VarType::Type when converting to popart type string."));
+  }
+}
+
-// count num should > 0
-bool GetBoolEnv(std::string str) {
+const bool GetBoolEnv(const std::string& str) {
   char* str_val = getenv(str.c_str());
   if (str_val == NULL) {
     return false;
@@ -184,8 +219,7 @@ bool GetBoolEnv(std::string str) {
   }
 }
 
-int RequestIpus(const int num_ipus) {
+const int RequestIpus(const int num_ipus) {
+  // num_ipus must be pow(2, n);
   return std::pow(2, ceil(log2(num_ipus)));
 }
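The renamed helpers above keep the same shape: a switch statement that maps one dtype enum onto another, and RequestIpus rounding the requested device count up to the next power of two. The sketch below mirrors those two ideas with simplified, made-up enums; it is illustrative only, under the assumption of toy FrameworkDType/PopartDType types, and is not Paddle's or PopART's actual API.

// Sketch of switch-based dtype mapping and power-of-two rounding.
#include <cmath>
#include <iostream>
#include <stdexcept>

enum class FrameworkDType { FP32, FP16, INT32 };  // hypothetical stand-in
enum class PopartDType { FLOAT, FLOAT16, INT32 };  // hypothetical stand-in

PopartDType ToPopartDType(FrameworkDType t) {
  switch (t) {
    case FrameworkDType::FP32:
      return PopartDType::FLOAT;
    case FrameworkDType::FP16:
      return PopartDType::FLOAT16;
    case FrameworkDType::INT32:
      return PopartDType::INT32;
    default:
      throw std::runtime_error("Unsupported dtype");
  }
}

// Round a device count up to the next power of two, as RequestIpus does.
int RequestIpus(int num_ipus) {
  return static_cast<int>(std::pow(2, std::ceil(std::log2(num_ipus))));
}

int main() {
  std::cout << static_cast<int>(ToPopartDType(FrameworkDType::FP16)) << "\n";
  std::cout << RequestIpus(3) << "\n";  // prints 4
  return 0;
}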
paddle/fluid/platform/device/ipu/ipu_utils.h

@@ -19,155 +19,32 @@ limitations under the License. */
 #include <popart/tensorinfo.hpp>
 #include <popart/vendored/any.hpp>
 
-#include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/scope.h"
-#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/platform/float16.h"
 
-using float16 = paddle::platform::float16;
-using Tensor = paddle::framework::Tensor;
-using LoDTensor = paddle::framework::LoDTensor;
-using Scope = paddle::framework::Scope;
-using OpDesc = paddle::framework::OpDesc;
-using Graph = paddle::framework::ir::Graph;
-using Node = paddle::framework::ir::Node;
-using BlockDesc = paddle::framework::BlockDesc;
-using VarType = paddle::framework::proto::VarType;
-
 namespace paddle {
 namespace platform {
 namespace ipu {
 
+using float16 = platform::float16;
+using Tensor = framework::Tensor;
+using LoDTensor = framework::LoDTensor;
+using Scope = framework::Scope;
+using OpDesc = framework::OpDesc;
+using Graph = framework::ir::Graph;
+using Node = framework::ir::Node;
+using BlockDesc = framework::BlockDesc;
+
 // onnx dtype
 // https://github.com/onnx/onnx/blob/master/onnx/onnx-ml.proto3
 enum ONNXDataType : int {
   UNDEFINED = 0,
   FLOAT = 1,
   UINT8 = 2,
   INT8 = 3,
   UINT16 = 4,
   INT16 = 5,
   INT32 = 6,
   INT64 = 7,
   STRING = 8,
   BOOL = 9,
   FLOAT16 = 10,
   DOUBLE = 11,
   UINT32 = 12,
   UINT64 = 13,
   COMPLEX64 = 14,
   COMPLEX128 = 15,
   BFLOAT16 = 16
 };
 
-class PaddleIArray final : public popart::IArray {
- public:
-  explicit PaddleIArray(const Tensor* tensor) {
-    tensor_.ShareDataWith(*tensor);
-    for (int i = 0; i < tensor->dims().size(); ++i) {
-      shape_.push_back(tensor->dims().at(i));
-    }
-  }
-
- public:
-  void* data();
-  popart::DataType dataType() const;
-  std::size_t rank() const;
-  int64_t dim(size_t index) const;
-  std::size_t nelms() const;
-  const popart::Shape shape() const;
-
- private:
-  Tensor tensor_;
-  std::vector<int64_t> shape_;
-};
-
-popart::DataType VarType2PopartType(const framework::proto::VarType::Type type);
-popart::DataType PdDataType2PopartType(
-    const paddle::experimental::DataType type);
-framework::proto::VarType::Type PopartType2VarType(const popart::DataType type);
-popart::DataType OnnxDtype2PopartType(const int type);
-bool GetBoolEnv(std::string str);
-
-template <typename T>
-std::unique_ptr<popart::NDArrayWrapper<T>> Tensor2IArray(const Tensor& tensor) {
-  auto dtype = PdDataType2PopartType(tensor.dtype());
-  auto shape = std::vector<int64_t>();
-  for (size_t i = 0; i < tensor.dims().size(); ++i) {
-    shape.push_back(tensor.dims().at(i));
-  }
-  popart::TensorInfo tensor_info(dtype, shape);
-  return std::make_unique<popart::NDArrayWrapper<T>>(
-      reinterpret_cast<T*>(tensor.data()), tensor_info);
-}
-
-template <typename T>
-std::unique_ptr<popart::NDArrayWrapper<T>> LoDTensor2IArray(
-    LoDTensor const& lod_tensor) {
-  if (lod_tensor.lod().size() == 0) {
-    return Tensor2IArray<T>(lod_tensor);
-  } else {
-    PADDLE_THROW(platform::errors::Unimplemented(
-        "LoDTensor2IArray is Unimplemented"));
-  }
-}
-
 template <typename T>
 T GetSingleVarFromScope(const Scope* scope, const std::string& var_name) {
   auto var = scope->GetVar(var_name);
   auto tensor = var->Get<framework::LoDTensor>();
-  // check dtype is ?
   return tensor.data<T>()[0];
 }
 
-struct CustomOpAttrVisitor : public boost::static_visitor<void> {
-  explicit CustomOpAttrVisitor(std::map<std::string, popart::any>* attr,
-                               const std::string& attr_name)
-      : attrs_(attr), attr_name_(attr_name) {}
-
-  mutable std::map<std::string, popart::any>* attrs_;
-  std::string attr_name_;
-
-  void operator()(int v) const { attrs_->emplace(attr_name_, v); }
-  void operator()(float v) const { attrs_->emplace(attr_name_, v); }
-  void operator()(const std::string& v) const {
-    attrs_->emplace(attr_name_, v);
-  }
-  void operator()(const std::vector<int>& v) const {
-    attrs_->emplace(attr_name_, v);
-  }
-  void operator()(const std::vector<float>& v) const {
-    attrs_->emplace(attr_name_, v);
-  }
-  void operator()(const std::vector<std::string>& v) const {
-    attrs_->emplace(attr_name_, v);
-  }
-  void operator()(bool v) const { attrs_->emplace(attr_name_, v); }
-  void operator()(const std::vector<bool>& v) const {
-    attrs_->emplace(attr_name_, v);
-  }
-  void operator()(BlockDesc* desc) const {
-    PADDLE_THROW(platform::errors::Unavailable(
-        "Unsupported calling method for `BlockDesc` type."));
-  }
-  void operator()(const std::vector<BlockDesc*>& v) const {
-    PADDLE_THROW(platform::errors::Unavailable(
-        "Unsupported calling method for `BlockDesc` type."));
-  }
-  void operator()(int64_t v) const { attrs_->emplace(attr_name_, v); }
-  void operator()(const std::vector<int64_t>& v) const {
-    attrs_->emplace(attr_name_, v);
-  }
-  void operator()(const std::vector<double>& v) const {
-    attrs_->emplace
(
attr_name_
,
v
);
}
void
operator
()(
boost
::
blank
)
const
{
PADDLE_THROW
(
platform
::
errors
::
Unavailable
(
"Unsupported calling method for `boost::blank` type."
));
}
};
struct
IpuCustomOpIdentifier
{
struct
IpuCustomOpIdentifier
{
IpuCustomOpIdentifier
(
const
std
::
string
&
_paddle_op
,
IpuCustomOpIdentifier
(
const
std
::
string
&
_paddle_op
,
const
std
::
string
&
_popart_op
,
const
std
::
string
&
_popart_op
,
...
@@ -185,51 +62,44 @@ struct IpuCustomOpIdentifier {
...
@@ -185,51 +62,44 @@ struct IpuCustomOpIdentifier {
popart
::
OperatorIdentifier
popart_op
;
popart
::
OperatorIdentifier
popart_op
;
};
};
struct
ConstantOpAttrVisitor
:
public
boost
::
static_visitor
<
void
>
{
// Onnx dtype
explicit
ConstantOpAttrVisitor
(
framework
::
LoDTensor
*
tensor
,
// https://github.com/onnx/onnx/blob/master/onnx/onnx-ml.proto3
framework
::
proto
::
VarType
::
Type
dtype
)
enum
ONNXDataType
:
int
{
:
tensor_
(
tensor
),
dtype_
(
dtype
)
{}
UNDEFINED
=
0
,
framework
::
LoDTensor
*
tensor_
;
FLOAT
=
1
,
framework
::
proto
::
VarType
::
Type
dtype_
;
UINT8
=
2
,
INT8
=
3
,
void
operator
()(
const
std
::
vector
<
int
>&
vec
)
const
{
UINT16
=
4
,
framework
::
TensorFromVector
<
int
>
(
vec
,
tensor_
);
INT16
=
5
,
}
INT32
=
6
,
void
operator
()(
const
std
::
vector
<
float
>&
vec
)
const
{
INT64
=
7
,
if
(
dtype_
==
framework
::
proto
::
VarType
::
FP16
)
{
STRING
=
8
,
std
::
vector
<
float16
>
vec_fp16
;
BOOL
=
9
,
std
::
transform
(
vec
.
begin
(),
vec
.
end
(),
std
::
back_inserter
(
vec_fp16
),
FLOAT16
=
10
,
[](
float
f
)
->
float16
{
return
float16
(
f
);
});
DOUBLE
=
11
,
framework
::
TensorFromVector
<
float16
>
(
vec_fp16
,
tensor_
);
UINT32
=
12
,
}
else
{
UINT64
=
13
,
framework
::
TensorFromVector
<
float
>
(
vec
,
tensor_
);
COMPLEX64
=
14
,
}
COMPLEX128
=
15
,
}
BFLOAT16
=
16
void
operator
()(
const
std
::
vector
<
bool
>&
vec
)
const
{
framework
::
TensorFromVector
<
bool
>
(
vec
,
tensor_
);
}
void
operator
()(
const
std
::
vector
<
int64_t
>&
vec
)
const
{
framework
::
TensorFromVector
<
int64_t
>
(
vec
,
tensor_
);
}
void
operator
()(
const
std
::
vector
<
double
>&
vec
)
const
{
framework
::
TensorFromVector
<
double
>
(
vec
,
tensor_
);
}
void
RaiseError
()
const
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Constant value must be a vector"
));
}
void
operator
()(
int
v
)
const
{
RaiseError
();
}
void
operator
()(
float
v
)
const
{
RaiseError
();
}
void
operator
()(
const
std
::
string
&
v
)
const
{
RaiseError
();
}
void
operator
()(
const
std
::
vector
<
std
::
string
>&
v
)
const
{
RaiseError
();
}
void
operator
()(
bool
v
)
const
{
RaiseError
();
}
void
operator
()(
BlockDesc
*
desc
)
const
{
RaiseError
();
}
void
operator
()(
const
std
::
vector
<
BlockDesc
*>&
v
)
const
{
RaiseError
();
}
void
operator
()(
int64_t
v
)
const
{
RaiseError
();
}
void
operator
()(
boost
::
blank
)
const
{
RaiseError
();
}
};
};
int
RequestIpus
(
const
int
num_ipus
);
// VarType::Type to popart::DataType
const
popart
::
DataType
VarType2PopartDType
(
const
VarType
::
Type
type
);
// phi::DataType to popart::DataType
const
popart
::
DataType
PhiDType2PopartDType
(
const
phi
::
DataType
type
);
// popart::DataType to VarType::Type
const
VarType
::
Type
PopartDType2VarType
(
const
popart
::
DataType
type
);
// ONNXDataType to popart::DataType
const
popart
::
DataType
OnnxDType2PopartType
(
const
ONNXDataType
type
);
// VarType::Type to ONNXDataType
const
ONNXDataType
VarType2OnnxDType
(
const
VarType
::
Type
type
);
// VarType::Type to String in Popart
const
std
::
string
VarType2PopartStr
(
const
VarType
::
Type
type
);
// Get bool from envirnment varaible
const
bool
GetBoolEnv
(
const
std
::
string
&
str
);
// Request number of ipus must be pow(2, n)
const
int
RequestIpus
(
const
int
num_ipus
);
}
// namespace ipu
}
// namespace ipu
}
// namespace platform
}
// namespace platform
...
...
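As a quick sanity check on the header above: the ONNXDataType values are meant to mirror the tag numbers in onnx-ml.proto3, which the comment links to. The snippet below is not part of the diff; the enum is copied locally so it compiles on its own.

// Standalone check (not Paddle code) that the ONNXDataType numbering used in
// ipu_utils.h matches the onnx-ml.proto3 tags it references.
enum ONNXDataType : int {
  UNDEFINED = 0, FLOAT = 1, UINT8 = 2, INT8 = 3, UINT16 = 4, INT16 = 5,
  INT32 = 6, INT64 = 7, STRING = 8, BOOL = 9, FLOAT16 = 10, DOUBLE = 11,
  UINT32 = 12, UINT64 = 13, COMPLEX64 = 14, COMPLEX128 = 15, BFLOAT16 = 16
};

static_assert(ONNXDataType::FLOAT16 == 10, "FLOAT16 tag per onnx-ml.proto3");
static_assert(ONNXDataType::BFLOAT16 == 16, "BFLOAT16 tag per onnx-ml.proto3");

int main() { return 0; }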
paddle/fluid/platform/device/ipu/popart_canonicalization/activation_ops.cc  (view file @ 1b5647d7)
...
@@ -56,15 +56,15 @@ Node *gelu_handler(Graph *graph, Node *node) {
     auto sqrt2 = CreateConst(graph, node, {}, {},
                              {{"value", std::vector<float>{1.4142135623730951}},
                               {"dims", std::vector<int64_t>{1}},
-                              {"dtype", GetOutputVarDtype(node)}});
+                              {"dtype", GetOutputVarDType(node)}});
     auto zero_point_five =
         CreateConst(graph, node, {}, {},
                     {{"value", std::vector<float>{0.5}},
                      {"dims", std::vector<int64_t>{1}},
-                     {"dtype", GetOutputVarDtype(node)}});
+                     {"dtype", GetOutputVarDType(node)}});
     auto one =
         CreateConst(graph, node, {}, {},
                     {{"value", std::vector<float>{1}},
                      {"dims", std::vector<int64_t>{1}},
-                     {"dtype", GetOutputVarDtype(node)}});
+                     {"dtype", GetOutputVarDType(node)}});
     auto div =
         CreateBaseOp(graph, node, "popart_div",
                      {GetInputVarNode("X", node), sqrt2->outputs[0]}, {}, {});
...
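For context on the hunk above: the constants sqrt(2), 0.5 and 1 are the pieces of the erf-based GELU, gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))), which the handler assembles from popart primitives (the division by sqrt2 is visible in the hunk; the erf/add/mul steps sit just outside the shown context). The reference computation below is not Paddle or Popart code; it only makes the formula concrete.

// Reference GELU matching the decomposition the handler builds
// (constants sqrt(2), 1 and 0.5, presumably feeding an erf op).
#include <cmath>
#include <cstdio>

double GeluReference(double x) {
  return 0.5 * x * (1.0 + std::erf(x / 1.4142135623730951));
}

int main() {
  for (double x : {-2.0, -0.5, 0.0, 0.5, 2.0}) {
    std::printf("gelu(%+.1f) = %+.6f\n", x, GeluReference(x));
  }
  return 0;
}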
paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.cc  (view file @ 1b5647d7)
...
@@ -18,7 +18,6 @@ namespace paddle {
 namespace platform {
 namespace ipu {

-// This avoids the static initialisation order fiasco,
 std::unordered_map<std::string, SymbolHandler> &SymbolHandlers() {
   static std::unordered_map<std::string, SymbolHandler> symbol_handlers;
   return symbol_handlers;
...
@@ -34,8 +33,6 @@ bool RegisterHandler(const std::string &symbol, const SymbolHandler &handler) {
   return new_handler;
 }

-// Return a pointer to a handler if one is registered for this kind of node or
-// an empty std::function otherwise.
 SymbolHandler GetHandler(const std::string &kind) {
   auto it = SymbolHandlers().find(kind);
   if (it != SymbolHandlers().end()) {
...
@@ -84,66 +81,6 @@ void CopyOpAttr(const std::string &attr_name, OpDesc *op, OpDesc *new_op,
   }
 }

-const int VarType2OnnxDtype(const int type) {
-  auto dtype = static_cast<framework::proto::VarType::Type>(type);
-  switch (dtype) {
-    case framework::proto::VarType::BOOL:
-      return static_cast<int>(ONNXDataType::BOOL);
-    case framework::proto::VarType::INT16:
-      return static_cast<int>(ONNXDataType::INT16);
-    case framework::proto::VarType::INT32:
-      return static_cast<int>(ONNXDataType::INT32);
-    case framework::proto::VarType::INT64:
-      return static_cast<int>(ONNXDataType::INT64);
-    case framework::proto::VarType::FP16:
-      return static_cast<int>(ONNXDataType::FLOAT16);
-    case framework::proto::VarType::FP32:
-      return static_cast<int>(ONNXDataType::FLOAT);
-    case framework::proto::VarType::FP64:
-      return static_cast<int>(ONNXDataType::DOUBLE);
-    case framework::proto::VarType::UINT8:
-      return static_cast<int>(ONNXDataType::UINT8);
-    case framework::proto::VarType::INT8:
-      return static_cast<int>(ONNXDataType::INT8);
-    case framework::proto::VarType::BF16:
-      return static_cast<int>(ONNXDataType::BFLOAT16);
-    case framework::proto::VarType::COMPLEX64:
-      return static_cast<int>(ONNXDataType::COMPLEX64);
-    case framework::proto::VarType::COMPLEX128:
-      return static_cast<int>(ONNXDataType::COMPLEX128);
-    default:
-      PADDLE_THROW(
-          platform::errors::Unimplemented("Unsupported data type: %d.", dtype));
-  }
-}
-
-const std::string VarType2PopStr(const int type) {
-  auto dtype = static_cast<framework::proto::VarType::Type>(type);
-  switch (dtype) {
-    case framework::proto::VarType::UINT8:
-      return "UINT8";
-    case framework::proto::VarType::INT8:
-      return "INT8";
-    case framework::proto::VarType::INT16:
-      return "INT16";
-    case framework::proto::VarType::INT32:
-      return "INT32";
-    case framework::proto::VarType::INT64:
-      return "INT64";
-    case framework::proto::VarType::BOOL:
-      return "BOOL";
-    case framework::proto::VarType::FP64:
-      return "DOUBLE";
-    case framework::proto::VarType::FP32:
-      return "FLOAT";
-    case framework::proto::VarType::FP16:
-      return "FLOAT16";
-    default:
-      PADDLE_THROW(
-          paddle::platform::errors::Unavailable("Unsupported data type."));
-  }
-}
-
 Node *GetInputVarNode(const std::string &input_name, const Node *op_node,
                       const int id) {
   auto var_name = op_node->Op()->Input(input_name).at(id);
...
@@ -180,7 +117,7 @@ const bool is_float_equal(float a, float b, float eps) {
   return std::fabs(a - b) <= eps;
 }

-const int GetOutputVarDtype(const Node *node, const std::string &output_name) {
+const int GetOutputVarDType(const Node *node, const std::string &output_name) {
   auto out_node = GetOutputVarNode(output_name, node);
   PADDLE_ENFORCE_NOT_NULL(out_node, platform::errors::Unavailable(
                                         "Node's out node does not exist."));
...
@@ -188,7 +125,7 @@ const int GetOutputVarDtype(const Node *node, const std::string &output_name) {
   PADDLE_ENFORCE_NOT_NULL(
       var, platform::errors::Unavailable("Node is not a variable."));
   auto proto_var_type = var->GetDataType();
-  return VarType2OnnxDtype(proto_var_type);
+  return static_cast<int>(VarType2OnnxDType(proto_var_type));
 }

 }  // namespace ipu
...
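The SymbolHandlers()/RegisterHandler/GetHandler trio above is a function-local-static registry: the static map is created on first use (which is what the deleted "static initialisation order fiasco" comment referred to), and the REGISTER_HANDLER macros in the op files populate it before any graph is lowered. The sketch below is a minimal standalone version of that pattern, not the Paddle code; names are placeholders.

// Minimal sketch of the registry pattern behind SymbolHandlers()/GetHandler().
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

using Handler = std::function<void()>;

std::unordered_map<std::string, Handler>& Handlers() {
  // Function-local static: constructed on first use, so registration order
  // across translation units does not matter.
  static std::unordered_map<std::string, Handler> handlers;
  return handlers;
}

bool Register(const std::string& symbol, Handler h) {
  return Handlers().emplace(symbol, std::move(h)).second;
}

// Roughly what a REGISTER_HANDLER(equal, equal_handler) expansion boils down to.
static bool equal_registered =
    Register("equal", [] { std::cout << "lowering `equal`\n"; });

int main() {
  if (auto it = Handlers().find("equal"); it != Handlers().end()) {
    it->second();  // invoke the registered handler
  }
  return 0;
}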
paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h  (view file @ 1b5647d7)
...
@@ -68,9 +68,6 @@ void ClearNode(Node *node);
 void CopyOpAttr(const std::string &attr_name, OpDesc *op, OpDesc *new_op,
                 bool override = false);

-const int VarType2OnnxDtype(const int type);
-const std::string VarType2PopStr(const int type);
-
 Node *GetInputVarNode(const std::string &input_name, const Node *op_node,
                       const int id = 0);
 Node *GetOutputVarNode(const std::string &output_name, const Node *op_node,
...
@@ -81,7 +78,7 @@ Node *GetOutputVarNodeByVarName(const std::string &var_name,
                                 const Node *op_node);

 const bool is_float_equal(float a, float b, float eps = 1e-8);
-const int GetOutputVarDtype(const Node *node,
+const int GetOutputVarDType(const Node *node,
                             const std::string &output_name = "Out");

 }  // namespace ipu
...
paddle/fluid/platform/device/ipu/popart_canonicalization/logic_ops.cc  (view file @ 1b5647d7)
...
@@ -28,6 +28,14 @@ Node *equal_handler(Graph *graph, Node *node) {
   return new_node;
 }

+Node *not_equal_handler(Graph *graph, Node *node) {
+  auto equal_node = CreateBaseOp(
+      graph, node, "popart_equal",
+      {GetInputVarNode("X", node), GetInputVarNode("Y", node)}, {});
+  return CreateBaseOp(graph, node, "popart_logical_not",
+                      {equal_node->outputs[0]}, node->outputs, {});
+}
+
 Node *logical_not_handler(Graph *graph, Node *node) {
   return CreateBaseOp(graph, node, "popart_logical_not",
                       {GetInputVarNode("X", node)},
...
@@ -64,6 +72,7 @@ Node *less_than_handler(Graph *graph, Node *node) {
 }  // namespace paddle

 REGISTER_HANDLER(equal, equal_handler);
+REGISTER_HANDLER(not_equal, not_equal_handler);
 REGISTER_HANDLER(logical_not, logical_not_handler);
 REGISTER_HANDLER(logical_or, logical_or_handler);
 REGISTER_HANDLER(logical_and, logical_and_handler);
...
paddle/fluid/platform/device/ipu/popart_canonicalization/math_ops.cc  (view file @ 1b5647d7)
...
@@ -41,7 +41,7 @@ Node *pow_handler(Graph *graph, Node *node) {
     // Op(pow) -> Op(Constant)->Var(const_out)->Op(Pow)
     auto value_ = BOOST_GET_CONST(float, op->GetAttr("factor"));
     auto attrs =
-        MakeConstAttrMapFromValue<float>(value_, {1}, GetOutputVarDtype(node));
+        MakeConstAttrMapFromValue<float>(value_, {1}, GetOutputVarDType(node));
     auto new_node_const = CreateConst(graph, node, {}, {}, attrs);
     return CreateBaseOp(graph, node, "popart_pow", {GetInputVarNode("X", node),
...
@@ -134,7 +134,7 @@ Node *matmul_handler(Graph *graph, Node *node) {
   } else {
     auto o_node =
         CreateBaseOp(graph, node, "popart_matmul", {x_node, y_node}, {});
-    auto attr = MakeConstAttrMapFromValue(alpha, {1}, GetOutputVarDtype(node));
+    auto attr = MakeConstAttrMapFromValue(alpha, {1}, GetOutputVarDType(node));
     auto const_node = CreateConst(graph, node, {}, {}, attr);
     return CreateBaseOp(graph, node, "popart_mul",
                         {o_node->outputs[0], const_node->outputs[0]},
...
@@ -299,6 +299,80 @@ Node *cross_entropy2_handler(Graph *graph, Node *node) {
   }
 }

+Node *softmax_with_cross_entropy_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  auto ignoreIndex = BOOST_GET_CONST(int, op->GetAttr("ignore_index"));
+  auto axis = BOOST_GET_CONST(int, op->GetAttr("axis"));
+  auto soft_label = BOOST_GET_CONST(bool, op->GetAttr("soft_label"));
+  if (soft_label) {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "soft_label is not supported yet in IPU"));
+  }
+  Node *new_cast = nullptr;
+  if (GetInputVarNode("Label", node)->Var()->GetDataType() ==
+      framework::proto::VarType::INT32) {
+    new_cast = GetInputVarNode("Label", node);
+  } else {
+    auto new_cast = CreateCast(graph, node, {GetInputVarNode("Label", node)},
+                               {}, framework::proto::VarType::INT32);
+    new_cast = new_cast->outputs[0];
+  }
+  auto softmax_node = CreateSoftmaxOpset11(
+      graph, node, {GetInputVarNode("Logits", node)}, {}, axis);
+
+  auto label_shape_ = GetInputVarNode("Label", node)->Var()->GetShape();
+  if (label_shape_[label_shape_.size() - 1] != 1) {
+    auto log = CreateBaseOp(graph, node, "popart_log",
+                            {softmax_node->outputs[0]}, {}, {});
+    // softmax_with_cross_entropy is split to several ops in python.
+    // reduction is not needed here.
+    return CreateBaseOp(
+        graph, node, "popart_nllloss_v2", {log->outputs[0], new_cast},
+        {GetOutputVarNode("Loss", node)},
+        {
+            {"reduction", 2},  // popart::ReductionType::NoReduction
+            {"ignoreIndex", ignoreIndex},
+            {"inputIsLogProbability", true},
+        });
+  } else {
+    std::vector<int64_t> new_shape_{label_shape_[0]};
+    auto const_before_loss = CreateBaseOp(
+        graph, node, "popart_constant", {}, {},
+        {{"value", new_shape_},
+         {"dims",
+          std::vector<int64_t>{static_cast<int64_t>(new_shape_.size())}},
+         {"dtype", ONNXDataType::INT64}});
+
+    auto reshape_before_loss =
+        CreateBaseOp(graph, node, "popart_reshape",
+                     {new_cast, const_before_loss->outputs[0]}, {}, {});
+
+    auto log = CreateBaseOp(graph, node, "popart_log",
+                            {softmax_node->outputs[0]}, {}, {});
+    auto nllloss = CreateBaseOp(
+        graph, node, "popart_nllloss_v2",
+        {log->outputs[0], reshape_before_loss->outputs[0]}, {},
+        {
+            {"reduction", 2},  // popart::ReductionType::NoReduction
+            {"ignoreIndex", ignoreIndex},
+            {"inputIsLogProbability", true},
+        });
+
+    auto const_after_loss = CreateBaseOp(
+        graph, node, "popart_constant", {}, {},
+        {{"value", label_shape_},
+         {"dims",
+          std::vector<int64_t>{static_cast<int64_t>(label_shape_.size())}},
+         {"dtype", ONNXDataType::INT64}});
+
+    auto reshape_after_loss =
+        CreateBaseOp(graph, node, "popart_reshape",
+                     {nllloss->outputs[0], const_after_loss->outputs[0]},
+                     {GetOutputVarNode("Loss", node)}, {});
+    return reshape_after_loss;
+  }
+}
+
 Node *cumsum_handler(Graph *graph, Node *node) {
   auto *op = node->Op();
   auto exclusive = BOOST_GET_CONST(bool, op->GetAttr("exclusive"));
...
@@ -378,6 +452,8 @@ REGISTER_HANDLER(matmul, matmul_handler);
 REGISTER_HANDLER(sum, sum_handler);
 REGISTER_HANDLER(softmax, softmax_handler);
 REGISTER_HANDLER(scale, scale_handler);
+REGISTER_HANDLER(softmax_with_cross_entropy,
+                 softmax_with_cross_entropy_handler);
 REGISTER_HANDLER(cross_entropy2, cross_entropy2_handler);
 REGISTER_HANDLER(cumsum, cumsum_handler);
 REGISTER_HANDLER(matmul_v2, matmul_v2_handler);
...
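The new softmax_with_cross_entropy_handler lowers the fused Paddle op into softmax, then log, then popart_nllloss_v2 with reduction disabled (reduction is done by the surrounding Python-level ops). The standalone snippet below is not Paddle or Popart code; it is a per-row reference of what that pipeline computes, with the ignore_index handling simplified to "ignored labels contribute zero".

// Reference for the softmax -> log -> nll pipeline built above:
// loss = -log(softmax(logits)[label]), no reduction, ignore_index simplified.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

double SoftmaxCrossEntropyRow(const std::vector<double>& logits, int label,
                              int ignore_index) {
  if (label == ignore_index) return 0.0;
  double max_logit = logits[0];
  for (double v : logits) max_logit = std::max(max_logit, v);
  double denom = 0.0;
  for (double v : logits) denom += std::exp(v - max_logit);
  double log_prob = (logits[label] - max_logit) - std::log(denom);
  return -log_prob;  // what nllloss applied to log(softmax) yields
}

int main() {
  std::vector<double> logits = {1.0, 2.0, 0.5};
  std::printf("loss = %.6f\n", SoftmaxCrossEntropyRow(logits, 1, -100));
  return 0;
}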
paddle/fluid/platform/device/ipu/popart_canonicalization/nn_ops.cc  (view file @ 1b5647d7)
...
@@ -299,7 +299,7 @@ Node *dropout_handler(Graph *graph, Node *node) {
           CreateConst(graph, node, {}, {},
                       {{"value", std::vector<float>{1 - dropout_prob_}},
                        {"dims", std::vector<int64_t>{1}},
-                       {"dtype", GetOutputVarDtype(node)}});
+                       {"dtype", GetOutputVarDType(node)}});
       return CreateBaseOp(graph, node, "popart_mul",
                           {GetInputVarNode("X", node), scale->outputs[0]},
                           {GetOutputVarNode("Out", node)}, {});
...
paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.cc  (view file @ 1b5647d7)
...
@@ -124,7 +124,7 @@ Node *CreateConst(Graph *graph, Node *node, const std::vector<Node *> &inputs,
 Node *CreateCast(Graph *graph, Node *node, const std::vector<Node *> &inputs,
                  const std::vector<Node *> &outputs, const int otype) {
-  auto to = VarType2PopStr(otype);
+  auto to = VarType2PopartStr(static_cast<VarType::Type>(otype));
   return CreateBaseOp(graph, node, "popart_cast", inputs, outputs,
                       {{"to", to}});
 }
...
paddle/fluid/platform/device/ipu/popart_canonicalization/tensor_ops.cc  (view file @ 1b5647d7)
...
@@ -23,12 +23,14 @@ namespace {
 Node *fill_constant_handler(Graph *graph, Node *node) {
   auto *op = node->Op();
-  if (!op->Input("ShapeTensor").empty()) {
+  auto op_inputs = op->Inputs();
+  if (op_inputs.find("ShapeTensor") != op_inputs.end() &&
+      !op->Input("ShapeTensor").empty()) {
     PADDLE_THROW(
         platform::errors::Unimplemented("op fill_constant with ShapeTensor"));
   }
   auto dtype_ = BOOST_GET_CONST(int, op->GetAttr("dtype"));
-  auto dtype = VarType2OnnxDtype(dtype_);
+  auto dtype = VarType2OnnxDType(static_cast<VarType::Type>(dtype_));
   auto dims = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("shape"));
   auto value_ = BOOST_GET_CONST(float, op->GetAttr("value"));
   size_t size = 1;
...
@@ -37,19 +39,20 @@ Node *fill_constant_handler(Graph *graph, Node *node) {
   }
   Attribute value;
   switch (dtype_) {
-    case framework::proto::VarType::FP32:
+    case VarType::FP16:
+    case VarType::FP32:
       value = std::vector<float>(size, value_);
       break;
-    case framework::proto::VarType::FP64:
+    case VarType::FP64:
       value = std::vector<double>(size, value_);
       break;
-    case framework::proto::VarType::INT32:
+    case VarType::INT32:
       value = std::vector<int>(size, value_);
       break;
-    case framework::proto::VarType::INT64:
+    case VarType::INT64:
       value = std::vector<int64_t>(size, value_);
       break;
-    case framework::proto::VarType::BOOL:
+    case VarType::BOOL:
       value = std::vector<bool>(size, value_);
       break;
     default:
...
@@ -66,7 +69,7 @@ Node *gaussian_random_handler(Graph *graph, Node *node) {
   auto *op = node->Op();
   auto shape = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("shape"));
   auto dtype_ = BOOST_GET_CONST(int, op->GetAttr("dtype"));
-  auto dtype = VarType2OnnxDtype(dtype_);
+  auto dtype = VarType2OnnxDType(static_cast<VarType::Type>(dtype_));
   auto mean = BOOST_GET_CONST(float, op->GetAttr("mean"));
   auto scale = BOOST_GET_CONST(float, op->GetAttr("std"));
   // seed not work
...
@@ -86,7 +89,7 @@ Node *uniform_random_handler(Graph *graph, Node *node) {
   auto *op = node->Op();
   auto shape = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("shape"));
   auto dtype_ = BOOST_GET_CONST(int, op->GetAttr("dtype"));
-  auto dtype = VarType2OnnxDtype(dtype_);
+  auto dtype = VarType2OnnxDType(static_cast<VarType::Type>(dtype_));
   auto high = BOOST_GET_CONST(float, op->GetAttr("max"));
   auto low = BOOST_GET_CONST(float, op->GetAttr("min"));
   // seed not work
...
@@ -172,9 +175,21 @@ Node *squeeze_handler(Graph *graph, Node *node) {
 Node *cast_handler(Graph *graph, Node *node) {
   auto *op = node->Op();
   auto otype = BOOST_GET_CONST(int, op->GetAttr("out_dtype"));
-  auto new_node_cast =
-      CreateCast(graph, node, node->inputs, node->outputs, otype);
-  return new_node_cast;
+  auto new_node =
+      CreateCast(graph, node, node->inputs, node->outputs, otype);
+  // Cast op created in mixed-precison has no pipline attrs
+  auto &prev_nodes = node->inputs.front()->inputs;
+  if (!prev_nodes.empty()) {
+    auto *prev_op = prev_nodes.front()->Op();
+    if (!new_node->Op()->HasAttr(sIpuIndexAttr) &&
+        prev_op->HasAttr(sIpuIndexAttr)) {
+      CopyOpAttr(sIpuIndexAttr, prev_op, new_node->Op());
+    }
+    if (!new_node->Op()->HasAttr(sIpuStageAttr) &&
+        prev_op->HasAttr(sIpuStageAttr)) {
+      CopyOpAttr(sIpuStageAttr, prev_op, new_node->Op());
+    }
+  }
+  return new_node;
 }

 Node *lookup_table_op_handler(Graph *graph, Node *node,
...
@@ -192,7 +207,7 @@ Node *lookup_table_op_handler(Graph *graph, Node *node,
       auto concat_const =
           CreateConst(graph, node, {}, {}, {{"value", const_value_},
                                             {"dims", const_shape_},
-                                            {"dtype", GetOutputVarDtype(node)}});
+                                            {"dtype", GetOutputVarDType(node)}});
       auto axes =
           CreateConst(graph, node, {}, {}, {{"value", std::vector<int64_t>{0}},
                                             {"dims", std::vector<int64_t>{1}},
...
@@ -397,7 +412,7 @@ Node *expand_handler(Graph *graph, Node *node) {
     // cast to int64
     expand_times =
         CreateCast(graph, node, {GetInputVarNode("ExpandTimes", node)}, {},
-                   framework::proto::VarType::INT64);
+                   VarType::INT64);
   } else {
     auto expand_times_i32 =
         BOOST_GET_CONST(std::vector<int>, op->GetAttr("expand_times"));
...
@@ -423,27 +438,28 @@ Node *assign_handler(Graph *graph, Node *node) {
 Node *assign_value_handler(Graph *graph, Node *node) {
   auto *op = node->Op();
   auto dtype_ = BOOST_GET_CONST(int, op->GetAttr("dtype"));
-  auto dtype = VarType2OnnxDtype(dtype_);
+  auto dtype = VarType2OnnxDType(static_cast<VarType::Type>(dtype_));
   auto dims_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("shape"));
   std::vector<int64_t> dims(dims_.begin(), dims_.end());
   Attribute values;
   std::string value_name;
   switch (dtype_) {
-    case framework::proto::VarType::BOOL: {
+    case VarType::BOOL: {
       value_name = "bool_values";
       auto vec_int = BOOST_GET_CONST(std::vector<int>, op->GetAttr(value_name));
       std::vector<bool> vec_bool(vec_int.begin(), vec_int.end());
       values = vec_bool;
     } break;
-    case framework::proto::VarType::INT32:
+    case VarType::INT32:
       value_name = "int32_values";
       values = BOOST_GET_CONST(std::vector<int>, op->GetAttr(value_name));
       break;
-    case framework::proto::VarType::FP32:
+    case VarType::FP16:
+    case VarType::FP32:
       value_name = "fp32_values";
       values = BOOST_GET_CONST(std::vector<float>, op->GetAttr(value_name));
       break;
-    case framework::proto::VarType::INT64:
+    case VarType::INT64:
       value_name = "int64_values";
       values = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr(value_name));
       break;
...
@@ -463,39 +479,40 @@ Node *fill_any_like_handler(Graph *graph, Node *node) {
   auto *op = node->Op();
   auto value = BOOST_GET_CONST(float, op->GetAttr("value"));
   auto x_shape = GetInputVarNode("X", node)->Var()->GetShape();
-  auto dtype = BOOST_GET_CONST(int, op->GetAttr("dtype"));
-  auto x_dtype = static_cast<framework::proto::VarType::Type>(dtype);
+  auto dtype_ = BOOST_GET_CONST(int, op->GetAttr("dtype"));
+  auto dtype = static_cast<VarType::Type>(dtype_);
   size_t size = 1;
   for (auto &dim : x_shape) {
     size *= dim;
   }
   Attribute out_value;
-  switch (x_dtype) {
-    case framework::proto::VarType::FP32:
+  switch (dtype) {
+    case VarType::FP16:
+    case VarType::FP32:
       out_value = std::vector<float>(size, value);
       break;
-    case framework::proto::VarType::FP64:
+    case VarType::FP64:
       out_value = std::vector<double>(size, value);
       break;
-    case framework::proto::VarType::INT32:
+    case VarType::INT32:
       out_value = std::vector<int>(size, value);
       break;
-    case framework::proto::VarType::INT64:
+    case VarType::INT64:
       out_value = std::vector<int64_t>(size, value);
       break;
-    case framework::proto::VarType::BOOL:
+    case VarType::BOOL:
       out_value = std::vector<int64_t>(size, value);
       break;
     default:
       PADDLE_THROW(
-          platform::errors::Unimplemented("fill_any_like dtype: %d", x_dtype));
+          platform::errors::Unimplemented("fill_any_like dtype: %d", dtype));
   }
   return CreateConst(graph, node, node->inputs, node->outputs,
                      AttributeMap{
                          {"value", out_value},
                          {"dims", x_shape},
-                         {"dtype", VarType2OnnxDtype(dtype)},
+                         {"dtype", VarType2OnnxDType(dtype)},
                      });
 }
...
@@ -538,8 +555,7 @@ Node *one_hot_v2_handler(Graph *graph, Node *node) {
                            {"dims", std::vector<int64_t>{1}},
                            {"dtype", ONNXDataType::INT32}});
   Node *value_tensor = nullptr;
-  if (GetOutputVarNode("Out", node)->Var()->GetDataType() ==
-      framework::proto::VarType::FP16) {
+  if (GetOutputVarNode("Out", node)->Var()->GetDataType() == VarType::FP16) {
     value_tensor =
         CreateConst(graph, node, {}, {}, {{"value", std::vector<float>{0, 1}},
                                           {"dims", std::vector<int64_t>{2}},
...
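One note on the reworked cast_handler above: it back-fills sIpuIndexAttr and sIpuStageAttr from the op that produces the cast's input, so cast ops inserted during mixed precision inherit the same IPU index and pipeline stage instead of arriving unplaced. The snippet below is a condensed standalone sketch of that "inherit from producer if not already set" guard; the AttrMap type and the "ipu_index"/"ipu_stage" keys are placeholders, not the Paddle API.

// Condensed sketch of the attribute-inheritance guard used in cast_handler.
#include <iostream>
#include <map>
#include <string>

using AttrMap = std::map<std::string, int>;

void InheritAttr(const AttrMap& producer, AttrMap* consumer,
                 const std::string& name) {
  // Copy only when the producer has the attribute and the consumer does not.
  if (consumer->count(name) == 0 && producer.count(name) != 0) {
    (*consumer)[name] = producer.at(name);
  }
}

int main() {
  AttrMap producer{{"ipu_index", 1}, {"ipu_stage", 3}};
  AttrMap cast_op;  // freshly created cast, no placement attributes yet
  InheritAttr(producer, &cast_op, "ipu_index");
  InheritAttr(producer, &cast_op, "ipu_stage");
  std::cout << "cast placed on IPU " << cast_op["ipu_index"] << ", stage "
            << cast_op["ipu_stage"] << "\n";
  return 0;
}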