Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
8643dbc2
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8643dbc2
编写于
4月 11, 2019
作者:
N
nhzlx
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
cherry-pick from 16691:Anakin subgraph support yolo_v3 and faster-rcnn
上级
463f88a7
变更
20
隐藏空白更改
内联
并排
Showing
20 changed file
with
382 addition
and
35 deletion
+382
-35
paddle/fluid/inference/anakin/convert/CMakeLists.txt
paddle/fluid/inference/anakin/convert/CMakeLists.txt
+2
-2
paddle/fluid/inference/anakin/convert/affine_channel.cc
paddle/fluid/inference/anakin/convert/affine_channel.cc
+100
-0
paddle/fluid/inference/anakin/convert/affine_channel.h
paddle/fluid/inference/anakin/convert/affine_channel.h
+39
-0
paddle/fluid/inference/anakin/convert/op_converter.h
paddle/fluid/inference/anakin/convert/op_converter.h
+8
-8
paddle/fluid/inference/anakin/convert/relu.cc
paddle/fluid/inference/anakin/convert/relu.cc
+18
-0
paddle/fluid/inference/anakin/convert/relu.h
paddle/fluid/inference/anakin/convert/relu.h
+11
-0
paddle/fluid/inference/anakin/convert/roi_align.cc
paddle/fluid/inference/anakin/convert/roi_align.cc
+59
-0
paddle/fluid/inference/anakin/convert/roi_align.h
paddle/fluid/inference/anakin/convert/roi_align.h
+38
-0
paddle/fluid/inference/anakin/convert/test_affine_channel_op.cc
.../fluid/inference/anakin/convert/test_affine_channel_op.cc
+55
-0
paddle/fluid/inference/anakin/convert/test_relu_op.cc
paddle/fluid/inference/anakin/convert/test_relu_op.cc
+9
-2
paddle/fluid/inference/anakin/convert/ut_helper.h
paddle/fluid/inference/anakin/convert/ut_helper.h
+2
-2
paddle/fluid/inference/anakin/engine.cc
paddle/fluid/inference/anakin/engine.cc
+18
-8
paddle/fluid/inference/anakin/engine.h
paddle/fluid/inference/anakin/engine.h
+12
-9
paddle/fluid/inference/anakin/op_teller.cc
paddle/fluid/inference/anakin/op_teller.cc
+2
-0
paddle/fluid/inference/anakin/test_anakin_engine.cc
paddle/fluid/inference/anakin/test_anakin_engine.cc
+1
-1
paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc
...luid/inference/analysis/ir_passes/anakin_subgraph_pass.cc
+2
-1
paddle/fluid/inference/analysis/ir_passes/subgraph_util.cc
paddle/fluid/inference/analysis/ir_passes/subgraph_util.cc
+0
-2
paddle/fluid/inference/api/CMakeLists.txt
paddle/fluid/inference/api/CMakeLists.txt
+1
-0
paddle/fluid/inference/api/analysis_predictor.cc
paddle/fluid/inference/api/analysis_predictor.cc
+2
-0
paddle/fluid/pybind/inference_api.cc
paddle/fluid/pybind/inference_api.cc
+3
-0
未找到文件。
paddle/fluid/inference/anakin/convert/CMakeLists.txt
浏览文件 @
8643dbc2
cc_library
(
anakin_op_converter SRCS fc.cc conv2d.cc conv2d_fusion.cc elementwise.cc activation.cc pool2d.cc concat.cc split.cc relu.cc
softmax.cc batch_norm.cc reshape.cc flatten.cc transpose.cc density_prior_box.cc detection_out.cc scale.cc dropout.cc im2sequence.cc sum
.cc DEPS anakin_engine framework_proto scope op_registry
)
cc_library
(
anakin_op_converter SRCS fc.cc conv2d.cc conv2d_fusion.cc elementwise.cc activation.cc pool2d.cc concat.cc split.cc relu.cc
softmax.cc batch_norm.cc reshape.cc flatten.cc transpose.cc density_prior_box.cc detection_out.cc scale.cc dropout.cc im2sequence.cc sum.cc affine_channel.cc roi_align
.cc DEPS anakin_engine framework_proto scope op_registry
)
cc_test
(
test_anakin_fc SRCS test_fc_op.cc DEPS anakin_op_converter mul_op SERIAL
)
cc_test
(
test_anakin_conv2d SRCS test_conv2d_op.cc DEPS anakin_op_converter conv_op im2col vol2col depthwise_conv SERIAL
)
...
...
@@ -14,5 +14,5 @@ cc_test(test_anakin_flatten SRCS test_flatten_op.cc DEPS anakin_op_converter fla
cc_test
(
test_anakin_transpose SRCS test_transpose_op.cc DEPS anakin_op_converter transpose_op SERIAL
)
cc_test
(
test_anakin_batch_norm SRCS test_batch_norm_op.cc DEPS anakin_op_converter batch_norm_op SERIAL
)
cc_test
(
test_anakin_dropout SRCS test_dropout_op.cc DEPS anakin_op_converter dropout_op SERIAL
)
#cc_test(test_anakin_im2sequence SRCS test_im2sequence_op.cc DEPS anakin_op_converter im2sequence_op im2col)
cc_test
(
test_anakin_sum SRCS test_sum_op.cc DEPS anakin_op_converter sum_op selected_rows_functor SERIAL
)
cc_test
(
test_anakin_affine_channel SRCS test_affine_channel_op.cc DEPS anakin_op_converter affine_channel_op SERIAL
)
paddle/fluid/inference/anakin/convert/affine_channel.cc
0 → 100644
浏览文件 @
8643dbc2
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/anakin/convert/affine_channel.h"
#include <algorithm>
#include <string>
#include <vector>
using
anakin
::
graph
::
GraphGlobalMem
;
using
anakin
::
AK_FLOAT
;
using
anakin
::
Precision
;
using
anakin
::
saber
::
NV
;
using
anakin
::
saber
::
X86
;
using
anakin
::
saber
::
Shape
;
using
anakin
::
PBlock
;
using
anakin
::
PTuple
;
namespace
paddle
{
namespace
inference
{
namespace
anakin
{
void
AffineChannelOpConverter
::
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
BlockDesc
&
block_desc
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
{
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
PADDLE_ENFORCE_EQ
(
op_desc
.
Input
(
"X"
).
size
(),
1
);
PADDLE_ENFORCE_EQ
(
op_desc
.
Output
(
"Out"
).
size
(),
1
);
auto
op_name
=
op_desc
.
Type
()
+
":"
+
op_desc
.
Output
(
"Out"
).
front
();
auto
input_name
=
op_desc
.
Input
(
"X"
).
front
();
auto
output_name
=
op_desc
.
Output
(
"Out"
).
front
();
// Copy the Scale to CPUPlace and get the pointer.
auto
*
scale_v
=
scope
.
FindVar
(
op_desc
.
Input
(
"Scale"
).
front
());
PADDLE_ENFORCE_NOT_NULL
(
scale_v
);
auto
*
scale_t
=
scale_v
->
GetMutable
<
framework
::
LoDTensor
>
();
std
::
unique_ptr
<
framework
::
LoDTensor
>
scale_tensor
(
new
framework
::
LoDTensor
());
scale_tensor
->
Resize
(
scale_t
->
dims
());
TensorCopySync
((
*
scale_t
),
platform
::
CPUPlace
(),
scale_tensor
.
get
());
// Copy the Bias to CPUPlace and get the pointer.
auto
*
bias_v
=
scope
.
FindVar
(
op_desc
.
Input
(
"Bias"
).
front
());
PADDLE_ENFORCE_NOT_NULL
(
bias_v
);
auto
*
bias_t
=
bias_v
->
GetMutable
<
framework
::
LoDTensor
>
();
std
::
unique_ptr
<
framework
::
LoDTensor
>
bias_tensor
(
new
framework
::
LoDTensor
());
bias_tensor
->
Resize
(
bias_t
->
dims
());
TensorCopySync
((
*
bias_t
),
platform
::
CPUPlace
(),
bias_tensor
.
get
());
engine_
->
AddOp
(
op_name
,
"AffineChannel"
,
{
input_name
},
{
output_name
});
// Generate the Scale parameter of Anakin.
auto
scale_shape
=
framework
::
vectorize2int
(
scale_t
->
dims
());
while
(
scale_shape
.
size
()
<
4
)
{
scale_shape
.
insert
(
scale_shape
.
begin
(),
1
);
}
Shape
anakin_scale_shape
(
scale_shape
);
auto
*
weight1
=
GraphGlobalMem
<
NV
>::
Global
().
template
new_block
<
AK_FLOAT
>(
anakin_scale_shape
);
float
*
scale_cpu_data
=
static_cast
<
float
*>
(
weight1
->
h_tensor
().
mutable_data
());
std
::
copy_n
(
scale_tensor
->
data
<
float
>
(),
scale_tensor
->
numel
(),
scale_cpu_data
);
weight1
->
d_tensor
().
set_shape
(
anakin_scale_shape
);
weight1
->
d_tensor
().
copy_from
(
weight1
->
h_tensor
());
engine_
->
AddOpAttr
(
op_name
,
"weight_1"
,
*
weight1
);
// Generate the Bias parameter of Anakin.
auto
bias_shape
=
framework
::
vectorize2int
(
bias_t
->
dims
());
while
(
bias_shape
.
size
()
<
4
)
{
bias_shape
.
insert
(
bias_shape
.
begin
(),
1
);
}
Shape
anakin_bias_shape
(
bias_shape
);
auto
*
weight2
=
GraphGlobalMem
<
NV
>::
Global
().
template
new_block
<
AK_FLOAT
>(
anakin_bias_shape
);
float
*
bias_cpu_data
=
static_cast
<
float
*>
(
weight2
->
h_tensor
().
mutable_data
());
std
::
copy_n
(
bias_tensor
->
data
<
float
>
(),
bias_tensor
->
numel
(),
bias_cpu_data
);
weight2
->
d_tensor
().
set_shape
(
anakin_bias_shape
);
weight2
->
d_tensor
().
copy_from
(
weight2
->
h_tensor
());
engine_
->
AddOpAttr
(
op_name
,
"weight_2"
,
*
weight2
);
}
}
// namespace anakin
}
// namespace inference
}
// namespace paddle
REGISTER_ANAKIN_OP_CONVERTER
(
affine_channel
,
AffineChannelOpConverter
);
paddle/fluid/inference/anakin/convert/affine_channel.h
0 → 100644
浏览文件 @
8643dbc2
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
namespace
paddle
{
namespace
inference
{
namespace
anakin
{
class
AffineChannelOpConverter
:
public
AnakinOpConverter
{
public:
AffineChannelOpConverter
()
=
default
;
virtual
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
BlockDesc
&
block_desc
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
;
virtual
~
AffineChannelOpConverter
()
{}
private:
};
}
// namespace anakin
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/anakin/convert/op_converter.h
浏览文件 @
8643dbc2
...
...
@@ -81,7 +81,6 @@ class AnakinOpConverter {
const
std
::
unordered_set
<
std
::
string
>
&
parameters
,
const
std
::
vector
<
std
::
string
>
&
outputs
,
AnakinNvEngine
*
engine
)
{
ConvertBlock
(
block_desc
,
parameters
,
*
scope
,
engine
);
engine
->
Freeze
();
// if the max_batch size
int
max_batch_size
=
engine
->
GetMaxBatchSize
();
PADDLE_ENFORCE
(
max_batch_size
>
0
,
...
...
@@ -91,7 +90,12 @@ class AnakinOpConverter {
// the block_desc.
auto
max_input_shape
=
engine
->
GetMaxInputShape
();
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
temp_max_input_shape
;
// Register outputs with anakin using the RegistVar interface before Freeze.
// Note that RegistVar's parameters can only be outputs, not inputs.
for
(
auto
&
output
:
outputs
)
{
engine
->
Graph
()
->
RegistVar
(
output
);
}
engine
->
Freeze
();
for
(
auto
&
input
:
inputs
)
{
if
(
parameters
.
count
(
input
))
continue
;
std
::
vector
<
int
>
input_shape
;
...
...
@@ -99,7 +103,7 @@ class AnakinOpConverter {
input_shape
[
0
]
=
max_batch_size
;
if
(
max_input_shape
.
count
(
input
))
{
PADDLE_ENFORCE
(
max_input_shape
[
input
].
size
()
==
4
,
"the dimensions of
max_input_shape setted from "
"the dimensions of max_input_shape setted from "
"config->EnableAnakinEngine must be 4"
);
for
(
int
i
=
1
;
i
<
4
;
i
++
)
{
input_shape
[
i
]
=
max_input_shape
[
input
][
i
];
...
...
@@ -118,14 +122,10 @@ class AnakinOpConverter {
}
temp_max_input_shape
[
input
]
=
input_shape
;
engine
->
SetInputShape
(
input
,
input_shape
);
engine
->
Graph
()
->
RegistVar
(
input
);
// For share from data.
}
engine
->
SetMaxInputShape
(
temp_max_input_shape
);
engine
->
Optimize
();
// For anakin share with fluid tensor.
engine
->
AllocTmpMem
();
engine
->
InitGraph
();
engine
->
InitNet
();
}
void
SetEngine
(
AnakinNvEngine
*
engine
)
{
engine_
=
engine
;
}
...
...
paddle/fluid/inference/anakin/convert/relu.cc
浏览文件 @
8643dbc2
...
...
@@ -41,8 +41,26 @@ void ReluOpConverter::operator()(const framework::proto::OpDesc &op,
engine_
->
AddOpAttr
(
op_name
,
"alpha"
,
0
);
}
void
LeakyReluOpConverter
::
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
BlockDesc
&
block_desc
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
{
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
PADDLE_ENFORCE_EQ
(
op_desc
.
Input
(
"X"
).
size
(),
1
);
PADDLE_ENFORCE_EQ
(
op_desc
.
Output
(
"Out"
).
size
(),
1
);
auto
op_name
=
op_desc
.
Type
()
+
":"
+
op_desc
.
Output
(
"Out"
).
front
();
auto
input_name
=
op_desc
.
Input
(
"X"
).
front
();
auto
output_name
=
op_desc
.
Output
(
"Out"
).
front
();
float
alpha
=
boost
::
get
<
float
>
(
op_desc
.
GetAttr
(
"alpha"
));
engine_
->
AddOp
(
op_name
,
"ReLU"
,
{
input_name
},
{
output_name
});
engine_
->
AddOpAttr
(
op_name
,
"alpha"
,
alpha
);
}
}
// namespace anakin
}
// namespace inference
}
// namespace paddle
REGISTER_ANAKIN_OP_CONVERTER
(
relu
,
ReluOpConverter
);
REGISTER_ANAKIN_OP_CONVERTER
(
leaky_relu
,
LeakyReluOpConverter
);
paddle/fluid/inference/anakin/convert/relu.h
浏览文件 @
8643dbc2
...
...
@@ -33,6 +33,17 @@ class ReluOpConverter : public AnakinOpConverter {
virtual
~
ReluOpConverter
()
{}
};
class
LeakyReluOpConverter
:
public
AnakinOpConverter
{
public:
LeakyReluOpConverter
()
=
default
;
virtual
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
BlockDesc
&
block_desc
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
;
virtual
~
LeakyReluOpConverter
()
{}
};
}
// namespace anakin
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/anakin/convert/roi_align.cc
0 → 100644
浏览文件 @
8643dbc2
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/inference/anakin/convert/roi_align.h"
#include <algorithm>
#include <map>
using
anakin
::
graph
::
GraphGlobalMem
;
using
anakin
::
AK_FLOAT
;
using
anakin
::
saber
::
NV
;
using
anakin
::
saber
::
Shape
;
namespace
paddle
{
namespace
inference
{
namespace
anakin
{
void
RoiAlignOpConverter
::
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
BlockDesc
&
block_desc
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
{
framework
::
OpDesc
op_desc
(
op
,
nullptr
);
PADDLE_ENFORCE_EQ
(
op_desc
.
Input
(
"X"
).
size
(),
1
);
PADDLE_ENFORCE_EQ
(
op_desc
.
Input
(
"ROIs"
).
size
(),
1
);
PADDLE_ENFORCE_EQ
(
op_desc
.
Output
(
"Out"
).
size
(),
1
);
auto
op_name
=
op_desc
.
Type
()
+
":"
+
op_desc
.
Output
(
"Out"
).
front
();
auto
input_x_name
=
op_desc
.
Input
(
"X"
).
front
();
auto
input_rois_name
=
op_desc
.
Input
(
"ROIs"
).
front
();
auto
output_name
=
op_desc
.
Output
(
"Out"
).
front
();
auto
spatial_scale
=
boost
::
get
<
float
>
(
op_desc
.
GetAttr
(
"spatial_scale"
));
auto
pooled_height
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"pooled_height"
));
auto
pooled_width
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"pooled_width"
));
auto
sampling_ratio
=
boost
::
get
<
int
>
(
op_desc
.
GetAttr
(
"sampling_ratio"
));
engine_
->
AddOp
(
op_name
,
"RoiAlign"
,
{
input_x_name
,
input_rois_name
},
{
output_name
});
engine_
->
AddOpAttr
(
op_name
,
"spatial_scale"
,
spatial_scale
);
engine_
->
AddOpAttr
(
op_name
,
"pooled_height"
,
pooled_height
);
engine_
->
AddOpAttr
(
op_name
,
"pooled_width"
,
pooled_width
);
engine_
->
AddOpAttr
(
op_name
,
"sampling_ratio"
,
sampling_ratio
);
}
}
// namespace anakin
}
// namespace inference
}
// namespace paddle
REGISTER_ANAKIN_OP_CONVERTER
(
roi_align
,
RoiAlignOpConverter
);
paddle/fluid/inference/anakin/convert/roi_align.h
0 → 100644
浏览文件 @
8643dbc2
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
namespace
paddle
{
namespace
inference
{
namespace
anakin
{
class
RoiAlignOpConverter
:
public
AnakinOpConverter
{
public:
RoiAlignOpConverter
()
=
default
;
virtual
void
operator
()(
const
framework
::
proto
::
OpDesc
&
op
,
const
framework
::
BlockDesc
&
block_desc
,
const
framework
::
Scope
&
scope
,
bool
test_mode
)
override
;
virtual
~
RoiAlignOpConverter
()
{}
};
}
// namespace anakin
}
// namespace inference
}
// namespace paddle
paddle/fluid/inference/anakin/convert/test_affine_channel_op.cc
0 → 100644
浏览文件 @
8643dbc2
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/anakin/convert/affine_channel.h"
#include "paddle/fluid/inference/anakin/convert/op_converter.h"
#include "paddle/fluid/inference/anakin/convert/ut_helper.h"
namespace
paddle
{
namespace
inference
{
namespace
anakin
{
TEST
(
affine_channel
,
native
)
{
// Declare the difference between the inputs.
std
::
unordered_set
<
std
::
string
>
parameters
({
"scale"
,
"bias"
});
framework
::
Scope
scope
;
AnakinConvertValidation
validator
(
parameters
,
&
scope
);
validator
.
DeclInputVar
(
"x"
,
{
1
,
3
,
5
,
2
});
validator
.
DeclOutputVar
(
"out"
,
{
1
,
3
,
5
,
2
});
validator
.
DeclParamVar
(
"scale"
,
{
1
,
3
,
1
,
1
});
validator
.
DeclParamVar
(
"bias"
,
{
1
,
3
,
1
,
1
});
// Prepare Op descriptions.
framework
::
OpDesc
desc
;
desc
.
SetType
(
"affine_channel"
);
desc
.
SetInput
(
"X"
,
{
"x"
});
desc
.
SetInput
(
"Bias"
,
{
"bias"
});
desc
.
SetInput
(
"Scale"
,
{
"scale"
});
desc
.
SetOutput
(
"Out"
,
{
"out"
});
// Layout must be explicitly specified here as NCHW.
desc
.
SetAttr
(
"data_layout"
,
std
::
string
(
"NCHW"
));
validator
.
SetOp
(
*
desc
.
Proto
());
validator
.
Execute
(
1
);
}
}
// namespace anakin
}
// namespace inference
}
// namespace paddle
USE_OP
(
affine_channel
);
USE_ANAKIN_CONVERTER
(
affine_channel
);
paddle/fluid/inference/anakin/convert/test_relu_op.cc
浏览文件 @
8643dbc2
...
...
@@ -21,7 +21,7 @@ namespace paddle {
namespace
inference
{
namespace
anakin
{
static
void
test_
activation
_op
(
const
std
::
string
&
op_type
)
{
static
void
test_
relu
_op
(
const
std
::
string
&
op_type
)
{
auto
*
converter
=
Registry
<
AnakinOpConverter
>::
Global
().
Lookup
(
op_type
);
PADDLE_ENFORCE
(
converter
!=
nullptr
);
std
::
unordered_set
<
std
::
string
>
parameters
;
...
...
@@ -33,6 +33,9 @@ static void test_activation_op(const std::string &op_type) {
desc
.
SetType
(
op_type
);
desc
.
SetInput
(
"X"
,
{
"act-X"
});
desc
.
SetOutput
(
"Out"
,
{
"act-Out"
});
if
(
op_type
==
"leaky_relu"
)
{
desc
.
SetAttr
(
"alpha"
,
0.1
f
);
}
LOG
(
INFO
)
<<
"set OP"
;
validator
.
SetOp
(
*
desc
.
Proto
());
...
...
@@ -41,10 +44,14 @@ static void test_activation_op(const std::string &op_type) {
validator
.
Execute
(
5
);
}
TEST
(
sigm_op
,
test
)
{
test_activation_op
(
"relu"
);
}
TEST
(
activation
,
relu
)
{
test_relu_op
(
"relu"
);
}
TEST
(
activation
,
leaky_relu
)
{
test_relu_op
(
"leaky_relu"
);
}
}
// namespace anakin
}
// namespace inference
}
// namespace paddle
USE_OP
(
relu
);
USE_ANAKIN_CONVERTER
(
relu
);
USE_OP
(
leaky_relu
);
USE_ANAKIN_CONVERTER
(
leaky_relu
);
paddle/fluid/inference/anakin/convert/ut_helper.h
浏览文件 @
8643dbc2
...
...
@@ -67,7 +67,7 @@ void RandomizeTensor(framework::LoDTensor* tensor, const platform::Place& place,
auto
*
temp_data
=
temp_tensor
.
mutable_data
<
float
>
(
cpu_place
);
for
(
size_t
i
=
0
;
i
<
num_elements
;
i
++
)
{
*
(
temp_data
+
i
)
=
random
(
0.
,
1
.
);
*
(
temp_data
+
i
)
=
random
(
-
128.
,
128
.
);
}
TensorCopySync
(
temp_tensor
,
place
,
tensor
);
...
...
@@ -151,7 +151,7 @@ class AnakinConvertValidation {
}
engine_
->
SetMaxInputShape
(
temp_max_input_shape
);
engine_
->
Optimize
();
engine_
->
Init
Graph
();
engine_
->
Init
Net
();
}
// We use the set 'neglected_output' here, because some Ops like batch norm,
...
...
paddle/fluid/inference/anakin/engine.cc
浏览文件 @
8643dbc2
...
...
@@ -35,12 +35,14 @@ namespace anakin {
template
<
typename
TargetT
,
Precision
PrecisionType
,
OpRunType
RunType
>
AnakinEngine
<
TargetT
,
PrecisionType
,
RunType
>::
AnakinEngine
(
bool
need_summary
,
int
device
,
int
max_batch_size
,
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
max_input_shape
)
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
max_input_shape
,
std
::
vector
<
std
::
string
>
program_inputs
)
:
graph_
(
new
AnakinGraphT
<
TargetT
,
PrecisionType
>
()),
net_
(
new
AnakinNetT
<
TargetT
,
PrecisionType
,
RunType
>
(
need_summary
))
{
device_
=
device
;
max_batch_size_
=
max_batch_size
;
max_input_shape_
=
max_input_shape
;
program_inputs_
=
program_inputs
;
}
template
<
typename
TargetT
,
Precision
PrecisionType
,
OpRunType
RunType
>
...
...
@@ -54,7 +56,7 @@ void AnakinEngine<TargetT, PrecisionType, RunType>::SetInputShape(
}
template
<
typename
TargetT
,
Precision
PrecisionType
,
OpRunType
RunType
>
void
AnakinEngine
<
TargetT
,
PrecisionType
,
RunType
>::
Init
Graph
()
{
void
AnakinEngine
<
TargetT
,
PrecisionType
,
RunType
>::
Init
Net
()
{
net_
->
init
(
*
graph_
);
}
...
...
@@ -85,11 +87,19 @@ void AnakinEngine<TargetT, PrecisionType, RunType>::Execute(
int
max_shape_sum
=
std
::
accumulate
(
max_input_shape
.
begin
(),
max_input_shape
.
end
(),
1
,
std
::
multiplies
<
int
>
());
PADDLE_ENFORCE
(
max_shape_sum
>=
tensor
->
numel
(),
"The anakin input max shape should be greater than"
" or equal to the real input shape, Please set the max "
"input shape using EnableAnakinEngine"
);
if
(
tensor
->
numel
()
>
max_shape_sum
)
{
PADDLE_ENFORCE
(
std
::
find
(
program_inputs_
.
begin
(),
program_inputs_
.
end
(),
input
.
first
)
==
program_inputs_
.
end
(),
"The anakin input max shape should be greater than"
" or equal to the real input shape, Please set the max "
"input shape using EnableAnakinEngine"
);
VLOG
(
3
)
<<
"Anakin Net will be reset because of the inputs out of range: "
<<
input
.
first
;
graph_
->
Reshape
(
input
.
first
,
fluid_input_shape
);
net_
.
reset
(
new
AnakinNetT
<
TargetT
,
PrecisionType
,
RunType
>
(
true
));
net_
->
init
(
*
graph_
);
anakin_input
=
net_
->
get_in
(
input
.
first
);
}
anakin_input
->
reshape
(
fluid_input_shape
);
::
anakin
::
saber
::
Tensor
<
TargetT
>
tmp_anakin_tensor
(
data
,
TargetT
(),
0
,
fluid_input_shape
);
...
...
@@ -114,7 +124,7 @@ void AnakinEngine<TargetT, PrecisionType, RunType>::Execute(
template
<
typename
TargetT
,
Precision
PrecisionType
,
OpRunType
RunType
>
void
AnakinEngine
<
TargetT
,
PrecisionType
,
RunType
>::
Freeze
()
{
PADDLE_ENFORCE
(
graph_
->
Freeze
_v3
(),
"Freeze anakin subgraph."
);
PADDLE_ENFORCE
(
graph_
->
Freeze
(),
"Freeze anakin subgraph."
);
}
template
<
typename
TargetT
,
Precision
PrecisionType
,
OpRunType
RunType
>
...
...
paddle/fluid/inference/anakin/engine.h
浏览文件 @
8643dbc2
...
...
@@ -58,9 +58,10 @@ class AnakinEngine {
public:
explicit
AnakinEngine
(
bool
need_summary
=
false
,
int
device
=
0
,
int
max_batch_size
=
1
,
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
max_input_shape
=
{});
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
max_input_shape
=
{},
std
::
vector
<
std
::
string
>
program_inputs
=
{});
~
AnakinEngine
();
void
Init
Graph
();
void
Init
Net
();
void
SetInputShape
(
const
std
::
string
&
name
,
std
::
vector
<
int
>
shape
);
void
AddOp
(
const
std
::
string
&
name
,
const
std
::
string
&
type
,
const
std
::
vector
<
std
::
string
>
&
inputs
,
...
...
@@ -81,15 +82,16 @@ class AnakinEngine {
void
SetMaxInputShape
(
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
shape
)
{
max_input_shape_
=
shape
;
}
const
std
::
vector
<
std
::
string
>
&
GetScalableInputs
()
{
return
program_inputs_
;
}
void
SetScalableInputs
(
std
::
vector
<
std
::
string
>
program_inputs
)
{
program_inputs_
=
program_inputs
;
}
int
GetMaxBatchSize
()
{
return
max_batch_size_
;
}
void
Freeze
();
void
Optimize
();
void
AllocTmpMem
()
{
PADDLE_ENFORCE
(
net_
->
alloc_memory_first
(
*
graph_
),
"anakin alloc temp memory first failed"
);
}
void
Save
(
std
::
string
path
)
{
graph_
->
save
(
path
);
}
bool
IsInit
()
{
return
initialized_
;
}
int
GetDevice
()
{
return
device_
;
}
void
Execute
(
const
std
::
map
<
std
::
string
,
framework
::
LoDTensor
*>
&
inputs
,
...
...
@@ -103,6 +105,7 @@ class AnakinEngine {
int
device_
;
std
::
unique_ptr
<
GraphT
>
graph_
;
std
::
unique_ptr
<
NetT
>
net_
;
std
::
vector
<
std
::
string
>
program_inputs_
;
};
class
AnakinEngineManager
{
...
...
@@ -120,10 +123,10 @@ class AnakinEngineManager {
AnakinNvEngineT
*
Create
(
bool
need_summary
,
int
device
,
int
max_batch_size
,
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>
max_input_shape
,
std
::
string
engine_name
)
{
std
::
vector
<
std
::
string
>
program_inputs
,
std
::
string
engine_name
)
{
std
::
unique_lock
<
std
::
mutex
>
lk
(
mut_
);
auto
*
p
=
new
AnakinEngine
<
NV
,
Precision
::
FP32
>
(
need_summary
,
device
,
max_batch_size
,
max_input_shape
);
need_summary
,
device
,
max_batch_size
,
max_input_shape
,
program_inputs
);
engines_
[
engine_name
].
reset
(
p
);
return
p
;
}
...
...
paddle/fluid/inference/anakin/op_teller.cc
浏览文件 @
8643dbc2
...
...
@@ -44,6 +44,8 @@ struct SimpleOpTypeSetTeller : public Teller {
teller_set
.
insert
(
"sum"
);
teller_set
.
insert
(
"depthwise_conv2d"
);
teller_set
.
insert
(
"prior_box"
);
teller_set
.
insert
(
"leaky_relu"
);
teller_set
.
insert
(
"affine_channel"
);
}
bool
operator
()(
const
std
::
string
&
op_type
,
...
...
paddle/fluid/inference/anakin/test_anakin_engine.cc
浏览文件 @
8643dbc2
...
...
@@ -68,7 +68,7 @@ TEST_F(TestAnakinEngine, Execute) {
// engine_->AddOpAttr("x", "input_shape", input_shape);
engine_
->
SetInputShape
(
"x"
,
{
1
,
1
,
1
,
1
});
engine_
->
Optimize
();
engine_
->
Init
Graph
();
engine_
->
Init
Net
();
framework
::
LoDTensor
x
;
framework
::
LoDTensor
y
;
x
.
Resize
({
1
,
1
,
1
,
1
});
...
...
paddle/fluid/inference/analysis/ir_passes/anakin_subgraph_pass.cc
浏览文件 @
8643dbc2
...
...
@@ -192,11 +192,12 @@ void AnakinSubgraphPass::CreateAnakinOp(
auto
max_input_shape
=
Get
<
std
::
map
<
std
::
string
,
std
::
vector
<
int
>>>
(
"max_input_shape"
);
auto
max_batch_size
=
Get
<
int
>
(
"max_batch_size"
);
auto
program_inputs
=
program_desc
->
GetFeedTargetNames
();
auto
*
anakin_engine
=
inference
::
Singleton
<
anakin
::
AnakinEngineManager
>::
Global
().
Create
(
true
,
Get
<
int
>
(
"gpu_device_id"
),
max_batch_size
,
max_input_shape
,
engine_key
);
program_inputs
,
engine_key
);
auto
*
scope
=
param_scope
();
std
::
unordered_set
<
std
::
string
>
param_set
(
params
.
begin
(),
params
.
end
());
...
...
paddle/fluid/inference/analysis/ir_passes/subgraph_util.cc
浏览文件 @
8643dbc2
...
...
@@ -100,7 +100,6 @@ void RenameAndGetOutputs(
const
std
::
string
arg_value
=
in_var
->
arguments
(
k
);
const
std
::
string
arg_value_with_id
=
arg_value
+
std
::
to_string
(
var2id
[
arg_value
]);
if
(
input_names_with_id
.
count
(
arg_value_with_id
))
{
replaced_names
.
push_back
(
arg_value
);
if
(
graph_var_map
.
count
(
arg_value
))
{
...
...
@@ -149,7 +148,6 @@ void RenameAndGetOutputs(
const
std
::
string
arg_value
=
out_var
->
arguments
(
k
);
const
std
::
string
arg_value_with_id
=
arg_value
+
std
::
to_string
(
var2id
[
arg_value
]);
if
(
graph_var_map
.
count
(
arg_value
))
{
add_block_var
(
arg_value
,
arg_value_with_id
);
}
...
...
paddle/fluid/inference/api/CMakeLists.txt
浏览文件 @
8643dbc2
...
...
@@ -70,3 +70,4 @@ if (WITH_ANAKIN AND WITH_MKL) # only needed in CI
anakin_target
(
inference_anakin_api
)
anakin_target
(
inference_anakin_api_shared
)
endif
()
inference_analysis_test
(
faster_rcnn_test SRCS faster_rcnn_test.cc EXTRA_DEPS paddle_fluid
)
paddle/fluid/inference/api/analysis_predictor.cc
浏览文件 @
8643dbc2
...
...
@@ -888,4 +888,6 @@ USE_ANAKIN_CONVERTER(density_prior_box);
USE_ANAKIN_CONVERTER
(
dropout
);
USE_ANAKIN_CONVERTER
(
sum
);
USE_ANAKIN_CONVERTER
(
prior_box
);
USE_ANAKIN_CONVERTER
(
leaky_relu
);
USE_ANAKIN_CONVERTER
(
affine_channel
);
#endif
paddle/fluid/pybind/inference_api.cc
浏览文件 @
8643dbc2
...
...
@@ -229,6 +229,9 @@ void BindAnalysisConfig(py::module *m) {
py
::
arg
(
"min_subgraph_size"
)
=
3
,
py
::
arg
(
"precision_mode"
)
=
AnalysisConfig
::
Precision
::
kFloat32
,
py
::
arg
(
"use_static"
)
=
true
)
.
def
(
"enable_anakin_engine"
,
&
AnalysisConfig
::
EnableAnakinEngine
,
py
::
arg
(
"max_batch_size"
)
=
1
,
py
::
arg
(
"max_input_shape"
)
=
{},
py
::
arg
(
"min_subgraph_size"
)
=
6
)
.
def
(
"tensorrt_engine_enabled"
,
&
AnalysisConfig
::
tensorrt_engine_enabled
)
.
def
(
"switch_ir_debug"
,
&
AnalysisConfig
::
SwitchIrDebug
,
py
::
arg
(
"x"
)
=
true
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录