Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
c754a38f
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
331
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c754a38f
编写于
3月 31, 2020
作者:
H
huzhiqiang
提交者:
GitHub
3月 31, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[operator] add InferShapeImpl method (#3294)
上级
50638e96
变更
243
隐藏空白更改
内联
并排
Showing
243 changed file
with
517 addition
and
502 deletion
+517
-502
lite/core/op_lite.cc
lite/core/op_lite.cc
+55
-0
lite/core/op_lite.h
lite/core/op_lite.h
+14
-6
lite/core/program.cc
lite/core/program.cc
+1
-2
lite/operators/activation_grad_ops.cc
lite/operators/activation_grad_ops.cc
+1
-1
lite/operators/activation_grad_ops.h
lite/operators/activation_grad_ops.h
+1
-1
lite/operators/activation_ops.cc
lite/operators/activation_ops.cc
+1
-1
lite/operators/activation_ops.h
lite/operators/activation_ops.h
+1
-1
lite/operators/affine_channel_op.cc
lite/operators/affine_channel_op.cc
+1
-1
lite/operators/affine_channel_op.h
lite/operators/affine_channel_op.h
+1
-1
lite/operators/anchor_generator_op.cc
lite/operators/anchor_generator_op.cc
+1
-1
lite/operators/anchor_generator_op.h
lite/operators/anchor_generator_op.h
+1
-1
lite/operators/argmax_op.cc
lite/operators/argmax_op.cc
+1
-1
lite/operators/argmax_op.h
lite/operators/argmax_op.h
+1
-1
lite/operators/assign_op.cc
lite/operators/assign_op.cc
+1
-1
lite/operators/assign_op.h
lite/operators/assign_op.h
+1
-1
lite/operators/assign_value_op.cc
lite/operators/assign_value_op.cc
+1
-1
lite/operators/assign_value_op.h
lite/operators/assign_value_op.h
+1
-1
lite/operators/attention_padding_mask_op.cc
lite/operators/attention_padding_mask_op.cc
+1
-1
lite/operators/attention_padding_mask_op.h
lite/operators/attention_padding_mask_op.h
+1
-1
lite/operators/axpy_op.cc
lite/operators/axpy_op.cc
+1
-1
lite/operators/axpy_op.h
lite/operators/axpy_op.h
+1
-1
lite/operators/batch_norm_op.cc
lite/operators/batch_norm_op.cc
+1
-1
lite/operators/batch_norm_op.h
lite/operators/batch_norm_op.h
+1
-1
lite/operators/beam_search_decode_op.cc
lite/operators/beam_search_decode_op.cc
+1
-1
lite/operators/beam_search_decode_op.h
lite/operators/beam_search_decode_op.h
+1
-1
lite/operators/beam_search_op.cc
lite/operators/beam_search_op.cc
+1
-1
lite/operators/beam_search_op.h
lite/operators/beam_search_op.h
+1
-1
lite/operators/box_clip_op.cc
lite/operators/box_clip_op.cc
+1
-1
lite/operators/box_clip_op.h
lite/operators/box_clip_op.h
+1
-1
lite/operators/box_coder_op.cc
lite/operators/box_coder_op.cc
+1
-1
lite/operators/box_coder_op.h
lite/operators/box_coder_op.h
+1
-1
lite/operators/calib_op.cc
lite/operators/calib_op.cc
+1
-1
lite/operators/calib_op.h
lite/operators/calib_op.h
+1
-1
lite/operators/cast_op.cc
lite/operators/cast_op.cc
+1
-1
lite/operators/cast_op.h
lite/operators/cast_op.h
+1
-1
lite/operators/collect_fpn_proposals_op.cc
lite/operators/collect_fpn_proposals_op.cc
+1
-1
lite/operators/collect_fpn_proposals_op.h
lite/operators/collect_fpn_proposals_op.h
+1
-1
lite/operators/compare_op.cc
lite/operators/compare_op.cc
+1
-1
lite/operators/compare_op.h
lite/operators/compare_op.h
+1
-1
lite/operators/concat_op.cc
lite/operators/concat_op.cc
+1
-1
lite/operators/concat_op.h
lite/operators/concat_op.h
+1
-1
lite/operators/conditional_block_op.cc
lite/operators/conditional_block_op.cc
+1
-1
lite/operators/conditional_block_op.h
lite/operators/conditional_block_op.h
+1
-1
lite/operators/conv_op.cc
lite/operators/conv_op.cc
+1
-29
lite/operators/conv_op.h
lite/operators/conv_op.h
+1
-3
lite/operators/conv_transpose_op.cc
lite/operators/conv_transpose_op.cc
+1
-1
lite/operators/conv_transpose_op.h
lite/operators/conv_transpose_op.h
+1
-1
lite/operators/crf_decoding_op.cc
lite/operators/crf_decoding_op.cc
+1
-1
lite/operators/crf_decoding_op.h
lite/operators/crf_decoding_op.h
+1
-1
lite/operators/crop_op.cc
lite/operators/crop_op.cc
+1
-1
lite/operators/crop_op.h
lite/operators/crop_op.h
+1
-1
lite/operators/decode_bboxes_op.cc
lite/operators/decode_bboxes_op.cc
+1
-1
lite/operators/decode_bboxes_op.h
lite/operators/decode_bboxes_op.h
+1
-1
lite/operators/density_prior_box_op.cc
lite/operators/density_prior_box_op.cc
+1
-1
lite/operators/density_prior_box_op.h
lite/operators/density_prior_box_op.h
+1
-1
lite/operators/distribute_fpn_proposals_op.cc
lite/operators/distribute_fpn_proposals_op.cc
+1
-1
lite/operators/distribute_fpn_proposals_op.h
lite/operators/distribute_fpn_proposals_op.h
+1
-1
lite/operators/dropout_op.cc
lite/operators/dropout_op.cc
+1
-1
lite/operators/dropout_op.h
lite/operators/dropout_op.h
+1
-1
lite/operators/elementwise_grad_ops.cc
lite/operators/elementwise_grad_ops.cc
+1
-1
lite/operators/elementwise_grad_ops.h
lite/operators/elementwise_grad_ops.h
+1
-1
lite/operators/elementwise_ops.cc
lite/operators/elementwise_ops.cc
+2
-33
lite/operators/elementwise_ops.h
lite/operators/elementwise_ops.h
+2
-3
lite/operators/expand_op.cc
lite/operators/expand_op.cc
+1
-1
lite/operators/expand_op.h
lite/operators/expand_op.h
+1
-1
lite/operators/fake_channel_wise_dequantize_max_abs.h
lite/operators/fake_channel_wise_dequantize_max_abs.h
+1
-1
lite/operators/fake_dequantize_max_abs.h
lite/operators/fake_dequantize_max_abs.h
+1
-1
lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h
lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h
+1
-1
lite/operators/fake_quantize_moving_avg_max_abs.h
lite/operators/fake_quantize_moving_avg_max_abs.h
+1
-1
lite/operators/fake_quantize_range_abs_max.h
lite/operators/fake_quantize_range_abs_max.h
+1
-1
lite/operators/fc_op.cc
lite/operators/fc_op.cc
+1
-28
lite/operators/fc_op.h
lite/operators/fc_op.h
+1
-2
lite/operators/feed_op.cc
lite/operators/feed_op.cc
+1
-1
lite/operators/fetch_op.cc
lite/operators/fetch_op.cc
+1
-1
lite/operators/fill_constant_batch_size_like_op.cc
lite/operators/fill_constant_batch_size_like_op.cc
+1
-1
lite/operators/fill_constant_batch_size_like_op.h
lite/operators/fill_constant_batch_size_like_op.h
+1
-1
lite/operators/fill_constant_op.cc
lite/operators/fill_constant_op.cc
+1
-1
lite/operators/fill_constant_op.h
lite/operators/fill_constant_op.h
+1
-1
lite/operators/flatten_op.cc
lite/operators/flatten_op.cc
+3
-3
lite/operators/flatten_op.h
lite/operators/flatten_op.h
+2
-2
lite/operators/fusion_elementwise_activation_ops.cc
lite/operators/fusion_elementwise_activation_ops.cc
+2
-2
lite/operators/fusion_elementwise_activation_ops.h
lite/operators/fusion_elementwise_activation_ops.h
+2
-2
lite/operators/gather_op.cc
lite/operators/gather_op.cc
+1
-1
lite/operators/gather_op.h
lite/operators/gather_op.h
+1
-1
lite/operators/generate_proposals_op.cc
lite/operators/generate_proposals_op.cc
+1
-1
lite/operators/generate_proposals_op.h
lite/operators/generate_proposals_op.h
+1
-1
lite/operators/grid_sampler_op.cc
lite/operators/grid_sampler_op.cc
+1
-1
lite/operators/grid_sampler_op.h
lite/operators/grid_sampler_op.h
+1
-1
lite/operators/gru_op.cc
lite/operators/gru_op.cc
+1
-1
lite/operators/gru_op.h
lite/operators/gru_op.h
+1
-1
lite/operators/gru_unit_op.cc
lite/operators/gru_unit_op.cc
+1
-1
lite/operators/gru_unit_op.h
lite/operators/gru_unit_op.h
+1
-1
lite/operators/im2sequence_op.cc
lite/operators/im2sequence_op.cc
+1
-1
lite/operators/im2sequence_op.h
lite/operators/im2sequence_op.h
+1
-1
lite/operators/increment_op.cc
lite/operators/increment_op.cc
+1
-1
lite/operators/increment_op.h
lite/operators/increment_op.h
+1
-1
lite/operators/instance_norm_op.cc
lite/operators/instance_norm_op.cc
+1
-1
lite/operators/instance_norm_op.h
lite/operators/instance_norm_op.h
+1
-1
lite/operators/interpolate_op.cc
lite/operators/interpolate_op.cc
+1
-1
lite/operators/interpolate_op.h
lite/operators/interpolate_op.h
+1
-1
lite/operators/io_copy_op.cc
lite/operators/io_copy_op.cc
+1
-1
lite/operators/io_copy_op.h
lite/operators/io_copy_op.h
+1
-1
lite/operators/is_empty_op.cc
lite/operators/is_empty_op.cc
+1
-1
lite/operators/is_empty_op.h
lite/operators/is_empty_op.h
+1
-1
lite/operators/layer_norm_op.cc
lite/operators/layer_norm_op.cc
+1
-1
lite/operators/layer_norm_op.h
lite/operators/layer_norm_op.h
+1
-1
lite/operators/layout_op.cc
lite/operators/layout_op.cc
+1
-1
lite/operators/layout_op.h
lite/operators/layout_op.h
+1
-1
lite/operators/lod_reset_op.cc
lite/operators/lod_reset_op.cc
+1
-1
lite/operators/lod_reset_op.h
lite/operators/lod_reset_op.h
+1
-1
lite/operators/logical_op.cc
lite/operators/logical_op.cc
+2
-2
lite/operators/logical_op.h
lite/operators/logical_op.h
+2
-2
lite/operators/lookup_table_dequant_op.cc
lite/operators/lookup_table_dequant_op.cc
+1
-1
lite/operators/lookup_table_dequant_op.h
lite/operators/lookup_table_dequant_op.h
+1
-1
lite/operators/lookup_table_op.cc
lite/operators/lookup_table_op.cc
+1
-1
lite/operators/lookup_table_op.h
lite/operators/lookup_table_op.h
+1
-1
lite/operators/lookup_table_v2_op.cc
lite/operators/lookup_table_v2_op.cc
+1
-1
lite/operators/lookup_table_v2_op.h
lite/operators/lookup_table_v2_op.h
+1
-1
lite/operators/lrn_op.cc
lite/operators/lrn_op.cc
+1
-1
lite/operators/lrn_op.h
lite/operators/lrn_op.h
+1
-1
lite/operators/lstm_op.cc
lite/operators/lstm_op.cc
+1
-1
lite/operators/lstm_op.h
lite/operators/lstm_op.h
+1
-1
lite/operators/match_matrix_tensor_op.cc
lite/operators/match_matrix_tensor_op.cc
+1
-1
lite/operators/match_matrix_tensor_op.h
lite/operators/match_matrix_tensor_op.h
+1
-1
lite/operators/matmul_op.cc
lite/operators/matmul_op.cc
+1
-1
lite/operators/matmul_op.h
lite/operators/matmul_op.h
+1
-1
lite/operators/mean_grad_op.cc
lite/operators/mean_grad_op.cc
+1
-1
lite/operators/mean_grad_op.h
lite/operators/mean_grad_op.h
+1
-1
lite/operators/mean_op.cc
lite/operators/mean_op.cc
+1
-1
lite/operators/mean_op.h
lite/operators/mean_op.h
+1
-1
lite/operators/merge_lod_tensor_op.cc
lite/operators/merge_lod_tensor_op.cc
+1
-1
lite/operators/merge_lod_tensor_op.h
lite/operators/merge_lod_tensor_op.h
+1
-1
lite/operators/mul_grad_op.cc
lite/operators/mul_grad_op.cc
+1
-1
lite/operators/mul_grad_op.h
lite/operators/mul_grad_op.h
+1
-1
lite/operators/mul_op.cc
lite/operators/mul_op.cc
+1
-1
lite/operators/mul_op.h
lite/operators/mul_op.h
+1
-1
lite/operators/multiclass_nms_op.cc
lite/operators/multiclass_nms_op.cc
+1
-1
lite/operators/multiclass_nms_op.h
lite/operators/multiclass_nms_op.h
+1
-1
lite/operators/negative_op.cc
lite/operators/negative_op.cc
+1
-1
lite/operators/negative_op.h
lite/operators/negative_op.h
+1
-1
lite/operators/norm_op.cc
lite/operators/norm_op.cc
+1
-1
lite/operators/norm_op.h
lite/operators/norm_op.h
+1
-1
lite/operators/op_params.h
lite/operators/op_params.h
+188
-116
lite/operators/pad2d_op.cc
lite/operators/pad2d_op.cc
+1
-1
lite/operators/pad2d_op.h
lite/operators/pad2d_op.h
+1
-1
lite/operators/pool_op.cc
lite/operators/pool_op.cc
+1
-1
lite/operators/pool_op.h
lite/operators/pool_op.h
+1
-1
lite/operators/power_op.cc
lite/operators/power_op.cc
+1
-1
lite/operators/power_op.h
lite/operators/power_op.h
+1
-1
lite/operators/prior_box_op.cc
lite/operators/prior_box_op.cc
+1
-1
lite/operators/prior_box_op.h
lite/operators/prior_box_op.h
+1
-1
lite/operators/range_op.cc
lite/operators/range_op.cc
+1
-1
lite/operators/range_op.h
lite/operators/range_op.h
+1
-1
lite/operators/read_from_array_op.cc
lite/operators/read_from_array_op.cc
+1
-1
lite/operators/read_from_array_op.h
lite/operators/read_from_array_op.h
+1
-1
lite/operators/reduce_max_op.cc
lite/operators/reduce_max_op.cc
+1
-1
lite/operators/reduce_max_op.h
lite/operators/reduce_max_op.h
+1
-1
lite/operators/reduce_mean_op.cc
lite/operators/reduce_mean_op.cc
+1
-1
lite/operators/reduce_mean_op.h
lite/operators/reduce_mean_op.h
+1
-1
lite/operators/reduce_ops.cc
lite/operators/reduce_ops.cc
+1
-1
lite/operators/reduce_ops.h
lite/operators/reduce_ops.h
+1
-1
lite/operators/reduce_prod_op.cc
lite/operators/reduce_prod_op.cc
+1
-1
lite/operators/reduce_prod_op.h
lite/operators/reduce_prod_op.h
+1
-1
lite/operators/relu_op.cc
lite/operators/relu_op.cc
+1
-1
lite/operators/relu_op.h
lite/operators/relu_op.h
+1
-1
lite/operators/reshape_op.cc
lite/operators/reshape_op.cc
+3
-3
lite/operators/reshape_op.h
lite/operators/reshape_op.h
+2
-2
lite/operators/roi_align_op.cc
lite/operators/roi_align_op.cc
+1
-1
lite/operators/roi_align_op.h
lite/operators/roi_align_op.h
+1
-1
lite/operators/scale_op.cc
lite/operators/scale_op.cc
+1
-1
lite/operators/scale_op.h
lite/operators/scale_op.h
+1
-1
lite/operators/search_aligned_mat_mul_op.cc
lite/operators/search_aligned_mat_mul_op.cc
+1
-1
lite/operators/search_aligned_mat_mul_op.h
lite/operators/search_aligned_mat_mul_op.h
+1
-1
lite/operators/search_fc_op.cc
lite/operators/search_fc_op.cc
+1
-1
lite/operators/search_fc_op.h
lite/operators/search_fc_op.h
+1
-1
lite/operators/search_grnn_op.cc
lite/operators/search_grnn_op.cc
+1
-1
lite/operators/search_grnn_op.h
lite/operators/search_grnn_op.h
+1
-1
lite/operators/search_group_padding_op.cc
lite/operators/search_group_padding_op.cc
+1
-1
lite/operators/search_group_padding_op.h
lite/operators/search_group_padding_op.h
+1
-1
lite/operators/search_seq_depadding_op.cc
lite/operators/search_seq_depadding_op.cc
+1
-1
lite/operators/search_seq_depadding_op.h
lite/operators/search_seq_depadding_op.h
+1
-1
lite/operators/search_seq_fc_op.cc
lite/operators/search_seq_fc_op.cc
+1
-1
lite/operators/search_seq_fc_op.h
lite/operators/search_seq_fc_op.h
+1
-1
lite/operators/search_seq_softmax_op.cc
lite/operators/search_seq_softmax_op.cc
+1
-1
lite/operators/search_seq_softmax_op.h
lite/operators/search_seq_softmax_op.h
+1
-1
lite/operators/sequence_arithmetic_op.cc
lite/operators/sequence_arithmetic_op.cc
+1
-1
lite/operators/sequence_arithmetic_op.h
lite/operators/sequence_arithmetic_op.h
+1
-1
lite/operators/sequence_concat_op.cc
lite/operators/sequence_concat_op.cc
+1
-1
lite/operators/sequence_concat_op.h
lite/operators/sequence_concat_op.h
+1
-1
lite/operators/sequence_conv_op.cc
lite/operators/sequence_conv_op.cc
+1
-1
lite/operators/sequence_conv_op.h
lite/operators/sequence_conv_op.h
+1
-1
lite/operators/sequence_expand_as_op.cc
lite/operators/sequence_expand_as_op.cc
+1
-1
lite/operators/sequence_expand_as_op.h
lite/operators/sequence_expand_as_op.h
+1
-1
lite/operators/sequence_expand_op.cc
lite/operators/sequence_expand_op.cc
+1
-1
lite/operators/sequence_expand_op.h
lite/operators/sequence_expand_op.h
+1
-1
lite/operators/sequence_pool_concat_op.cc
lite/operators/sequence_pool_concat_op.cc
+1
-1
lite/operators/sequence_pool_concat_op.h
lite/operators/sequence_pool_concat_op.h
+1
-1
lite/operators/sequence_pool_op.cc
lite/operators/sequence_pool_op.cc
+1
-1
lite/operators/sequence_pool_op.h
lite/operators/sequence_pool_op.h
+1
-1
lite/operators/sequence_reshape_op.cc
lite/operators/sequence_reshape_op.cc
+1
-1
lite/operators/sequence_reshape_op.h
lite/operators/sequence_reshape_op.h
+1
-1
lite/operators/sequence_reverse_op.cc
lite/operators/sequence_reverse_op.cc
+1
-1
lite/operators/sequence_reverse_op.h
lite/operators/sequence_reverse_op.h
+1
-1
lite/operators/sequence_softmax_op.cc
lite/operators/sequence_softmax_op.cc
+1
-1
lite/operators/sequence_softmax_op.h
lite/operators/sequence_softmax_op.h
+1
-1
lite/operators/sequence_topk_avg_pooling_op.cc
lite/operators/sequence_topk_avg_pooling_op.cc
+1
-1
lite/operators/sequence_topk_avg_pooling_op.h
lite/operators/sequence_topk_avg_pooling_op.h
+1
-1
lite/operators/sgd_op.cc
lite/operators/sgd_op.cc
+1
-1
lite/operators/sgd_op.h
lite/operators/sgd_op.h
+1
-1
lite/operators/shape_op.cc
lite/operators/shape_op.cc
+1
-1
lite/operators/shape_op.h
lite/operators/shape_op.h
+1
-1
lite/operators/shuffle_channel_op.cc
lite/operators/shuffle_channel_op.cc
+1
-1
lite/operators/shuffle_channel_op.h
lite/operators/shuffle_channel_op.h
+1
-1
lite/operators/slice_op.cc
lite/operators/slice_op.cc
+1
-1
lite/operators/slice_op.h
lite/operators/slice_op.h
+1
-1
lite/operators/softmax_op.cc
lite/operators/softmax_op.cc
+1
-29
lite/operators/softmax_op.h
lite/operators/softmax_op.h
+1
-2
lite/operators/split_lod_tensor_op.cc
lite/operators/split_lod_tensor_op.cc
+1
-1
lite/operators/split_lod_tensor_op.h
lite/operators/split_lod_tensor_op.h
+1
-1
lite/operators/split_op.cc
lite/operators/split_op.cc
+1
-1
lite/operators/split_op.h
lite/operators/split_op.h
+1
-1
lite/operators/squeeze_op.cc
lite/operators/squeeze_op.cc
+3
-3
lite/operators/squeeze_op.h
lite/operators/squeeze_op.h
+2
-2
lite/operators/stack_op.cc
lite/operators/stack_op.cc
+1
-1
lite/operators/stack_op.h
lite/operators/stack_op.h
+1
-1
lite/operators/subgraph_op.cc
lite/operators/subgraph_op.cc
+1
-1
lite/operators/subgraph_op.h
lite/operators/subgraph_op.h
+1
-1
lite/operators/topk_op.cc
lite/operators/topk_op.cc
+1
-1
lite/operators/topk_op.h
lite/operators/topk_op.h
+1
-1
lite/operators/transpose_op.cc
lite/operators/transpose_op.cc
+2
-2
lite/operators/transpose_op.h
lite/operators/transpose_op.h
+2
-2
lite/operators/uniform_random_op.cc
lite/operators/uniform_random_op.cc
+1
-1
lite/operators/uniform_random_op.h
lite/operators/uniform_random_op.h
+1
-1
lite/operators/unsqueeze_op.cc
lite/operators/unsqueeze_op.cc
+3
-3
lite/operators/unsqueeze_op.h
lite/operators/unsqueeze_op.h
+2
-2
lite/operators/var_conv_2d_op.cc
lite/operators/var_conv_2d_op.cc
+1
-1
lite/operators/var_conv_2d_op.h
lite/operators/var_conv_2d_op.h
+1
-1
lite/operators/while_op.cc
lite/operators/while_op.cc
+1
-1
lite/operators/while_op.h
lite/operators/while_op.h
+1
-1
lite/operators/write_to_array_op.cc
lite/operators/write_to_array_op.cc
+1
-1
lite/operators/write_to_array_op.h
lite/operators/write_to_array_op.h
+1
-1
lite/operators/yolo_box_op.cc
lite/operators/yolo_box_op.cc
+1
-1
lite/operators/yolo_box_op.h
lite/operators/yolo_box_op.h
+1
-1
未找到文件。
lite/core/op_lite.cc
浏览文件 @
c754a38f
...
...
@@ -22,6 +22,61 @@
namespace
paddle
{
namespace
lite
{
bool
OpLite
::
InferShape
()
{
// if input_tensor_ptrs and output_tensor_ptrs are overloaded in param_
// InferShapeByMemoryInternal will be applied.
if
(
param_
.
input_tensor_ptrs
()
&&
param_
.
output_tensor_ptrs
())
{
return
this
->
InferShapeWithCache
();
}
else
{
// otherwise, InferShapeImpl is applied directly.
return
this
->
InferShapeImpl
();
}
}
bool
OpLite
::
InferShapeWithCache
()
{
// 1. Get vector of current input tensors
auto
*
current_inputs
=
param_
.
input_tensor_ptrs
();
// 2. Get hash value of current inputs shape and lod
size_t
new_hash
=
0
;
for
(
auto
iter
=
current_inputs
->
begin
();
iter
!=
current_inputs
->
end
();
iter
++
)
{
// combined dims value into new_hash value.
auto
&
element_dims
=
(
*
iter
)
->
dims
();
for
(
int
i
=
0
;
i
<
element_dims
.
size
();
i
++
)
{
new_hash
=
lite
::
hash_combine
(
new_hash
,
static_cast
<
int
>
(
element_dims
[
i
]));
}
// combine lod value into new_hash valud.
auto
&
emement_lods
=
(
*
iter
)
->
lod
();
for
(
auto
lod_iter
=
emement_lods
.
begin
();
lod_iter
!=
emement_lods
.
end
();
lod_iter
++
)
{
for
(
int
i
=
0
;
i
<
lod_iter
->
size
();
i
++
)
{
new_hash
=
lite
::
hash_combine
(
new_hash
,
static_cast
<
int
>
(
lod_iter
->
at
(
i
)));
}
}
}
// 3. infer shapes of output tensors
if
(
new_hash
==
io_shape_lod_hash_
&&
new_hash
!=
0
)
{
// if current hash value is consistent with io_shape_lod_hash_,
// previous outputs shape and lod are reused.
auto
*
current_outputs
=
param_
.
output_tensor_ptrs
();
for
(
int
i
=
0
;
i
<
current_outputs
->
size
();
i
++
)
{
current_outputs
->
at
(
i
)
->
Resize
(
last_output_shapes
[
i
]);
current_outputs
->
at
(
i
)
->
set_lod
(
last_output_lods
[
i
]);
}
}
else
{
// otherwise, current hash value is changed, InferShapeImpl will apply.
io_shape_lod_hash_
=
new_hash
;
this
->
InferShapeImpl
();
auto
*
current_outputs
=
param_
.
output_tensor_ptrs
();
for
(
int
i
=
0
;
i
<
current_outputs
->
size
();
i
++
)
{
last_output_shapes
[
i
]
=
current_outputs
->
at
(
i
)
->
dims
();
last_output_lods
[
i
]
=
current_outputs
->
at
(
i
)
->
lod
();
}
}
return
true
;
}
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
OpLite
::
CreateKernels
(
const
std
::
vector
<
Place
>
&
places
,
const
std
::
string
&
kernel_type
)
{
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
kernels
;
...
...
lite/core/op_lite.h
浏览文件 @
c754a38f
...
...
@@ -14,6 +14,7 @@
#pragma once
#include <functional>
#include <list>
#include <map>
#include <memory>
...
...
@@ -24,6 +25,7 @@
#include "lite/core/kernel.h"
#include "lite/core/scope.h"
#include "lite/model_parser/cpp/op_desc.h"
#include "lite/operators/op_params.h"
namespace
paddle
{
namespace
lite
{
...
...
@@ -64,8 +66,8 @@ class OpLite : public Registry {
// Check the shape.
virtual
bool
CheckShape
()
const
{
return
true
;
}
// Inference the outputs' shape.
virtual
bool
InferShape
()
const
{
return
true
;
}
virtual
bool
SmartInferShape
()
{
return
this
->
InferShape
();
}
virtual
bool
InferShape
Impl
()
const
{
return
true
;
}
virtual
bool
InferShape
();
// Run this operator.
virtual
bool
Run
();
// Indicate whether the Op runs only once or not
...
...
@@ -151,10 +153,16 @@ class OpLite : public Registry {
std
::
vector
<
Place
>
valid_places_
;
Place
kernel_place_
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)};
std
::
unique_ptr
<
OpInfo
>
op_info_
;
std
::
vector
<
DDimLite
>
last_input_shapes
;
std
::
vector
<
DDimLite
>
last_output_shapes
;
std
::
vector
<
std
::
vector
<
std
::
vector
<
uint64_t
>>>
last_output_lods
;
std
::
vector
<
std
::
vector
<
std
::
vector
<
uint64_t
>>>
last_input_lods
;
std
::
vector
<
DDimLite
>
last_output_shapes
{};
std
::
vector
<
std
::
vector
<
std
::
vector
<
uint64_t
>>>
last_output_lods
{};
size_t
io_shape_lod_hash_
{};
mutable
operators
::
ParamBase
param_
;
private:
// Infer Shape according to memory, if current input shapes are consistent
// with that of previous inputs, output shapes of last time will be reused.
bool
InferShapeWithCache
();
};
/*
...
...
lite/core/program.cc
浏览文件 @
c754a38f
...
...
@@ -286,8 +286,7 @@ void Instruction::Run() {
return
;
}
// op_->InferShape();
op_
->
SmartInferShape
();
op_
->
InferShape
();
kernel_
->
Launch
();
has_run_
=
true
;
}
...
...
lite/operators/activation_grad_ops.cc
浏览文件 @
c754a38f
...
...
@@ -25,7 +25,7 @@ bool ActivationGradOp::CheckShape() const {
return
true
;
}
bool
ActivationGradOp
::
InferShape
()
const
{
bool
ActivationGradOp
::
InferShape
Impl
()
const
{
param_
.
X_grad
->
Resize
(
param_
.
Out_grad
->
dims
());
return
true
;
}
...
...
lite/operators/activation_grad_ops.h
浏览文件 @
c754a38f
...
...
@@ -26,7 +26,7 @@ class ActivationGradOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/activation_ops.cc
浏览文件 @
c754a38f
...
...
@@ -25,7 +25,7 @@ bool ActivationOp::CheckShape() const {
return
true
;
}
bool
ActivationOp
::
InferShape
()
const
{
bool
ActivationOp
::
InferShape
Impl
()
const
{
param_
.
Out
->
Resize
(
param_
.
X
->
dims
());
auto
out_lod
=
param_
.
Out
->
mutable_lod
();
*
out_lod
=
param_
.
X
->
lod
();
...
...
lite/operators/activation_ops.h
浏览文件 @
c754a38f
...
...
@@ -26,7 +26,7 @@ class ActivationOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/affine_channel_op.cc
浏览文件 @
c754a38f
...
...
@@ -44,7 +44,7 @@ bool AffineChannelOpLite::CheckShape() const {
return
true
;
}
bool
AffineChannelOpLite
::
InferShape
()
const
{
bool
AffineChannelOpLite
::
InferShape
Impl
()
const
{
const
auto
x_dims
=
param_
.
X
->
dims
();
param_
.
Out
->
Resize
(
x_dims
);
return
true
;
...
...
lite/operators/affine_channel_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class AffineChannelOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/anchor_generator_op.cc
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ bool AnchorGeneratorOpLite::CheckShape() const {
return
true
;
}
bool
AnchorGeneratorOpLite
::
InferShape
()
const
{
bool
AnchorGeneratorOpLite
::
InferShape
Impl
()
const
{
auto
input_dims
=
param_
.
Input
->
dims
();
size_t
num_anchors
=
param_
.
aspect_ratios
.
size
()
*
param_
.
anchor_sizes
.
size
();
std
::
vector
<
int64_t
>
output_shape
(
...
...
lite/operators/anchor_generator_op.h
浏览文件 @
c754a38f
...
...
@@ -32,7 +32,7 @@ class AnchorGeneratorOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/argmax_op.cc
浏览文件 @
c754a38f
...
...
@@ -29,7 +29,7 @@ bool ArgmaxOpLite::CheckShape() const {
return
true
;
}
bool
ArgmaxOpLite
::
InferShape
()
const
{
bool
ArgmaxOpLite
::
InferShape
Impl
()
const
{
auto
x_dims
=
param_
.
X
->
dims
();
int
x_rank
=
x_dims
.
size
();
int
axis
=
param_
.
Axis
;
...
...
lite/operators/argmax_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class ArgmaxOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/assign_op.cc
浏览文件 @
c754a38f
...
...
@@ -26,7 +26,7 @@ bool AssignOpLite::CheckShape() const {
return
true
;
}
bool
AssignOpLite
::
InferShape
()
const
{
bool
AssignOpLite
::
InferShape
Impl
()
const
{
lite
::
DDim
input_dims
;
input_dims
=
param_
.
X
->
dims
();
param_
.
Out
->
Resize
(
lite
::
DDim
(
input_dims
));
...
...
lite/operators/assign_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class AssignOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/assign_value_op.cc
浏览文件 @
c754a38f
...
...
@@ -35,7 +35,7 @@ bool AssignValueOpLite::CheckShape() const {
return
true
;
}
bool
AssignValueOpLite
::
InferShape
()
const
{
bool
AssignValueOpLite
::
InferShape
Impl
()
const
{
std
::
vector
<
int
>
shape
=
param_
.
shape
;
std
::
vector
<
int64_t
>
out_shape
;
for
(
size_t
i
=
0
;
i
<
shape
.
size
();
i
++
)
out_shape
.
push_back
(
shape
[
i
]);
...
...
lite/operators/assign_value_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class AssignValueOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/attention_padding_mask_op.cc
浏览文件 @
c754a38f
...
...
@@ -28,7 +28,7 @@ bool AttentionPaddingMaskOp::CheckShape() const {
return
true
;
}
bool
AttentionPaddingMaskOp
::
InferShape
()
const
{
bool
AttentionPaddingMaskOp
::
InferShape
Impl
()
const
{
auto
src_len
=
param_
.
X
->
lod
()[
0
][
1
];
CHECK_EQ
(
src_len
,
param_
.
X
->
dims
()[
1
])
<<
"Mismatch source length, expect: "
<<
src_len
...
...
lite/operators/attention_padding_mask_op.h
浏览文件 @
c754a38f
...
...
@@ -29,7 +29,7 @@ class AttentionPaddingMaskOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/axpy_op.cc
浏览文件 @
c754a38f
...
...
@@ -34,7 +34,7 @@ bool AxpyOpLite::CheckShape() const {
return
true
;
}
bool
AxpyOpLite
::
InferShape
()
const
{
bool
AxpyOpLite
::
InferShape
Impl
()
const
{
auto
dims
=
param_
.
Bias
->
dims
();
// Set output dims
...
...
lite/operators/axpy_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class AxpyOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/batch_norm_op.cc
浏览文件 @
c754a38f
...
...
@@ -46,7 +46,7 @@ bool BatchNormOp::CheckShape() const {
return
true
;
}
bool
BatchNormOp
::
InferShape
()
const
{
bool
BatchNormOp
::
InferShape
Impl
()
const
{
auto
x_dims
=
param_
.
x
->
dims
();
int64_t
channel_size
=
0
;
switch
(
param_
.
data_layout
)
{
...
...
lite/operators/batch_norm_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class BatchNormOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/beam_search_decode_op.cc
浏览文件 @
c754a38f
...
...
@@ -28,7 +28,7 @@ bool BeamSearchDecodeOpLite::CheckShape() const {
return
true
;
}
bool
BeamSearchDecodeOpLite
::
InferShape
()
const
{
return
true
;
}
bool
BeamSearchDecodeOpLite
::
InferShape
Impl
()
const
{
return
true
;
}
bool
BeamSearchDecodeOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
...
...
lite/operators/beam_search_decode_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class BeamSearchDecodeOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/beam_search_op.cc
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ bool BeamSearchOp::CheckShape() const {
return
true
;
}
bool
BeamSearchOp
::
InferShape
()
const
{
return
true
;
}
bool
BeamSearchOp
::
InferShape
Impl
()
const
{
return
true
;
}
bool
BeamSearchOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
param_
.
pre_ids
=
scope
->
FindTensor
(
opdesc
.
Input
(
"pre_ids"
).
front
());
...
...
lite/operators/beam_search_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class BeamSearchOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/box_clip_op.cc
浏览文件 @
c754a38f
...
...
@@ -35,7 +35,7 @@ bool BoxClipOpLite::CheckShape() const {
return
true
;
}
bool
BoxClipOpLite
::
InferShape
()
const
{
bool
BoxClipOpLite
::
InferShape
Impl
()
const
{
auto
*
input
=
param_
.
Input
;
auto
*
output
=
param_
.
Output
;
output
->
Resize
(
input
->
dims
());
...
...
lite/operators/box_clip_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class BoxClipOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/box_coder_op.cc
浏览文件 @
c754a38f
...
...
@@ -35,7 +35,7 @@ bool BoxCoderOpLite::CheckShape() const {
return
true
;
}
bool
BoxCoderOpLite
::
InferShape
()
const
{
bool
BoxCoderOpLite
::
InferShape
Impl
()
const
{
auto
prior_box_dims
=
param_
.
prior_box
->
dims
();
auto
target_box_dims
=
param_
.
target_box
->
dims
();
std
::
string
code_type
=
param_
.
code_type
;
...
...
lite/operators/box_coder_op.h
浏览文件 @
c754a38f
...
...
@@ -29,7 +29,7 @@ class BoxCoderOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/calib_op.cc
浏览文件 @
c754a38f
...
...
@@ -24,7 +24,7 @@ bool CalibOpLite::CheckShape() const {
CHECK_OR_FALSE
(
param_
.
output
);
return
true
;
}
bool
CalibOpLite
::
InferShape
()
const
{
bool
CalibOpLite
::
InferShape
Impl
()
const
{
param_
.
output
->
Resize
(
param_
.
input
->
dims
());
return
true
;
}
...
...
lite/operators/calib_op.h
浏览文件 @
c754a38f
...
...
@@ -42,7 +42,7 @@ class CalibOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
);
...
...
lite/operators/cast_op.cc
浏览文件 @
c754a38f
...
...
@@ -25,7 +25,7 @@ bool CastOp::CheckShape() const {
return
true
;
}
bool
CastOp
::
InferShape
()
const
{
bool
CastOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
Out
);
// TODO(Superjomn) Enable data sharing.
auto
out_dims
=
param_
.
X
->
dims
();
...
...
lite/operators/cast_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class CastOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/collect_fpn_proposals_op.cc
浏览文件 @
c754a38f
...
...
@@ -43,7 +43,7 @@ bool CollectFpnProposalsOpLite::CheckShape() const {
return
true
;
}
bool
CollectFpnProposalsOpLite
::
InferShape
()
const
{
bool
CollectFpnProposalsOpLite
::
InferShape
Impl
()
const
{
param_
.
fpn_rois
->
Resize
({
param_
.
post_nms_topN
,
4
});
return
true
;
...
...
lite/operators/collect_fpn_proposals_op.h
浏览文件 @
c754a38f
...
...
@@ -32,7 +32,7 @@ class CollectFpnProposalsOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/compare_op.cc
浏览文件 @
c754a38f
...
...
@@ -26,7 +26,7 @@ bool CompareOp::CheckShape() const {
return
true
;
}
bool
CompareOp
::
InferShape
()
const
{
bool
CompareOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
Out
);
// TODO(Superjomn) Enable data sharing.
auto
input_dims
=
param_
.
X
->
dims
();
...
...
lite/operators/compare_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class CompareOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/concat_op.cc
浏览文件 @
c754a38f
...
...
@@ -26,7 +26,7 @@ bool ConcatOpLite::CheckShape() const {
return
true
;
}
bool
ConcatOpLite
::
InferShape
()
const
{
bool
ConcatOpLite
::
InferShape
Impl
()
const
{
const
std
::
vector
<
Tensor
*>
&
inputs
=
param_
.
x
;
const
size_t
n
=
inputs
.
size
();
CHECK_GT_OR_FALSE
(
n
,
0
);
...
...
lite/operators/concat_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class ConcatOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/conditional_block_op.cc
浏览文件 @
c754a38f
...
...
@@ -27,7 +27,7 @@ bool ConditionalBlockOpLite::CheckShape() const {
return
true
;
}
bool
ConditionalBlockOpLite
::
InferShape
()
const
{
return
true
;
}
bool
ConditionalBlockOpLite
::
InferShape
Impl
()
const
{
return
true
;
}
bool
ConditionalBlockOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
...
...
lite/operators/conditional_block_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class ConditionalBlockOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/conv_op.cc
浏览文件 @
c754a38f
...
...
@@ -80,35 +80,7 @@ void UpdatePaddingAndDilation(std::vector<int>* paddings,
}
}
bool
ConvOpLite
::
SmartInferShape
()
{
if
(
!
last_input_shapes
.
empty
())
{
if
(
last_input_shapes
[
0
]
==
param_
.
x
->
dims
()
&&
last_input_lods
[
0
]
==
param_
.
x
->
lod
())
{
param_
.
output
->
Resize
(
last_output_shapes
[
0
]);
param_
.
output
->
set_lod
(
last_output_lods
[
0
]);
return
true
;
}
}
this
->
InferShape
();
if
(
!
last_input_shapes
.
empty
())
{
last_input_shapes
.
clear
();
last_input_lods
.
clear
();
}
last_input_shapes
.
push_back
(
param_
.
x
->
dims
());
last_input_lods
.
push_back
(
param_
.
x
->
lod
());
if
(
!
last_output_shapes
.
empty
())
{
last_output_shapes
.
clear
();
last_output_lods
.
clear
();
}
last_output_shapes
.
push_back
(
param_
.
output
->
dims
());
last_output_lods
.
push_back
(
param_
.
output
->
lod
());
return
true
;
}
bool
ConvOpLite
::
InferShape
()
const
{
bool
ConvOpLite
::
InferShapeImpl
()
const
{
const
auto
in_dims
=
param_
.
x
->
dims
();
const
auto
filter_dims
=
param_
.
filter
->
dims
();
...
...
lite/operators/conv_op.h
浏览文件 @
c754a38f
...
...
@@ -34,9 +34,7 @@ class ConvOpLite : public OpLite {
explicit
ConvOpLite
(
const
std
::
string
&
type
)
:
OpLite
(
type
)
{}
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
SmartInferShape
()
override
;
bool
InferShapeImpl
()
const
override
;
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
...
...
lite/operators/conv_transpose_op.cc
浏览文件 @
c754a38f
...
...
@@ -52,7 +52,7 @@ inline int ConvTransposeOutputSize(int input_size,
return
output_size
;
}
bool
ConvTransposeOpLite
::
InferShape
()
const
{
bool
ConvTransposeOpLite
::
InferShape
Impl
()
const
{
const
auto
in_dims
=
param_
.
x
->
dims
();
const
auto
filter_dims
=
param_
.
filter
->
dims
();
...
...
lite/operators/conv_transpose_op.h
浏览文件 @
c754a38f
...
...
@@ -34,7 +34,7 @@ class ConvTransposeOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/crf_decoding_op.cc
浏览文件 @
c754a38f
...
...
@@ -60,7 +60,7 @@ bool CrfDecodingOpLite::CheckShape() const {
return
true
;
}
bool
CrfDecodingOpLite
::
InferShape
()
const
{
bool
CrfDecodingOpLite
::
InferShape
Impl
()
const
{
auto
emission_dims
=
param_
.
emission
->
dims
();
if
(
param_
.
length
==
nullptr
)
{
param_
.
viterbi_path
->
Resize
({
emission_dims
[
0
],
1
});
...
...
lite/operators/crf_decoding_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class CrfDecodingOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/crop_op.cc
浏览文件 @
c754a38f
...
...
@@ -26,7 +26,7 @@ bool CropOpLite::CheckShape() const {
return
true
;
}
bool
CropOpLite
::
InferShape
()
const
{
bool
CropOpLite
::
InferShape
Impl
()
const
{
// nchw
auto
x_dims
=
param_
.
X
->
dims
();
lite
::
DDim
output_shape
(
x_dims
);
...
...
lite/operators/crop_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class CropOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/decode_bboxes_op.cc
浏览文件 @
c754a38f
...
...
@@ -29,7 +29,7 @@ bool DecodeBboxesOpLite::CheckShape() const {
return
true
;
}
bool
DecodeBboxesOpLite
::
InferShape
()
const
{
bool
DecodeBboxesOpLite
::
InferShape
Impl
()
const
{
param_
.
bbox_data
->
Resize
(
param_
.
loc_data
->
dims
());
return
true
;
}
...
...
lite/operators/decode_bboxes_op.h
浏览文件 @
c754a38f
...
...
@@ -29,7 +29,7 @@ class DecodeBboxesOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/density_prior_box_op.cc
浏览文件 @
c754a38f
...
...
@@ -27,7 +27,7 @@ bool DensityPriorBoxOpLite::CheckShape() const {
return
true
;
}
bool
DensityPriorBoxOpLite
::
InferShape
()
const
{
return
true
;
}
bool
DensityPriorBoxOpLite
::
InferShape
Impl
()
const
{
return
true
;
}
bool
DensityPriorBoxOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
...
...
lite/operators/density_prior_box_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class DensityPriorBoxOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/distribute_fpn_proposals_op.cc
浏览文件 @
c754a38f
...
...
@@ -32,7 +32,7 @@ bool DistributeFpnProposalsOpLite::CheckShape() const {
return
true
;
}
bool
DistributeFpnProposalsOpLite
::
InferShape
()
const
{
bool
DistributeFpnProposalsOpLite
::
InferShape
Impl
()
const
{
int
num_out_rois
=
param_
.
max_level
-
param_
.
min_level
+
1
;
for
(
int
i
=
0
;
i
<
num_out_rois
;
i
++
)
{
param_
.
multi_fpn_rois
[
i
]
->
Resize
({
-
1
,
4
});
...
...
lite/operators/distribute_fpn_proposals_op.h
浏览文件 @
c754a38f
...
...
@@ -32,7 +32,7 @@ class DistributeFpnProposalsOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/dropout_op.cc
浏览文件 @
c754a38f
...
...
@@ -26,7 +26,7 @@ bool DropoutOp::CheckShape() const {
return
true
;
}
bool
DropoutOp
::
InferShape
()
const
{
bool
DropoutOp
::
InferShape
Impl
()
const
{
const
auto
x_dims
=
param_
.
x
->
dims
();
param_
.
output
->
Resize
(
x_dims
);
if
(
param_
.
is_test
==
false
)
{
...
...
lite/operators/dropout_op.h
浏览文件 @
c754a38f
...
...
@@ -28,7 +28,7 @@ class DropoutOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
...
...
lite/operators/elementwise_grad_ops.cc
浏览文件 @
c754a38f
...
...
@@ -26,7 +26,7 @@ bool ElementwiseGradOp::CheckShape() const {
return
true
;
}
bool
ElementwiseGradOp
::
InferShape
()
const
{
bool
ElementwiseGradOp
::
InferShape
Impl
()
const
{
auto
x_dim
=
param_
.
X
->
dims
();
auto
y_dim
=
param_
.
Y
->
dims
();
if
(
param_
.
XGrad
)
{
...
...
lite/operators/elementwise_grad_ops.h
浏览文件 @
c754a38f
...
...
@@ -27,7 +27,7 @@ class ElementwiseGradOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/elementwise_ops.cc
浏览文件 @
c754a38f
...
...
@@ -26,39 +26,8 @@ bool ElementwiseOp::CheckShape() const {
CHECK_OR_FALSE
(
param_
.
Out
);
return
true
;
}
bool
ElementwiseOp
::
SmartInferShape
()
{
if
(
!
last_input_shapes
.
empty
())
{
if
(
last_input_shapes
[
0
]
==
param_
.
X
->
dims
()
&&
last_input_shapes
[
1
]
==
param_
.
Y
->
dims
()
&&
last_input_lods
[
0
]
==
param_
.
X
->
lod
()
&&
last_input_lods
[
1
]
==
param_
.
Y
->
lod
())
{
param_
.
Out
->
Resize
(
last_output_shapes
[
0
]);
param_
.
Out
->
set_lod
(
last_output_lods
[
0
]);
return
true
;
}
}
this
->
InferShape
();
if
(
!
last_input_shapes
.
empty
())
{
last_input_shapes
.
clear
();
last_input_lods
.
clear
();
}
last_input_shapes
.
push_back
(
param_
.
X
->
dims
());
last_input_lods
.
push_back
(
param_
.
X
->
lod
());
last_input_shapes
.
push_back
(
param_
.
Y
->
dims
());
last_input_lods
.
push_back
(
param_
.
Y
->
lod
());
if
(
!
last_output_shapes
.
empty
())
{
last_output_shapes
.
clear
();
last_output_lods
.
clear
();
}
last_output_shapes
.
push_back
(
param_
.
Out
->
dims
());
last_output_lods
.
push_back
(
param_
.
Out
->
lod
());
return
true
;
}
bool
ElementwiseOp
::
InferShape
()
const
{
bool
ElementwiseOp
::
InferShapeImpl
()
const
{
auto
x_dim
=
param_
.
X
->
dims
();
auto
y_dim
=
param_
.
Y
->
dims
();
if
(
x_dim
==
y_dim
)
{
...
...
@@ -136,7 +105,7 @@ bool ElementwiseOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
// return true;
//}
// bool ElementwiseGradExplicitOp::InferShape() const {
// bool ElementwiseGradExplicitOp::InferShape
Impl
() const {
// param_.X_grad->Resize(param_.Out_grad->dims());
// if (param_.Y_grad) param_.Y_grad->Resize(param_.Y->dims());
// return true;
...
...
lite/operators/elementwise_ops.h
浏览文件 @
c754a38f
...
...
@@ -27,8 +27,7 @@ class ElementwiseOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
SmartInferShape
()
override
;
bool
InferShapeImpl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
@@ -48,7 +47,7 @@ class ElementwiseOp : public OpLite {
// bool CheckShape() const override;
// bool InferShape() const override;
// bool InferShape
Impl
() const override;
// bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
...
...
lite/operators/expand_op.cc
浏览文件 @
c754a38f
...
...
@@ -32,7 +32,7 @@ bool ExpandOpLite::CheckShape() const {
return
true
;
}
bool
ExpandOpLite
::
InferShape
()
const
{
bool
ExpandOpLite
::
InferShape
Impl
()
const
{
DDim
out_dims
(
param_
.
X
->
dims
());
for
(
size_t
i
=
0
;
i
<
param_
.
expand_times
.
size
();
++
i
)
{
out_dims
[
i
]
*=
param_
.
expand_times
[
i
];
...
...
lite/operators/expand_op.h
浏览文件 @
c754a38f
...
...
@@ -28,7 +28,7 @@ class ExpandOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/fake_channel_wise_dequantize_max_abs.h
浏览文件 @
c754a38f
...
...
@@ -36,7 +36,7 @@ class FakeChannelWiseDequantizeMaxAbsOpLite : public OpLite {
bool
CheckShape
()
const
override
{
return
true
;
}
bool
InferShape
()
const
override
{
return
true
;
}
bool
InferShape
Impl
()
const
override
{
return
true
;
}
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
...
...
lite/operators/fake_dequantize_max_abs.h
浏览文件 @
c754a38f
...
...
@@ -35,7 +35,7 @@ class FakeDequantizeMaxAbsOpLite : public OpLite {
bool
CheckShape
()
const
override
{
return
true
;
}
bool
InferShape
()
const
override
{
return
true
;
}
bool
InferShape
Impl
()
const
override
{
return
true
;
}
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
...
...
lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h
浏览文件 @
c754a38f
...
...
@@ -36,7 +36,7 @@ class FakeQuantizeDequantizeMovingAvgMaxAbsOpLite : public OpLite {
bool
CheckShape
()
const
override
{
return
true
;
}
bool
InferShape
()
const
override
{
return
true
;
}
bool
InferShape
Impl
()
const
override
{
return
true
;
}
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
...
...
lite/operators/fake_quantize_moving_avg_max_abs.h
浏览文件 @
c754a38f
...
...
@@ -36,7 +36,7 @@ class FakeQuantizeMovingAvgMaxAbsOpLite : public OpLite {
bool
CheckShape
()
const
override
{
return
true
;
}
bool
InferShape
()
const
override
{
return
true
;
}
bool
InferShape
Impl
()
const
override
{
return
true
;
}
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
...
...
lite/operators/fake_quantize_range_abs_max.h
浏览文件 @
c754a38f
...
...
@@ -36,7 +36,7 @@ class FakeQuantizeRangeMaxAbsOpLite : public OpLite {
bool
CheckShape
()
const
override
{
return
true
;
}
bool
InferShape
()
const
override
{
return
true
;
}
bool
InferShape
Impl
()
const
override
{
return
true
;
}
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
...
...
lite/operators/fc_op.cc
浏览文件 @
c754a38f
...
...
@@ -48,34 +48,7 @@ bool FcOpLite::CheckShape() const {
return
true
;
}
bool
FcOpLite
::
SmartInferShape
()
{
if
(
!
last_input_shapes
.
empty
()
&&
!
last_output_shapes
.
empty
())
{
if
(
last_input_shapes
[
0
]
==
param_
.
input
->
dims
()
&&
last_input_lods
[
0
]
==
param_
.
input
->
lod
())
{
param_
.
output
->
Resize
(
last_output_shapes
[
0
]);
param_
.
output
->
set_lod
(
last_output_lods
[
0
]);
return
true
;
}
}
this
->
InferShape
();
if
(
!
last_input_shapes
.
empty
())
{
last_input_shapes
.
clear
();
last_input_lods
.
clear
();
}
last_input_shapes
.
push_back
(
param_
.
input
->
dims
());
last_input_lods
.
push_back
(
param_
.
input
->
lod
());
if
(
!
last_output_shapes
.
empty
())
{
last_output_shapes
.
clear
();
last_output_lods
.
clear
();
}
last_output_shapes
.
push_back
(
param_
.
output
->
dims
());
last_output_lods
.
push_back
(
param_
.
output
->
lod
());
return
true
;
}
bool
FcOpLite
::
InferShape
()
const
{
bool
FcOpLite
::
InferShapeImpl
()
const
{
const
auto
&
input_dims
=
param_
.
input
->
dims
();
const
auto
&
w_dims
=
param_
.
w
->
dims
();
int
in_num_col_dims
=
param_
.
in_num_col_dims
;
...
...
lite/operators/fc_op.h
浏览文件 @
c754a38f
...
...
@@ -35,8 +35,7 @@ class FcOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
SmartInferShape
()
override
;
bool
InferShapeImpl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/feed_op.cc
浏览文件 @
c754a38f
...
...
@@ -29,7 +29,7 @@ class FeedOp : public OpLite {
return
true
;
}
bool
InferShape
()
const
override
{
return
true
;
}
bool
InferShape
Impl
()
const
override
{
return
true
;
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/fetch_op.cc
浏览文件 @
c754a38f
...
...
@@ -29,7 +29,7 @@ class FetchOp : public OpLite {
return
true
;
}
bool
InferShape
()
const
override
{
return
true
;
}
bool
InferShape
Impl
()
const
override
{
return
true
;
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
protected:
...
...
lite/operators/fill_constant_batch_size_like_op.cc
浏览文件 @
c754a38f
...
...
@@ -28,7 +28,7 @@ bool FillConstantBatchSizeLikeOp::CheckShape() const {
return
true
;
}
bool
FillConstantBatchSizeLikeOp
::
InferShape
()
const
{
bool
FillConstantBatchSizeLikeOp
::
InferShape
Impl
()
const
{
std
::
vector
<
int64_t
>
output_dim
{
param_
.
shape
.
begin
(),
param_
.
shape
.
end
()};
if
(
param_
.
input_dim_idx
==
0
&&
!
param_
.
input
->
lod
().
empty
())
{
output_dim
[
param_
.
output_dim_idx
]
=
param_
.
input
->
lod
().
back
().
size
()
-
1
;
...
...
lite/operators/fill_constant_batch_size_like_op.h
浏览文件 @
c754a38f
...
...
@@ -32,7 +32,7 @@ class FillConstantBatchSizeLikeOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/fill_constant_op.cc
浏览文件 @
c754a38f
...
...
@@ -24,7 +24,7 @@ bool FillConstantOp::CheckShape() const {
return
true
;
}
bool
FillConstantOp
::
InferShape
()
const
{
bool
FillConstantOp
::
InferShape
Impl
()
const
{
std
::
vector
<
int64_t
>
out_shape
;
auto
shape_tensor
=
param_
.
shape_tensor
;
auto
shape_tensor_list
=
param_
.
shape_tensor_list
;
...
...
lite/operators/fill_constant_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class FillConstantOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/flatten_op.cc
浏览文件 @
c754a38f
...
...
@@ -25,7 +25,7 @@ bool FlattenOp::CheckShape() const {
return
true
;
}
bool
FlattenOp
::
InferShape
()
const
{
bool
FlattenOp
::
InferShape
Impl
()
const
{
auto
x_dims
=
param_
.
x
->
dims
();
auto
out_lod
=
param_
.
output
->
mutable_lod
();
...
...
@@ -71,8 +71,8 @@ bool Flatten2Op::CheckShape() const {
return
true
;
}
bool
Flatten2Op
::
InferShape
()
const
{
FlattenOp
::
InferShape
();
bool
Flatten2Op
::
InferShape
Impl
()
const
{
FlattenOp
::
InferShape
Impl
();
auto
x_dims
=
param_
.
x
->
dims
();
std
::
vector
<
DDim
::
value_type
>
xshape_dims
(
x_dims
.
size
()
+
1
,
0
);
for
(
size_t
i
=
0
;
i
<
x_dims
.
size
();
i
++
)
{
...
...
lite/operators/flatten_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class FlattenOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
@@ -49,7 +49,7 @@ class Flatten2Op : public FlattenOp {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/fusion_elementwise_activation_ops.cc
浏览文件 @
c754a38f
...
...
@@ -27,7 +27,7 @@ bool FusionElementwiseActivationOp::CheckShape() const {
return
true
;
}
bool
FusionElementwiseActivationOp
::
InferShape
()
const
{
bool
FusionElementwiseActivationOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
X
->
dims
().
size
()
>=
param_
.
Y
->
dims
().
size
());
param_
.
Out
->
Resize
(
param_
.
X
->
dims
());
return
true
;
...
...
@@ -59,7 +59,7 @@ bool FusionElementwiseActivationOp::AttachImpl(const cpp::OpDesc& opdesc,
// return true;
// }
// bool FusionElementwiseActivationGradExplicitOp::InferShape() const {
// bool FusionElementwiseActivationGradExplicitOp::InferShape
Impl
() const {
// param_.X_grad->Resize(param_.Out_grad->dims());
// param_.Y_grad->Resize(param_.Y->dims());
// return true;
...
...
lite/operators/fusion_elementwise_activation_ops.h
浏览文件 @
c754a38f
...
...
@@ -29,7 +29,7 @@ class FusionElementwiseActivationOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
@@ -51,7 +51,7 @@ class FusionElementwiseActivationOp : public OpLite {
// bool CheckShape() const override;
// bool InferShape() const override;
// bool InferShape
Impl
() const override;
// bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
...
...
lite/operators/gather_op.cc
浏览文件 @
c754a38f
...
...
@@ -26,7 +26,7 @@ bool GatherOp::CheckShape() const {
return
true
;
}
bool
GatherOp
::
InferShape
()
const
{
bool
GatherOp
::
InferShape
Impl
()
const
{
auto
index_dims
=
param_
.
Index
->
dims
();
CHECK
(
index_dims
.
size
()
==
1
||
(
index_dims
.
size
()
==
2
&&
index_dims
[
1
]
==
1
))
...
...
lite/operators/gather_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class GatherOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/generate_proposals_op.cc
浏览文件 @
c754a38f
...
...
@@ -43,7 +43,7 @@ bool GenerateProposalsOpLite::CheckShape() const {
return
true
;
}
bool
GenerateProposalsOpLite
::
InferShape
()
const
{
bool
GenerateProposalsOpLite
::
InferShape
Impl
()
const
{
param_
.
RpnRois
->
Resize
(
std
::
vector
<
int64_t
>
({
-
1
,
4
}));
param_
.
RpnRoiProbs
->
Resize
(
std
::
vector
<
int64_t
>
({
-
1
,
1
}));
return
true
;
...
...
lite/operators/generate_proposals_op.h
浏览文件 @
c754a38f
...
...
@@ -32,7 +32,7 @@ class GenerateProposalsOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/grid_sampler_op.cc
浏览文件 @
c754a38f
...
...
@@ -42,7 +42,7 @@ bool GridSamplerOp::CheckShape() const {
return
true
;
}
bool
GridSamplerOp
::
InferShape
()
const
{
bool
GridSamplerOp
::
InferShape
Impl
()
const
{
auto
x_dims
=
param_
.
x
->
dims
();
param_
.
out
->
Resize
(
x_dims
);
return
true
;
...
...
lite/operators/grid_sampler_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class GridSamplerOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/gru_op.cc
浏览文件 @
c754a38f
...
...
@@ -51,7 +51,7 @@ bool GRUOpLite::CheckShape() const {
return
true
;
}
bool
GRUOpLite
::
InferShape
()
const
{
bool
GRUOpLite
::
InferShape
Impl
()
const
{
const
auto
&
input_dims
=
param_
.
input
->
dims
();
const
auto
&
weight_dims
=
param_
.
weight
->
dims
();
int
frame_size
=
weight_dims
[
0
];
...
...
lite/operators/gru_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class GRUOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/gru_unit_op.cc
浏览文件 @
c754a38f
...
...
@@ -51,7 +51,7 @@ bool GRUUnitOpLite::CheckShape() const {
return
true
;
}
bool
GRUUnitOpLite
::
InferShape
()
const
{
bool
GRUUnitOpLite
::
InferShape
Impl
()
const
{
auto
input_dims
=
param_
.
input
->
dims
();
auto
hidden_prev_dims
=
param_
.
hidden_prev
->
dims
();
auto
weight_dims
=
param_
.
weight
->
dims
();
...
...
lite/operators/gru_unit_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class GRUUnitOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/im2sequence_op.cc
浏览文件 @
c754a38f
...
...
@@ -26,7 +26,7 @@ inline int Im2SeqOutputSize(
}
bool
Im2SequenceOp
::
CheckShape
()
const
{
return
true
;
}
bool
Im2SequenceOp
::
InferShape
()
const
{
bool
Im2SequenceOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
Out
);
// TODO(Superjomn) Enable data sharing.
auto
input_dims
=
param_
.
X
->
dims
();
...
...
lite/operators/im2sequence_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class Im2SequenceOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/increment_op.cc
浏览文件 @
c754a38f
...
...
@@ -25,7 +25,7 @@ bool IncrementOp::CheckShape() const {
return
true
;
}
bool
IncrementOp
::
InferShape
()
const
{
bool
IncrementOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
Out
);
// TODO(Superjomn) Enable data sharing.
auto
out_dims
=
param_
.
X
->
dims
();
...
...
lite/operators/increment_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class IncrementOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/instance_norm_op.cc
浏览文件 @
c754a38f
...
...
@@ -42,7 +42,7 @@ bool InstanceNormOp::CheckShape() const {
return
true
;
}
bool
InstanceNormOp
::
InferShape
()
const
{
bool
InstanceNormOp
::
InferShape
Impl
()
const
{
auto
x_dims
=
param_
.
x
->
dims
();
int64_t
batch_size
=
x_dims
[
0
];
int64_t
channel_size
=
x_dims
[
1
];
...
...
lite/operators/instance_norm_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class InstanceNormOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/interpolate_op.cc
浏览文件 @
c754a38f
...
...
@@ -34,7 +34,7 @@ bool InterpolateOp::CheckShape() const {
return
true
;
}
bool
InterpolateOp
::
InferShape
()
const
{
bool
InterpolateOp
::
InferShape
Impl
()
const
{
auto
X
=
param_
.
X
;
int
n
=
X
->
dims
()[
0
];
...
...
lite/operators/interpolate_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class InterpolateOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/io_copy_op.cc
浏览文件 @
c754a38f
...
...
@@ -24,7 +24,7 @@ bool IoCopyOp::CheckShape() const {
CHECK_OR_FALSE
(
param_
.
y
);
return
true
;
}
bool
IoCopyOp
::
InferShape
()
const
{
bool
IoCopyOp
::
InferShape
Impl
()
const
{
param_
.
y
->
Resize
(
param_
.
x
->
dims
());
return
true
;
}
...
...
lite/operators/io_copy_op.h
浏览文件 @
c754a38f
...
...
@@ -24,7 +24,7 @@ class IoCopyOp : public OpLite {
public:
explicit
IoCopyOp
(
const
std
::
string
&
type
)
:
OpLite
(
type
)
{}
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
Run
()
override
;
std
::
string
DebugString
()
const
override
;
...
...
lite/operators/is_empty_op.cc
浏览文件 @
c754a38f
...
...
@@ -21,7 +21,7 @@ namespace operators {
bool
IsEmptyOp
::
CheckShape
()
const
{
return
true
;
}
bool
IsEmptyOp
::
InferShape
()
const
{
return
true
;
}
bool
IsEmptyOp
::
InferShape
Impl
()
const
{
return
true
;
}
bool
IsEmptyOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
param_
.
X
=
...
...
lite/operators/is_empty_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class IsEmptyOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/layer_norm_op.cc
浏览文件 @
c754a38f
...
...
@@ -27,7 +27,7 @@ bool LayerNormOp::CheckShape() const {
return
true
;
}
bool
LayerNormOp
::
InferShape
()
const
{
bool
LayerNormOp
::
InferShape
Impl
()
const
{
auto
out_dims
=
param_
.
X
->
dims
();
param_
.
Y
->
Resize
(
out_dims
);
auto
inner_size
=
out_dims
.
Flatten2D
(
param_
.
begin_norm_axis
)[
0
];
...
...
lite/operators/layer_norm_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class LayerNormOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/layout_op.cc
浏览文件 @
c754a38f
...
...
@@ -24,7 +24,7 @@ bool LayoutOp::CheckShape() const {
CHECK_OR_FALSE
(
param_
.
y
);
return
true
;
}
bool
LayoutOp
::
InferShape
()
const
{
bool
LayoutOp
::
InferShape
Impl
()
const
{
param_
.
y
->
Resize
(
param_
.
x
->
dims
());
return
true
;
}
...
...
lite/operators/layout_op.h
浏览文件 @
c754a38f
...
...
@@ -24,7 +24,7 @@ class LayoutOp : public OpLite {
public:
explicit
LayoutOp
(
const
std
::
string
&
type
)
:
OpLite
(
type
)
{}
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
Run
()
override
;
std
::
string
DebugString
()
const
override
;
...
...
lite/operators/lod_reset_op.cc
浏览文件 @
c754a38f
...
...
@@ -25,7 +25,7 @@ bool LodResetOp::CheckShape() const {
return
true
;
}
bool
LodResetOp
::
InferShape
()
const
{
bool
LodResetOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
Out
);
// TODO(Superjomn) Enable data sharing.
param_
.
Out
->
Resize
(
param_
.
X
->
dims
());
...
...
lite/operators/lod_reset_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class LodResetOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/logical_op.cc
浏览文件 @
c754a38f
...
...
@@ -26,7 +26,7 @@ bool BinaryLogicalOp::CheckShape() const {
return
true
;
}
bool
BinaryLogicalOp
::
InferShape
()
const
{
bool
BinaryLogicalOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
Out
);
// TODO(Superjomn) Enable data sharing.
auto
input_dims
=
param_
.
X
->
dims
();
...
...
@@ -53,7 +53,7 @@ bool UnaryLogicalOp::CheckShape() const {
return
true
;
}
bool
UnaryLogicalOp
::
InferShape
()
const
{
bool
UnaryLogicalOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
Out
);
// TODO(Superjomn) Enable data sharing.
auto
input_dims
=
param_
.
X
->
dims
();
...
...
lite/operators/logical_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class BinaryLogicalOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
@@ -49,7 +49,7 @@ class UnaryLogicalOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/lookup_table_dequant_op.cc
浏览文件 @
c754a38f
...
...
@@ -36,7 +36,7 @@ bool LookupTableDequantOpLite::CheckShape() const {
return
true
;
}
bool
LookupTableDequantOpLite
::
InferShape
()
const
{
bool
LookupTableDequantOpLite
::
InferShape
Impl
()
const
{
const
auto
&
table_dims
=
param_
.
W
->
dims
();
const
auto
&
ids_dims
=
param_
.
Ids
->
dims
();
...
...
lite/operators/lookup_table_dequant_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class LookupTableDequantOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/lookup_table_op.cc
浏览文件 @
c754a38f
...
...
@@ -36,7 +36,7 @@ bool LookupTableOpLite::CheckShape() const {
return
true
;
}
bool
LookupTableOpLite
::
InferShape
()
const
{
bool
LookupTableOpLite
::
InferShape
Impl
()
const
{
const
auto
&
table_dims
=
param_
.
W
->
dims
();
const
auto
&
ids_dims
=
param_
.
Ids
->
dims
();
...
...
lite/operators/lookup_table_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class LookupTableOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/lookup_table_v2_op.cc
浏览文件 @
c754a38f
...
...
@@ -32,7 +32,7 @@ bool LookupTableV2OpLite::CheckShape() const {
return
true
;
}
bool
LookupTableV2OpLite
::
InferShape
()
const
{
bool
LookupTableV2OpLite
::
InferShape
Impl
()
const
{
auto
table_dims
=
param_
.
W
->
dims
();
auto
ids_dims
=
param_
.
Ids
->
dims
();
...
...
lite/operators/lookup_table_v2_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class LookupTableV2OpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/lrn_op.cc
浏览文件 @
c754a38f
...
...
@@ -27,7 +27,7 @@ bool LrnOpLite::CheckShape() const {
return
true
;
}
bool
LrnOpLite
::
InferShape
()
const
{
bool
LrnOpLite
::
InferShape
Impl
()
const
{
param_
.
Out
->
Resize
(
param_
.
X
->
dims
());
return
true
;
}
...
...
lite/operators/lrn_op.h
浏览文件 @
c754a38f
...
...
@@ -28,7 +28,7 @@ class LrnOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/lstm_op.cc
浏览文件 @
c754a38f
...
...
@@ -26,7 +26,7 @@ bool LstmOp::CheckShape() const {
return
true
;
}
bool
LstmOp
::
InferShape
()
const
{
bool
LstmOp
::
InferShape
Impl
()
const
{
auto
in_dims
=
param_
.
Input
->
dims
();
if
(
param_
.
H0
)
{
CHECK
(
param_
.
C0
)
<<
"lstm must has H0 and C0 in the same time"
;
...
...
lite/operators/lstm_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class LstmOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/match_matrix_tensor_op.cc
浏览文件 @
c754a38f
...
...
@@ -42,7 +42,7 @@ bool MatchMatrixTensorOpLite::CheckShape() const {
return
true
;
}
bool
MatchMatrixTensorOpLite
::
InferShape
()
const
{
bool
MatchMatrixTensorOpLite
::
InferShape
Impl
()
const
{
const
Tensor
*
x
=
param_
.
x
;
const
Tensor
*
y
=
param_
.
y
;
DDim
x_dims
=
param_
.
x
->
dims
();
...
...
lite/operators/match_matrix_tensor_op.h
浏览文件 @
c754a38f
...
...
@@ -32,7 +32,7 @@ class MatchMatrixTensorOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/matmul_op.cc
浏览文件 @
c754a38f
...
...
@@ -27,7 +27,7 @@ bool MatMulOpLite::CheckShape() const {
return
true
;
}
bool
MatMulOpLite
::
InferShape
()
const
{
bool
MatMulOpLite
::
InferShape
Impl
()
const
{
const
auto
x_dims
=
param_
.
X
->
dims
();
const
auto
y_dims
=
param_
.
Y
->
dims
();
bool
x_transpose
=
param_
.
transpose_X
;
...
...
lite/operators/matmul_op.h
浏览文件 @
c754a38f
...
...
@@ -33,7 +33,7 @@ class MatMulOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/mean_grad_op.cc
浏览文件 @
c754a38f
...
...
@@ -28,7 +28,7 @@ bool MeanGradOp::CheckShape() const {
return
true
;
}
bool
MeanGradOp
::
InferShape
()
const
{
bool
MeanGradOp
::
InferShape
Impl
()
const
{
param_
.
X_grad
->
Resize
(
param_
.
X
->
dims
());
return
true
;
}
...
...
lite/operators/mean_grad_op.h
浏览文件 @
c754a38f
...
...
@@ -27,7 +27,7 @@ class MeanGradOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/mean_op.cc
浏览文件 @
c754a38f
...
...
@@ -27,7 +27,7 @@ bool MeanOp::CheckShape() const {
return
true
;
}
bool
MeanOp
::
InferShape
()
const
{
bool
MeanOp
::
InferShape
Impl
()
const
{
param_
.
Out
->
Resize
(
std
::
vector
<
int64_t
>
{
1
});
return
true
;
}
...
...
lite/operators/mean_op.h
浏览文件 @
c754a38f
...
...
@@ -27,7 +27,7 @@ class MeanOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/merge_lod_tensor_op.cc
浏览文件 @
c754a38f
...
...
@@ -34,7 +34,7 @@ bool MergeLodTensorOpLite::CheckShape() const {
return
true
;
}
bool
MergeLodTensorOpLite
::
InferShape
()
const
{
bool
MergeLodTensorOpLite
::
InferShape
Impl
()
const
{
auto
dims
=
param_
.
in_true
->
dims
();
param_
.
out
->
Resize
(
dims
);
return
true
;
...
...
lite/operators/merge_lod_tensor_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class MergeLodTensorOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/mul_grad_op.cc
浏览文件 @
c754a38f
...
...
@@ -46,7 +46,7 @@ bool MulGradOpLite::CheckShape() const {
return
true
;
}
bool
MulGradOpLite
::
InferShape
()
const
{
bool
MulGradOpLite
::
InferShape
Impl
()
const
{
const
auto
x_dims
=
param_
.
x
->
dims
();
const
auto
y_dims
=
param_
.
y
->
dims
();
if
(
param_
.
x_grad
)
{
...
...
lite/operators/mul_grad_op.h
浏览文件 @
c754a38f
...
...
@@ -33,7 +33,7 @@ class MulGradOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/mul_op.cc
浏览文件 @
c754a38f
...
...
@@ -35,7 +35,7 @@ bool MulOpLite::CheckShape() const {
return
true
;
}
bool
MulOpLite
::
InferShape
()
const
{
bool
MulOpLite
::
InferShape
Impl
()
const
{
const
auto
x_dims
=
param_
.
x
->
dims
();
const
auto
y_dims
=
param_
.
y
->
dims
();
...
...
lite/operators/mul_op.h
浏览文件 @
c754a38f
...
...
@@ -33,7 +33,7 @@ class MulOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
...
...
lite/operators/multiclass_nms_op.cc
浏览文件 @
c754a38f
...
...
@@ -41,7 +41,7 @@ bool MulticlassNmsOpLite::CheckShape() const {
return
true
;
}
bool
MulticlassNmsOpLite
::
InferShape
()
const
{
bool
MulticlassNmsOpLite
::
InferShape
Impl
()
const
{
auto
box_dims
=
param_
.
bboxes
->
dims
();
auto
score_dims
=
param_
.
scores
->
dims
();
auto
score_size
=
score_dims
.
size
();
...
...
lite/operators/multiclass_nms_op.h
浏览文件 @
c754a38f
...
...
@@ -29,7 +29,7 @@ class MulticlassNmsOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/negative_op.cc
浏览文件 @
c754a38f
...
...
@@ -26,7 +26,7 @@ bool NegativeOpLite::CheckShape() const {
return
true
;
}
bool
NegativeOpLite
::
InferShape
()
const
{
bool
NegativeOpLite
::
InferShape
Impl
()
const
{
lite
::
DDim
input_dims
;
input_dims
=
param_
.
X
->
dims
();
param_
.
Out
->
Resize
(
lite
::
DDim
(
input_dims
));
...
...
lite/operators/negative_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class NegativeOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/norm_op.cc
浏览文件 @
c754a38f
...
...
@@ -25,7 +25,7 @@ bool NormOp::CheckShape() const {
return
true
;
}
bool
NormOp
::
InferShape
()
const
{
bool
NormOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
Out
);
// TODO(Superjomn) Enable data sharing.
auto
out_dims
=
param_
.
X
->
dims
();
...
...
lite/operators/norm_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class NormOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/op_params.h
浏览文件 @
c754a38f
...
...
@@ -24,6 +24,7 @@
#include "lite/model_parser/cpp/block_desc.h"
#include "lite/model_parser/desc_apis.h"
#include "lite/utils/all.h"
#include "lite/utils/variant.h"
/*
* This file contains all the argument parameter data structure for operators.
*/
...
...
@@ -32,6 +33,16 @@ namespace paddle {
namespace
lite
{
namespace
operators
{
struct
ParamBase
{
public:
const
std
::
vector
<
Tensor
*>*
input_tensor_ptrs
()
const
{
return
nullptr
;
}
std
::
vector
<
Tensor
*>*
output_tensor_ptrs
()
{
return
nullptr
;
}
protected:
std
::
shared_ptr
<
std
::
vector
<
const
Tensor
*>>
input_tensor_ptrs_cache_
{
nullptr
};
std
::
shared_ptr
<
std
::
vector
<
Tensor
*>>
output_tensor_ptrs_cache_
{
nullptr
};
};
using
param_t
=
Any
;
#define WITH_INT8_CONFIG \
bool enable_int8{false}; \
...
...
@@ -41,38 +52,38 @@ using param_t = Any;
int bit_length{8};
/// ----------------------- Functional operators ------------------------------
struct
FeedParam
{
struct
FeedParam
:
ParamBase
{
std
::
vector
<
lite
::
Tensor
>*
feed_list
{};
lite
::
Tensor
*
out
{};
int
col
;
};
struct
FetchParam
{
struct
FetchParam
:
ParamBase
{
const
lite
::
Tensor
*
input
{};
std
::
vector
<
lite
::
Tensor
>*
fetch_list
{};
int
col
;
};
// Helper op for lite framework
struct
IoCopyParam
{
struct
IoCopyParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
y
{};
int
process_type
{
0
};
};
struct
LayoutParam
{
struct
LayoutParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
y
{};
int
process_type
{
0
};
};
struct
CalibParam
{
struct
CalibParam
:
ParamBase
{
const
lite
::
Tensor
*
input
{};
lite
::
Tensor
*
output
{};
float
scale
;
};
struct
SubgraphParam
{
struct
SubgraphParam
:
ParamBase
{
std
::
vector
<
std
::
string
>
input_names
{};
std
::
vector
<
std
::
string
>
output_names
{};
std
::
vector
<
std
::
string
>
input_data_names
{};
...
...
@@ -84,7 +95,7 @@ struct SubgraphParam {
/// -------------------------- NN operators ------------------------------------
struct
FcParam
{
struct
FcParam
:
ParamBase
{
lite
::
Tensor
*
input
{
nullptr
};
lite
::
Tensor
*
w
{
nullptr
};
lite
::
Tensor
*
bias
{
nullptr
};
...
...
@@ -95,9 +106,24 @@ struct FcParam {
bool
padding_weights
{
false
};
// for int8
WITH_INT8_CONFIG
};
struct
SearchSeqFcParam
{
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const
std
::
vector
<
const
Tensor
*>*
input_tensor_ptrs
()
{
if
(
UNLIKELY
(
input_tensor_ptrs_cache_
))
{
input_tensor_ptrs_cache_
.
reset
(
new
std
::
vector
<
const
Tensor
*>
({
input
}));
}
return
input_tensor_ptrs_cache_
.
get
();
}
// get a vector of output tensors
const
std
::
vector
<
Tensor
*>*
output_tensor_ptrs
()
{
if
(
UNLIKELY
(
output_tensor_ptrs_cache_
))
{
output_tensor_ptrs_cache_
.
reset
(
new
std
::
vector
<
lite
::
Tensor
*>
({
output
}));
}
return
output_tensor_ptrs_cache_
.
get
();
}
};
struct
SearchSeqFcParam
:
ParamBase
{
lite
::
Tensor
*
x
{
nullptr
};
lite
::
Tensor
*
w
{
nullptr
};
lite
::
Tensor
*
b
{
nullptr
};
...
...
@@ -106,7 +132,7 @@ struct SearchSeqFcParam {
};
// For Interpolate Op
struct
InterpolateParam
{
struct
InterpolateParam
:
ParamBase
{
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
OutSize
{};
lite
::
Tensor
*
Out
{};
...
...
@@ -123,7 +149,7 @@ struct InterpolateParam {
};
// For Mul Op
struct
MulParam
{
struct
MulParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
y
{};
lite
::
Tensor
*
output
{};
...
...
@@ -134,7 +160,7 @@ struct MulParam {
WITH_INT8_CONFIG
};
struct
MulGradParam
{
struct
MulGradParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
y
{};
const
lite
::
Tensor
*
output_grad
{};
...
...
@@ -146,7 +172,7 @@ struct MulGradParam {
};
// For ReduceMean Op
struct
ReduceMeanParam
{
struct
ReduceMeanParam
:
ParamBase
{
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
...
...
@@ -155,7 +181,7 @@ struct ReduceMeanParam {
};
// For Stack Op
struct
StackParam
{
struct
StackParam
:
ParamBase
{
std
::
vector
<
lite
::
Tensor
*>
X
;
lite
::
Tensor
*
Out
{};
...
...
@@ -163,7 +189,7 @@ struct StackParam {
};
// For Power Op
struct
PowerParam
{
struct
PowerParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
...
...
@@ -172,7 +198,7 @@ struct PowerParam {
float
power
{};
};
struct
ShuffleChannelParam
{
struct
ShuffleChannelParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
...
...
@@ -180,7 +206,7 @@ struct ShuffleChannelParam {
};
// For Yolobox
struct
YoloBoxParam
{
struct
YoloBoxParam
:
ParamBase
{
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
ImgSize
{};
lite
::
Tensor
*
Boxes
{};
...
...
@@ -193,7 +219,7 @@ struct YoloBoxParam {
};
// For Scale Op
struct
ScaleParam
{
struct
ScaleParam
:
ParamBase
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
output
{};
...
...
@@ -203,14 +229,29 @@ struct ScaleParam {
};
// For Softmax op
struct
SoftmaxParam
{
struct
SoftmaxParam
:
ParamBase
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
output
{};
int
axis
{
-
1
};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const
std
::
vector
<
const
Tensor
*>*
input_tensor_ptrs
()
{
if
(
UNLIKELY
(
input_tensor_ptrs_cache_
))
{
input_tensor_ptrs_cache_
.
reset
(
new
std
::
vector
<
const
Tensor
*>
({
x
}));
}
return
input_tensor_ptrs_cache_
.
get
();
}
// get a vector of output tensors
const
std
::
vector
<
Tensor
*>*
output_tensor_ptrs
()
{
if
(
UNLIKELY
(
output_tensor_ptrs_cache_
))
{
output_tensor_ptrs_cache_
.
reset
(
new
std
::
vector
<
lite
::
Tensor
*>
({
output
}));
}
return
output_tensor_ptrs_cache_
.
get
();
}
};
// For Reshape and Reshape2 Op
struct
ReshapeParam
{
struct
ReshapeParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
std
::
vector
<
const
lite
::
Tensor
*>
shape_tensor_vct
{};
const
lite
::
Tensor
*
shape_tensor
{};
...
...
@@ -222,7 +263,7 @@ struct ReshapeParam {
};
// For Concat op
struct
ConcatParam
{
struct
ConcatParam
:
ParamBase
{
std
::
vector
<
lite
::
Tensor
*>
x
{};
lite
::
Tensor
*
output
{};
int
axis
{
0
};
...
...
@@ -230,7 +271,7 @@ struct ConcatParam {
};
/// ----------------------- activation operators ----------------------
struct
ActivationParam
{
struct
ActivationParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
float
Leaky_relu_alpha
{
0
};
// leaky_relu param
float
Relu_clipped_coef
{
6
};
// relu_clipped param
...
...
@@ -245,7 +286,7 @@ struct ActivationParam {
lite_api
::
ActivationType
active_type
;
};
struct
ActivationGradParam
{
struct
ActivationGradParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Out
{};
// for backward
...
...
@@ -254,7 +295,7 @@ struct ActivationGradParam {
};
// For Convolution op
struct
ConvParam
{
struct
ConvParam
:
ParamBase
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
filter
{};
lite
::
Tensor
*
bias
{
nullptr
};
...
...
@@ -294,10 +335,26 @@ struct ConvParam {
std
::
vector
<
int
>
output_size
;
// for int8
WITH_INT8_CONFIG
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const
std
::
vector
<
const
Tensor
*>*
input_tensor_ptrs
()
{
if
(
UNLIKELY
(
input_tensor_ptrs_cache_
))
{
input_tensor_ptrs_cache_
.
reset
(
new
std
::
vector
<
const
Tensor
*>
({
x
}));
}
return
input_tensor_ptrs_cache_
.
get
();
}
// get a vector of output tensors
const
std
::
vector
<
Tensor
*>*
output_tensor_ptrs
()
{
if
(
UNLIKELY
(
output_tensor_ptrs_cache_
))
{
output_tensor_ptrs_cache_
.
reset
(
new
std
::
vector
<
lite
::
Tensor
*>
({
output
}));
}
return
output_tensor_ptrs_cache_
.
get
();
}
};
// For BatchNorm op
struct
BatchNormParam
{
struct
BatchNormParam
:
ParamBase
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
bias
{};
lite
::
Tensor
*
scale
{};
...
...
@@ -316,7 +373,7 @@ struct BatchNormParam {
};
// For Pooling op
struct
PoolParam
{
struct
PoolParam
:
ParamBase
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
output
{};
std
::
string
pooling_type
{
""
};
...
...
@@ -340,7 +397,7 @@ struct PoolParam {
};
// For Dropout op
struct
DropoutParam
{
struct
DropoutParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
output
{};
lite
::
Tensor
*
mask
{};
...
...
@@ -352,7 +409,7 @@ struct DropoutParam {
};
// For Split op
struct
SplitParam
{
struct
SplitParam
:
ParamBase
{
lite
::
Tensor
*
x
{};
std
::
vector
<
lite
::
Tensor
*>
output
{};
lite
::
Tensor
*
axis_tensor
;
...
...
@@ -364,7 +421,7 @@ struct SplitParam {
};
// For Transpose op
struct
TransposeParam
{
struct
TransposeParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
output
{};
lite
::
Tensor
*
xshape
{};
...
...
@@ -375,7 +432,7 @@ struct TransposeParam {
};
/// ----------------------- element wise operators ----------------------
struct
ElementwiseParam
{
struct
ElementwiseParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Y
{};
lite
::
Tensor
*
Out
{};
...
...
@@ -384,9 +441,24 @@ struct ElementwiseParam {
WITH_INT8_CONFIG
float
x_input_scale
{
1.0
};
float
y_input_scale
{
1.0
};
};
struct
ElementwiseGradParam
{
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const
std
::
vector
<
const
Tensor
*>*
input_tensor_ptrs
()
{
if
(
UNLIKELY
(
input_tensor_ptrs_cache_
))
{
input_tensor_ptrs_cache_
.
reset
(
new
std
::
vector
<
const
Tensor
*>
({
X
,
Y
}));
}
return
input_tensor_ptrs_cache_
.
get
();
}
// get a vector of output tensors
const
std
::
vector
<
Tensor
*>*
output_tensor_ptrs
()
{
if
(
UNLIKELY
(
output_tensor_ptrs_cache_
))
{
output_tensor_ptrs_cache_
.
reset
(
new
std
::
vector
<
lite
::
Tensor
*>
({
Out
}));
}
return
output_tensor_ptrs_cache_
.
get
();
}
};
struct
ElementwiseGradParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Y
{};
const
lite
::
Tensor
*
OutGrad
{};
...
...
@@ -404,12 +476,12 @@ struct FusionElementwiseActivationGradParam : public ElementwiseGradParam {
};
/// ----------------------- mean operators ----------------------
struct
MeanParam
{
struct
MeanParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
};
struct
MeanGradParam
{
struct
MeanGradParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Out_grad
{};
// for backward
...
...
@@ -417,7 +489,7 @@ struct MeanGradParam {
};
/// ----------------------- fill_constant operators ----------------------
struct
FillConstantParam
{
struct
FillConstantParam
:
ParamBase
{
int
dtype
{
static_cast
<
int
>
(
VarDescAPI
::
VarDataType
::
FP32
)};
std
::
vector
<
int64_t
>
shape
{};
lite
::
Tensor
*
shape_tensor
{
nullptr
};
...
...
@@ -429,7 +501,7 @@ struct FillConstantParam {
lite
::
Tensor
*
out
{};
};
struct
FillConstantBatchSizeLikeParam
{
struct
FillConstantBatchSizeLikeParam
:
ParamBase
{
const
lite
::
Tensor
*
input
{
nullptr
};
lite
::
Tensor
*
out
{
nullptr
};
...
...
@@ -443,7 +515,7 @@ struct FillConstantBatchSizeLikeParam {
};
//
struct
FakeQuantizeMovingAvgMaxAbsParam
{
struct
FakeQuantizeMovingAvgMaxAbsParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
in_scale
{};
const
lite
::
Tensor
*
in_accum
{};
...
...
@@ -457,14 +529,14 @@ struct FakeQuantizeMovingAvgMaxAbsParam {
float
moving_rate
{
0.9
};
};
struct
FakeDequantizeMaxAbsParam
{
struct
FakeDequantizeMaxAbsParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
in_scale
{};
lite
::
Tensor
*
out
{};
float
max_range
;
};
struct
FakeChannelWiseDequantizeMaxAbsParam
{
struct
FakeChannelWiseDequantizeMaxAbsParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
std
::
vector
<
const
lite
::
Tensor
*>
scale_tensors
{};
lite
::
Tensor
*
out
{};
...
...
@@ -472,7 +544,7 @@ struct FakeChannelWiseDequantizeMaxAbsParam {
};
/// ----------------------- sgd operators ----------------------
struct
SGDParam
{
struct
SGDParam
:
ParamBase
{
int
dtype
{
static_cast
<
int
>
(
VarDescAPI
::
VarDataType
::
FP32
)};
const
lite
::
Tensor
*
Param
{};
...
...
@@ -482,7 +554,7 @@ struct SGDParam {
};
/// ----------------------- uniform_random operators ----------------------
struct
UniformRandomParam
{
struct
UniformRandomParam
:
ParamBase
{
std
::
vector
<
int64_t
>
shape
{};
float
min
{
-
1.0
f
};
float
max
{
1.0
f
};
...
...
@@ -491,12 +563,12 @@ struct UniformRandomParam {
lite
::
Tensor
*
Out
{};
};
/// ----------------------- negative operators --------------
struct
NegativeParam
{
struct
NegativeParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
};
/// ----------------------- pad2d operators ----------------------
struct
Pad2dParam
{
struct
Pad2dParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
std
::
vector
<
int
>
paddings
{
0
,
0
,
0
,
0
};
...
...
@@ -506,7 +578,7 @@ struct Pad2dParam {
};
/// ----------------------- Crop operators ----------------------
struct
CropParam
{
struct
CropParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
std
::
vector
<
int
>
offsets
;
...
...
@@ -514,21 +586,21 @@ struct CropParam {
};
///----------------------- argmax operators ----------------------
struct
ArgmaxParam
{
struct
ArgmaxParam
:
ParamBase
{
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
int
Axis
{
0
};
};
///----------------------- axpy operators ----------------------
struct
AxpyParam
{
struct
AxpyParam
:
ParamBase
{
lite
::
Tensor
*
Scale
{};
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Bias
{};
lite
::
Tensor
*
Out
{};
};
/// ----------------------- GRU unit operators ----------------------f
struct
GRUUnitParam
{
struct
GRUUnitParam
:
ParamBase
{
enum
ActType
{
identity
,
sigmoid
,
tanh
,
relu
};
const
lite
::
Tensor
*
input
{
nullptr
};
const
lite
::
Tensor
*
hidden_prev
{
nullptr
};
...
...
@@ -544,7 +616,7 @@ struct GRUUnitParam {
};
/// ------------------------------ lrn operators ------------------------------
struct
LrnParam
{
struct
LrnParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
int
n
{
5
};
...
...
@@ -555,7 +627,7 @@ struct LrnParam {
};
/// ----------------------- decode_bboxes operators ----------------------
struct
DecodeBboxesParam
{
struct
DecodeBboxesParam
:
ParamBase
{
const
lite
::
Tensor
*
loc_data
{};
const
lite
::
Tensor
*
prior_data
{};
lite
::
Tensor
*
bbox_data
{};
...
...
@@ -571,7 +643,7 @@ struct DecodeBboxesParam {
};
/// ----------------------- box_coder operators ----------------------
struct
BoxCoderParam
{
struct
BoxCoderParam
:
ParamBase
{
const
lite
::
Tensor
*
prior_box
{};
const
lite
::
Tensor
*
prior_box_var
{};
const
lite
::
Tensor
*
target_box
{};
...
...
@@ -584,7 +656,7 @@ struct BoxCoderParam {
};
/// ----------------------- multiclass_nms operators ----------------------
struct
MulticlassNmsParam
{
struct
MulticlassNmsParam
:
ParamBase
{
const
lite
::
Tensor
*
bboxes
{};
const
lite
::
Tensor
*
scores
{};
lite
::
Tensor
*
out
{};
...
...
@@ -599,7 +671,7 @@ struct MulticlassNmsParam {
};
/// ----------------------- priorbox operators ----------------------
struct
PriorBoxParam
{
struct
PriorBoxParam
:
ParamBase
{
lite
::
Tensor
*
input
{};
lite
::
Tensor
*
image
{};
lite
::
Tensor
*
boxes
{};
...
...
@@ -628,7 +700,7 @@ struct DensityPriorBoxParam : public PriorBoxParam {
std
::
vector
<
int
>
density_sizes
;
};
/// ----------------------- GRU operators ----------------------f
struct
GRUParam
{
struct
GRUParam
:
ParamBase
{
const
lite
::
Tensor
*
input
{
nullptr
};
const
lite
::
Tensor
*
h0
{
nullptr
};
const
lite
::
Tensor
*
weight
{
nullptr
};
...
...
@@ -645,7 +717,7 @@ struct GRUParam {
};
/// ----------------------- BeamSearchDecode operators ----------------------f
struct
BeamSearchDecodeParam
{
struct
BeamSearchDecodeParam
:
ParamBase
{
std
::
vector
<
lite
::
Tensor
>*
ids
{
nullptr
};
std
::
vector
<
lite
::
Tensor
>*
scores
{
nullptr
};
lite
::
Tensor
*
sentence_ids
{
nullptr
};
...
...
@@ -655,21 +727,21 @@ struct BeamSearchDecodeParam {
};
/// ----------------------- LookupTable operators ----------------------f
struct
LookupTableParam
{
struct
LookupTableParam
:
ParamBase
{
const
lite
::
Tensor
*
W
{
nullptr
};
const
lite
::
Tensor
*
Ids
{
nullptr
};
lite
::
Tensor
*
Out
{
nullptr
};
int64_t
padding_idx
{
-
1
};
};
struct
LookupTableDequantParam
{
struct
LookupTableDequantParam
:
ParamBase
{
lite
::
Tensor
*
W
{
nullptr
};
lite
::
Tensor
*
Ids
{
nullptr
};
lite
::
Tensor
*
Out
{
nullptr
};
int64_t
padding_idx
{
-
1
};
};
struct
Im2SequenceParam
{
struct
Im2SequenceParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Y
{};
lite
::
Tensor
*
Out
{};
...
...
@@ -679,19 +751,19 @@ struct Im2SequenceParam {
std
::
vector
<
int
>
out_strides
{
1
,
1
};
};
struct
SequenceSoftmaxParam
{
struct
SequenceSoftmaxParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
};
struct
NormParam
{
struct
NormParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Norm
{};
int
axis
{
1
};
float
epsilon
{
1e-10
};
};
struct
LayerNormParam
{
struct
LayerNormParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Scale
{};
const
lite
::
Tensor
*
Bias
{};
...
...
@@ -702,13 +774,13 @@ struct LayerNormParam {
float
epsilon
{
1e-5
};
};
struct
LogicalParam
{
struct
LogicalParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Y
{};
lite
::
Tensor
*
Out
{};
};
struct
CompareParam
{
struct
CompareParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Y
{};
bool
force_cpu
{
0
};
...
...
@@ -716,7 +788,7 @@ struct CompareParam {
lite
::
Tensor
*
Out
{};
};
struct
WhileParam
{
struct
WhileParam
:
ParamBase
{
Scope
*
scope
{};
Tensor
*
cond
{};
cpp
::
BlockDesc
*
sub_block
{};
...
...
@@ -724,32 +796,32 @@ struct WhileParam {
std
::
vector
<
Tensor
*>
outs
{};
};
struct
TopkParam
{
struct
TopkParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Indices
{};
int
K
{
1
};
};
struct
IncrementParam
{
struct
IncrementParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
float
step
{
1
};
};
struct
WriteToArrayParam
{
struct
WriteToArrayParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{
nullptr
};
const
lite
::
Tensor
*
I
{
nullptr
};
std
::
vector
<
lite
::
Tensor
>*
Out
{
nullptr
};
};
struct
ReadFromArrayParam
{
struct
ReadFromArrayParam
:
ParamBase
{
const
std
::
vector
<
lite
::
Tensor
>*
X
{
nullptr
};
const
lite
::
Tensor
*
I
{
nullptr
};
lite
::
Tensor
*
Out
{
nullptr
};
};
struct
BeamSearchParam
{
struct
BeamSearchParam
:
ParamBase
{
const
lite
::
Tensor
*
pre_ids
{};
const
lite
::
Tensor
*
pre_scores
{};
const
lite
::
Tensor
*
ids
{};
...
...
@@ -763,7 +835,7 @@ struct BeamSearchParam {
bool
is_accumulated
;
};
struct
SequencePoolParam
{
struct
SequencePoolParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
std
::
string
pool_type
{
"AVERAGE"
};
...
...
@@ -773,7 +845,7 @@ struct SequencePoolParam {
#endif
};
struct
SequenceConvParam
{
struct
SequenceConvParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Filter
{};
lite
::
Tensor
*
Out
{};
...
...
@@ -782,13 +854,13 @@ struct SequenceConvParam {
int
contextLength
;
};
struct
SequencePoolConcatParam
{
struct
SequencePoolConcatParam
:
ParamBase
{
std
::
vector
<
lite
::
Tensor
*>
X
{};
lite
::
Tensor
*
Out
{};
std
::
vector
<
std
::
string
>
pool_type
{};
};
struct
SearchGroupPaddingParam
{
struct
SearchGroupPaddingParam
:
ParamBase
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
out_emb_padding
{};
lite
::
Tensor
*
out_new
{};
...
...
@@ -796,36 +868,36 @@ struct SearchGroupPaddingParam {
int
pad_id
;
};
struct
SequenceReshapeParam
{
struct
SequenceReshapeParam
:
ParamBase
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
output
{};
int
new_dim
;
};
struct
SequenceExpandParam
{
struct
SequenceExpandParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Y
{};
lite
::
Tensor
*
Out
{};
int
ref_level
{
-
1
};
};
struct
SequenceExpandAsParam
{
struct
SequenceExpandAsParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{
nullptr
};
const
lite
::
Tensor
*
y
{
nullptr
};
lite
::
Tensor
*
out
{
nullptr
};
};
struct
SequenceReverseParam
{
struct
SequenceReverseParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
};
struct
SequenceConcatParam
{
struct
SequenceConcatParam
:
ParamBase
{
std
::
vector
<
lite
::
Tensor
*>
X
{};
lite
::
Tensor
*
Out
{};
};
struct
AttentionPaddingMaskParam
{
struct
AttentionPaddingMaskParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Y
{};
int
pad_id
;
...
...
@@ -834,21 +906,21 @@ struct AttentionPaddingMaskParam {
lite
::
Tensor
*
pad_begin
{};
};
struct
SequenceArithmeticParam
{
struct
SequenceArithmeticParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Y
{};
int
op_type
{
1
};
lite
::
Tensor
*
Out
{};
};
struct
ReduceMaxParam
{
struct
ReduceMaxParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
std
::
vector
<
int
>
dim
{};
bool
keep_dim
{
false
};
};
struct
LodResetParam
{
struct
LodResetParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Y
{};
lite
::
Tensor
*
Out
{};
...
...
@@ -856,12 +928,12 @@ struct LodResetParam {
bool
append
;
};
struct
IsEmptyParam
{
struct
IsEmptyParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
};
struct
ReduceParam
{
struct
ReduceParam
:
ParamBase
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
output
{};
std
::
vector
<
int
>
dim
{
0
};
...
...
@@ -869,7 +941,7 @@ struct ReduceParam {
bool
reduce_all
{
false
};
};
struct
VarConv2DParam
{
struct
VarConv2DParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
ROW
{};
const
lite
::
Tensor
*
COLUMN
{};
...
...
@@ -888,19 +960,19 @@ struct VarConv2DParam {
};
/// ----------------------- shape operators ----------------------
struct
ShapeParam
{
struct
ShapeParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
};
struct
CastParam
{
struct
CastParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
int
out_dtype
{
2
};
int
in_dtype
{
2
};
};
struct
SliceParam
{
struct
SliceParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
std
::
vector
<
int
>
axes
{};
...
...
@@ -914,7 +986,7 @@ struct SliceParam {
lite
::
Tensor
*
EndsTensor
{
nullptr
};
};
struct
AffineChannelParam
{
struct
AffineChannelParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
// X is 4D tensor
const
lite
::
Tensor
*
Scale
{};
const
lite
::
Tensor
*
Bias
{};
...
...
@@ -922,7 +994,7 @@ struct AffineChannelParam {
lite
::
Tensor
*
Out
{};
};
struct
AnchorGeneratorParam
{
struct
AnchorGeneratorParam
:
ParamBase
{
const
lite
::
Tensor
*
Input
{};
std
::
vector
<
float
>
anchor_sizes
{};
std
::
vector
<
float
>
aspect_ratios
{};
...
...
@@ -934,7 +1006,7 @@ struct AnchorGeneratorParam {
lite
::
Tensor
*
Variances
{};
};
struct
GenerateProposalsParam
{
struct
GenerateProposalsParam
:
ParamBase
{
// inputs
const
lite
::
Tensor
*
Scores
{};
const
lite
::
Tensor
*
BboxDeltas
{};
...
...
@@ -954,14 +1026,14 @@ struct GenerateProposalsParam {
lite
::
Tensor
*
RpnRoiProbs
{};
};
/// ----------------------- squeeze operators ----------------------
struct
SqueezeParam
{
struct
SqueezeParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
XShape
{};
std
::
vector
<
int
>
axes
{};
};
struct
UnsqueezeParam
{
struct
UnsqueezeParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
XShape
{};
...
...
@@ -971,14 +1043,14 @@ struct UnsqueezeParam {
};
/// ----------------------- expand operators ----------------------
struct
ExpandParam
{
struct
ExpandParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
std
::
vector
<
int
>
expand_times
{};
};
/// ----------------------- matmul operators ----------------------
struct
MatMulParam
{
struct
MatMulParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Y
{};
lite
::
Tensor
*
Out
{};
...
...
@@ -987,20 +1059,20 @@ struct MatMulParam {
float
alpha
{
1.0
f
};
};
struct
GatherParam
{
struct
GatherParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Index
{};
lite
::
Tensor
*
Out
{};
};
/// ----------------------- assign operators -----------------------
struct
AssignParam
{
struct
AssignParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
};
/// ----------------------- roi_align operators -----------------------
struct
RoiAlignParam
{
struct
RoiAlignParam
:
ParamBase
{
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
ROIs
{};
lite
::
Tensor
*
Out
{};
...
...
@@ -1011,13 +1083,13 @@ struct RoiAlignParam {
};
/// ----------------------- box_clip operators -----------------------
struct
BoxClipParam
{
struct
BoxClipParam
:
ParamBase
{
const
lite
::
Tensor
*
Input
{};
const
lite
::
Tensor
*
ImInfo
{};
lite
::
Tensor
*
Output
{};
};
struct
RangeParam
{
struct
RangeParam
:
ParamBase
{
const
lite
::
Tensor
*
Start
;
const
lite
::
Tensor
*
End
;
const
lite
::
Tensor
*
Step
;
...
...
@@ -1025,7 +1097,7 @@ struct RangeParam {
};
/// ----------------------- assign_value operators -----------------------
struct
AssignValueParam
{
struct
AssignValueParam
:
ParamBase
{
std
::
vector
<
int
>
shape
{};
int
dtype
{};
std
::
vector
<
float
>
fp32_values
{};
...
...
@@ -1034,7 +1106,7 @@ struct AssignValueParam {
};
/// --------------- sequence_topk_avg_pooling operators ------------------
struct
SequenceTopkAvgPoolingParam
{
struct
SequenceTopkAvgPoolingParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
ROW
{};
const
lite
::
Tensor
*
COLUMN
{};
...
...
@@ -1045,7 +1117,7 @@ struct SequenceTopkAvgPoolingParam {
};
/// --------------- search_fc operators ------------------
struct
SearchFcParam
{
struct
SearchFcParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
W
{};
const
lite
::
Tensor
*
b
{};
...
...
@@ -1053,7 +1125,7 @@ struct SearchFcParam {
int
out_size
{};
};
/// --------------------- match_matrix_tensor operators --------------------
struct
MatchMatrixTensorParam
{
struct
MatchMatrixTensorParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
y
{};
const
lite
::
Tensor
*
w
{};
...
...
@@ -1064,14 +1136,14 @@ struct MatchMatrixTensorParam {
};
/// --------------------- search_seq_depadding operators --------------------
struct
SearchSeqDepaddingParam
{
struct
SearchSeqDepaddingParam
:
ParamBase
{
const
lite
::
Tensor
*
pad
{};
const
lite
::
Tensor
*
src
{};
lite
::
Tensor
*
out
{};
};
/// --------------------- search_grnn operators --------------------
struct
SearchGrnnParam
{
struct
SearchGrnnParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
wi
{};
const
lite
::
Tensor
*
wh
{};
...
...
@@ -1084,7 +1156,7 @@ struct SearchGrnnParam {
lite
::
Tensor
*
layout_input
{};
};
struct
SplitLodTensorParam
{
struct
SplitLodTensorParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
mask
{};
lite
::
Tensor
*
out_true
{};
...
...
@@ -1092,7 +1164,7 @@ struct SplitLodTensorParam {
int
level
{};
};
struct
MergeLodTensorParam
{
struct
MergeLodTensorParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
mask
{};
const
lite
::
Tensor
*
in_true
{};
...
...
@@ -1101,7 +1173,7 @@ struct MergeLodTensorParam {
int
level
{};
};
struct
ConditionalBlockParam
{
struct
ConditionalBlockParam
:
ParamBase
{
const
lite
::
Tensor
*
cond
{};
std
::
vector
<
lite
::
Tensor
*>
x
{};
std
::
vector
<
lite
::
Tensor
*>
outs
{};
...
...
@@ -1110,14 +1182,14 @@ struct ConditionalBlockParam {
bool
is_scalar_condition
{};
};
struct
CollectFpnProposalsParam
{
struct
CollectFpnProposalsParam
:
ParamBase
{
std
::
vector
<
lite
::
Tensor
*>
multi_level_rois
{};
std
::
vector
<
lite
::
Tensor
*>
multi_level_scores
{};
lite
::
Tensor
*
fpn_rois
{};
int
post_nms_topN
{};
};
struct
DistributeFpnProposalsParam
{
struct
DistributeFpnProposalsParam
:
ParamBase
{
const
lite
::
Tensor
*
fpn_rois
{};
std
::
vector
<
lite
::
Tensor
*>
multi_fpn_rois
{};
lite
::
Tensor
*
restore_index
{};
...
...
@@ -1128,7 +1200,7 @@ struct DistributeFpnProposalsParam {
};
/// --------------------- instance_norm operators --------------------
struct
InstanceNormParam
{
struct
InstanceNormParam
:
ParamBase
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
out
{};
lite
::
Tensor
*
bias
{};
...
...
@@ -1138,12 +1210,12 @@ struct InstanceNormParam {
float
epsilon
;
};
/// --------------------- grid sampler operators --------------------
struct
GridSamplerParam
{
struct
GridSamplerParam
:
ParamBase
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
out
{};
lite
::
Tensor
*
grid
{};
};
struct
LstmParam
{
struct
LstmParam
:
ParamBase
{
lite
::
Tensor
*
Input
{};
lite
::
Tensor
*
Weight
{};
lite
::
Tensor
*
Bias
{};
...
...
@@ -1160,7 +1232,7 @@ struct LstmParam {
std
::
string
candidate_activation
;
};
struct
CrfDecodingParam
{
struct
CrfDecodingParam
:
ParamBase
{
lite
::
Tensor
*
emission
{};
lite
::
Tensor
*
transition
{};
lite
::
Tensor
*
label
{};
...
...
lite/operators/pad2d_op.cc
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ bool Pad2dOpLite::CheckShape() const {
return
true
;
}
bool
Pad2dOpLite
::
InferShape
()
const
{
bool
Pad2dOpLite
::
InferShape
Impl
()
const
{
// nchw
auto
x_dims
=
param_
.
X
->
dims
();
int
out_h
=
x_dims
[
2
]
+
param_
.
paddings
[
0
]
+
param_
.
paddings
[
1
];
...
...
lite/operators/pad2d_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class Pad2dOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/pool_op.cc
浏览文件 @
c754a38f
...
...
@@ -60,7 +60,7 @@ int PoolOutputSize(int input_size,
return
output_size
;
}
bool
PoolOpLite
::
InferShape
()
const
{
bool
PoolOpLite
::
InferShape
Impl
()
const
{
const
auto
x_dims
=
param_
.
x
->
dims
();
std
::
vector
<
int
>&
ksize
=
param_
.
ksize
;
// dynamic update 4-pad
...
...
lite/operators/pool_op.h
浏览文件 @
c754a38f
...
...
@@ -37,7 +37,7 @@ class PoolOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
...
...
lite/operators/power_op.cc
浏览文件 @
c754a38f
...
...
@@ -27,7 +27,7 @@ bool PowerOp::CheckShape() const {
return
true
;
}
bool
PowerOp
::
InferShape
()
const
{
bool
PowerOp
::
InferShape
Impl
()
const
{
param_
.
Out
->
Resize
(
param_
.
X
->
dims
());
return
true
;
}
...
...
lite/operators/power_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class PowerOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/prior_box_op.cc
浏览文件 @
c754a38f
...
...
@@ -27,7 +27,7 @@ bool PriorBoxOpLite::CheckShape() const {
return
true
;
}
bool
PriorBoxOpLite
::
InferShape
()
const
{
return
true
;
}
bool
PriorBoxOpLite
::
InferShape
Impl
()
const
{
return
true
;
}
bool
PriorBoxOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
auto
input
=
opdesc
.
Input
(
"Input"
).
front
();
...
...
lite/operators/prior_box_op.h
浏览文件 @
c754a38f
...
...
@@ -29,7 +29,7 @@ class PriorBoxOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/range_op.cc
浏览文件 @
c754a38f
...
...
@@ -41,7 +41,7 @@ void GetSize(T start, T end, T step, int64_t* size) {
:
std
::
ceil
(
std
::
abs
((
end
-
start
)
/
step
));
}
bool
RangeOpLite
::
InferShape
()
const
{
bool
RangeOpLite
::
InferShape
Impl
()
const
{
int
start
=
param_
.
Start
->
data
<
float
>
()[
0
];
int
end
=
param_
.
End
->
data
<
float
>
()[
0
];
int
step
=
param_
.
Step
->
data
<
float
>
()[
0
];
...
...
lite/operators/range_op.h
浏览文件 @
c754a38f
...
...
@@ -29,7 +29,7 @@ class RangeOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/read_from_array_op.cc
浏览文件 @
c754a38f
...
...
@@ -26,7 +26,7 @@ bool ReadFromArrayOp::CheckShape() const {
return
true
;
}
bool
ReadFromArrayOp
::
InferShape
()
const
{
bool
ReadFromArrayOp
::
InferShape
Impl
()
const
{
int
id
=
param_
.
I
->
data
<
int64_t
>
()[
0
];
auto
out_dims
=
(
*
param_
.
X
)[
id
].
dims
();
param_
.
Out
->
Resize
(
out_dims
);
...
...
lite/operators/read_from_array_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class ReadFromArrayOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/reduce_max_op.cc
浏览文件 @
c754a38f
...
...
@@ -39,7 +39,7 @@ bool ReduceMaxOp::CheckShape() const {
return
true
;
}
bool
ReduceMaxOp
::
InferShape
()
const
{
bool
ReduceMaxOp
::
InferShape
Impl
()
const
{
auto
dims
=
param_
.
dim
;
auto
x_dims
=
param_
.
X
->
dims
();
bool
reduce_all
=
false
;
...
...
lite/operators/reduce_max_op.h
浏览文件 @
c754a38f
...
...
@@ -28,7 +28,7 @@ class ReduceMaxOp : public OpLite {
ReduceMaxOp
()
{}
explicit
ReduceMaxOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/reduce_mean_op.cc
浏览文件 @
c754a38f
...
...
@@ -39,7 +39,7 @@ bool ReduceMeanOp::CheckShape() const {
return
true
;
}
bool
ReduceMeanOp
::
InferShape
()
const
{
bool
ReduceMeanOp
::
InferShape
Impl
()
const
{
auto
dims
=
param_
.
dim
;
auto
x_dims
=
param_
.
X
->
dims
();
bool
reduce_all
=
false
;
...
...
lite/operators/reduce_mean_op.h
浏览文件 @
c754a38f
...
...
@@ -28,7 +28,7 @@ class ReduceMeanOp : public OpLite {
ReduceMeanOp
()
{}
explicit
ReduceMeanOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/reduce_ops.cc
浏览文件 @
c754a38f
...
...
@@ -28,7 +28,7 @@ bool ReduceOp::CheckShape() const {
return
true
;
}
bool
ReduceOp
::
InferShape
()
const
{
bool
ReduceOp
::
InferShape
Impl
()
const
{
const
auto
&
x_dims
=
param_
.
x
->
dims
();
auto
x_rank
=
x_dims
.
size
();
auto
dims
=
param_
.
dim
;
...
...
lite/operators/reduce_ops.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class ReduceOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/reduce_prod_op.cc
浏览文件 @
c754a38f
...
...
@@ -28,7 +28,7 @@ bool ReduceProdOpLite::CheckShape() const {
return
true
;
}
bool
ReduceProdOpLite
::
InferShape
()
const
{
bool
ReduceProdOpLite
::
InferShape
Impl
()
const
{
auto
x
=
param_
.
x
;
auto
out
=
param_
.
output
;
std
::
vector
<
int
>
dim
=
param_
.
dim
;
...
...
lite/operators/reduce_prod_op.h
浏览文件 @
c754a38f
...
...
@@ -29,7 +29,7 @@ class ReduceProdOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/relu_op.cc
浏览文件 @
c754a38f
...
...
@@ -20,7 +20,7 @@ namespace lite {
namespace
operators
{
bool
ReluOp
::
CheckShape
()
const
{
return
true
;
}
bool
ReluOp
::
InferShape
()
const
{
bool
ReluOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
X
);
CHECK_OR_FALSE
(
param_
.
Out
);
// TODO(Superjomn) Enable data sharing.
...
...
lite/operators/relu_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class ReluOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/reshape_op.cc
浏览文件 @
c754a38f
...
...
@@ -26,7 +26,7 @@ bool ReshapeOp::CheckShape() const {
return
true
;
}
bool
ReshapeOp
::
InferShape
()
const
{
bool
ReshapeOp
::
InferShape
Impl
()
const
{
const
auto
&
shape_tensor_vct
=
param_
.
shape_tensor_vct
;
auto
*
shape_tensor
=
param_
.
shape_tensor
;
const
auto
&
shape_vct
=
param_
.
shape_vct
;
...
...
@@ -97,8 +97,8 @@ bool Reshape2Op::CheckShape() const {
return
true
;
}
bool
Reshape2Op
::
InferShape
()
const
{
ReshapeOp
::
InferShape
();
bool
Reshape2Op
::
InferShape
Impl
()
const
{
ReshapeOp
::
InferShape
Impl
();
const
auto
&
x_dims
=
param_
.
x
->
dims
();
std
::
vector
<
DDim
::
value_type
>
xshape_dims
(
x_dims
.
size
()
+
1
);
xshape_dims
[
0
]
=
0
;
...
...
lite/operators/reshape_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class ReshapeOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
@@ -48,7 +48,7 @@ class Reshape2Op : public ReshapeOp {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/roi_align_op.cc
浏览文件 @
c754a38f
...
...
@@ -38,7 +38,7 @@ bool RoiAlignOpLite::CheckShape() const {
return
true
;
}
bool
RoiAlignOpLite
::
InferShape
()
const
{
bool
RoiAlignOpLite
::
InferShape
Impl
()
const
{
auto
x_dims
=
param_
.
X
->
dims
();
auto
rois_dims
=
param_
.
ROIs
->
dims
();
...
...
lite/operators/roi_align_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class RoiAlignOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/scale_op.cc
浏览文件 @
c754a38f
...
...
@@ -24,7 +24,7 @@ bool ScaleOp::CheckShape() const {
return
true
;
}
bool
ScaleOp
::
InferShape
()
const
{
bool
ScaleOp
::
InferShape
Impl
()
const
{
param_
.
output
->
Resize
(
param_
.
x
->
dims
());
return
true
;
}
...
...
lite/operators/scale_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class ScaleOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/search_aligned_mat_mul_op.cc
浏览文件 @
c754a38f
...
...
@@ -27,7 +27,7 @@ bool SearchAlignedMatMulOpLite::CheckShape() const {
return
true
;
}
bool
SearchAlignedMatMulOpLite
::
InferShape
()
const
{
bool
SearchAlignedMatMulOpLite
::
InferShape
Impl
()
const
{
const
auto
x_dims
=
param_
.
X
->
dims
();
const
auto
y_dims
=
param_
.
Y
->
dims
();
const
auto
&
x_lod
=
param_
.
X
->
lod
();
...
...
lite/operators/search_aligned_mat_mul_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class SearchAlignedMatMulOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/search_fc_op.cc
浏览文件 @
c754a38f
...
...
@@ -50,7 +50,7 @@ bool SearchFcOpLite::CheckShape() const {
return
true
;
}
bool
SearchFcOpLite
::
InferShape
()
const
{
bool
SearchFcOpLite
::
InferShape
Impl
()
const
{
auto
out_size
=
param_
.
out_size
;
lite
::
DDim
dims
(
std
::
vector
<
int64_t
>
({
-
1
,
out_size
}));
param_
.
Out
->
Resize
(
dims
);
...
...
lite/operators/search_fc_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class SearchFcOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/search_grnn_op.cc
浏览文件 @
c754a38f
...
...
@@ -51,7 +51,7 @@ bool SearchGrnnOpLite::CheckShape() const {
return
true
;
}
bool
SearchGrnnOpLite
::
InferShape
()
const
{
bool
SearchGrnnOpLite
::
InferShape
Impl
()
const
{
const
auto
&
x_dims
=
param_
.
x
->
dims
();
const
auto
&
x_lod
=
param_
.
x
->
lod
();
CHECK_OR_FALSE
(
!
x_lod
.
empty
());
...
...
lite/operators/search_grnn_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class SearchGrnnOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/search_group_padding_op.cc
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ bool SearchGroupPaddingOp::CheckShape() const {
return
true
;
}
bool
SearchGroupPaddingOp
::
InferShape
()
const
{
bool
SearchGroupPaddingOp
::
InferShape
Impl
()
const
{
std
::
vector
<
int64_t
>
x_dims
=
param_
.
x
->
dims
().
Vectorize
();
param_
.
out_emb_padding
->
Resize
({
-
1
,
x_dims
[
1
]});
...
...
lite/operators/search_group_padding_op.h
浏览文件 @
c754a38f
...
...
@@ -27,7 +27,7 @@ class SearchGroupPaddingOp : public OpLite {
SearchGroupPaddingOp
()
{}
explicit
SearchGroupPaddingOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"search_group_padding"
;
}
...
...
lite/operators/search_seq_depadding_op.cc
浏览文件 @
c754a38f
...
...
@@ -44,7 +44,7 @@ bool SearchSeqDepaddingOpLite::CheckShape() const {
return
true
;
}
bool
SearchSeqDepaddingOpLite
::
InferShape
()
const
{
bool
SearchSeqDepaddingOpLite
::
InferShape
Impl
()
const
{
DDim
pad_dims
=
param_
.
pad
->
dims
();
param_
.
out
->
Resize
({
-
1
,
pad_dims
[
1
]});
return
true
;
...
...
lite/operators/search_seq_depadding_op.h
浏览文件 @
c754a38f
...
...
@@ -32,7 +32,7 @@ class SearchSeqDepaddingOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/search_seq_fc_op.cc
浏览文件 @
c754a38f
...
...
@@ -26,7 +26,7 @@ bool SearchSeqFcOpLite::CheckShape() const {
return
true
;
}
bool
SearchSeqFcOpLite
::
InferShape
()
const
{
bool
SearchSeqFcOpLite
::
InferShape
Impl
()
const
{
const
auto
x_dims
=
param_
.
x
->
dims
();
const
auto
w_dims
=
param_
.
w
->
dims
();
const
auto
&
x_lod
=
param_
.
x
->
lod
();
...
...
lite/operators/search_seq_fc_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class SearchSeqFcOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/search_seq_softmax_op.cc
浏览文件 @
c754a38f
...
...
@@ -25,7 +25,7 @@ bool SearchSeqSoftmaxOp::CheckShape() const {
return
true
;
}
bool
SearchSeqSoftmaxOp
::
InferShape
()
const
{
bool
SearchSeqSoftmaxOp
::
InferShape
Impl
()
const
{
param_
.
output
->
Resize
(
param_
.
x
->
dims
());
param_
.
output
->
set_lod
(
param_
.
x
->
lod
());
return
true
;
...
...
lite/operators/search_seq_softmax_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class SearchSeqSoftmaxOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/sequence_arithmetic_op.cc
浏览文件 @
c754a38f
...
...
@@ -28,7 +28,7 @@ bool SequenceArithmeticOp::CheckShape() const {
return
true
;
}
bool
SequenceArithmeticOp
::
InferShape
()
const
{
bool
SequenceArithmeticOp
::
InferShape
Impl
()
const
{
param_
.
Out
->
Resize
(
param_
.
X
->
dims
());
param_
.
Out
->
set_lod
(
param_
.
X
->
lod
());
return
true
;
...
...
lite/operators/sequence_arithmetic_op.h
浏览文件 @
c754a38f
...
...
@@ -29,7 +29,7 @@ class SequenceArithmeticOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/sequence_concat_op.cc
浏览文件 @
c754a38f
...
...
@@ -26,7 +26,7 @@ bool SequenceConcatOp::CheckShape() const {
return
true
;
}
bool
SequenceConcatOp
::
InferShape
()
const
{
return
true
;
}
bool
SequenceConcatOp
::
InferShape
Impl
()
const
{
return
true
;
}
bool
SequenceConcatOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
...
...
lite/operators/sequence_concat_op.h
浏览文件 @
c754a38f
...
...
@@ -27,7 +27,7 @@ class SequenceConcatOp : public OpLite {
SequenceConcatOp
()
{}
explicit
SequenceConcatOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"sequence_concat"
;
}
...
...
lite/operators/sequence_conv_op.cc
浏览文件 @
c754a38f
...
...
@@ -44,7 +44,7 @@ bool SequenceConvOp::CheckShape() const {
return
true
;
}
bool
SequenceConvOp
::
InferShape
()
const
{
bool
SequenceConvOp
::
InferShape
Impl
()
const
{
const
auto
*
input
=
param_
.
X
;
const
auto
*
filter
=
param_
.
Filter
;
auto
in_dims
=
input
->
dims
();
...
...
lite/operators/sequence_conv_op.h
浏览文件 @
c754a38f
...
...
@@ -28,7 +28,7 @@ class SequenceConvOp : public OpLite {
SequenceConvOp
()
{}
explicit
SequenceConvOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/sequence_expand_as_op.cc
浏览文件 @
c754a38f
...
...
@@ -34,7 +34,7 @@ bool SequenceExpandAsOpLite::CheckShape() const {
return
true
;
}
bool
SequenceExpandAsOpLite
::
InferShape
()
const
{
bool
SequenceExpandAsOpLite
::
InferShape
Impl
()
const
{
auto
x_dims
=
param_
.
x
->
dims
();
auto
y_lod
=
param_
.
y
->
lod
();
auto
out_dims
=
x_dims
;
...
...
lite/operators/sequence_expand_as_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class SequenceExpandAsOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/sequence_expand_op.cc
浏览文件 @
c754a38f
...
...
@@ -40,7 +40,7 @@ bool SequenceExpandOp::CheckShape() const {
return
true
;
}
bool
SequenceExpandOp
::
InferShape
()
const
{
bool
SequenceExpandOp
::
InferShape
Impl
()
const
{
const
auto
x_lod
=
param_
.
X
->
lod
();
auto
x_dims
=
param_
.
X
->
dims
();
int
ref_level
=
param_
.
ref_level
;
...
...
lite/operators/sequence_expand_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class SequenceExpandOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/sequence_pool_concat_op.cc
浏览文件 @
c754a38f
...
...
@@ -26,7 +26,7 @@ bool SequencePoolConcatOp::CheckShape() const {
return
true
;
}
bool
SequencePoolConcatOp
::
InferShape
()
const
{
bool
SequencePoolConcatOp
::
InferShape
Impl
()
const
{
int
out_dim
=
0
;
for
(
int
i
=
0
;
i
<
param_
.
X
.
size
();
++
i
)
{
out_dim
+=
param_
.
X
[
i
]
->
dims
().
count
(
1
,
param_
.
X
[
i
]
->
dims
().
size
());
...
...
lite/operators/sequence_pool_concat_op.h
浏览文件 @
c754a38f
...
...
@@ -28,7 +28,7 @@ class SequencePoolConcatOp : public OpLite {
SequencePoolConcatOp
()
{}
explicit
SequencePoolConcatOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/sequence_pool_op.cc
浏览文件 @
c754a38f
...
...
@@ -29,7 +29,7 @@ bool SequencePoolOp::CheckShape() const {
return
true
;
}
bool
SequencePoolOp
::
InferShape
()
const
{
bool
SequencePoolOp
::
InferShape
Impl
()
const
{
const
auto
*
input
=
param_
.
X
;
auto
out_dims
=
input
->
dims
();
out_dims
[
0
]
=
input
->
lod
()[
0
].
size
()
-
1
;
...
...
lite/operators/sequence_pool_op.h
浏览文件 @
c754a38f
...
...
@@ -28,7 +28,7 @@ class SequencePoolOp : public OpLite {
SequencePoolOp
()
{}
explicit
SequencePoolOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/sequence_reshape_op.cc
浏览文件 @
c754a38f
...
...
@@ -27,7 +27,7 @@ bool SequenceReshapeOp::CheckShape() const {
return
true
;
}
bool
SequenceReshapeOp
::
InferShape
()
const
{
bool
SequenceReshapeOp
::
InferShape
Impl
()
const
{
int
new_dim
=
param_
.
new_dim
;
auto
x_numel
=
param_
.
x
->
dims
().
production
();
std
::
vector
<
int64_t
>
out_shape
{
x_numel
/
new_dim
,
...
...
lite/operators/sequence_reshape_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class SequenceReshapeOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/sequence_reverse_op.cc
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ bool SequenceReverseOp::CheckShape() const {
return
true
;
}
bool
SequenceReverseOp
::
InferShape
()
const
{
bool
SequenceReverseOp
::
InferShape
Impl
()
const
{
const
auto
*
input
=
param_
.
X
;
auto
out_dims
=
input
->
dims
();
param_
.
Out
->
Resize
(
out_dims
);
...
...
lite/operators/sequence_reverse_op.h
浏览文件 @
c754a38f
...
...
@@ -27,7 +27,7 @@ class SequenceReverseOp : public OpLite {
SequenceReverseOp
()
{}
explicit
SequenceReverseOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"sequence_reverse"
;
}
...
...
lite/operators/sequence_softmax_op.cc
浏览文件 @
c754a38f
...
...
@@ -24,7 +24,7 @@ bool SequenceSoftmaxOp::CheckShape() const {
CHECK_OR_FALSE
(
param_
.
Out
);
return
true
;
}
bool
SequenceSoftmaxOp
::
InferShape
()
const
{
bool
SequenceSoftmaxOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
Out
);
// TODO(Superjomn) Enable data sharing.
auto
input_dims
=
param_
.
X
->
dims
();
...
...
lite/operators/sequence_softmax_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class SequenceSoftmaxOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/sequence_topk_avg_pooling_op.cc
浏览文件 @
c754a38f
...
...
@@ -43,7 +43,7 @@ bool SequenceTopkAvgPoolingOpLite::CheckShape() const {
return
true
;
}
bool
SequenceTopkAvgPoolingOpLite
::
InferShape
()
const
{
bool
SequenceTopkAvgPoolingOpLite
::
InferShape
Impl
()
const
{
int
channel_num
=
param_
.
channel_num
;
std
::
vector
<
int
>
topks
=
param_
.
topks
;
auto
row_dim
=
param_
.
ROW
->
dims
();
...
...
lite/operators/sequence_topk_avg_pooling_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class SequenceTopkAvgPoolingOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/sgd_op.cc
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ bool SGDOpLite::CheckShape() const {
return
true
;
}
bool
SGDOpLite
::
InferShape
()
const
{
bool
SGDOpLite
::
InferShape
Impl
()
const
{
param_
.
ParamOut
->
Resize
(
param_
.
Param
->
dims
());
return
true
;
}
...
...
lite/operators/sgd_op.h
浏览文件 @
c754a38f
...
...
@@ -33,7 +33,7 @@ class SGDOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/shape_op.cc
浏览文件 @
c754a38f
...
...
@@ -25,7 +25,7 @@ bool ShapeOpLite::CheckShape() const {
return
true
;
}
bool
ShapeOpLite
::
InferShape
()
const
{
bool
ShapeOpLite
::
InferShape
Impl
()
const
{
std
::
vector
<
int64_t
>
shape_vec
;
shape_vec
.
push_back
(
static_cast
<
int64_t
>
(
param_
.
X
->
dims
().
size
()));
param_
.
Out
->
Resize
(
shape_vec
);
...
...
lite/operators/shape_op.h
浏览文件 @
c754a38f
...
...
@@ -28,7 +28,7 @@ class ShapeOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/shuffle_channel_op.cc
浏览文件 @
c754a38f
...
...
@@ -27,7 +27,7 @@ bool ShuffleChannelOpLite::CheckShape() const {
return
true
;
}
bool
ShuffleChannelOpLite
::
InferShape
()
const
{
bool
ShuffleChannelOpLite
::
InferShape
Impl
()
const
{
param_
.
Out
->
Resize
(
param_
.
X
->
dims
());
return
true
;
}
...
...
lite/operators/shuffle_channel_op.h
浏览文件 @
c754a38f
...
...
@@ -33,7 +33,7 @@ class ShuffleChannelOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/slice_op.cc
浏览文件 @
c754a38f
...
...
@@ -27,7 +27,7 @@ bool SliceOp::CheckShape() const {
return
true
;
}
bool
SliceOp
::
InferShape
()
const
{
bool
SliceOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
Out
);
// TODO(Superjomn) Enable data sharing.
auto
in_dims
=
param_
.
X
->
dims
();
...
...
lite/operators/slice_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class SliceOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/softmax_op.cc
浏览文件 @
c754a38f
...
...
@@ -29,35 +29,7 @@ bool SoftmaxOp::CheckShape() const {
return
true
;
}
bool
SoftmaxOp
::
SmartInferShape
()
{
if
(
!
last_input_shapes
.
empty
()
&&
!
last_output_shapes
.
empty
())
{
if
(
param_
.
x
->
dims
()
==
last_input_shapes
[
0
]
&&
param_
.
x
->
lod
()
==
last_input_lods
[
0
])
{
param_
.
output
->
Resize
(
last_output_shapes
[
0
]);
param_
.
output
->
set_lod
(
last_output_lods
[
0
]);
return
true
;
}
}
this
->
InferShape
();
if
(
!
last_input_shapes
.
empty
())
{
last_input_shapes
.
clear
();
last_input_lods
.
clear
();
}
last_input_shapes
.
push_back
(
param_
.
x
->
dims
());
last_input_lods
.
push_back
(
param_
.
x
->
lod
());
if
(
!
last_output_shapes
.
empty
())
{
last_output_shapes
.
clear
();
last_output_lods
.
clear
();
}
last_output_shapes
.
push_back
(
param_
.
output
->
dims
());
last_output_lods
.
push_back
(
param_
.
output
->
lod
());
return
true
;
}
bool
SoftmaxOp
::
InferShape
()
const
{
bool
SoftmaxOp
::
InferShapeImpl
()
const
{
param_
.
output
->
Resize
(
param_
.
x
->
dims
());
auto
out_lod
=
param_
.
output
->
mutable_lod
();
*
out_lod
=
param_
.
x
->
lod
();
...
...
lite/operators/softmax_op.h
浏览文件 @
c754a38f
...
...
@@ -30,8 +30,7 @@ class SoftmaxOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
SmartInferShape
()
override
;
bool
InferShapeImpl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/split_lod_tensor_op.cc
浏览文件 @
c754a38f
...
...
@@ -33,7 +33,7 @@ bool SplitLodTensorOpLite::CheckShape() const {
return
true
;
}
bool
SplitLodTensorOpLite
::
InferShape
()
const
{
bool
SplitLodTensorOpLite
::
InferShape
Impl
()
const
{
auto
x_dims
=
param_
.
x
->
dims
();
param_
.
out_true
->
Resize
(
x_dims
);
param_
.
out_false
->
Resize
(
x_dims
);
...
...
lite/operators/split_lod_tensor_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class SplitLodTensorOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/split_op.cc
浏览文件 @
c754a38f
...
...
@@ -29,7 +29,7 @@ bool SplitOp::CheckShape() const {
return
true
;
}
bool
SplitOp
::
InferShape
()
const
{
bool
SplitOp
::
InferShape
Impl
()
const
{
const
auto
&
outs
=
param_
.
output
;
auto
in_dims
=
param_
.
x
->
dims
();
int
axis
=
param_
.
axis
;
...
...
lite/operators/split_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class SplitOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/squeeze_op.cc
浏览文件 @
c754a38f
...
...
@@ -75,7 +75,7 @@ bool SqueezeOp::CheckShape() const {
return
true
;
}
bool
SqueezeOp
::
InferShape
()
const
{
bool
SqueezeOp
::
InferShape
Impl
()
const
{
std
::
vector
<
int
>
squeeze_dims
=
param_
.
axes
;
DDim
in_dims
=
param_
.
X
->
dims
();
DDim
out_dim
=
GetOutputShape
(
squeeze_dims
,
in_dims
,
true
);
...
...
@@ -105,8 +105,8 @@ bool Squeeze2Op::CheckShape() const {
return
true
;
}
bool
Squeeze2Op
::
InferShape
()
const
{
SqueezeOp
::
InferShape
();
bool
Squeeze2Op
::
InferShape
Impl
()
const
{
SqueezeOp
::
InferShape
Impl
();
auto
x_dims
=
param_
.
X
->
dims
();
std
::
vector
<
DDim
::
value_type
>
xshape_dims
(
x_dims
.
size
()
+
1
,
1
);
for
(
size_t
i
=
0
;
i
<
x_dims
.
size
();
i
++
)
{
...
...
lite/operators/squeeze_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class SqueezeOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
@@ -48,7 +48,7 @@ class Squeeze2Op : public SqueezeOp {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/stack_op.cc
浏览文件 @
c754a38f
...
...
@@ -32,7 +32,7 @@ bool StackOp::CheckShape() const {
return
true
;
}
bool
StackOp
::
InferShape
()
const
{
bool
StackOp
::
InferShape
Impl
()
const
{
auto
input
=
param_
.
X
;
auto
input_dims
=
input
[
0
]
->
dims
();
int
axis
=
param_
.
axis
;
...
...
lite/operators/stack_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class StackOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/subgraph_op.cc
浏览文件 @
c754a38f
...
...
@@ -22,7 +22,7 @@ namespace operators {
bool
SubgraphOp
::
CheckShape
()
const
{
return
true
;
}
bool
SubgraphOp
::
InferShape
()
const
{
return
CheckShape
();
/* enrich me */
}
bool
SubgraphOp
::
InferShape
Impl
()
const
{
return
CheckShape
();
/* enrich me */
}
bool
SubgraphOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
param_
.
input_names
=
op_desc
.
Input
(
"Inputs"
);
...
...
lite/operators/subgraph_op.h
浏览文件 @
c754a38f
...
...
@@ -35,7 +35,7 @@ class SubgraphOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/topk_op.cc
浏览文件 @
c754a38f
...
...
@@ -25,7 +25,7 @@ bool TopkOp::CheckShape() const {
return
true
;
}
bool
TopkOp
::
InferShape
()
const
{
bool
TopkOp
::
InferShape
Impl
()
const
{
auto
out_dims
=
param_
.
X
->
dims
();
out_dims
[
out_dims
.
size
()
-
1
]
=
param_
.
K
;
auto
out
=
param_
.
Out
;
...
...
lite/operators/topk_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class TopkOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/transpose_op.cc
浏览文件 @
c754a38f
...
...
@@ -42,7 +42,7 @@ bool TransposeOp::CheckShape() const {
return
true
;
}
bool
TransposeOp
::
InferShape
()
const
{
bool
TransposeOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
x
);
CHECK_OR_FALSE
(
param_
.
output
);
auto
x_dims
=
param_
.
x
->
dims
();
...
...
@@ -111,7 +111,7 @@ bool Transpose2Op::CheckShape() const {
return
true
;
}
bool
Transpose2Op
::
InferShape
()
const
{
bool
Transpose2Op
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
x
);
CHECK_OR_FALSE
(
param_
.
output
);
auto
x_dims
=
param_
.
x
->
dims
();
...
...
lite/operators/transpose_op.h
浏览文件 @
c754a38f
...
...
@@ -31,7 +31,7 @@ class TransposeOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
@@ -50,7 +50,7 @@ class Transpose2Op : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/uniform_random_op.cc
浏览文件 @
c754a38f
...
...
@@ -22,7 +22,7 @@ namespace operators {
bool
UniformRandomOpLite
::
CheckShape
()
const
{
return
true
;
}
bool
UniformRandomOpLite
::
InferShape
()
const
{
bool
UniformRandomOpLite
::
InferShape
Impl
()
const
{
param_
.
Out
->
Resize
(
param_
.
shape
);
return
true
;
}
...
...
lite/operators/uniform_random_op.h
浏览文件 @
c754a38f
...
...
@@ -33,7 +33,7 @@ class UniformRandomOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/unsqueeze_op.cc
浏览文件 @
c754a38f
...
...
@@ -62,7 +62,7 @@ bool UnsqueezeOp::CheckShape() const {
return
true
;
}
bool
UnsqueezeOp
::
InferShape
()
const
{
bool
UnsqueezeOp
::
InferShape
Impl
()
const
{
std
::
vector
<
int
>
final_axes
;
auto
axes
=
param_
.
axes
;
auto
*
axes_tensor
=
param_
.
axes_tensor
;
...
...
@@ -129,8 +129,8 @@ bool Unsqueeze2Op::CheckShape() const {
return
true
;
}
bool
Unsqueeze2Op
::
InferShape
()
const
{
UnsqueezeOp
::
InferShape
();
bool
Unsqueeze2Op
::
InferShape
Impl
()
const
{
UnsqueezeOp
::
InferShape
Impl
();
auto
x_dims
=
param_
.
X
->
dims
();
std
::
vector
<
DDim
::
value_type
>
xshape_dims
(
x_dims
.
size
()
+
1
,
1
);
for
(
size_t
i
=
0
;
i
<
x_dims
.
size
();
i
++
)
{
...
...
lite/operators/unsqueeze_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class UnsqueezeOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
@@ -48,7 +48,7 @@ class Unsqueeze2Op : public UnsqueezeOp {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/var_conv_2d_op.cc
浏览文件 @
c754a38f
...
...
@@ -21,7 +21,7 @@ namespace operators {
bool
VarConv2dOp
::
CheckShape
()
const
{
return
true
;
}
bool
VarConv2dOp
::
InferShape
()
const
{
return
true
;
}
bool
VarConv2dOp
::
InferShape
Impl
()
const
{
return
true
;
}
bool
VarConv2dOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
param_
.
X
=
const_cast
<
lite
::
Tensor
*>
(
...
...
lite/operators/var_conv_2d_op.h
浏览文件 @
c754a38f
...
...
@@ -27,7 +27,7 @@ class VarConv2dOp : public OpLite {
VarConv2dOp
()
{}
explicit
VarConv2dOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"var_conv_2d"
;
}
...
...
lite/operators/while_op.cc
浏览文件 @
c754a38f
...
...
@@ -27,7 +27,7 @@ bool WhileOpLite::CheckShape() const {
return
true
;
}
bool
WhileOpLite
::
InferShape
()
const
{
return
true
;
}
bool
WhileOpLite
::
InferShape
Impl
()
const
{
return
true
;
}
bool
WhileOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
auto
inputs
=
op_desc
.
Input
(
"X"
);
...
...
lite/operators/while_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class WhileOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/write_to_array_op.cc
浏览文件 @
c754a38f
...
...
@@ -26,7 +26,7 @@ bool WriteToArrayOp::CheckShape() const {
return
true
;
}
bool
WriteToArrayOp
::
InferShape
()
const
{
bool
WriteToArrayOp
::
InferShape
Impl
()
const
{
int
id
=
param_
.
I
->
data
<
int64_t
>
()[
0
];
if
(
param_
.
Out
->
size
()
<
id
+
1
)
{
param_
.
Out
->
resize
(
id
+
1
);
...
...
lite/operators/write_to_array_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class WriteToArrayOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/yolo_box_op.cc
浏览文件 @
c754a38f
...
...
@@ -46,7 +46,7 @@ bool YoloBoxOp::CheckShape() const {
return
true
;
}
bool
YoloBoxOp
::
InferShape
()
const
{
bool
YoloBoxOp
::
InferShape
Impl
()
const
{
auto
*
X
=
param_
.
X
;
auto
anchors
=
param_
.
anchors
;
int
anchor_num
=
anchors
.
size
()
/
2
;
...
...
lite/operators/yolo_box_op.h
浏览文件 @
c754a38f
...
...
@@ -30,7 +30,7 @@ class YoloBoxOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录