Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
6641a314
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
332
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
6641a314
编写于
2月 15, 2019
作者:
H
hjchen2
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Delivery input lod to output for elementwise_add/top_k/activation ops
上级
276ec6f9
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
557 addition
and
325 deletion
+557
-325
src/operators/activation_op.cpp
src/operators/activation_op.cpp
+6
-5
src/operators/elementwise_add_op.cpp
src/operators/elementwise_add_op.cpp
+1
-0
src/operators/kernel/arm/beam_search_decode_kernel.cpp
src/operators/kernel/arm/beam_search_decode_kernel.cpp
+243
-10
src/operators/kernel/arm/conv_add_bn_relu_kernel.cpp
src/operators/kernel/arm/conv_add_bn_relu_kernel.cpp
+2
-2
src/operators/kernel/arm/conv_bn_add_relu_kernel.cpp
src/operators/kernel/arm/conv_bn_add_relu_kernel.cpp
+2
-2
src/operators/kernel/arm/conv_bn_relu_kernel.cpp
src/operators/kernel/arm/conv_bn_relu_kernel.cpp
+2
-2
src/operators/kernel/arm/conv_kernel.cpp
src/operators/kernel/arm/conv_kernel.cpp
+1
-1
src/operators/kernel/arm/dwconv_bn_relu_kernel.cpp
src/operators/kernel/arm/dwconv_bn_relu_kernel.cpp
+2
-2
src/operators/kernel/arm/sequence_softmax_kernel.cpp
src/operators/kernel/arm/sequence_softmax_kernel.cpp
+4
-6
src/operators/op_param.h
src/operators/op_param.h
+291
-291
src/operators/softmax_op.cpp
src/operators/softmax_op.cpp
+1
-0
src/operators/top_k_op.cpp
src/operators/top_k_op.cpp
+2
-4
未找到文件。
src/operators/activation_op.cpp
浏览文件 @
6641a314
...
...
@@ -22,6 +22,7 @@ namespace operators {
void OpName##Op<Dtype, T>::InferShape() const { \
const auto &input_dims = this->param_.InputX()->dims(); \
this->param_.Out()->Resize(input_dims); \
this->param_.Out()->set_lod(this->param_.InputX()->lod()); \
}
#ifdef RELU_OP
...
...
src/operators/elementwise_add_op.cpp
浏览文件 @
6641a314
...
...
@@ -23,6 +23,7 @@ template <typename Dtype, typename T>
void
ElementwiseAddOp
<
Dtype
,
T
>::
InferShape
()
const
{
auto
x_dim
=
this
->
param_
.
InputX
()
->
dims
();
this
->
param_
.
Out
()
->
Resize
(
x_dim
);
this
->
param_
.
Out
()
->
set_lod
(
this
->
param_
.
InputX
()
->
lod
());
}
}
// namespace operators
...
...
src/operators/kernel/arm/beam_search_decode_kernel.cpp
浏览文件 @
6641a314
...
...
@@ -15,27 +15,260 @@ limitations under the License. */
#ifdef BEAM_SEARCH_DECODE_OP
#include "operators/kernel/beam_search_decode_kernel.h"
#include "framework/data_type.h"
namespace
paddle_mobile
{
namespace
operators
{
using
LoDTensor
=
framework
::
LoDTensor
;
using
LoDTensorArray
=
framework
::
LoDTensorArray
;
// all the lod have 2 levels.
// The first is source level, the second is sentence level.
// source level describe how many prefixes (branchs) for each source sentece
// (beam). sentence level describe how these candidates belong to the prefixes.
const
size_t
kSourceLevel
=
0
;
const
size_t
kSentenceLevel
=
1
;
template
<
typename
T
>
struct
Sentence
{
std
::
vector
<
int64_t
>
word_ids
;
std
::
vector
<
T
>
scores
;
};
template
<
typename
T
>
using
SentenceVector
=
std
::
vector
<
Sentence
<
T
>>
;
template
<
typename
T
>
struct
BeamSearchDecoder
{
BeamSearchDecoder
(
size_t
beam_size
,
int
end_id
)
:
beam_size_
(
beam_size
),
end_id_
(
end_id
)
{}
/**
* convert the result sentence_vector for each source sentence into two
* LodTensor.
* One is all candidate sentences with word id, one is all candidate sentences
* with word score.
* Param:
* sentence_vector_list: sentence_vector for each source sentence.
* id_tensor: result LoDTensor for sentences of id.
* score_tensor: result LoDTensor for sentences of score.
* reverse: whether ids of sentence in sentence_vector_list is reversed
* sort_by_score: whether to sort hypotheses of each sentence by scores.
*/
void
ConvertSentenceVectorToLodTensor
(
std
::
vector
<
SentenceVector
<
T
>>
sentence_vector_list
,
LoDTensor
*
id_tensor
,
LoDTensor
*
score_tensor
,
bool
reverse
=
true
,
bool
sort_by_score
=
true
)
const
;
/**
* Gather the hypotheses for each source sentence by backtrace though the
* LoDTensorArray step_ids whose lods reserve the path in the tree.
*/
void
Backtrace
(
const
LoDTensorArray
&
step_ids
,
const
LoDTensorArray
&
step_scores
,
LoDTensor
*
id_tensor
,
LoDTensor
*
score_tensor
)
const
;
size_t
beam_size_
;
int
end_id_
;
};
template
<
typename
T
>
void
BeamSearchDecoder
<
T
>::
ConvertSentenceVectorToLodTensor
(
std
::
vector
<
SentenceVector
<
T
>>
sentence_vector_list
,
LoDTensor
*
id_tensor
,
LoDTensor
*
score_tensor
,
bool
reverse
,
bool
sort_by_score
)
const
{
size_t
src_num
=
sentence_vector_list
.
size
();
PADDLE_MOBILE_ENFORCE
(
src_num
>
0
,
"src_num should be larger than 0"
);
std
::
vector
<
size_t
>
source_level_lod
=
{
0
};
std
::
vector
<
size_t
>
sentence_level_lod
=
{
0
};
std
::
vector
<
int64_t
>
id_data
;
std
::
vector
<
T
>
score_data
;
for
(
size_t
src_idx
=
0
;
src_idx
<
src_num
;
++
src_idx
)
{
if
(
sort_by_score
)
{
sort
(
sentence_vector_list
[
src_idx
].
begin
(),
sentence_vector_list
[
src_idx
].
end
(),
[
reverse
](
const
Sentence
<
T
>&
a
,
const
Sentence
<
T
>&
b
)
{
if
(
reverse
)
return
a
.
scores
.
front
()
>
b
.
scores
.
front
();
else
return
a
.
scores
.
back
()
>
b
.
scores
.
back
();
});
}
for
(
Sentence
<
T
>&
sentence
:
sentence_vector_list
[
src_idx
])
{
if
(
reverse
)
{
id_data
.
insert
(
id_data
.
end
(),
sentence
.
word_ids
.
rbegin
(),
sentence
.
word_ids
.
rend
());
score_data
.
insert
(
score_data
.
end
(),
sentence
.
scores
.
rbegin
(),
sentence
.
scores
.
rend
());
}
else
{
id_data
.
insert
(
id_data
.
end
(),
sentence
.
word_ids
.
begin
(),
sentence
.
word_ids
.
end
());
score_data
.
insert
(
score_data
.
end
(),
sentence
.
scores
.
begin
(),
sentence
.
scores
.
end
());
}
sentence_level_lod
.
push_back
(
sentence_level_lod
.
back
()
+
sentence
.
word_ids
.
size
());
}
source_level_lod
.
push_back
(
source_level_lod
.
back
()
+
sentence_vector_list
[
src_idx
].
size
());
}
framework
::
LoD
lod
;
lod
.
push_back
(
source_level_lod
);
lod
.
push_back
(
sentence_level_lod
);
id_tensor
->
set_lod
(
lod
);
id_tensor
->
Resize
({
static_cast
<
int64_t
>
(
id_data
.
size
())});
id_tensor
->
mutable_data
<
int64_t
>
();
// framework::TensorFromVector<int64_t>(id_data, cpu_ctx, id_tensor);
score_tensor
->
set_lod
(
lod
);
score_tensor
->
Resize
({
static_cast
<
int64_t
>
(
score_data
.
size
())});
score_tensor
->
mutable_data
<
T
>
();
// framework::TensorFromVector<T>(score_data, cpu_ctx, score_tensor);
}
template
<
typename
T
>
void
BeamSearchDecoder
<
T
>::
Backtrace
(
const
LoDTensorArray
&
step_ids
,
const
LoDTensorArray
&
step_scores
,
LoDTensor
*
id_tensor
,
LoDTensor
*
score_tensor
)
const
{
PADDLE_MOBILE_ENFORCE
(
!
step_ids
.
empty
(),
"step num should be larger than 0"
);
PADDLE_MOBILE_ENFORCE
(
step_ids
.
size
()
==
step_scores
.
size
(),
"step_ids and step_scores should be the same"
);
const
size_t
step_num
=
step_ids
.
size
();
const
size_t
src_num
=
step_ids
.
at
(
0
).
lod
().
at
(
kSourceLevel
).
size
()
-
1
;
std
::
vector
<
SentenceVector
<
T
>>
sentence_vector_list
(
src_num
,
SentenceVector
<
T
>
(
beam_size_
));
std
::
vector
<
std
::
vector
<
size_t
>>
prefix_idx_vector_list
(
src_num
);
for
(
int
step_id
=
step_num
-
1
;
step_id
>=
0
;
--
step_id
)
{
auto
&
cur_ids
=
step_ids
.
at
(
step_id
);
auto
&
cur_scores
=
step_scores
.
at
(
step_id
);
for
(
size_t
src_idx
=
0
;
src_idx
<
src_num
;
++
src_idx
)
{
// for each source sentence
auto
&
sentence_vector
=
sentence_vector_list
.
at
(
src_idx
);
auto
&
prefix_idx_vector
=
prefix_idx_vector_list
.
at
(
src_idx
);
size_t
src_prefix_start
=
cur_ids
.
lod
().
at
(
kSourceLevel
)[
src_idx
];
size_t
src_prefix_end
=
cur_ids
.
lod
().
at
(
kSourceLevel
)[
src_idx
+
1
];
if
(
prefix_idx_vector
.
empty
())
{
// be finished and pruned at this step
// or the last time step
for
(
size_t
prefix_idx
=
src_prefix_start
;
prefix_idx
<
src_prefix_end
;
++
prefix_idx
)
{
size_t
candidate_start
=
cur_ids
.
lod
().
at
(
kSentenceLevel
)[
prefix_idx
];
size_t
candidate_end
=
cur_ids
.
lod
().
at
(
kSentenceLevel
)[
prefix_idx
+
1
];
for
(
size_t
candidate_idx
=
candidate_start
;
candidate_idx
<
candidate_end
;
++
candidate_idx
)
{
prefix_idx_vector
.
push_back
(
prefix_idx
);
size_t
idx
=
prefix_idx_vector
.
size
()
-
1
;
auto
cur_id
=
cur_ids
.
data
<
int64_t
>
()[
candidate_idx
];
auto
cur_score
=
cur_scores
.
data
<
T
>
()[
candidate_idx
];
sentence_vector
.
at
(
idx
).
word_ids
.
push_back
(
cur_id
);
sentence_vector
.
at
(
idx
).
scores
.
push_back
(
cur_score
);
}
}
}
else
{
// use prefix_idx_vector to backtrace
size_t
src_candidate_start
=
cur_ids
.
lod
().
at
(
kSentenceLevel
)[
src_prefix_start
];
size_t
prefix_idx
=
src_prefix_start
;
size_t
candidate_num
=
cur_ids
.
lod
().
at
(
kSentenceLevel
)[
prefix_idx
+
1
]
-
cur_ids
.
lod
().
at
(
kSentenceLevel
)[
prefix_idx
];
for
(
size_t
idx
=
0
;
idx
<
prefix_idx_vector
.
size
();
++
idx
)
{
auto
candidate_idx
=
prefix_idx_vector
.
at
(
idx
);
auto
cur_id
=
cur_ids
.
data
<
int64_t
>
()[
candidate_idx
];
auto
cur_score
=
cur_scores
.
data
<
T
>
()[
candidate_idx
];
if
(
cur_id
!=
end_id_
||
sentence_vector
.
at
(
idx
).
word_ids
.
empty
())
{
// to skip redundant end tokens
sentence_vector
.
at
(
idx
).
word_ids
.
push_back
(
cur_id
);
sentence_vector
.
at
(
idx
).
scores
.
push_back
(
cur_score
);
}
while
(
src_candidate_start
+
candidate_num
<=
candidate_idx
)
{
// search the corresponding prefix
prefix_idx
++
;
candidate_num
+=
cur_ids
.
lod
().
at
(
kSentenceLevel
)[
prefix_idx
+
1
]
-
cur_ids
.
lod
().
at
(
kSentenceLevel
)[
prefix_idx
];
}
prefix_idx_vector
.
at
(
idx
)
=
prefix_idx
;
}
}
}
}
ConvertSentenceVectorToLodTensor
(
sentence_vector_list
,
id_tensor
,
score_tensor
,
true
,
true
);
}
struct
BeamSearchDecodeFunctor
{
BeamSearchDecodeFunctor
(
const
LoDTensorArray
&
step_ids
,
const
LoDTensorArray
&
step_scores
,
LoDTensor
*
id_tensor
,
LoDTensor
*
score_tensor
,
size_t
beam_size
,
int
end_id
)
:
beam_size_
(
beam_size
),
end_id_
(
end_id
),
step_ids_
(
step_ids
),
step_scores_
(
step_scores
),
id_tensor_
(
id_tensor
),
score_tensor_
(
score_tensor
)
{}
template
<
typename
T
>
void
apply
()
const
;
size_t
beam_size_
;
int
end_id_
;
const
LoDTensorArray
&
step_ids_
;
const
LoDTensorArray
&
step_scores_
;
LoDTensor
*
id_tensor_
;
LoDTensor
*
score_tensor_
;
};
template
<
typename
T
>
void
BeamSearchDecodeFunctor
::
apply
()
const
{
BeamSearchDecoder
<
T
>
beam_search_decoder
(
beam_size_
,
end_id_
);
beam_search_decoder
.
Backtrace
(
step_ids_
,
step_scores_
,
id_tensor_
,
score_tensor_
);
}
template
<
>
void
BeamSearchDecodeFunctor
::
apply
<
bool
>
()
const
{
PADDLE_MOBILE_THROW_EXCEPTION
(
"beam search decode op does not support bool."
);
}
template
<
>
bool
BeamSearchDecodeKernel
<
CPU
,
float
>::
Init
(
BeamSearchDecodeParam
<
CPU
>
*
param
)
{
BeamSearchDecodeParam
<
CPU
>
*
param
)
{
return
true
;
}
template
<
>
void
BeamSearchDecodeKernel
<
CPU
,
float
>::
Compute
(
const
BeamSearchDecodeParam
<
CPU
>
&
param
)
{
// TODO(hjchen2)
DLOG
<<
"BeamSearchDecodeKernel"
;
param
.
sentence_scores_
->
Resize
(
framework
::
make_ddim
({
10
}));
param
.
sentence_scores_
->
mutable_data
<
float
>
();
DLOG
<<
"BeamSearchDecodeKernel"
;
param
.
sentence_ids_
->
Resize
(
framework
::
make_ddim
({
10
}));
param
.
sentence_ids_
->
mutable_data
<
int64_t
>
();
const
BeamSearchDecodeParam
<
CPU
>&
param
)
{
const
LoDTensorArray
*
ids
=
param
.
ids_
;
const
LoDTensorArray
*
scores
=
param
.
scores_
;
const
size_t
step_num
=
ids
->
size
();
PADDLE_MOBILE_ENFORCE
(
step_num
>
0
,
"beam search steps should be larger than 0"
);
for
(
size_t
i
=
0
;
i
<
step_num
;
++
i
)
{
PADDLE_MOBILE_ENFORCE
(
ids
->
at
(
i
).
lod
().
size
()
==
2
,
"Level of LodTensor should be 2"
);
}
const
size_t
source_num
=
ids
->
at
(
0
).
lod
().
at
(
0
).
size
()
-
1
;
PADDLE_MOBILE_ENFORCE
(
source_num
>
0
,
"source num should be larger than 0"
);
LoDTensor
*
sentence_ids
=
param
.
sentence_ids_
;
LoDTensor
*
sentence_scores
=
param
.
sentence_scores_
;
framework
::
VisitDataType
(
framework
::
ToDataType
(
scores
->
at
(
0
).
type
()),
BeamSearchDecodeFunctor
(
*
ids
,
*
scores
,
sentence_ids
,
sentence_scores
,
param
.
beam_size_
,
param
.
end_id_
));
}
}
// namespace operators
...
...
src/operators/kernel/arm/conv_add_bn_relu_kernel.cpp
浏览文件 @
6641a314
...
...
@@ -41,8 +41,8 @@ bool ConvAddBNReluKernel<CPU, float>::Init(
inv_std_ptr
[
i
]
=
1
/
static_cast
<
float
>
(
pow
((
variance_ptr
[
i
]
+
epsilon
),
0.5
));
}
Tensor
*
new_scale
=
new
Tensor
();
Tensor
*
new_bias
=
new
Tensor
();
LoDTensor
*
new_scale
=
new
LoD
Tensor
();
LoDTensor
*
new_bias
=
new
LoD
Tensor
();
auto
new_scale_ptr
=
new_scale
->
mutable_data
<
float
>
({
C
});
auto
new_bias_ptr
=
new_bias
->
mutable_data
<
float
>
({
C
});
for
(
int
i
=
0
;
i
<
C
;
i
++
)
{
...
...
src/operators/kernel/arm/conv_bn_add_relu_kernel.cpp
浏览文件 @
6641a314
...
...
@@ -41,8 +41,8 @@ bool ConvBNAddReluKernel<CPU, float>::Init(
inv_std_ptr
[
i
]
=
1
/
static_cast
<
float
>
(
pow
((
variance_ptr
[
i
]
+
epsilon
),
0.5
));
}
Tensor
*
new_scale
=
new
Tensor
();
Tensor
*
new_bias
=
new
Tensor
();
LoDTensor
*
new_scale
=
new
LoD
Tensor
();
LoDTensor
*
new_bias
=
new
LoD
Tensor
();
auto
new_scale_ptr
=
new_scale
->
mutable_data
<
float
>
({
C
});
auto
new_bias_ptr
=
new_bias
->
mutable_data
<
float
>
({
C
});
for
(
int
i
=
0
;
i
<
C
;
i
++
)
{
...
...
src/operators/kernel/arm/conv_bn_relu_kernel.cpp
浏览文件 @
6641a314
...
...
@@ -42,8 +42,8 @@ bool ConvBNReluKernel<CPU, float>::Init(FusionConvBNReluParam<CPU> *param) {
inv_std_ptr
[
i
]
=
1
/
static_cast
<
float
>
(
pow
((
variance_ptr
[
i
]
+
epsilon
),
0.5
));
}
Tensor
*
new_scale
=
new
Tensor
();
Tensor
*
new_bias
=
new
Tensor
();
LoDTensor
*
new_scale
=
new
LoD
Tensor
();
LoDTensor
*
new_bias
=
new
LoD
Tensor
();
auto
new_scale_ptr
=
new_scale
->
mutable_data
<
float
>
({
C
});
auto
new_bias_ptr
=
new_bias
->
mutable_data
<
float
>
({
C
});
for
(
int
i
=
0
;
i
<
C
;
i
++
)
{
...
...
src/operators/kernel/arm/conv_kernel.cpp
浏览文件 @
6641a314
...
...
@@ -69,7 +69,7 @@ bool ConvKernel<CPU, float>::Init(ConvParam<CPU> *param) {
param
->
Input
()
->
dims
()[
2
]
<=
140
/* refered from ncnn */
)
{
param
->
ExecMode
()
=
ConvParam
<
CPU
>::
EXEC_WINOGRAD3X3_FLOAT
;
// transform weight
param
->
transformed_filter_
=
new
framework
::
Tensor
;
param
->
transformed_filter_
=
new
framework
::
LoD
Tensor
;
operators
::
math
::
winograd_transform_weight
<
8
,
3
>
(
*
param
->
Filter
(),
param
->
transformed_filter_
);
#endif
...
...
src/operators/kernel/arm/dwconv_bn_relu_kernel.cpp
浏览文件 @
6641a314
...
...
@@ -40,8 +40,8 @@ bool DWConvBNReluKernel<CPU, float>::Init(FusionDWConvBNReluParam<CPU> *param) {
inv_std_ptr
[
i
]
=
1
/
static_cast
<
float
>
(
pow
((
variance_ptr
[
i
]
+
epsilon
),
0.5
));
}
Tensor
*
new_scale
=
new
Tensor
();
Tensor
*
new_bias
=
new
Tensor
();
LoDTensor
*
new_scale
=
new
LoD
Tensor
();
LoDTensor
*
new_bias
=
new
LoD
Tensor
();
auto
new_scale_ptr
=
new_scale
->
mutable_data
<
float
>
({
C
});
auto
new_bias_ptr
=
new_bias
->
mutable_data
<
float
>
({
C
});
for
(
int
i
=
0
;
i
<
C
;
i
++
)
{
...
...
src/operators/kernel/arm/sequence_softmax_kernel.cpp
浏览文件 @
6641a314
...
...
@@ -29,12 +29,10 @@ class SequenceSoftmaxKernel<CPU, T>
void
Compute
(
const
SoftmaxParam
<
CPU
>
&
param
)
{
param
.
Out
()
->
mutable_data
<
float
>
();
/*
const
framework
::
LoDTensor
*
input
=
param
.
InputX
();
framework
::
LoDTensor
*
output
=
param
.
Out
();
math
::
SequenceSoftmaxFuntor
<
CPU
,
T
>
sequence_softmax
;
sequence_softmax
(
input
,
output
);
*/
}
};
...
...
src/operators/op_param.h
浏览文件 @
6641a314
...
...
@@ -426,11 +426,11 @@ class ConvParam : public OpParam {
groups
=
OpParam
::
GetAttr
<
int
>
(
"groups"
,
attrs
);
}
const
R
Type
*
Input
()
const
{
return
input_
;
}
const
G
Type
*
Input
()
const
{
return
input_
;
}
R
Type
*
Filter
()
const
{
return
filter_
;
}
G
Type
*
Filter
()
const
{
return
filter_
;
}
R
Type
*
Output
()
const
{
return
output_
;
}
G
Type
*
Output
()
const
{
return
output_
;
}
const
vector
<
int
>
&
Strides
()
const
{
return
strides_
;
}
...
...
@@ -465,10 +465,10 @@ class ConvParam : public OpParam {
#endif
public:
R
Type
*
input_
;
R
Type
*
output_
;
R
Type
*
filter_
;
R
Type
*
transformed_filter_
;
G
Type
*
input_
;
G
Type
*
output_
;
G
Type
*
filter_
;
G
Type
*
transformed_filter_
;
vector
<
int
>
strides_
;
vector
<
int
>
paddings_
;
vector
<
int
>
dilations_
;
...
...
@@ -728,11 +728,11 @@ class LrnParam : public OpParam {
data_format_
=
GetStringAttr
(
"data_format"
,
attrs
);
}
const
R
Type
*
InputX
()
const
{
return
input_x_
;
}
const
G
Type
*
InputX
()
const
{
return
input_x_
;
}
R
Type
*
Out
()
const
{
return
out_
;
}
G
Type
*
Out
()
const
{
return
out_
;
}
R
Type
*
MidOut
()
const
{
return
mid_out_
;
}
G
Type
*
MidOut
()
const
{
return
mid_out_
;
}
const
int
&
N
()
const
{
return
n_
;
}
...
...
@@ -745,9 +745,9 @@ class LrnParam : public OpParam {
const
string
&
DataFormat
()
const
{
return
data_format_
;
}
private:
R
Type
*
input_x_
;
R
Type
*
out_
;
R
Type
*
mid_out_
;
G
Type
*
input_x_
;
G
Type
*
out_
;
G
Type
*
mid_out_
;
int
n_
;
float
alpha_
;
float
beta_
;
...
...
@@ -772,20 +772,20 @@ class NormParam : OpParam {
axis_
=
GetAttr
<
int
>
(
"axis"
,
attrs
);
}
const
R
Type
*
InputX
()
const
{
return
input_x_
;
}
const
G
Type
*
InputX
()
const
{
return
input_x_
;
}
R
Type
*
Out
()
const
{
return
out_
;
}
G
Type
*
Out
()
const
{
return
out_
;
}
R
Type
*
OutputNorm
()
const
{
return
output_norm_
;
}
G
Type
*
OutputNorm
()
const
{
return
output_norm_
;
}
const
float
&
Epsilon
()
const
{
return
epsilon_
;
}
const
int
&
Axis
()
const
{
return
axis_
;
}
private:
R
Type
*
input_x_
;
R
Type
*
out_
;
R
Type
*
output_norm_
;
G
Type
*
input_x_
;
G
Type
*
out_
;
G
Type
*
output_norm_
;
float
epsilon_
;
int
axis_
;
};
...
...
@@ -811,17 +811,17 @@ class BatchNormParam : OpParam {
// is_test_ = GetAttr<bool>("is_test", attrs);
}
const
R
Type
*
InputX
()
const
{
return
input_x_
;
}
const
G
Type
*
InputX
()
const
{
return
input_x_
;
}
R
Type
*
OutputY
()
const
{
return
output_y_
;
}
G
Type
*
OutputY
()
const
{
return
output_y_
;
}
const
R
Type
*
InputBias
()
const
{
return
input_bias_
;
}
const
G
Type
*
InputBias
()
const
{
return
input_bias_
;
}
const
R
Type
*
InputMean
()
const
{
return
input_mean_
;
}
const
G
Type
*
InputMean
()
const
{
return
input_mean_
;
}
const
R
Type
*
InputScale
()
const
{
return
input_scale_
;
}
const
G
Type
*
InputScale
()
const
{
return
input_scale_
;
}
const
R
Type
*
InputVariance
()
const
{
return
input_variance_
;
}
const
G
Type
*
InputVariance
()
const
{
return
input_variance_
;
}
const
float
&
Epsilon
()
const
{
return
epsilon_
;
}
...
...
@@ -831,27 +831,27 @@ class BatchNormParam : OpParam {
const
string
&
DataFormat
()
const
{
return
data_format_
;
}
void
SetNewScale
(
R
Type
*
new_scale
)
{
new_scale_
=
new_scale
;
}
void
SetNewScale
(
G
Type
*
new_scale
)
{
new_scale_
=
new_scale
;
}
void
SetNewBias
(
R
Type
*
new_bias
)
{
new_bias_
=
new_bias
;
}
void
SetNewBias
(
G
Type
*
new_bias
)
{
new_bias_
=
new_bias
;
}
const
R
Type
*
NewScale
()
const
{
return
new_scale_
;
}
const
G
Type
*
NewScale
()
const
{
return
new_scale_
;
}
const
R
Type
*
NewBias
()
const
{
return
new_bias_
;
}
const
G
Type
*
NewBias
()
const
{
return
new_bias_
;
}
private:
R
Type
*
input_x_
;
R
Type
*
output_y_
;
R
Type
*
input_bias_
;
R
Type
*
input_mean_
;
R
Type
*
input_scale_
;
R
Type
*
input_variance_
;
G
Type
*
input_x_
;
G
Type
*
output_y_
;
G
Type
*
input_bias_
;
G
Type
*
input_mean_
;
G
Type
*
input_scale_
;
G
Type
*
input_variance_
;
float
epsilon_
;
float
momentum_
;
bool
is_test_
;
string
data_format_
;
R
Type
*
new_bias_
;
R
Type
*
new_scale_
;
G
Type
*
new_bias_
;
G
Type
*
new_scale_
;
};
#endif
...
...
@@ -875,9 +875,9 @@ class PoolParam : public OpParam {
global_pooling_
=
GetAttr
<
bool
>
(
"global_pooling"
,
attrs
);
}
const
R
Type
*
Input
()
const
{
return
input_
;
}
const
G
Type
*
Input
()
const
{
return
input_
;
}
R
Type
*
Output
()
const
{
return
output_
;
}
G
Type
*
Output
()
const
{
return
output_
;
}
const
string
&
PoolingType
()
const
{
return
pooling_type_
;
}
...
...
@@ -892,8 +892,8 @@ class PoolParam : public OpParam {
bool
isGlobalPooling
()
const
{
return
global_pooling_
;
}
private:
R
Type
*
input_
;
R
Type
*
output_
;
G
Type
*
input_
;
G
Type
*
output_
;
string
pooling_type_
;
vector
<
int
>
ksize_
;
vector
<
int
>
strides_
;
...
...
@@ -942,13 +942,13 @@ class PriorBoxParam : public OpParam {
step_h_
=
GetAttr
<
float
>
(
"step_h"
,
attrs
);
offset_
=
GetAttr
<
float
>
(
"offset"
,
attrs
);
}
const
R
Type
*
Input
()
const
{
return
input_
;
}
const
G
Type
*
Input
()
const
{
return
input_
;
}
const
R
Type
*
InputImage
()
const
{
return
input_image_
;
}
const
G
Type
*
InputImage
()
const
{
return
input_image_
;
}
R
Type
*
OutputBoxes
()
const
{
return
output_boxes_
;
}
G
Type
*
OutputBoxes
()
const
{
return
output_boxes_
;
}
R
Type
*
OutputVariances
()
const
{
return
output_variances_
;
}
G
Type
*
OutputVariances
()
const
{
return
output_variances_
;
}
const
vector
<
float
>
&
MinSizes
()
const
{
return
min_sizes_
;
}
...
...
@@ -973,10 +973,10 @@ class PriorBoxParam : public OpParam {
}
private:
R
Type
*
input_
;
R
Type
*
input_image_
;
R
Type
*
output_boxes_
;
R
Type
*
output_variances_
;
G
Type
*
input_
;
G
Type
*
input_image_
;
G
Type
*
output_boxes_
;
G
Type
*
output_variances_
;
vector
<
float
>
min_sizes_
;
vector
<
float
>
max_sizes_
;
vector
<
float
>
aspect_ratios_
;
...
...
@@ -1005,21 +1005,21 @@ class BoxCoderParam : public OpParam {
output_box_
=
OutputBoxFrom
<
GType
>
(
outputs
,
scope
);
code_type_
=
GetStringAttr
(
"code_type"
,
attrs
);
}
const
R
Type
*
InputPriorBox
()
const
{
return
input_priorbox_
;
}
const
G
Type
*
InputPriorBox
()
const
{
return
input_priorbox_
;
}
const
R
Type
*
InputPriorBoxVar
()
const
{
return
input_priorboxvar_
;
}
const
G
Type
*
InputPriorBoxVar
()
const
{
return
input_priorboxvar_
;
}
const
R
Type
*
InputTargetBox
()
const
{
return
input_targetbox_
;
}
const
G
Type
*
InputTargetBox
()
const
{
return
input_targetbox_
;
}
R
Type
*
OutputBox
()
const
{
return
output_box_
;
}
G
Type
*
OutputBox
()
const
{
return
output_box_
;
}
const
std
::
string
&
CodeType
()
const
{
return
code_type_
;
}
private:
R
Type
*
input_priorbox_
;
R
Type
*
input_priorboxvar_
;
R
Type
*
input_targetbox_
;
R
Type
*
output_box_
;
G
Type
*
input_priorbox_
;
G
Type
*
input_priorboxvar_
;
G
Type
*
input_targetbox_
;
G
Type
*
output_box_
;
std
::
string
code_type_
;
};
#endif
...
...
@@ -1046,11 +1046,11 @@ class SoftmaxParam : public OpParam {
#ifdef PADDLE_MOBILE_FPGA
private:
std
::
shared_ptr
<
R
Type
>
float_input_x_
;
std
::
shared_ptr
<
G
Type
>
float_input_x_
;
fpga
::
BypassArgs
fpga_bypass_args
;
public:
R
Type
*
FloatInput
()
const
{
G
Type
*
FloatInput
()
const
{
return
float_input_x_
==
nullptr
?
input_x_
:
float_input_x_
.
get
();
}
void
SetFloatInput
(
Tensor
*
input
)
{
float_input_x_
.
reset
(
input
);
}
...
...
@@ -1072,12 +1072,12 @@ class SigmoidParam : public OpParam {
input_x_
=
InputXFrom
<
GType
>
(
inputs
,
scope
);
out_
=
OutFrom
<
GType
>
(
outputs
,
scope
);
}
const
R
Type
*
InputX
()
const
{
return
input_x_
;
}
R
Type
*
Out
()
const
{
return
out_
;
}
const
G
Type
*
InputX
()
const
{
return
input_x_
;
}
G
Type
*
Out
()
const
{
return
out_
;
}
private:
R
Type
*
input_x_
;
R
Type
*
out_
;
G
Type
*
input_x_
;
G
Type
*
out_
;
#ifdef PADDLE_MOBILE_FPGA
private:
...
...
@@ -1111,11 +1111,11 @@ class MultiClassNMSParam : public OpParam {
score_threshold_
=
GetAttr
<
float
>
(
"score_threshold"
,
attrs
);
}
R
Type
*
InputBBoxes
()
const
{
return
input_bboxes_
;
}
G
Type
*
InputBBoxes
()
const
{
return
input_bboxes_
;
}
R
Type
*
InputScores
()
const
{
return
input_scores_
;
}
G
Type
*
InputScores
()
const
{
return
input_scores_
;
}
R
Type
*
Out
()
const
{
return
out_
;
}
G
Type
*
Out
()
const
{
return
out_
;
}
const
int
&
BackGroundLabel
()
const
{
return
background_label_
;
}
...
...
@@ -1130,9 +1130,9 @@ class MultiClassNMSParam : public OpParam {
const
float
&
ScoreThreshold
()
const
{
return
score_threshold_
;
}
private:
R
Type
*
input_bboxes_
;
R
Type
*
input_scores_
;
R
Type
*
out_
;
G
Type
*
input_bboxes_
;
G
Type
*
input_scores_
;
G
Type
*
out_
;
int
background_label_
;
int
nms_top_k_
;
int
keep_top_k_
;
...
...
@@ -1155,12 +1155,12 @@ class PolygonBoxTransformParam : public OpParam {
input_
=
InputFrom
<
GType
>
(
inputs
,
scope
);
output_
=
OutputFrom
<
GType
>
(
outputs
,
scope
);
}
const
R
Type
*
Input
()
const
{
return
input_
;
}
R
Type
*
Output
()
const
{
return
output_
;
}
const
G
Type
*
Input
()
const
{
return
input_
;
}
G
Type
*
Output
()
const
{
return
output_
;
}
private:
R
Type
*
input_
;
R
Type
*
output_
;
G
Type
*
input_
;
G
Type
*
output_
;
};
#endif
...
...
@@ -1214,11 +1214,11 @@ class FetchParam : public OpParam {
#ifdef PADDLE_MOBILE_FPGA
private:
std
::
shared_ptr
<
R
Type
>
float_input_x_
;
std
::
shared_ptr
<
G
Type
>
float_input_x_
;
fpga
::
BypassArgs
fpga_bypass_args
;
public:
R
Type
*
FloatInput
()
const
{
G
Type
*
FloatInput
()
const
{
return
float_input_x_
==
nullptr
?
input_x_
:
float_input_x_
.
get
();
}
void
SetFloatInput
(
Tensor
*
input
)
{
float_input_x_
.
reset
(
input
);
}
...
...
@@ -1246,7 +1246,7 @@ class FillConstantParam : public OpParam {
Variable
*
OutVar
()
const
{
return
out_var_
;
}
R
Type
*
Out
()
const
{
return
out_
;
}
G
Type
*
Out
()
const
{
return
out_
;
}
const
int
&
DataDtype
()
const
{
return
dtype_
;
}
...
...
@@ -1256,7 +1256,7 @@ class FillConstantParam : public OpParam {
private:
Variable
*
out_var_
;
R
Type
*
out_
;
G
Type
*
out_
;
int
dtype_
;
vector
<
int
>
shape_
;
float
value_
;
...
...
@@ -1277,15 +1277,15 @@ class TransposeParam : public OpParam {
axis_
=
GetAttr
<
vector
<
int
>>
(
"axis"
,
attrs
);
}
const
R
Type
*
InputX
()
const
{
return
input_x_
;
}
const
G
Type
*
InputX
()
const
{
return
input_x_
;
}
R
Type
*
Out
()
const
{
return
out_
;
}
G
Type
*
Out
()
const
{
return
out_
;
}
const
vector
<
int
>
&
Axis
()
const
{
return
axis_
;
}
private:
R
Type
*
input_x_
;
R
Type
*
out_
;
G
Type
*
input_x_
;
G
Type
*
out_
;
vector
<
int
>
axis_
;
};
#endif
...
...
@@ -1305,18 +1305,18 @@ class Transpose2Param : public OpParam {
axis_
=
GetAttr
<
vector
<
int
>>
(
"axis"
,
attrs
);
}
const
R
Type
*
InputX
()
const
{
return
input_x_
;
}
const
G
Type
*
InputX
()
const
{
return
input_x_
;
}
R
Type
*
Out
()
const
{
return
out_
;
}
G
Type
*
Out
()
const
{
return
out_
;
}
R
Type
*
OutputXShape
()
const
{
return
output_xshape_
;
}
G
Type
*
OutputXShape
()
const
{
return
output_xshape_
;
}
const
vector
<
int
>
&
Axis
()
const
{
return
axis_
;
}
private:
R
Type
*
input_x_
;
R
Type
*
out_
;
R
Type
*
output_xshape_
;
G
Type
*
input_x_
;
G
Type
*
out_
;
G
Type
*
output_xshape_
;
vector
<
int
>
axis_
;
};
#endif
...
...
@@ -1371,8 +1371,8 @@ class CrfParam : public OpParam {
const
GType
*
InputTransition
()
const
{
return
input_transition_
;
}
const
GType
*
InputLabel
()
const
{
return
input_label_
;
}
GType
*
outputVBP
()
const
{
return
output_viterbipath_
;
}
// const
R
Type *InputIds() const { return input_ids_; }
//
R
Type *Out() const { return out_; }
// const
G
Type *InputIds() const { return input_ids_; }
//
G
Type *Out() const { return out_; }
// int64_t PaddingIdx() const { return padding_idx_; }
private:
...
...
@@ -1381,8 +1381,8 @@ class CrfParam : public OpParam {
GType
*
input_label_
;
GType
*
output_viterbipath_
;
//
R
Type *input_ids_;
//
R
Type *out_;
//
G
Type *input_ids_;
//
G
Type *out_;
// int64_t padding_idx_;
};
#endif
...
...
@@ -1409,20 +1409,20 @@ class ReshapeParam : public OpParam {
}
}
const
R
Type
*
InputX
()
const
{
return
input_x_
;
}
const
G
Type
*
InputX
()
const
{
return
input_x_
;
}
const
R
Type
*
InputShape
()
const
{
return
input_shape_
;
}
const
G
Type
*
InputShape
()
const
{
return
input_shape_
;
}
R
Type
*
Out
()
const
{
return
out_
;
}
G
Type
*
Out
()
const
{
return
out_
;
}
const
vector
<
int
>
&
Shape
()
const
{
return
shape_
;
}
const
bool
&
Inplace
()
const
{
return
inplace_
;
}
private:
R
Type
*
input_x_
;
R
Type
*
input_shape_
;
R
Type
*
out_
;
G
Type
*
input_x_
;
G
Type
*
input_shape_
;
G
Type
*
out_
;
vector
<
int
>
shape_
;
bool
inplace_
;
};
...
...
@@ -1489,11 +1489,11 @@ class ScaleParam : public OpParam {
biases_
=
GetAttr
<
vector
<
float
>>
(
"biases"
,
attrs
);
}
const
R
Type
*
InputX
()
const
{
return
input_x_
;
}
const
G
Type
*
InputX
()
const
{
return
input_x_
;
}
const
R
Type
*
InputBias
()
const
{
return
input_bias_
;
}
const
G
Type
*
InputBias
()
const
{
return
input_bias_
;
}
R
Type
*
Out
()
const
{
return
out_
;
}
G
Type
*
Out
()
const
{
return
out_
;
}
const
bool
&
Inplace
()
const
{
return
inplace_
;
}
...
...
@@ -1504,9 +1504,9 @@ class ScaleParam : public OpParam {
const
vector
<
float
>
&
Biases
()
const
{
return
biases_
;
}
private:
R
Type
*
input_x_
;
R
Type
*
input_bias_
;
R
Type
*
out_
;
G
Type
*
input_x_
;
G
Type
*
input_bias_
;
G
Type
*
out_
;
bool
inplace_
;
bool
has_bias_
;
vector
<
float
>
scales_
;
...
...
@@ -1559,11 +1559,11 @@ class ResizeParam : public OpParam {
out_width_scale_
=
GetAttr
<
float
>
(
"out_width_scale"
,
attrs
);
}
const
R
Type
*
InputX
()
const
{
return
input_x_
;
}
const
G
Type
*
InputX
()
const
{
return
input_x_
;
}
const
R
Type
*
InputShape
()
const
{
return
input_shape_
;
}
const
G
Type
*
InputShape
()
const
{
return
input_shape_
;
}
R
Type
*
Out
()
const
{
return
out_
;
}
G
Type
*
Out
()
const
{
return
out_
;
}
const
bool
&
IsPyramidTest
()
const
{
return
is_pyramid_test_
;
}
...
...
@@ -1576,9 +1576,9 @@ class ResizeParam : public OpParam {
const
float
&
OutWidthScale
()
const
{
return
out_width_scale_
;
}
private:
R
Type
*
input_x_
;
R
Type
*
input_shape_
;
R
Type
*
out_
;
G
Type
*
input_x_
;
G
Type
*
input_shape_
;
G
Type
*
out_
;
bool
is_pyramid_test_
;
int
height_
;
int
width_
;
...
...
@@ -1603,13 +1603,13 @@ class ReluParamBase : public OpParam {
out_
=
OutFrom
<
GType
>
(
outputs
,
scope
);
}
const
R
Type
*
InputX
()
const
{
return
input_x_
;
}
const
G
Type
*
InputX
()
const
{
return
input_x_
;
}
R
Type
*
Out
()
const
{
return
out_
;
}
G
Type
*
Out
()
const
{
return
out_
;
}
private:
R
Type
*
input_x_
;
R
Type
*
out_
;
G
Type
*
input_x_
;
G
Type
*
out_
;
};
template
<
typename
Dtype
>
...
...
@@ -1644,20 +1644,20 @@ class TanhParam : public OpParam {
input_x_
=
InputXFrom
<
GType
>
(
inputs
,
scope
);
out_
=
OutFrom
<
GType
>
(
outputs
,
scope
);
}
const
R
Type
*
InputX
()
const
{
return
input_x_
;
}
R
Type
*
Out
()
const
{
return
out_
;
}
const
G
Type
*
InputX
()
const
{
return
input_x_
;
}
G
Type
*
Out
()
const
{
return
out_
;
}
private:
R
Type
*
input_x_
;
R
Type
*
out_
;
G
Type
*
input_x_
;
G
Type
*
out_
;
#ifdef PADDLE_MOBILE_FPGA
private:
std
::
shared_ptr
<
R
Type
>
float_input_x_
;
std
::
shared_ptr
<
G
Type
>
float_input_x_
;
fpga
::
BypassArgs
fpga_bypass_args
;
public:
R
Type
*
FloatInput
()
const
{
G
Type
*
FloatInput
()
const
{
return
float_input_x_
==
nullptr
?
input_x_
:
float_input_x_
.
get
();
}
void
SetFloatInput
(
Tensor
*
input
)
{
float_input_x_
.
reset
(
input
);
}
...
...
@@ -1684,15 +1684,15 @@ class PReluParam : public OpParam {
mode_
=
GetStringAttr
(
"mode"
,
attrs
);
DLOG
<<
"PReluParam mode after"
<<
mode_
;
}
const
R
Type
*
InputX
()
const
{
return
input_x_
;
}
const
R
Type
*
InputAlpha
()
const
{
return
alpha_
;
}
R
Type
*
Out
()
const
{
return
out_
;
}
const
G
Type
*
InputX
()
const
{
return
input_x_
;
}
const
G
Type
*
InputAlpha
()
const
{
return
alpha_
;
}
G
Type
*
Out
()
const
{
return
out_
;
}
const
std
::
string
&
Mode
()
const
{
return
mode_
;
}
private:
R
Type
*
input_x_
;
R
Type
*
out_
;
R
Type
*
alpha_
;
G
Type
*
input_x_
;
G
Type
*
out_
;
G
Type
*
alpha_
;
std
::
string
mode_
;
};
#endif
...
...
@@ -1715,9 +1715,9 @@ class FusionFcParam : public OpParam {
}
GType
*
InputX
()
const
{
return
input_x_
;
}
R
Type
*
InputY
()
const
{
return
input_y_
;
}
G
Type
*
InputY
()
const
{
return
input_y_
;
}
R
Type
*
InputZ
()
const
{
return
input_z_
;
}
G
Type
*
InputZ
()
const
{
return
input_z_
;
}
GType
*
Out
()
const
{
return
out_
;
}
...
...
@@ -1729,8 +1729,8 @@ class FusionFcParam : public OpParam {
private:
GType
*
input_x_
;
R
Type
*
input_y_
;
R
Type
*
input_z_
;
G
Type
*
input_y_
;
G
Type
*
input_z_
;
GType
*
out_
;
int
x_num_col_dims_
;
int
y_num_col_dims_
;
...
...
@@ -1765,16 +1765,16 @@ class FusionConvAddParam : public ConvParam<Dtype> {
axis_
=
OpParam
::
GetAttr
<
int
>
(
"axis"
,
attrs
);
output_
=
OpParam
::
OutFrom
<
GType
>
(
outputs
,
scope
);
}
R
Type
*
Bias
()
const
{
return
bias_
;
}
G
Type
*
Bias
()
const
{
return
bias_
;
}
const
int
&
Axis
()
const
{
return
axis_
;
}
R
Type
*
Output
()
const
{
return
output_
;
}
G
Type
*
Output
()
const
{
return
output_
;
}
protected:
R
Type
*
bias_
;
G
Type
*
bias_
;
int
axis_
;
R
Type
*
output_
;
G
Type
*
output_
;
};
template
<
typename
Dtype
>
...
...
@@ -1809,17 +1809,17 @@ class FusionConvAddPReluParam : public ConvParam<Dtype> {
axis_
=
OpParam
::
GetAttr
<
int
>
(
"axis"
,
attrs
);
output_
=
OpParam
::
OutFrom
<
GType
>
(
outputs
,
scope
);
}
const
R
Type
*
InputAlpha
()
const
{
return
alpha_
;
}
const
G
Type
*
InputAlpha
()
const
{
return
alpha_
;
}
const
std
::
string
&
Mode
()
const
{
return
mode_
;
}
R
Type
*
Bias
()
const
{
return
bias_
;
}
G
Type
*
Bias
()
const
{
return
bias_
;
}
const
int
&
Axis
()
const
{
return
axis_
;
}
R
Type
*
Output
()
const
{
return
output_
;
}
G
Type
*
Output
()
const
{
return
output_
;
}
protected:
R
Type
*
bias_
;
G
Type
*
bias_
;
int
axis_
;
R
Type
*
output_
;
R
Type
*
alpha_
;
G
Type
*
output_
;
G
Type
*
alpha_
;
std
::
string
mode_
;
};
#endif
...
...
@@ -1851,22 +1851,22 @@ class FusionConvAddAddPReluParam : public ConvParam<Dtype> {
bias1_
=
OpParam
::
InputXFrom1
<
GType
>
(
inputs
,
scope
);
}
}
const
R
Type
*
InputAlpha
()
const
{
return
alpha_
;
}
const
G
Type
*
InputAlpha
()
const
{
return
alpha_
;
}
const
std
::
string
&
Mode
()
const
{
return
mode_
;
}
const
R
Type
*
Bias1
()
const
{
return
bias1_
;
}
const
G
Type
*
Bias1
()
const
{
return
bias1_
;
}
R
Type
*
Bias
()
const
{
return
bias_
;
}
G
Type
*
Bias
()
const
{
return
bias_
;
}
const
int
&
Axis
()
const
{
return
axis_
;
}
R
Type
*
Output
()
const
{
return
output_
;
}
G
Type
*
Output
()
const
{
return
output_
;
}
protected:
R
Type
*
bias_
;
G
Type
*
bias_
;
int
axis_
;
R
Type
*
output_
;
R
Type
*
alpha_
;
G
Type
*
output_
;
G
Type
*
alpha_
;
std
::
string
mode_
;
R
Type
*
bias1_
;
G
Type
*
bias1_
;
std
::
string
keyOutput_
;
std
::
string
keyX1_
;
std
::
string
keyY1_
;
...
...
@@ -1895,19 +1895,19 @@ class FusionConvAddBNReluParam : public ConvParam<Dtype> {
momentum_
=
OpParam
::
GetAttr
<
float
>
(
"momentum"
,
attrs
);
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
}
R
Type
*
Bias
()
const
{
return
bias_
;
}
G
Type
*
Bias
()
const
{
return
bias_
;
}
const
int
&
Axis
()
const
{
return
axis_
;
}
R
Type
*
Output
()
const
{
return
output_
;
}
G
Type
*
Output
()
const
{
return
output_
;
}
const
R
Type
*
InputBias
()
const
{
return
input_bias_
;
}
const
G
Type
*
InputBias
()
const
{
return
input_bias_
;
}
const
R
Type
*
InputMean
()
const
{
return
input_mean_
;
}
const
G
Type
*
InputMean
()
const
{
return
input_mean_
;
}
const
R
Type
*
InputScale
()
const
{
return
input_scale_
;
}
const
G
Type
*
InputScale
()
const
{
return
input_scale_
;
}
const
R
Type
*
InputVariance
()
const
{
return
input_variance_
;
}
const
G
Type
*
InputVariance
()
const
{
return
input_variance_
;
}
const
float
&
Epsilon
()
const
{
return
epsilon_
;
}
...
...
@@ -1915,27 +1915,27 @@ class FusionConvAddBNReluParam : public ConvParam<Dtype> {
const
bool
&
IsTest
()
const
{
return
is_test_
;
}
void
SetNewScale
(
R
Type
*
new_scale
)
{
new_scale_
=
new_scale
;
}
void
SetNewScale
(
G
Type
*
new_scale
)
{
new_scale_
=
new_scale
;
}
void
SetNewBias
(
R
Type
*
new_bias
)
{
new_bias_
=
new_bias
;
}
void
SetNewBias
(
G
Type
*
new_bias
)
{
new_bias_
=
new_bias
;
}
const
R
Type
*
NewScale
()
const
{
return
new_scale_
;
}
const
G
Type
*
NewScale
()
const
{
return
new_scale_
;
}
const
R
Type
*
NewBias
()
const
{
return
new_bias_
;
}
const
G
Type
*
NewBias
()
const
{
return
new_bias_
;
}
protected:
R
Type
*
bias_
;
G
Type
*
bias_
;
int
axis_
;
R
Type
*
output_
;
R
Type
*
input_bias_
;
R
Type
*
input_mean_
;
R
Type
*
input_scale_
;
R
Type
*
input_variance_
;
G
Type
*
output_
;
G
Type
*
input_bias_
;
G
Type
*
input_mean_
;
G
Type
*
input_scale_
;
G
Type
*
input_variance_
;
float
epsilon_
;
float
momentum_
;
bool
is_test_
;
R
Type
*
new_bias_
;
R
Type
*
new_scale_
;
G
Type
*
new_bias_
;
G
Type
*
new_scale_
;
};
#endif
...
...
@@ -1969,19 +1969,19 @@ class FusionConvBNAddReluParam : public ConvParam<Dtype> {
}
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
}
R
Type
*
Bias
()
const
{
return
bias_
;
}
G
Type
*
Bias
()
const
{
return
bias_
;
}
const
int
&
Axis
()
const
{
return
axis_
;
}
R
Type
*
Output
()
const
{
return
output_
;
}
G
Type
*
Output
()
const
{
return
output_
;
}
const
R
Type
*
InputBias
()
const
{
return
input_bias_
;
}
const
G
Type
*
InputBias
()
const
{
return
input_bias_
;
}
const
R
Type
*
InputMean
()
const
{
return
input_mean_
;
}
const
G
Type
*
InputMean
()
const
{
return
input_mean_
;
}
const
R
Type
*
InputScale
()
const
{
return
input_scale_
;
}
const
G
Type
*
InputScale
()
const
{
return
input_scale_
;
}
const
R
Type
*
InputVariance
()
const
{
return
input_variance_
;
}
const
G
Type
*
InputVariance
()
const
{
return
input_variance_
;
}
const
float
&
Epsilon
()
const
{
return
epsilon_
;
}
...
...
@@ -1989,27 +1989,27 @@ class FusionConvBNAddReluParam : public ConvParam<Dtype> {
const
bool
&
IsTest
()
const
{
return
is_test_
;
}
void
SetNewScale
(
R
Type
*
new_scale
)
{
new_scale_
=
new_scale
;
}
void
SetNewScale
(
G
Type
*
new_scale
)
{
new_scale_
=
new_scale
;
}
void
SetNewBias
(
R
Type
*
new_bias
)
{
new_bias_
=
new_bias
;
}
void
SetNewBias
(
G
Type
*
new_bias
)
{
new_bias_
=
new_bias
;
}
const
R
Type
*
NewScale
()
const
{
return
new_scale_
;
}
const
G
Type
*
NewScale
()
const
{
return
new_scale_
;
}
const
R
Type
*
NewBias
()
const
{
return
new_bias_
;
}
const
G
Type
*
NewBias
()
const
{
return
new_bias_
;
}
protected:
R
Type
*
bias_
;
G
Type
*
bias_
;
int
axis_
;
R
Type
*
output_
;
R
Type
*
input_bias_
;
R
Type
*
input_mean_
;
R
Type
*
input_scale_
;
R
Type
*
input_variance_
;
G
Type
*
output_
;
G
Type
*
input_bias_
;
G
Type
*
input_mean_
;
G
Type
*
input_scale_
;
G
Type
*
input_variance_
;
float
epsilon_
;
float
momentum_
;
bool
is_test_
;
R
Type
*
new_bias_
;
R
Type
*
new_scale_
;
G
Type
*
new_bias_
;
G
Type
*
new_scale_
;
std
::
string
keyBNY_
;
std
::
string
keyX_
;
std
::
string
keyY_
;
...
...
@@ -2036,15 +2036,15 @@ class FusionConvBNParam : public ConvParam<Dtype> {
momentum_
=
OpParam
::
GetAttr
<
float
>
(
"momentum"
,
attrs
);
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
}
R
Type
*
Output
()
const
{
return
output_y_
;
}
G
Type
*
Output
()
const
{
return
output_y_
;
}
const
R
Type
*
InputBias
()
const
{
return
input_bias_
;
}
const
G
Type
*
InputBias
()
const
{
return
input_bias_
;
}
const
R
Type
*
InputMean
()
const
{
return
input_mean_
;
}
const
G
Type
*
InputMean
()
const
{
return
input_mean_
;
}
const
R
Type
*
InputScale
()
const
{
return
input_scale_
;
}
const
G
Type
*
InputScale
()
const
{
return
input_scale_
;
}
const
R
Type
*
InputVariance
()
const
{
return
input_variance_
;
}
const
G
Type
*
InputVariance
()
const
{
return
input_variance_
;
}
const
float
&
Epsilon
()
const
{
return
epsilon_
;
}
...
...
@@ -2052,25 +2052,25 @@ class FusionConvBNParam : public ConvParam<Dtype> {
const
bool
&
IsTest
()
const
{
return
is_test_
;
}
void
SetNewScale
(
R
Type
*
new_scale
)
{
new_scale_
=
new_scale
;
}
void
SetNewScale
(
G
Type
*
new_scale
)
{
new_scale_
=
new_scale
;
}
void
SetNewBias
(
R
Type
*
new_bias
)
{
new_bias_
=
new_bias
;
}
void
SetNewBias
(
G
Type
*
new_bias
)
{
new_bias_
=
new_bias
;
}
const
R
Type
*
NewScale
()
const
{
return
new_scale_
;
}
const
G
Type
*
NewScale
()
const
{
return
new_scale_
;
}
const
R
Type
*
NewBias
()
const
{
return
new_bias_
;
}
const
G
Type
*
NewBias
()
const
{
return
new_bias_
;
}
protected:
R
Type
*
output_y_
;
R
Type
*
input_bias_
;
R
Type
*
input_mean_
;
R
Type
*
input_scale_
;
R
Type
*
input_variance_
;
G
Type
*
output_y_
;
G
Type
*
input_bias_
;
G
Type
*
input_mean_
;
G
Type
*
input_scale_
;
G
Type
*
input_variance_
;
float
epsilon_
;
float
momentum_
;
bool
is_test_
;
R
Type
*
new_bias_
;
R
Type
*
new_scale_
;
G
Type
*
new_bias_
;
G
Type
*
new_scale_
;
};
#endif
...
...
@@ -2096,19 +2096,19 @@ class FusionConvAddBNParam : public ConvParam<Dtype> {
momentum_
=
OpParam
::
GetAttr
<
float
>
(
"momentum"
,
attrs
);
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
}
R
Type
*
Bias
()
const
{
return
bias_
;
}
G
Type
*
Bias
()
const
{
return
bias_
;
}
const
int
&
Axis
()
const
{
return
axis_
;
}
R
Type
*
Output
()
const
{
return
output_y_
;
}
G
Type
*
Output
()
const
{
return
output_y_
;
}
const
R
Type
*
InputBias
()
const
{
return
input_bias_
;
}
const
G
Type
*
InputBias
()
const
{
return
input_bias_
;
}
const
R
Type
*
InputMean
()
const
{
return
input_mean_
;
}
const
G
Type
*
InputMean
()
const
{
return
input_mean_
;
}
const
R
Type
*
InputScale
()
const
{
return
input_scale_
;
}
const
G
Type
*
InputScale
()
const
{
return
input_scale_
;
}
const
R
Type
*
InputVariance
()
const
{
return
input_variance_
;
}
const
G
Type
*
InputVariance
()
const
{
return
input_variance_
;
}
const
float
&
Epsilon
()
const
{
return
epsilon_
;
}
...
...
@@ -2116,27 +2116,27 @@ class FusionConvAddBNParam : public ConvParam<Dtype> {
const
bool
&
IsTest
()
const
{
return
is_test_
;
}
void
SetNewScale
(
R
Type
*
new_scale
)
{
new_scale_
=
new_scale
;
}
void
SetNewScale
(
G
Type
*
new_scale
)
{
new_scale_
=
new_scale
;
}
void
SetNewBias
(
R
Type
*
new_bias
)
{
new_bias_
=
new_bias
;
}
void
SetNewBias
(
G
Type
*
new_bias
)
{
new_bias_
=
new_bias
;
}
const
R
Type
*
NewScale
()
const
{
return
new_scale_
;
}
const
G
Type
*
NewScale
()
const
{
return
new_scale_
;
}
const
R
Type
*
NewBias
()
const
{
return
new_bias_
;
}
const
G
Type
*
NewBias
()
const
{
return
new_bias_
;
}
protected:
R
Type
*
bias_
;
G
Type
*
bias_
;
int
axis_
;
R
Type
*
output_y_
;
R
Type
*
input_bias_
;
R
Type
*
input_mean_
;
R
Type
*
input_scale_
;
R
Type
*
input_variance_
;
G
Type
*
output_y_
;
G
Type
*
input_bias_
;
G
Type
*
input_mean_
;
G
Type
*
input_scale_
;
G
Type
*
input_variance_
;
float
epsilon_
;
float
momentum_
;
bool
is_test_
;
R
Type
*
new_bias_
;
R
Type
*
new_scale_
;
G
Type
*
new_bias_
;
G
Type
*
new_scale_
;
};
#endif
...
...
@@ -2160,15 +2160,15 @@ class FusionDWConvBNReluParam : public ConvParam<Dtype> {
momentum_
=
OpParam
::
GetAttr
<
float
>
(
"momentum"
,
attrs
);
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
}
R
Type
*
Output
()
const
{
return
output_
;
}
G
Type
*
Output
()
const
{
return
output_
;
}
const
R
Type
*
InputBias
()
const
{
return
input_bias_
;
}
const
G
Type
*
InputBias
()
const
{
return
input_bias_
;
}
const
R
Type
*
InputMean
()
const
{
return
input_mean_
;
}
const
G
Type
*
InputMean
()
const
{
return
input_mean_
;
}
const
R
Type
*
InputScale
()
const
{
return
input_scale_
;
}
const
G
Type
*
InputScale
()
const
{
return
input_scale_
;
}
const
R
Type
*
InputVariance
()
const
{
return
input_variance_
;
}
const
G
Type
*
InputVariance
()
const
{
return
input_variance_
;
}
const
float
&
Epsilon
()
const
{
return
epsilon_
;
}
...
...
@@ -2176,25 +2176,25 @@ class FusionDWConvBNReluParam : public ConvParam<Dtype> {
const
bool
&
IsTest
()
const
{
return
is_test_
;
}
void
SetNewScale
(
R
Type
*
new_scale
)
{
new_scale_
=
new_scale
;
}
void
SetNewScale
(
G
Type
*
new_scale
)
{
new_scale_
=
new_scale
;
}
void
SetNewBias
(
R
Type
*
new_bias
)
{
new_bias_
=
new_bias
;
}
void
SetNewBias
(
G
Type
*
new_bias
)
{
new_bias_
=
new_bias
;
}
const
R
Type
*
NewScale
()
const
{
return
new_scale_
;
}
const
G
Type
*
NewScale
()
const
{
return
new_scale_
;
}
const
R
Type
*
NewBias
()
const
{
return
new_bias_
;
}
const
G
Type
*
NewBias
()
const
{
return
new_bias_
;
}
protected:
R
Type
*
output_
;
R
Type
*
input_bias_
;
R
Type
*
input_mean_
;
R
Type
*
input_scale_
;
R
Type
*
input_variance_
;
G
Type
*
output_
;
G
Type
*
input_bias_
;
G
Type
*
input_mean_
;
G
Type
*
input_scale_
;
G
Type
*
input_variance_
;
float
epsilon_
;
float
momentum_
;
bool
is_test_
;
R
Type
*
new_bias_
;
R
Type
*
new_scale_
;
G
Type
*
new_bias_
;
G
Type
*
new_scale_
;
};
#endif
...
...
@@ -2219,15 +2219,15 @@ class FusionConvBNReluParam : public ConvParam<Dtype> {
momentum_
=
OpParam
::
GetAttr
<
float
>
(
"momentum"
,
attrs
);
// is_test_ = OpParam::GetAttr<bool>("is_test", attrs);
}
R
Type
*
Output
()
const
{
return
output_
;
}
G
Type
*
Output
()
const
{
return
output_
;
}
const
R
Type
*
InputBias
()
const
{
return
input_bias_
;
}
const
G
Type
*
InputBias
()
const
{
return
input_bias_
;
}
const
R
Type
*
InputMean
()
const
{
return
input_mean_
;
}
const
G
Type
*
InputMean
()
const
{
return
input_mean_
;
}
const
R
Type
*
InputScale
()
const
{
return
input_scale_
;
}
const
G
Type
*
InputScale
()
const
{
return
input_scale_
;
}
const
R
Type
*
InputVariance
()
const
{
return
input_variance_
;
}
const
G
Type
*
InputVariance
()
const
{
return
input_variance_
;
}
const
float
&
Epsilon
()
const
{
return
epsilon_
;
}
...
...
@@ -2235,25 +2235,25 @@ class FusionConvBNReluParam : public ConvParam<Dtype> {
const
bool
&
IsTest
()
const
{
return
is_test_
;
}
void
SetNewScale
(
R
Type
*
new_scale
)
{
new_scale_
=
new_scale
;
}
void
SetNewScale
(
G
Type
*
new_scale
)
{
new_scale_
=
new_scale
;
}
void
SetNewBias
(
R
Type
*
new_bias
)
{
new_bias_
=
new_bias
;
}
void
SetNewBias
(
G
Type
*
new_bias
)
{
new_bias_
=
new_bias
;
}
const
R
Type
*
NewScale
()
const
{
return
new_scale_
;
}
const
G
Type
*
NewScale
()
const
{
return
new_scale_
;
}
const
R
Type
*
NewBias
()
const
{
return
new_bias_
;
}
const
G
Type
*
NewBias
()
const
{
return
new_bias_
;
}
protected:
R
Type
*
output_
;
R
Type
*
input_bias_
;
R
Type
*
input_mean_
;
R
Type
*
input_scale_
;
R
Type
*
input_variance_
;
G
Type
*
output_
;
G
Type
*
input_bias_
;
G
Type
*
input_mean_
;
G
Type
*
input_scale_
;
G
Type
*
input_variance_
;
float
epsilon_
;
float
momentum_
;
bool
is_test_
;
R
Type
*
new_bias_
;
R
Type
*
new_scale_
;
G
Type
*
new_bias_
;
G
Type
*
new_scale_
;
};
#endif
...
...
@@ -2308,15 +2308,15 @@ class DropoutParam : public OpParam {
dropout_prob_
=
GetAttr
<
float
>
(
"dropout_prob"
,
attrs
);
}
const
R
Type
*
InputX
()
const
{
return
input_x_
;
}
const
G
Type
*
InputX
()
const
{
return
input_x_
;
}
R
Type
*
Out
()
const
{
return
out_
;
}
G
Type
*
Out
()
const
{
return
out_
;
}
float
DropoutProb
()
const
{
return
dropout_prob_
;
}
private:
R
Type
*
input_x_
;
R
Type
*
out_
;
G
Type
*
input_x_
;
G
Type
*
out_
;
float
dropout_prob_
;
};
#endif
...
...
@@ -2342,11 +2342,11 @@ class ConvTransposeParam : public OpParam {
groups
=
GetAttr
<
int
>
(
"groups"
,
attrs
);
}
const
R
Type
*
Input
()
const
{
return
input_
;
}
const
G
Type
*
Input
()
const
{
return
input_
;
}
const
R
Type
*
Filter
()
const
{
return
filter_
;
}
const
G
Type
*
Filter
()
const
{
return
filter_
;
}
R
Type
*
Output
()
const
{
return
output_
;
}
G
Type
*
Output
()
const
{
return
output_
;
}
const
vector
<
int
>
&
Strides
()
const
{
return
strides_
;
}
...
...
@@ -2357,9 +2357,9 @@ class ConvTransposeParam : public OpParam {
const
int
&
Groups
()
const
{
return
groups
;
}
private:
R
Type
*
input_
;
R
Type
*
output_
;
R
Type
*
filter_
;
G
Type
*
input_
;
G
Type
*
output_
;
G
Type
*
filter_
;
vector
<
int
>
strides_
;
vector
<
int
>
paddings_
;
vector
<
int
>
dilations_
;
...
...
@@ -2398,16 +2398,16 @@ class FusionDeconvAddParam : public ConvTransposeParam<Dtype> {
axis_
=
OpParam
::
GetAttr
<
int
>
(
"axis"
,
attrs
);
output_
=
OpParam
::
OutFrom
<
GType
>
(
outputs
,
scope
);
}
R
Type
*
Bias
()
const
{
return
bias_
;
}
G
Type
*
Bias
()
const
{
return
bias_
;
}
const
int
&
Axis
()
const
{
return
axis_
;
}
R
Type
*
Output
()
const
{
return
output_
;
}
G
Type
*
Output
()
const
{
return
output_
;
}
protected:
R
Type
*
bias_
;
G
Type
*
bias_
;
int
axis_
;
R
Type
*
output_
;
G
Type
*
output_
;
};
#endif
...
...
@@ -2539,13 +2539,13 @@ class FlattenParam : public OpParam {
out_
=
OutFrom
<
GType
>
(
outputs
,
scope
);
axis
=
GetAttr
<
int
>
(
"axis"
,
attrs
);
}
const
R
Type
*
InputX
()
const
{
return
input_x_
;
}
R
Type
*
Out
()
const
{
return
out_
;
}
const
G
Type
*
InputX
()
const
{
return
input_x_
;
}
G
Type
*
Out
()
const
{
return
out_
;
}
const
int
&
Axis
()
const
{
return
axis
;
}
private:
R
Type
*
input_x_
;
R
Type
*
out_
;
G
Type
*
input_x_
;
G
Type
*
out_
;
int
axis
;
};
#endif
...
...
@@ -2569,7 +2569,7 @@ class SplitParam : public OpParam {
// out_ts_.push_back(*scope.FindVar(outs_[i])->GetMutable());
// }
}
const
R
Type
*
InputX
()
const
{
return
input_x_
;
}
const
G
Type
*
InputX
()
const
{
return
input_x_
;
}
std
::
vector
<
GType
*>
Outs
()
const
{
return
outs_
;
}
int
Axis
()
const
{
return
axis
;
}
int
Num
()
const
{
return
num
;
}
...
...
@@ -2577,7 +2577,7 @@ class SplitParam : public OpParam {
// std::vector<GType> OutTs() const { return out_ts_; }
private:
R
Type
*
input_x_
;
G
Type
*
input_x_
;
std
::
vector
<
GType
*>
outs_
;
int
axis
;
int
num
;
...
...
@@ -2611,16 +2611,16 @@ class BilinearInterpParam : public OpParam {
out_h_
=
GetAttr
<
int
>
(
"out_h"
,
attrs
);
out_w_
=
GetAttr
<
int
>
(
"out_w"
,
attrs
);
}
const
R
Type
*
InputX
()
const
{
return
input_x_
;
}
const
R
Type
*
InputOutPutSize
()
const
{
return
input_outsize_
;
}
R
Type
*
Out
()
const
{
return
out_
;
}
const
G
Type
*
InputX
()
const
{
return
input_x_
;
}
const
G
Type
*
InputOutPutSize
()
const
{
return
input_outsize_
;
}
G
Type
*
Out
()
const
{
return
out_
;
}
int
OutH
()
const
{
return
out_h_
;
}
int
OutW
()
const
{
return
out_w_
;
}
private:
R
Type
*
input_x_
;
R
Type
*
input_outsize_
;
R
Type
*
out_
;
G
Type
*
input_x_
;
G
Type
*
input_outsize_
;
G
Type
*
out_
;
int
out_h_
;
int
out_w_
;
};
...
...
@@ -2638,12 +2638,12 @@ class ShapeParam : public OpParam {
input_
=
InputFrom
<
GType
>
(
inputs
,
scope
);
out_
=
OutFrom
<
GType
>
(
outputs
,
scope
);
}
const
R
Type
*
Input
()
const
{
return
input_
;
}
R
Type
*
Out
()
const
{
return
out_
;
}
const
G
Type
*
Input
()
const
{
return
input_
;
}
G
Type
*
Out
()
const
{
return
out_
;
}
private:
R
Type
*
input_
;
R
Type
*
out_
;
G
Type
*
input_
;
G
Type
*
out_
;
};
#endif
...
...
@@ -2686,8 +2686,8 @@ class CastParam : public OpParam {
}
public:
R
Type
*
input_
;
R
Type
*
output_
;
G
Type
*
input_
;
G
Type
*
output_
;
int
input_type_
;
int
output_type_
;
};
...
...
@@ -2723,9 +2723,9 @@ class QuantizeParam : public OpParam {
GType
*
input_
;
// op output
GType
*
output_
;
R
Type
*
online_scale_
;
G
Type
*
online_scale_
;
// quantize offline scale
R
Type
*
offline_scale_
;
G
Type
*
offline_scale_
;
// if offine scale or not
bool
offline_
=
false
;
// round method type
...
...
@@ -2759,7 +2759,7 @@ class DequantizeParam : public OpParam {
GType
*
input_
;
// op output
GType
*
output_
;
R
Type
*
activation_scale_
;
G
Type
*
activation_scale_
;
float
weight_scale_
;
};
#endif
...
...
@@ -2789,10 +2789,10 @@ class FusionDequantBNParam : public DequantizeParam<Dtype> {
public:
// batch norm
R
Type
*
bn_mean_
;
R
Type
*
bn_variance_
;
R
Type
*
bn_scale_
;
R
Type
*
bn_bias_
;
G
Type
*
bn_mean_
;
G
Type
*
bn_variance_
;
G
Type
*
bn_scale_
;
G
Type
*
bn_bias_
;
float
epsilon_
;
};
#endif
...
...
@@ -2819,7 +2819,7 @@ class FusionDequantAddBNParam : public FusionDequantBNParam<Dtype> {
public:
// elementwise add
int
axis_
;
R
Type
*
bias_
;
G
Type
*
bias_
;
};
#endif
...
...
@@ -2848,9 +2848,9 @@ class FusionDequantAddBNQuantParam : public FusionDequantAddBNParam<Dtype> {
}
public:
R
Type
*
online_scale_
;
G
Type
*
online_scale_
;
// quantize offline scale
R
Type
*
offline_scale_
;
G
Type
*
offline_scale_
;
// if offine scale or not
bool
offline_
=
false
;
// round method type
...
...
src/operators/softmax_op.cpp
浏览文件 @
6641a314
...
...
@@ -21,6 +21,7 @@ namespace operators {
template
<
typename
DeviceType
,
typename
T
>
void
SoftmaxOp
<
DeviceType
,
T
>::
InferShape
()
const
{
this
->
param_
.
Out
()
->
Resize
(
this
->
param_
.
InputX
()
->
dims
());
this
->
param_
.
Out
()
->
set_lod
(
this
->
param_
.
InputX
()
->
lod
());
}
}
// namespace operators
...
...
src/operators/top_k_op.cpp
浏览文件 @
6641a314
...
...
@@ -26,11 +26,9 @@ void TopKOp<DeviceType, T>::InferShape() const {
// should check k <= dims[-1] && k >= 1
dims
[
dims
.
size
()
-
1
]
=
k
;
this
->
param_
.
output_
->
Resize
(
dims
);
// this->param_.output_->set_lod(this->param_.input_->lod());
this
->
param_
.
output_
->
set_lod
({{
0
,
1
}});
this
->
param_
.
indices_
->
Resize
(
dims
);
// this->param_.indices
_->set_lod(this->param_.input_->lod());
this
->
param_
.
indices_
->
set_lod
(
{{
0
,
1
}}
);
this
->
param_
.
output
_
->
set_lod
(
this
->
param_
.
input_
->
lod
());
this
->
param_
.
indices_
->
set_lod
(
this
->
param_
.
input_
->
lod
()
);
}
}
// namespace operators
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录