Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
c754a38f
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
332
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c754a38f
编写于
3月 31, 2020
作者:
H
huzhiqiang
提交者:
GitHub
3月 31, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[operator] add InferShapeImpl method (#3294)
上级
50638e96
变更
243
显示空白变更内容
内联
并排
Showing
243 changed file
with
517 addition
and
502 deletion
+517
-502
lite/core/op_lite.cc
lite/core/op_lite.cc
+55
-0
lite/core/op_lite.h
lite/core/op_lite.h
+14
-6
lite/core/program.cc
lite/core/program.cc
+1
-2
lite/operators/activation_grad_ops.cc
lite/operators/activation_grad_ops.cc
+1
-1
lite/operators/activation_grad_ops.h
lite/operators/activation_grad_ops.h
+1
-1
lite/operators/activation_ops.cc
lite/operators/activation_ops.cc
+1
-1
lite/operators/activation_ops.h
lite/operators/activation_ops.h
+1
-1
lite/operators/affine_channel_op.cc
lite/operators/affine_channel_op.cc
+1
-1
lite/operators/affine_channel_op.h
lite/operators/affine_channel_op.h
+1
-1
lite/operators/anchor_generator_op.cc
lite/operators/anchor_generator_op.cc
+1
-1
lite/operators/anchor_generator_op.h
lite/operators/anchor_generator_op.h
+1
-1
lite/operators/argmax_op.cc
lite/operators/argmax_op.cc
+1
-1
lite/operators/argmax_op.h
lite/operators/argmax_op.h
+1
-1
lite/operators/assign_op.cc
lite/operators/assign_op.cc
+1
-1
lite/operators/assign_op.h
lite/operators/assign_op.h
+1
-1
lite/operators/assign_value_op.cc
lite/operators/assign_value_op.cc
+1
-1
lite/operators/assign_value_op.h
lite/operators/assign_value_op.h
+1
-1
lite/operators/attention_padding_mask_op.cc
lite/operators/attention_padding_mask_op.cc
+1
-1
lite/operators/attention_padding_mask_op.h
lite/operators/attention_padding_mask_op.h
+1
-1
lite/operators/axpy_op.cc
lite/operators/axpy_op.cc
+1
-1
lite/operators/axpy_op.h
lite/operators/axpy_op.h
+1
-1
lite/operators/batch_norm_op.cc
lite/operators/batch_norm_op.cc
+1
-1
lite/operators/batch_norm_op.h
lite/operators/batch_norm_op.h
+1
-1
lite/operators/beam_search_decode_op.cc
lite/operators/beam_search_decode_op.cc
+1
-1
lite/operators/beam_search_decode_op.h
lite/operators/beam_search_decode_op.h
+1
-1
lite/operators/beam_search_op.cc
lite/operators/beam_search_op.cc
+1
-1
lite/operators/beam_search_op.h
lite/operators/beam_search_op.h
+1
-1
lite/operators/box_clip_op.cc
lite/operators/box_clip_op.cc
+1
-1
lite/operators/box_clip_op.h
lite/operators/box_clip_op.h
+1
-1
lite/operators/box_coder_op.cc
lite/operators/box_coder_op.cc
+1
-1
lite/operators/box_coder_op.h
lite/operators/box_coder_op.h
+1
-1
lite/operators/calib_op.cc
lite/operators/calib_op.cc
+1
-1
lite/operators/calib_op.h
lite/operators/calib_op.h
+1
-1
lite/operators/cast_op.cc
lite/operators/cast_op.cc
+1
-1
lite/operators/cast_op.h
lite/operators/cast_op.h
+1
-1
lite/operators/collect_fpn_proposals_op.cc
lite/operators/collect_fpn_proposals_op.cc
+1
-1
lite/operators/collect_fpn_proposals_op.h
lite/operators/collect_fpn_proposals_op.h
+1
-1
lite/operators/compare_op.cc
lite/operators/compare_op.cc
+1
-1
lite/operators/compare_op.h
lite/operators/compare_op.h
+1
-1
lite/operators/concat_op.cc
lite/operators/concat_op.cc
+1
-1
lite/operators/concat_op.h
lite/operators/concat_op.h
+1
-1
lite/operators/conditional_block_op.cc
lite/operators/conditional_block_op.cc
+1
-1
lite/operators/conditional_block_op.h
lite/operators/conditional_block_op.h
+1
-1
lite/operators/conv_op.cc
lite/operators/conv_op.cc
+1
-29
lite/operators/conv_op.h
lite/operators/conv_op.h
+1
-3
lite/operators/conv_transpose_op.cc
lite/operators/conv_transpose_op.cc
+1
-1
lite/operators/conv_transpose_op.h
lite/operators/conv_transpose_op.h
+1
-1
lite/operators/crf_decoding_op.cc
lite/operators/crf_decoding_op.cc
+1
-1
lite/operators/crf_decoding_op.h
lite/operators/crf_decoding_op.h
+1
-1
lite/operators/crop_op.cc
lite/operators/crop_op.cc
+1
-1
lite/operators/crop_op.h
lite/operators/crop_op.h
+1
-1
lite/operators/decode_bboxes_op.cc
lite/operators/decode_bboxes_op.cc
+1
-1
lite/operators/decode_bboxes_op.h
lite/operators/decode_bboxes_op.h
+1
-1
lite/operators/density_prior_box_op.cc
lite/operators/density_prior_box_op.cc
+1
-1
lite/operators/density_prior_box_op.h
lite/operators/density_prior_box_op.h
+1
-1
lite/operators/distribute_fpn_proposals_op.cc
lite/operators/distribute_fpn_proposals_op.cc
+1
-1
lite/operators/distribute_fpn_proposals_op.h
lite/operators/distribute_fpn_proposals_op.h
+1
-1
lite/operators/dropout_op.cc
lite/operators/dropout_op.cc
+1
-1
lite/operators/dropout_op.h
lite/operators/dropout_op.h
+1
-1
lite/operators/elementwise_grad_ops.cc
lite/operators/elementwise_grad_ops.cc
+1
-1
lite/operators/elementwise_grad_ops.h
lite/operators/elementwise_grad_ops.h
+1
-1
lite/operators/elementwise_ops.cc
lite/operators/elementwise_ops.cc
+2
-33
lite/operators/elementwise_ops.h
lite/operators/elementwise_ops.h
+2
-3
lite/operators/expand_op.cc
lite/operators/expand_op.cc
+1
-1
lite/operators/expand_op.h
lite/operators/expand_op.h
+1
-1
lite/operators/fake_channel_wise_dequantize_max_abs.h
lite/operators/fake_channel_wise_dequantize_max_abs.h
+1
-1
lite/operators/fake_dequantize_max_abs.h
lite/operators/fake_dequantize_max_abs.h
+1
-1
lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h
lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h
+1
-1
lite/operators/fake_quantize_moving_avg_max_abs.h
lite/operators/fake_quantize_moving_avg_max_abs.h
+1
-1
lite/operators/fake_quantize_range_abs_max.h
lite/operators/fake_quantize_range_abs_max.h
+1
-1
lite/operators/fc_op.cc
lite/operators/fc_op.cc
+1
-28
lite/operators/fc_op.h
lite/operators/fc_op.h
+1
-2
lite/operators/feed_op.cc
lite/operators/feed_op.cc
+1
-1
lite/operators/fetch_op.cc
lite/operators/fetch_op.cc
+1
-1
lite/operators/fill_constant_batch_size_like_op.cc
lite/operators/fill_constant_batch_size_like_op.cc
+1
-1
lite/operators/fill_constant_batch_size_like_op.h
lite/operators/fill_constant_batch_size_like_op.h
+1
-1
lite/operators/fill_constant_op.cc
lite/operators/fill_constant_op.cc
+1
-1
lite/operators/fill_constant_op.h
lite/operators/fill_constant_op.h
+1
-1
lite/operators/flatten_op.cc
lite/operators/flatten_op.cc
+3
-3
lite/operators/flatten_op.h
lite/operators/flatten_op.h
+2
-2
lite/operators/fusion_elementwise_activation_ops.cc
lite/operators/fusion_elementwise_activation_ops.cc
+2
-2
lite/operators/fusion_elementwise_activation_ops.h
lite/operators/fusion_elementwise_activation_ops.h
+2
-2
lite/operators/gather_op.cc
lite/operators/gather_op.cc
+1
-1
lite/operators/gather_op.h
lite/operators/gather_op.h
+1
-1
lite/operators/generate_proposals_op.cc
lite/operators/generate_proposals_op.cc
+1
-1
lite/operators/generate_proposals_op.h
lite/operators/generate_proposals_op.h
+1
-1
lite/operators/grid_sampler_op.cc
lite/operators/grid_sampler_op.cc
+1
-1
lite/operators/grid_sampler_op.h
lite/operators/grid_sampler_op.h
+1
-1
lite/operators/gru_op.cc
lite/operators/gru_op.cc
+1
-1
lite/operators/gru_op.h
lite/operators/gru_op.h
+1
-1
lite/operators/gru_unit_op.cc
lite/operators/gru_unit_op.cc
+1
-1
lite/operators/gru_unit_op.h
lite/operators/gru_unit_op.h
+1
-1
lite/operators/im2sequence_op.cc
lite/operators/im2sequence_op.cc
+1
-1
lite/operators/im2sequence_op.h
lite/operators/im2sequence_op.h
+1
-1
lite/operators/increment_op.cc
lite/operators/increment_op.cc
+1
-1
lite/operators/increment_op.h
lite/operators/increment_op.h
+1
-1
lite/operators/instance_norm_op.cc
lite/operators/instance_norm_op.cc
+1
-1
lite/operators/instance_norm_op.h
lite/operators/instance_norm_op.h
+1
-1
lite/operators/interpolate_op.cc
lite/operators/interpolate_op.cc
+1
-1
lite/operators/interpolate_op.h
lite/operators/interpolate_op.h
+1
-1
lite/operators/io_copy_op.cc
lite/operators/io_copy_op.cc
+1
-1
lite/operators/io_copy_op.h
lite/operators/io_copy_op.h
+1
-1
lite/operators/is_empty_op.cc
lite/operators/is_empty_op.cc
+1
-1
lite/operators/is_empty_op.h
lite/operators/is_empty_op.h
+1
-1
lite/operators/layer_norm_op.cc
lite/operators/layer_norm_op.cc
+1
-1
lite/operators/layer_norm_op.h
lite/operators/layer_norm_op.h
+1
-1
lite/operators/layout_op.cc
lite/operators/layout_op.cc
+1
-1
lite/operators/layout_op.h
lite/operators/layout_op.h
+1
-1
lite/operators/lod_reset_op.cc
lite/operators/lod_reset_op.cc
+1
-1
lite/operators/lod_reset_op.h
lite/operators/lod_reset_op.h
+1
-1
lite/operators/logical_op.cc
lite/operators/logical_op.cc
+2
-2
lite/operators/logical_op.h
lite/operators/logical_op.h
+2
-2
lite/operators/lookup_table_dequant_op.cc
lite/operators/lookup_table_dequant_op.cc
+1
-1
lite/operators/lookup_table_dequant_op.h
lite/operators/lookup_table_dequant_op.h
+1
-1
lite/operators/lookup_table_op.cc
lite/operators/lookup_table_op.cc
+1
-1
lite/operators/lookup_table_op.h
lite/operators/lookup_table_op.h
+1
-1
lite/operators/lookup_table_v2_op.cc
lite/operators/lookup_table_v2_op.cc
+1
-1
lite/operators/lookup_table_v2_op.h
lite/operators/lookup_table_v2_op.h
+1
-1
lite/operators/lrn_op.cc
lite/operators/lrn_op.cc
+1
-1
lite/operators/lrn_op.h
lite/operators/lrn_op.h
+1
-1
lite/operators/lstm_op.cc
lite/operators/lstm_op.cc
+1
-1
lite/operators/lstm_op.h
lite/operators/lstm_op.h
+1
-1
lite/operators/match_matrix_tensor_op.cc
lite/operators/match_matrix_tensor_op.cc
+1
-1
lite/operators/match_matrix_tensor_op.h
lite/operators/match_matrix_tensor_op.h
+1
-1
lite/operators/matmul_op.cc
lite/operators/matmul_op.cc
+1
-1
lite/operators/matmul_op.h
lite/operators/matmul_op.h
+1
-1
lite/operators/mean_grad_op.cc
lite/operators/mean_grad_op.cc
+1
-1
lite/operators/mean_grad_op.h
lite/operators/mean_grad_op.h
+1
-1
lite/operators/mean_op.cc
lite/operators/mean_op.cc
+1
-1
lite/operators/mean_op.h
lite/operators/mean_op.h
+1
-1
lite/operators/merge_lod_tensor_op.cc
lite/operators/merge_lod_tensor_op.cc
+1
-1
lite/operators/merge_lod_tensor_op.h
lite/operators/merge_lod_tensor_op.h
+1
-1
lite/operators/mul_grad_op.cc
lite/operators/mul_grad_op.cc
+1
-1
lite/operators/mul_grad_op.h
lite/operators/mul_grad_op.h
+1
-1
lite/operators/mul_op.cc
lite/operators/mul_op.cc
+1
-1
lite/operators/mul_op.h
lite/operators/mul_op.h
+1
-1
lite/operators/multiclass_nms_op.cc
lite/operators/multiclass_nms_op.cc
+1
-1
lite/operators/multiclass_nms_op.h
lite/operators/multiclass_nms_op.h
+1
-1
lite/operators/negative_op.cc
lite/operators/negative_op.cc
+1
-1
lite/operators/negative_op.h
lite/operators/negative_op.h
+1
-1
lite/operators/norm_op.cc
lite/operators/norm_op.cc
+1
-1
lite/operators/norm_op.h
lite/operators/norm_op.h
+1
-1
lite/operators/op_params.h
lite/operators/op_params.h
+188
-116
lite/operators/pad2d_op.cc
lite/operators/pad2d_op.cc
+1
-1
lite/operators/pad2d_op.h
lite/operators/pad2d_op.h
+1
-1
lite/operators/pool_op.cc
lite/operators/pool_op.cc
+1
-1
lite/operators/pool_op.h
lite/operators/pool_op.h
+1
-1
lite/operators/power_op.cc
lite/operators/power_op.cc
+1
-1
lite/operators/power_op.h
lite/operators/power_op.h
+1
-1
lite/operators/prior_box_op.cc
lite/operators/prior_box_op.cc
+1
-1
lite/operators/prior_box_op.h
lite/operators/prior_box_op.h
+1
-1
lite/operators/range_op.cc
lite/operators/range_op.cc
+1
-1
lite/operators/range_op.h
lite/operators/range_op.h
+1
-1
lite/operators/read_from_array_op.cc
lite/operators/read_from_array_op.cc
+1
-1
lite/operators/read_from_array_op.h
lite/operators/read_from_array_op.h
+1
-1
lite/operators/reduce_max_op.cc
lite/operators/reduce_max_op.cc
+1
-1
lite/operators/reduce_max_op.h
lite/operators/reduce_max_op.h
+1
-1
lite/operators/reduce_mean_op.cc
lite/operators/reduce_mean_op.cc
+1
-1
lite/operators/reduce_mean_op.h
lite/operators/reduce_mean_op.h
+1
-1
lite/operators/reduce_ops.cc
lite/operators/reduce_ops.cc
+1
-1
lite/operators/reduce_ops.h
lite/operators/reduce_ops.h
+1
-1
lite/operators/reduce_prod_op.cc
lite/operators/reduce_prod_op.cc
+1
-1
lite/operators/reduce_prod_op.h
lite/operators/reduce_prod_op.h
+1
-1
lite/operators/relu_op.cc
lite/operators/relu_op.cc
+1
-1
lite/operators/relu_op.h
lite/operators/relu_op.h
+1
-1
lite/operators/reshape_op.cc
lite/operators/reshape_op.cc
+3
-3
lite/operators/reshape_op.h
lite/operators/reshape_op.h
+2
-2
lite/operators/roi_align_op.cc
lite/operators/roi_align_op.cc
+1
-1
lite/operators/roi_align_op.h
lite/operators/roi_align_op.h
+1
-1
lite/operators/scale_op.cc
lite/operators/scale_op.cc
+1
-1
lite/operators/scale_op.h
lite/operators/scale_op.h
+1
-1
lite/operators/search_aligned_mat_mul_op.cc
lite/operators/search_aligned_mat_mul_op.cc
+1
-1
lite/operators/search_aligned_mat_mul_op.h
lite/operators/search_aligned_mat_mul_op.h
+1
-1
lite/operators/search_fc_op.cc
lite/operators/search_fc_op.cc
+1
-1
lite/operators/search_fc_op.h
lite/operators/search_fc_op.h
+1
-1
lite/operators/search_grnn_op.cc
lite/operators/search_grnn_op.cc
+1
-1
lite/operators/search_grnn_op.h
lite/operators/search_grnn_op.h
+1
-1
lite/operators/search_group_padding_op.cc
lite/operators/search_group_padding_op.cc
+1
-1
lite/operators/search_group_padding_op.h
lite/operators/search_group_padding_op.h
+1
-1
lite/operators/search_seq_depadding_op.cc
lite/operators/search_seq_depadding_op.cc
+1
-1
lite/operators/search_seq_depadding_op.h
lite/operators/search_seq_depadding_op.h
+1
-1
lite/operators/search_seq_fc_op.cc
lite/operators/search_seq_fc_op.cc
+1
-1
lite/operators/search_seq_fc_op.h
lite/operators/search_seq_fc_op.h
+1
-1
lite/operators/search_seq_softmax_op.cc
lite/operators/search_seq_softmax_op.cc
+1
-1
lite/operators/search_seq_softmax_op.h
lite/operators/search_seq_softmax_op.h
+1
-1
lite/operators/sequence_arithmetic_op.cc
lite/operators/sequence_arithmetic_op.cc
+1
-1
lite/operators/sequence_arithmetic_op.h
lite/operators/sequence_arithmetic_op.h
+1
-1
lite/operators/sequence_concat_op.cc
lite/operators/sequence_concat_op.cc
+1
-1
lite/operators/sequence_concat_op.h
lite/operators/sequence_concat_op.h
+1
-1
lite/operators/sequence_conv_op.cc
lite/operators/sequence_conv_op.cc
+1
-1
lite/operators/sequence_conv_op.h
lite/operators/sequence_conv_op.h
+1
-1
lite/operators/sequence_expand_as_op.cc
lite/operators/sequence_expand_as_op.cc
+1
-1
lite/operators/sequence_expand_as_op.h
lite/operators/sequence_expand_as_op.h
+1
-1
lite/operators/sequence_expand_op.cc
lite/operators/sequence_expand_op.cc
+1
-1
lite/operators/sequence_expand_op.h
lite/operators/sequence_expand_op.h
+1
-1
lite/operators/sequence_pool_concat_op.cc
lite/operators/sequence_pool_concat_op.cc
+1
-1
lite/operators/sequence_pool_concat_op.h
lite/operators/sequence_pool_concat_op.h
+1
-1
lite/operators/sequence_pool_op.cc
lite/operators/sequence_pool_op.cc
+1
-1
lite/operators/sequence_pool_op.h
lite/operators/sequence_pool_op.h
+1
-1
lite/operators/sequence_reshape_op.cc
lite/operators/sequence_reshape_op.cc
+1
-1
lite/operators/sequence_reshape_op.h
lite/operators/sequence_reshape_op.h
+1
-1
lite/operators/sequence_reverse_op.cc
lite/operators/sequence_reverse_op.cc
+1
-1
lite/operators/sequence_reverse_op.h
lite/operators/sequence_reverse_op.h
+1
-1
lite/operators/sequence_softmax_op.cc
lite/operators/sequence_softmax_op.cc
+1
-1
lite/operators/sequence_softmax_op.h
lite/operators/sequence_softmax_op.h
+1
-1
lite/operators/sequence_topk_avg_pooling_op.cc
lite/operators/sequence_topk_avg_pooling_op.cc
+1
-1
lite/operators/sequence_topk_avg_pooling_op.h
lite/operators/sequence_topk_avg_pooling_op.h
+1
-1
lite/operators/sgd_op.cc
lite/operators/sgd_op.cc
+1
-1
lite/operators/sgd_op.h
lite/operators/sgd_op.h
+1
-1
lite/operators/shape_op.cc
lite/operators/shape_op.cc
+1
-1
lite/operators/shape_op.h
lite/operators/shape_op.h
+1
-1
lite/operators/shuffle_channel_op.cc
lite/operators/shuffle_channel_op.cc
+1
-1
lite/operators/shuffle_channel_op.h
lite/operators/shuffle_channel_op.h
+1
-1
lite/operators/slice_op.cc
lite/operators/slice_op.cc
+1
-1
lite/operators/slice_op.h
lite/operators/slice_op.h
+1
-1
lite/operators/softmax_op.cc
lite/operators/softmax_op.cc
+1
-29
lite/operators/softmax_op.h
lite/operators/softmax_op.h
+1
-2
lite/operators/split_lod_tensor_op.cc
lite/operators/split_lod_tensor_op.cc
+1
-1
lite/operators/split_lod_tensor_op.h
lite/operators/split_lod_tensor_op.h
+1
-1
lite/operators/split_op.cc
lite/operators/split_op.cc
+1
-1
lite/operators/split_op.h
lite/operators/split_op.h
+1
-1
lite/operators/squeeze_op.cc
lite/operators/squeeze_op.cc
+3
-3
lite/operators/squeeze_op.h
lite/operators/squeeze_op.h
+2
-2
lite/operators/stack_op.cc
lite/operators/stack_op.cc
+1
-1
lite/operators/stack_op.h
lite/operators/stack_op.h
+1
-1
lite/operators/subgraph_op.cc
lite/operators/subgraph_op.cc
+1
-1
lite/operators/subgraph_op.h
lite/operators/subgraph_op.h
+1
-1
lite/operators/topk_op.cc
lite/operators/topk_op.cc
+1
-1
lite/operators/topk_op.h
lite/operators/topk_op.h
+1
-1
lite/operators/transpose_op.cc
lite/operators/transpose_op.cc
+2
-2
lite/operators/transpose_op.h
lite/operators/transpose_op.h
+2
-2
lite/operators/uniform_random_op.cc
lite/operators/uniform_random_op.cc
+1
-1
lite/operators/uniform_random_op.h
lite/operators/uniform_random_op.h
+1
-1
lite/operators/unsqueeze_op.cc
lite/operators/unsqueeze_op.cc
+3
-3
lite/operators/unsqueeze_op.h
lite/operators/unsqueeze_op.h
+2
-2
lite/operators/var_conv_2d_op.cc
lite/operators/var_conv_2d_op.cc
+1
-1
lite/operators/var_conv_2d_op.h
lite/operators/var_conv_2d_op.h
+1
-1
lite/operators/while_op.cc
lite/operators/while_op.cc
+1
-1
lite/operators/while_op.h
lite/operators/while_op.h
+1
-1
lite/operators/write_to_array_op.cc
lite/operators/write_to_array_op.cc
+1
-1
lite/operators/write_to_array_op.h
lite/operators/write_to_array_op.h
+1
-1
lite/operators/yolo_box_op.cc
lite/operators/yolo_box_op.cc
+1
-1
lite/operators/yolo_box_op.h
lite/operators/yolo_box_op.h
+1
-1
未找到文件。
lite/core/op_lite.cc
浏览文件 @
c754a38f
...
@@ -22,6 +22,61 @@
...
@@ -22,6 +22,61 @@
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
bool
OpLite
::
InferShape
()
{
// if input_tensor_ptrs and output_tensor_ptrs are overloaded in param_
// InferShapeByMemoryInternal will be applied.
if
(
param_
.
input_tensor_ptrs
()
&&
param_
.
output_tensor_ptrs
())
{
return
this
->
InferShapeWithCache
();
}
else
{
// otherwise, InferShapeImpl is applied directly.
return
this
->
InferShapeImpl
();
}
}
bool
OpLite
::
InferShapeWithCache
()
{
// 1. Get vector of current input tensors
auto
*
current_inputs
=
param_
.
input_tensor_ptrs
();
// 2. Get hash value of current inputs shape and lod
size_t
new_hash
=
0
;
for
(
auto
iter
=
current_inputs
->
begin
();
iter
!=
current_inputs
->
end
();
iter
++
)
{
// combined dims value into new_hash value.
auto
&
element_dims
=
(
*
iter
)
->
dims
();
for
(
int
i
=
0
;
i
<
element_dims
.
size
();
i
++
)
{
new_hash
=
lite
::
hash_combine
(
new_hash
,
static_cast
<
int
>
(
element_dims
[
i
]));
}
// combine lod value into new_hash valud.
auto
&
emement_lods
=
(
*
iter
)
->
lod
();
for
(
auto
lod_iter
=
emement_lods
.
begin
();
lod_iter
!=
emement_lods
.
end
();
lod_iter
++
)
{
for
(
int
i
=
0
;
i
<
lod_iter
->
size
();
i
++
)
{
new_hash
=
lite
::
hash_combine
(
new_hash
,
static_cast
<
int
>
(
lod_iter
->
at
(
i
)));
}
}
}
// 3. infer shapes of output tensors
if
(
new_hash
==
io_shape_lod_hash_
&&
new_hash
!=
0
)
{
// if current hash value is consistent with io_shape_lod_hash_,
// previous outputs shape and lod are reused.
auto
*
current_outputs
=
param_
.
output_tensor_ptrs
();
for
(
int
i
=
0
;
i
<
current_outputs
->
size
();
i
++
)
{
current_outputs
->
at
(
i
)
->
Resize
(
last_output_shapes
[
i
]);
current_outputs
->
at
(
i
)
->
set_lod
(
last_output_lods
[
i
]);
}
}
else
{
// otherwise, current hash value is changed, InferShapeImpl will apply.
io_shape_lod_hash_
=
new_hash
;
this
->
InferShapeImpl
();
auto
*
current_outputs
=
param_
.
output_tensor_ptrs
();
for
(
int
i
=
0
;
i
<
current_outputs
->
size
();
i
++
)
{
last_output_shapes
[
i
]
=
current_outputs
->
at
(
i
)
->
dims
();
last_output_lods
[
i
]
=
current_outputs
->
at
(
i
)
->
lod
();
}
}
return
true
;
}
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
OpLite
::
CreateKernels
(
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
OpLite
::
CreateKernels
(
const
std
::
vector
<
Place
>
&
places
,
const
std
::
string
&
kernel_type
)
{
const
std
::
vector
<
Place
>
&
places
,
const
std
::
string
&
kernel_type
)
{
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
kernels
;
std
::
vector
<
std
::
unique_ptr
<
KernelBase
>>
kernels
;
...
...
lite/core/op_lite.h
浏览文件 @
c754a38f
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#pragma once
#pragma once
#include <functional>
#include <list>
#include <list>
#include <map>
#include <map>
#include <memory>
#include <memory>
...
@@ -24,6 +25,7 @@
...
@@ -24,6 +25,7 @@
#include "lite/core/kernel.h"
#include "lite/core/kernel.h"
#include "lite/core/scope.h"
#include "lite/core/scope.h"
#include "lite/model_parser/cpp/op_desc.h"
#include "lite/model_parser/cpp/op_desc.h"
#include "lite/operators/op_params.h"
namespace
paddle
{
namespace
paddle
{
namespace
lite
{
namespace
lite
{
...
@@ -64,8 +66,8 @@ class OpLite : public Registry {
...
@@ -64,8 +66,8 @@ class OpLite : public Registry {
// Check the shape.
// Check the shape.
virtual
bool
CheckShape
()
const
{
return
true
;
}
virtual
bool
CheckShape
()
const
{
return
true
;
}
// Inference the outputs' shape.
// Inference the outputs' shape.
virtual
bool
InferShape
()
const
{
return
true
;
}
virtual
bool
InferShape
Impl
()
const
{
return
true
;
}
virtual
bool
SmartInferShape
()
{
return
this
->
InferShape
();
}
virtual
bool
InferShape
();
// Run this operator.
// Run this operator.
virtual
bool
Run
();
virtual
bool
Run
();
// Indicate whether the Op runs only once or not
// Indicate whether the Op runs only once or not
...
@@ -151,10 +153,16 @@ class OpLite : public Registry {
...
@@ -151,10 +153,16 @@ class OpLite : public Registry {
std
::
vector
<
Place
>
valid_places_
;
std
::
vector
<
Place
>
valid_places_
;
Place
kernel_place_
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)};
Place
kernel_place_
{
TARGET
(
kHost
),
PRECISION
(
kFloat
)};
std
::
unique_ptr
<
OpInfo
>
op_info_
;
std
::
unique_ptr
<
OpInfo
>
op_info_
;
std
::
vector
<
DDimLite
>
last_input_shapes
;
std
::
vector
<
DDimLite
>
last_output_shapes
;
std
::
vector
<
DDimLite
>
last_output_shapes
{};
std
::
vector
<
std
::
vector
<
std
::
vector
<
uint64_t
>>>
last_output_lods
;
std
::
vector
<
std
::
vector
<
std
::
vector
<
uint64_t
>>>
last_output_lods
{};
std
::
vector
<
std
::
vector
<
std
::
vector
<
uint64_t
>>>
last_input_lods
;
size_t
io_shape_lod_hash_
{};
mutable
operators
::
ParamBase
param_
;
private:
// Infer Shape according to memory, if current input shapes are consistent
// with that of previous inputs, output shapes of last time will be reused.
bool
InferShapeWithCache
();
};
};
/*
/*
...
...
lite/core/program.cc
浏览文件 @
c754a38f
...
@@ -286,8 +286,7 @@ void Instruction::Run() {
...
@@ -286,8 +286,7 @@ void Instruction::Run() {
return
;
return
;
}
}
// op_->InferShape();
op_
->
InferShape
();
op_
->
SmartInferShape
();
kernel_
->
Launch
();
kernel_
->
Launch
();
has_run_
=
true
;
has_run_
=
true
;
}
}
...
...
lite/operators/activation_grad_ops.cc
浏览文件 @
c754a38f
...
@@ -25,7 +25,7 @@ bool ActivationGradOp::CheckShape() const {
...
@@ -25,7 +25,7 @@ bool ActivationGradOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
ActivationGradOp
::
InferShape
()
const
{
bool
ActivationGradOp
::
InferShape
Impl
()
const
{
param_
.
X_grad
->
Resize
(
param_
.
Out_grad
->
dims
());
param_
.
X_grad
->
Resize
(
param_
.
Out_grad
->
dims
());
return
true
;
return
true
;
}
}
...
...
lite/operators/activation_grad_ops.h
浏览文件 @
c754a38f
...
@@ -26,7 +26,7 @@ class ActivationGradOp : public OpLite {
...
@@ -26,7 +26,7 @@ class ActivationGradOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/activation_ops.cc
浏览文件 @
c754a38f
...
@@ -25,7 +25,7 @@ bool ActivationOp::CheckShape() const {
...
@@ -25,7 +25,7 @@ bool ActivationOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
ActivationOp
::
InferShape
()
const
{
bool
ActivationOp
::
InferShape
Impl
()
const
{
param_
.
Out
->
Resize
(
param_
.
X
->
dims
());
param_
.
Out
->
Resize
(
param_
.
X
->
dims
());
auto
out_lod
=
param_
.
Out
->
mutable_lod
();
auto
out_lod
=
param_
.
Out
->
mutable_lod
();
*
out_lod
=
param_
.
X
->
lod
();
*
out_lod
=
param_
.
X
->
lod
();
...
...
lite/operators/activation_ops.h
浏览文件 @
c754a38f
...
@@ -26,7 +26,7 @@ class ActivationOp : public OpLite {
...
@@ -26,7 +26,7 @@ class ActivationOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/affine_channel_op.cc
浏览文件 @
c754a38f
...
@@ -44,7 +44,7 @@ bool AffineChannelOpLite::CheckShape() const {
...
@@ -44,7 +44,7 @@ bool AffineChannelOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
AffineChannelOpLite
::
InferShape
()
const
{
bool
AffineChannelOpLite
::
InferShape
Impl
()
const
{
const
auto
x_dims
=
param_
.
X
->
dims
();
const
auto
x_dims
=
param_
.
X
->
dims
();
param_
.
Out
->
Resize
(
x_dims
);
param_
.
Out
->
Resize
(
x_dims
);
return
true
;
return
true
;
...
...
lite/operators/affine_channel_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class AffineChannelOpLite : public OpLite {
...
@@ -31,7 +31,7 @@ class AffineChannelOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/anchor_generator_op.cc
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ bool AnchorGeneratorOpLite::CheckShape() const {
...
@@ -31,7 +31,7 @@ bool AnchorGeneratorOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
AnchorGeneratorOpLite
::
InferShape
()
const
{
bool
AnchorGeneratorOpLite
::
InferShape
Impl
()
const
{
auto
input_dims
=
param_
.
Input
->
dims
();
auto
input_dims
=
param_
.
Input
->
dims
();
size_t
num_anchors
=
param_
.
aspect_ratios
.
size
()
*
param_
.
anchor_sizes
.
size
();
size_t
num_anchors
=
param_
.
aspect_ratios
.
size
()
*
param_
.
anchor_sizes
.
size
();
std
::
vector
<
int64_t
>
output_shape
(
std
::
vector
<
int64_t
>
output_shape
(
...
...
lite/operators/anchor_generator_op.h
浏览文件 @
c754a38f
...
@@ -32,7 +32,7 @@ class AnchorGeneratorOpLite : public OpLite {
...
@@ -32,7 +32,7 @@ class AnchorGeneratorOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/argmax_op.cc
浏览文件 @
c754a38f
...
@@ -29,7 +29,7 @@ bool ArgmaxOpLite::CheckShape() const {
...
@@ -29,7 +29,7 @@ bool ArgmaxOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
ArgmaxOpLite
::
InferShape
()
const
{
bool
ArgmaxOpLite
::
InferShape
Impl
()
const
{
auto
x_dims
=
param_
.
X
->
dims
();
auto
x_dims
=
param_
.
X
->
dims
();
int
x_rank
=
x_dims
.
size
();
int
x_rank
=
x_dims
.
size
();
int
axis
=
param_
.
Axis
;
int
axis
=
param_
.
Axis
;
...
...
lite/operators/argmax_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class ArgmaxOpLite : public OpLite {
...
@@ -31,7 +31,7 @@ class ArgmaxOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/assign_op.cc
浏览文件 @
c754a38f
...
@@ -26,7 +26,7 @@ bool AssignOpLite::CheckShape() const {
...
@@ -26,7 +26,7 @@ bool AssignOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
AssignOpLite
::
InferShape
()
const
{
bool
AssignOpLite
::
InferShape
Impl
()
const
{
lite
::
DDim
input_dims
;
lite
::
DDim
input_dims
;
input_dims
=
param_
.
X
->
dims
();
input_dims
=
param_
.
X
->
dims
();
param_
.
Out
->
Resize
(
lite
::
DDim
(
input_dims
));
param_
.
Out
->
Resize
(
lite
::
DDim
(
input_dims
));
...
...
lite/operators/assign_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class AssignOpLite : public OpLite {
...
@@ -30,7 +30,7 @@ class AssignOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/assign_value_op.cc
浏览文件 @
c754a38f
...
@@ -35,7 +35,7 @@ bool AssignValueOpLite::CheckShape() const {
...
@@ -35,7 +35,7 @@ bool AssignValueOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
AssignValueOpLite
::
InferShape
()
const
{
bool
AssignValueOpLite
::
InferShape
Impl
()
const
{
std
::
vector
<
int
>
shape
=
param_
.
shape
;
std
::
vector
<
int
>
shape
=
param_
.
shape
;
std
::
vector
<
int64_t
>
out_shape
;
std
::
vector
<
int64_t
>
out_shape
;
for
(
size_t
i
=
0
;
i
<
shape
.
size
();
i
++
)
out_shape
.
push_back
(
shape
[
i
]);
for
(
size_t
i
=
0
;
i
<
shape
.
size
();
i
++
)
out_shape
.
push_back
(
shape
[
i
]);
...
...
lite/operators/assign_value_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class AssignValueOpLite : public OpLite {
...
@@ -31,7 +31,7 @@ class AssignValueOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/attention_padding_mask_op.cc
浏览文件 @
c754a38f
...
@@ -28,7 +28,7 @@ bool AttentionPaddingMaskOp::CheckShape() const {
...
@@ -28,7 +28,7 @@ bool AttentionPaddingMaskOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
AttentionPaddingMaskOp
::
InferShape
()
const
{
bool
AttentionPaddingMaskOp
::
InferShape
Impl
()
const
{
auto
src_len
=
param_
.
X
->
lod
()[
0
][
1
];
auto
src_len
=
param_
.
X
->
lod
()[
0
][
1
];
CHECK_EQ
(
src_len
,
param_
.
X
->
dims
()[
1
])
CHECK_EQ
(
src_len
,
param_
.
X
->
dims
()[
1
])
<<
"Mismatch source length, expect: "
<<
src_len
<<
"Mismatch source length, expect: "
<<
src_len
...
...
lite/operators/attention_padding_mask_op.h
浏览文件 @
c754a38f
...
@@ -29,7 +29,7 @@ class AttentionPaddingMaskOp : public OpLite {
...
@@ -29,7 +29,7 @@ class AttentionPaddingMaskOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/axpy_op.cc
浏览文件 @
c754a38f
...
@@ -34,7 +34,7 @@ bool AxpyOpLite::CheckShape() const {
...
@@ -34,7 +34,7 @@ bool AxpyOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
AxpyOpLite
::
InferShape
()
const
{
bool
AxpyOpLite
::
InferShape
Impl
()
const
{
auto
dims
=
param_
.
Bias
->
dims
();
auto
dims
=
param_
.
Bias
->
dims
();
// Set output dims
// Set output dims
...
...
lite/operators/axpy_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class AxpyOpLite : public OpLite {
...
@@ -31,7 +31,7 @@ class AxpyOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/batch_norm_op.cc
浏览文件 @
c754a38f
...
@@ -46,7 +46,7 @@ bool BatchNormOp::CheckShape() const {
...
@@ -46,7 +46,7 @@ bool BatchNormOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
BatchNormOp
::
InferShape
()
const
{
bool
BatchNormOp
::
InferShape
Impl
()
const
{
auto
x_dims
=
param_
.
x
->
dims
();
auto
x_dims
=
param_
.
x
->
dims
();
int64_t
channel_size
=
0
;
int64_t
channel_size
=
0
;
switch
(
param_
.
data_layout
)
{
switch
(
param_
.
data_layout
)
{
...
...
lite/operators/batch_norm_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class BatchNormOp : public OpLite {
...
@@ -30,7 +30,7 @@ class BatchNormOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/beam_search_decode_op.cc
浏览文件 @
c754a38f
...
@@ -28,7 +28,7 @@ bool BeamSearchDecodeOpLite::CheckShape() const {
...
@@ -28,7 +28,7 @@ bool BeamSearchDecodeOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
BeamSearchDecodeOpLite
::
InferShape
()
const
{
return
true
;
}
bool
BeamSearchDecodeOpLite
::
InferShape
Impl
()
const
{
return
true
;
}
bool
BeamSearchDecodeOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
bool
BeamSearchDecodeOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
lite
::
Scope
*
scope
)
{
...
...
lite/operators/beam_search_decode_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class BeamSearchDecodeOpLite : public OpLite {
...
@@ -31,7 +31,7 @@ class BeamSearchDecodeOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/beam_search_op.cc
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ bool BeamSearchOp::CheckShape() const {
...
@@ -30,7 +30,7 @@ bool BeamSearchOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
BeamSearchOp
::
InferShape
()
const
{
return
true
;
}
bool
BeamSearchOp
::
InferShape
Impl
()
const
{
return
true
;
}
bool
BeamSearchOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
BeamSearchOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
param_
.
pre_ids
=
scope
->
FindTensor
(
opdesc
.
Input
(
"pre_ids"
).
front
());
param_
.
pre_ids
=
scope
->
FindTensor
(
opdesc
.
Input
(
"pre_ids"
).
front
());
...
...
lite/operators/beam_search_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class BeamSearchOp : public OpLite {
...
@@ -30,7 +30,7 @@ class BeamSearchOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/box_clip_op.cc
浏览文件 @
c754a38f
...
@@ -35,7 +35,7 @@ bool BoxClipOpLite::CheckShape() const {
...
@@ -35,7 +35,7 @@ bool BoxClipOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
BoxClipOpLite
::
InferShape
()
const
{
bool
BoxClipOpLite
::
InferShape
Impl
()
const
{
auto
*
input
=
param_
.
Input
;
auto
*
input
=
param_
.
Input
;
auto
*
output
=
param_
.
Output
;
auto
*
output
=
param_
.
Output
;
output
->
Resize
(
input
->
dims
());
output
->
Resize
(
input
->
dims
());
...
...
lite/operators/box_clip_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class BoxClipOpLite : public OpLite {
...
@@ -31,7 +31,7 @@ class BoxClipOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/box_coder_op.cc
浏览文件 @
c754a38f
...
@@ -35,7 +35,7 @@ bool BoxCoderOpLite::CheckShape() const {
...
@@ -35,7 +35,7 @@ bool BoxCoderOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
BoxCoderOpLite
::
InferShape
()
const
{
bool
BoxCoderOpLite
::
InferShape
Impl
()
const
{
auto
prior_box_dims
=
param_
.
prior_box
->
dims
();
auto
prior_box_dims
=
param_
.
prior_box
->
dims
();
auto
target_box_dims
=
param_
.
target_box
->
dims
();
auto
target_box_dims
=
param_
.
target_box
->
dims
();
std
::
string
code_type
=
param_
.
code_type
;
std
::
string
code_type
=
param_
.
code_type
;
...
...
lite/operators/box_coder_op.h
浏览文件 @
c754a38f
...
@@ -29,7 +29,7 @@ class BoxCoderOpLite : public OpLite {
...
@@ -29,7 +29,7 @@ class BoxCoderOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/calib_op.cc
浏览文件 @
c754a38f
...
@@ -24,7 +24,7 @@ bool CalibOpLite::CheckShape() const {
...
@@ -24,7 +24,7 @@ bool CalibOpLite::CheckShape() const {
CHECK_OR_FALSE
(
param_
.
output
);
CHECK_OR_FALSE
(
param_
.
output
);
return
true
;
return
true
;
}
}
bool
CalibOpLite
::
InferShape
()
const
{
bool
CalibOpLite
::
InferShape
Impl
()
const
{
param_
.
output
->
Resize
(
param_
.
input
->
dims
());
param_
.
output
->
Resize
(
param_
.
input
->
dims
());
return
true
;
return
true
;
}
}
...
...
lite/operators/calib_op.h
浏览文件 @
c754a38f
...
@@ -42,7 +42,7 @@ class CalibOpLite : public OpLite {
...
@@ -42,7 +42,7 @@ class CalibOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
);
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
);
...
...
lite/operators/cast_op.cc
浏览文件 @
c754a38f
...
@@ -25,7 +25,7 @@ bool CastOp::CheckShape() const {
...
@@ -25,7 +25,7 @@ bool CastOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
CastOp
::
InferShape
()
const
{
bool
CastOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
Out
);
CHECK_OR_FALSE
(
param_
.
Out
);
// TODO(Superjomn) Enable data sharing.
// TODO(Superjomn) Enable data sharing.
auto
out_dims
=
param_
.
X
->
dims
();
auto
out_dims
=
param_
.
X
->
dims
();
...
...
lite/operators/cast_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class CastOp : public OpLite {
...
@@ -30,7 +30,7 @@ class CastOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/collect_fpn_proposals_op.cc
浏览文件 @
c754a38f
...
@@ -43,7 +43,7 @@ bool CollectFpnProposalsOpLite::CheckShape() const {
...
@@ -43,7 +43,7 @@ bool CollectFpnProposalsOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
CollectFpnProposalsOpLite
::
InferShape
()
const
{
bool
CollectFpnProposalsOpLite
::
InferShape
Impl
()
const
{
param_
.
fpn_rois
->
Resize
({
param_
.
post_nms_topN
,
4
});
param_
.
fpn_rois
->
Resize
({
param_
.
post_nms_topN
,
4
});
return
true
;
return
true
;
...
...
lite/operators/collect_fpn_proposals_op.h
浏览文件 @
c754a38f
...
@@ -32,7 +32,7 @@ class CollectFpnProposalsOpLite : public OpLite {
...
@@ -32,7 +32,7 @@ class CollectFpnProposalsOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/compare_op.cc
浏览文件 @
c754a38f
...
@@ -26,7 +26,7 @@ bool CompareOp::CheckShape() const {
...
@@ -26,7 +26,7 @@ bool CompareOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
CompareOp
::
InferShape
()
const
{
bool
CompareOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
Out
);
CHECK_OR_FALSE
(
param_
.
Out
);
// TODO(Superjomn) Enable data sharing.
// TODO(Superjomn) Enable data sharing.
auto
input_dims
=
param_
.
X
->
dims
();
auto
input_dims
=
param_
.
X
->
dims
();
...
...
lite/operators/compare_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class CompareOp : public OpLite {
...
@@ -30,7 +30,7 @@ class CompareOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/concat_op.cc
浏览文件 @
c754a38f
...
@@ -26,7 +26,7 @@ bool ConcatOpLite::CheckShape() const {
...
@@ -26,7 +26,7 @@ bool ConcatOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
ConcatOpLite
::
InferShape
()
const
{
bool
ConcatOpLite
::
InferShape
Impl
()
const
{
const
std
::
vector
<
Tensor
*>
&
inputs
=
param_
.
x
;
const
std
::
vector
<
Tensor
*>
&
inputs
=
param_
.
x
;
const
size_t
n
=
inputs
.
size
();
const
size_t
n
=
inputs
.
size
();
CHECK_GT_OR_FALSE
(
n
,
0
);
CHECK_GT_OR_FALSE
(
n
,
0
);
...
...
lite/operators/concat_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class ConcatOpLite : public OpLite {
...
@@ -30,7 +30,7 @@ class ConcatOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/conditional_block_op.cc
浏览文件 @
c754a38f
...
@@ -27,7 +27,7 @@ bool ConditionalBlockOpLite::CheckShape() const {
...
@@ -27,7 +27,7 @@ bool ConditionalBlockOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
ConditionalBlockOpLite
::
InferShape
()
const
{
return
true
;
}
bool
ConditionalBlockOpLite
::
InferShape
Impl
()
const
{
return
true
;
}
bool
ConditionalBlockOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
bool
ConditionalBlockOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
lite
::
Scope
*
scope
)
{
...
...
lite/operators/conditional_block_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class ConditionalBlockOpLite : public OpLite {
...
@@ -31,7 +31,7 @@ class ConditionalBlockOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/conv_op.cc
浏览文件 @
c754a38f
...
@@ -80,35 +80,7 @@ void UpdatePaddingAndDilation(std::vector<int>* paddings,
...
@@ -80,35 +80,7 @@ void UpdatePaddingAndDilation(std::vector<int>* paddings,
}
}
}
}
bool
ConvOpLite
::
SmartInferShape
()
{
bool
ConvOpLite
::
InferShapeImpl
()
const
{
if
(
!
last_input_shapes
.
empty
())
{
if
(
last_input_shapes
[
0
]
==
param_
.
x
->
dims
()
&&
last_input_lods
[
0
]
==
param_
.
x
->
lod
())
{
param_
.
output
->
Resize
(
last_output_shapes
[
0
]);
param_
.
output
->
set_lod
(
last_output_lods
[
0
]);
return
true
;
}
}
this
->
InferShape
();
if
(
!
last_input_shapes
.
empty
())
{
last_input_shapes
.
clear
();
last_input_lods
.
clear
();
}
last_input_shapes
.
push_back
(
param_
.
x
->
dims
());
last_input_lods
.
push_back
(
param_
.
x
->
lod
());
if
(
!
last_output_shapes
.
empty
())
{
last_output_shapes
.
clear
();
last_output_lods
.
clear
();
}
last_output_shapes
.
push_back
(
param_
.
output
->
dims
());
last_output_lods
.
push_back
(
param_
.
output
->
lod
());
return
true
;
}
bool
ConvOpLite
::
InferShape
()
const
{
const
auto
in_dims
=
param_
.
x
->
dims
();
const
auto
in_dims
=
param_
.
x
->
dims
();
const
auto
filter_dims
=
param_
.
filter
->
dims
();
const
auto
filter_dims
=
param_
.
filter
->
dims
();
...
...
lite/operators/conv_op.h
浏览文件 @
c754a38f
...
@@ -34,9 +34,7 @@ class ConvOpLite : public OpLite {
...
@@ -34,9 +34,7 @@ class ConvOpLite : public OpLite {
explicit
ConvOpLite
(
const
std
::
string
&
type
)
:
OpLite
(
type
)
{}
explicit
ConvOpLite
(
const
std
::
string
&
type
)
:
OpLite
(
type
)
{}
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShapeImpl
()
const
override
;
bool
InferShape
()
const
override
;
bool
SmartInferShape
()
override
;
// TODO(Superjomn) replace framework::OpDesc with a lite one.
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
...
...
lite/operators/conv_transpose_op.cc
浏览文件 @
c754a38f
...
@@ -52,7 +52,7 @@ inline int ConvTransposeOutputSize(int input_size,
...
@@ -52,7 +52,7 @@ inline int ConvTransposeOutputSize(int input_size,
return
output_size
;
return
output_size
;
}
}
bool
ConvTransposeOpLite
::
InferShape
()
const
{
bool
ConvTransposeOpLite
::
InferShape
Impl
()
const
{
const
auto
in_dims
=
param_
.
x
->
dims
();
const
auto
in_dims
=
param_
.
x
->
dims
();
const
auto
filter_dims
=
param_
.
filter
->
dims
();
const
auto
filter_dims
=
param_
.
filter
->
dims
();
...
...
lite/operators/conv_transpose_op.h
浏览文件 @
c754a38f
...
@@ -34,7 +34,7 @@ class ConvTransposeOpLite : public OpLite {
...
@@ -34,7 +34,7 @@ class ConvTransposeOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/crf_decoding_op.cc
浏览文件 @
c754a38f
...
@@ -60,7 +60,7 @@ bool CrfDecodingOpLite::CheckShape() const {
...
@@ -60,7 +60,7 @@ bool CrfDecodingOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
CrfDecodingOpLite
::
InferShape
()
const
{
bool
CrfDecodingOpLite
::
InferShape
Impl
()
const
{
auto
emission_dims
=
param_
.
emission
->
dims
();
auto
emission_dims
=
param_
.
emission
->
dims
();
if
(
param_
.
length
==
nullptr
)
{
if
(
param_
.
length
==
nullptr
)
{
param_
.
viterbi_path
->
Resize
({
emission_dims
[
0
],
1
});
param_
.
viterbi_path
->
Resize
({
emission_dims
[
0
],
1
});
...
...
lite/operators/crf_decoding_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class CrfDecodingOpLite : public OpLite {
...
@@ -31,7 +31,7 @@ class CrfDecodingOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/crop_op.cc
浏览文件 @
c754a38f
...
@@ -26,7 +26,7 @@ bool CropOpLite::CheckShape() const {
...
@@ -26,7 +26,7 @@ bool CropOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
CropOpLite
::
InferShape
()
const
{
bool
CropOpLite
::
InferShape
Impl
()
const
{
// nchw
// nchw
auto
x_dims
=
param_
.
X
->
dims
();
auto
x_dims
=
param_
.
X
->
dims
();
lite
::
DDim
output_shape
(
x_dims
);
lite
::
DDim
output_shape
(
x_dims
);
...
...
lite/operators/crop_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class CropOpLite : public OpLite {
...
@@ -30,7 +30,7 @@ class CropOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/decode_bboxes_op.cc
浏览文件 @
c754a38f
...
@@ -29,7 +29,7 @@ bool DecodeBboxesOpLite::CheckShape() const {
...
@@ -29,7 +29,7 @@ bool DecodeBboxesOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
DecodeBboxesOpLite
::
InferShape
()
const
{
bool
DecodeBboxesOpLite
::
InferShape
Impl
()
const
{
param_
.
bbox_data
->
Resize
(
param_
.
loc_data
->
dims
());
param_
.
bbox_data
->
Resize
(
param_
.
loc_data
->
dims
());
return
true
;
return
true
;
}
}
...
...
lite/operators/decode_bboxes_op.h
浏览文件 @
c754a38f
...
@@ -29,7 +29,7 @@ class DecodeBboxesOpLite : public OpLite {
...
@@ -29,7 +29,7 @@ class DecodeBboxesOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/density_prior_box_op.cc
浏览文件 @
c754a38f
...
@@ -27,7 +27,7 @@ bool DensityPriorBoxOpLite::CheckShape() const {
...
@@ -27,7 +27,7 @@ bool DensityPriorBoxOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
DensityPriorBoxOpLite
::
InferShape
()
const
{
return
true
;
}
bool
DensityPriorBoxOpLite
::
InferShape
Impl
()
const
{
return
true
;
}
bool
DensityPriorBoxOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
bool
DensityPriorBoxOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
lite
::
Scope
*
scope
)
{
...
...
lite/operators/density_prior_box_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class DensityPriorBoxOpLite : public OpLite {
...
@@ -30,7 +30,7 @@ class DensityPriorBoxOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/distribute_fpn_proposals_op.cc
浏览文件 @
c754a38f
...
@@ -32,7 +32,7 @@ bool DistributeFpnProposalsOpLite::CheckShape() const {
...
@@ -32,7 +32,7 @@ bool DistributeFpnProposalsOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
DistributeFpnProposalsOpLite
::
InferShape
()
const
{
bool
DistributeFpnProposalsOpLite
::
InferShape
Impl
()
const
{
int
num_out_rois
=
param_
.
max_level
-
param_
.
min_level
+
1
;
int
num_out_rois
=
param_
.
max_level
-
param_
.
min_level
+
1
;
for
(
int
i
=
0
;
i
<
num_out_rois
;
i
++
)
{
for
(
int
i
=
0
;
i
<
num_out_rois
;
i
++
)
{
param_
.
multi_fpn_rois
[
i
]
->
Resize
({
-
1
,
4
});
param_
.
multi_fpn_rois
[
i
]
->
Resize
({
-
1
,
4
});
...
...
lite/operators/distribute_fpn_proposals_op.h
浏览文件 @
c754a38f
...
@@ -32,7 +32,7 @@ class DistributeFpnProposalsOpLite : public OpLite {
...
@@ -32,7 +32,7 @@ class DistributeFpnProposalsOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/dropout_op.cc
浏览文件 @
c754a38f
...
@@ -26,7 +26,7 @@ bool DropoutOp::CheckShape() const {
...
@@ -26,7 +26,7 @@ bool DropoutOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
DropoutOp
::
InferShape
()
const
{
bool
DropoutOp
::
InferShape
Impl
()
const
{
const
auto
x_dims
=
param_
.
x
->
dims
();
const
auto
x_dims
=
param_
.
x
->
dims
();
param_
.
output
->
Resize
(
x_dims
);
param_
.
output
->
Resize
(
x_dims
);
if
(
param_
.
is_test
==
false
)
{
if
(
param_
.
is_test
==
false
)
{
...
...
lite/operators/dropout_op.h
浏览文件 @
c754a38f
...
@@ -28,7 +28,7 @@ class DropoutOp : public OpLite {
...
@@ -28,7 +28,7 @@ class DropoutOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
// TODO(Superjomn) replace framework::OpDesc with a lite one.
...
...
lite/operators/elementwise_grad_ops.cc
浏览文件 @
c754a38f
...
@@ -26,7 +26,7 @@ bool ElementwiseGradOp::CheckShape() const {
...
@@ -26,7 +26,7 @@ bool ElementwiseGradOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
ElementwiseGradOp
::
InferShape
()
const
{
bool
ElementwiseGradOp
::
InferShape
Impl
()
const
{
auto
x_dim
=
param_
.
X
->
dims
();
auto
x_dim
=
param_
.
X
->
dims
();
auto
y_dim
=
param_
.
Y
->
dims
();
auto
y_dim
=
param_
.
Y
->
dims
();
if
(
param_
.
XGrad
)
{
if
(
param_
.
XGrad
)
{
...
...
lite/operators/elementwise_grad_ops.h
浏览文件 @
c754a38f
...
@@ -27,7 +27,7 @@ class ElementwiseGradOp : public OpLite {
...
@@ -27,7 +27,7 @@ class ElementwiseGradOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/elementwise_ops.cc
浏览文件 @
c754a38f
...
@@ -26,39 +26,8 @@ bool ElementwiseOp::CheckShape() const {
...
@@ -26,39 +26,8 @@ bool ElementwiseOp::CheckShape() const {
CHECK_OR_FALSE
(
param_
.
Out
);
CHECK_OR_FALSE
(
param_
.
Out
);
return
true
;
return
true
;
}
}
bool
ElementwiseOp
::
SmartInferShape
()
{
if
(
!
last_input_shapes
.
empty
())
{
if
(
last_input_shapes
[
0
]
==
param_
.
X
->
dims
()
&&
last_input_shapes
[
1
]
==
param_
.
Y
->
dims
()
&&
last_input_lods
[
0
]
==
param_
.
X
->
lod
()
&&
last_input_lods
[
1
]
==
param_
.
Y
->
lod
())
{
param_
.
Out
->
Resize
(
last_output_shapes
[
0
]);
param_
.
Out
->
set_lod
(
last_output_lods
[
0
]);
return
true
;
}
}
this
->
InferShape
();
if
(
!
last_input_shapes
.
empty
())
{
last_input_shapes
.
clear
();
last_input_lods
.
clear
();
}
last_input_shapes
.
push_back
(
param_
.
X
->
dims
());
last_input_lods
.
push_back
(
param_
.
X
->
lod
());
last_input_shapes
.
push_back
(
param_
.
Y
->
dims
());
last_input_lods
.
push_back
(
param_
.
Y
->
lod
());
if
(
!
last_output_shapes
.
empty
())
{
bool
ElementwiseOp
::
InferShapeImpl
()
const
{
last_output_shapes
.
clear
();
last_output_lods
.
clear
();
}
last_output_shapes
.
push_back
(
param_
.
Out
->
dims
());
last_output_lods
.
push_back
(
param_
.
Out
->
lod
());
return
true
;
}
bool
ElementwiseOp
::
InferShape
()
const
{
auto
x_dim
=
param_
.
X
->
dims
();
auto
x_dim
=
param_
.
X
->
dims
();
auto
y_dim
=
param_
.
Y
->
dims
();
auto
y_dim
=
param_
.
Y
->
dims
();
if
(
x_dim
==
y_dim
)
{
if
(
x_dim
==
y_dim
)
{
...
@@ -136,7 +105,7 @@ bool ElementwiseOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
...
@@ -136,7 +105,7 @@ bool ElementwiseOp::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
// return true;
// return true;
//}
//}
// bool ElementwiseGradExplicitOp::InferShape() const {
// bool ElementwiseGradExplicitOp::InferShape
Impl
() const {
// param_.X_grad->Resize(param_.Out_grad->dims());
// param_.X_grad->Resize(param_.Out_grad->dims());
// if (param_.Y_grad) param_.Y_grad->Resize(param_.Y->dims());
// if (param_.Y_grad) param_.Y_grad->Resize(param_.Y->dims());
// return true;
// return true;
...
...
lite/operators/elementwise_ops.h
浏览文件 @
c754a38f
...
@@ -27,8 +27,7 @@ class ElementwiseOp : public OpLite {
...
@@ -27,8 +27,7 @@ class ElementwiseOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShapeImpl
()
const
override
;
bool
SmartInferShape
()
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
@@ -48,7 +47,7 @@ class ElementwiseOp : public OpLite {
...
@@ -48,7 +47,7 @@ class ElementwiseOp : public OpLite {
// bool CheckShape() const override;
// bool CheckShape() const override;
// bool InferShape() const override;
// bool InferShape
Impl
() const override;
// bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
// bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
...
...
lite/operators/expand_op.cc
浏览文件 @
c754a38f
...
@@ -32,7 +32,7 @@ bool ExpandOpLite::CheckShape() const {
...
@@ -32,7 +32,7 @@ bool ExpandOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
ExpandOpLite
::
InferShape
()
const
{
bool
ExpandOpLite
::
InferShape
Impl
()
const
{
DDim
out_dims
(
param_
.
X
->
dims
());
DDim
out_dims
(
param_
.
X
->
dims
());
for
(
size_t
i
=
0
;
i
<
param_
.
expand_times
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
param_
.
expand_times
.
size
();
++
i
)
{
out_dims
[
i
]
*=
param_
.
expand_times
[
i
];
out_dims
[
i
]
*=
param_
.
expand_times
[
i
];
...
...
lite/operators/expand_op.h
浏览文件 @
c754a38f
...
@@ -28,7 +28,7 @@ class ExpandOpLite : public OpLite {
...
@@ -28,7 +28,7 @@ class ExpandOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/fake_channel_wise_dequantize_max_abs.h
浏览文件 @
c754a38f
...
@@ -36,7 +36,7 @@ class FakeChannelWiseDequantizeMaxAbsOpLite : public OpLite {
...
@@ -36,7 +36,7 @@ class FakeChannelWiseDequantizeMaxAbsOpLite : public OpLite {
bool
CheckShape
()
const
override
{
return
true
;
}
bool
CheckShape
()
const
override
{
return
true
;
}
bool
InferShape
()
const
override
{
return
true
;
}
bool
InferShape
Impl
()
const
override
{
return
true
;
}
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
...
...
lite/operators/fake_dequantize_max_abs.h
浏览文件 @
c754a38f
...
@@ -35,7 +35,7 @@ class FakeDequantizeMaxAbsOpLite : public OpLite {
...
@@ -35,7 +35,7 @@ class FakeDequantizeMaxAbsOpLite : public OpLite {
bool
CheckShape
()
const
override
{
return
true
;
}
bool
CheckShape
()
const
override
{
return
true
;
}
bool
InferShape
()
const
override
{
return
true
;
}
bool
InferShape
Impl
()
const
override
{
return
true
;
}
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
...
...
lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h
浏览文件 @
c754a38f
...
@@ -36,7 +36,7 @@ class FakeQuantizeDequantizeMovingAvgMaxAbsOpLite : public OpLite {
...
@@ -36,7 +36,7 @@ class FakeQuantizeDequantizeMovingAvgMaxAbsOpLite : public OpLite {
bool
CheckShape
()
const
override
{
return
true
;
}
bool
CheckShape
()
const
override
{
return
true
;
}
bool
InferShape
()
const
override
{
return
true
;
}
bool
InferShape
Impl
()
const
override
{
return
true
;
}
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
...
...
lite/operators/fake_quantize_moving_avg_max_abs.h
浏览文件 @
c754a38f
...
@@ -36,7 +36,7 @@ class FakeQuantizeMovingAvgMaxAbsOpLite : public OpLite {
...
@@ -36,7 +36,7 @@ class FakeQuantizeMovingAvgMaxAbsOpLite : public OpLite {
bool
CheckShape
()
const
override
{
return
true
;
}
bool
CheckShape
()
const
override
{
return
true
;
}
bool
InferShape
()
const
override
{
return
true
;
}
bool
InferShape
Impl
()
const
override
{
return
true
;
}
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
...
...
lite/operators/fake_quantize_range_abs_max.h
浏览文件 @
c754a38f
...
@@ -36,7 +36,7 @@ class FakeQuantizeRangeMaxAbsOpLite : public OpLite {
...
@@ -36,7 +36,7 @@ class FakeQuantizeRangeMaxAbsOpLite : public OpLite {
bool
CheckShape
()
const
override
{
return
true
;
}
bool
CheckShape
()
const
override
{
return
true
;
}
bool
InferShape
()
const
override
{
return
true
;
}
bool
InferShape
Impl
()
const
override
{
return
true
;
}
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
auto
x
=
op_desc
.
Input
(
"X"
).
front
();
...
...
lite/operators/fc_op.cc
浏览文件 @
c754a38f
...
@@ -48,34 +48,7 @@ bool FcOpLite::CheckShape() const {
...
@@ -48,34 +48,7 @@ bool FcOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
FcOpLite
::
SmartInferShape
()
{
bool
FcOpLite
::
InferShapeImpl
()
const
{
if
(
!
last_input_shapes
.
empty
()
&&
!
last_output_shapes
.
empty
())
{
if
(
last_input_shapes
[
0
]
==
param_
.
input
->
dims
()
&&
last_input_lods
[
0
]
==
param_
.
input
->
lod
())
{
param_
.
output
->
Resize
(
last_output_shapes
[
0
]);
param_
.
output
->
set_lod
(
last_output_lods
[
0
]);
return
true
;
}
}
this
->
InferShape
();
if
(
!
last_input_shapes
.
empty
())
{
last_input_shapes
.
clear
();
last_input_lods
.
clear
();
}
last_input_shapes
.
push_back
(
param_
.
input
->
dims
());
last_input_lods
.
push_back
(
param_
.
input
->
lod
());
if
(
!
last_output_shapes
.
empty
())
{
last_output_shapes
.
clear
();
last_output_lods
.
clear
();
}
last_output_shapes
.
push_back
(
param_
.
output
->
dims
());
last_output_lods
.
push_back
(
param_
.
output
->
lod
());
return
true
;
}
bool
FcOpLite
::
InferShape
()
const
{
const
auto
&
input_dims
=
param_
.
input
->
dims
();
const
auto
&
input_dims
=
param_
.
input
->
dims
();
const
auto
&
w_dims
=
param_
.
w
->
dims
();
const
auto
&
w_dims
=
param_
.
w
->
dims
();
int
in_num_col_dims
=
param_
.
in_num_col_dims
;
int
in_num_col_dims
=
param_
.
in_num_col_dims
;
...
...
lite/operators/fc_op.h
浏览文件 @
c754a38f
...
@@ -35,8 +35,7 @@ class FcOpLite : public OpLite {
...
@@ -35,8 +35,7 @@ class FcOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShapeImpl
()
const
override
;
bool
SmartInferShape
()
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/feed_op.cc
浏览文件 @
c754a38f
...
@@ -29,7 +29,7 @@ class FeedOp : public OpLite {
...
@@ -29,7 +29,7 @@ class FeedOp : public OpLite {
return
true
;
return
true
;
}
}
bool
InferShape
()
const
override
{
return
true
;
}
bool
InferShape
Impl
()
const
override
{
return
true
;
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/fetch_op.cc
浏览文件 @
c754a38f
...
@@ -29,7 +29,7 @@ class FetchOp : public OpLite {
...
@@ -29,7 +29,7 @@ class FetchOp : public OpLite {
return
true
;
return
true
;
}
}
bool
InferShape
()
const
override
{
return
true
;
}
bool
InferShape
Impl
()
const
override
{
return
true
;
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
protected:
protected:
...
...
lite/operators/fill_constant_batch_size_like_op.cc
浏览文件 @
c754a38f
...
@@ -28,7 +28,7 @@ bool FillConstantBatchSizeLikeOp::CheckShape() const {
...
@@ -28,7 +28,7 @@ bool FillConstantBatchSizeLikeOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
FillConstantBatchSizeLikeOp
::
InferShape
()
const
{
bool
FillConstantBatchSizeLikeOp
::
InferShape
Impl
()
const
{
std
::
vector
<
int64_t
>
output_dim
{
param_
.
shape
.
begin
(),
param_
.
shape
.
end
()};
std
::
vector
<
int64_t
>
output_dim
{
param_
.
shape
.
begin
(),
param_
.
shape
.
end
()};
if
(
param_
.
input_dim_idx
==
0
&&
!
param_
.
input
->
lod
().
empty
())
{
if
(
param_
.
input_dim_idx
==
0
&&
!
param_
.
input
->
lod
().
empty
())
{
output_dim
[
param_
.
output_dim_idx
]
=
param_
.
input
->
lod
().
back
().
size
()
-
1
;
output_dim
[
param_
.
output_dim_idx
]
=
param_
.
input
->
lod
().
back
().
size
()
-
1
;
...
...
lite/operators/fill_constant_batch_size_like_op.h
浏览文件 @
c754a38f
...
@@ -32,7 +32,7 @@ class FillConstantBatchSizeLikeOp : public OpLite {
...
@@ -32,7 +32,7 @@ class FillConstantBatchSizeLikeOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/fill_constant_op.cc
浏览文件 @
c754a38f
...
@@ -24,7 +24,7 @@ bool FillConstantOp::CheckShape() const {
...
@@ -24,7 +24,7 @@ bool FillConstantOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
FillConstantOp
::
InferShape
()
const
{
bool
FillConstantOp
::
InferShape
Impl
()
const
{
std
::
vector
<
int64_t
>
out_shape
;
std
::
vector
<
int64_t
>
out_shape
;
auto
shape_tensor
=
param_
.
shape_tensor
;
auto
shape_tensor
=
param_
.
shape_tensor
;
auto
shape_tensor_list
=
param_
.
shape_tensor_list
;
auto
shape_tensor_list
=
param_
.
shape_tensor_list
;
...
...
lite/operators/fill_constant_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class FillConstantOp : public OpLite {
...
@@ -31,7 +31,7 @@ class FillConstantOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/flatten_op.cc
浏览文件 @
c754a38f
...
@@ -25,7 +25,7 @@ bool FlattenOp::CheckShape() const {
...
@@ -25,7 +25,7 @@ bool FlattenOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
FlattenOp
::
InferShape
()
const
{
bool
FlattenOp
::
InferShape
Impl
()
const
{
auto
x_dims
=
param_
.
x
->
dims
();
auto
x_dims
=
param_
.
x
->
dims
();
auto
out_lod
=
param_
.
output
->
mutable_lod
();
auto
out_lod
=
param_
.
output
->
mutable_lod
();
...
@@ -71,8 +71,8 @@ bool Flatten2Op::CheckShape() const {
...
@@ -71,8 +71,8 @@ bool Flatten2Op::CheckShape() const {
return
true
;
return
true
;
}
}
bool
Flatten2Op
::
InferShape
()
const
{
bool
Flatten2Op
::
InferShape
Impl
()
const
{
FlattenOp
::
InferShape
();
FlattenOp
::
InferShape
Impl
();
auto
x_dims
=
param_
.
x
->
dims
();
auto
x_dims
=
param_
.
x
->
dims
();
std
::
vector
<
DDim
::
value_type
>
xshape_dims
(
x_dims
.
size
()
+
1
,
0
);
std
::
vector
<
DDim
::
value_type
>
xshape_dims
(
x_dims
.
size
()
+
1
,
0
);
for
(
size_t
i
=
0
;
i
<
x_dims
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
x_dims
.
size
();
i
++
)
{
...
...
lite/operators/flatten_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class FlattenOp : public OpLite {
...
@@ -30,7 +30,7 @@ class FlattenOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
@@ -49,7 +49,7 @@ class Flatten2Op : public FlattenOp {
...
@@ -49,7 +49,7 @@ class Flatten2Op : public FlattenOp {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/fusion_elementwise_activation_ops.cc
浏览文件 @
c754a38f
...
@@ -27,7 +27,7 @@ bool FusionElementwiseActivationOp::CheckShape() const {
...
@@ -27,7 +27,7 @@ bool FusionElementwiseActivationOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
FusionElementwiseActivationOp
::
InferShape
()
const
{
bool
FusionElementwiseActivationOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
X
->
dims
().
size
()
>=
param_
.
Y
->
dims
().
size
());
CHECK_OR_FALSE
(
param_
.
X
->
dims
().
size
()
>=
param_
.
Y
->
dims
().
size
());
param_
.
Out
->
Resize
(
param_
.
X
->
dims
());
param_
.
Out
->
Resize
(
param_
.
X
->
dims
());
return
true
;
return
true
;
...
@@ -59,7 +59,7 @@ bool FusionElementwiseActivationOp::AttachImpl(const cpp::OpDesc& opdesc,
...
@@ -59,7 +59,7 @@ bool FusionElementwiseActivationOp::AttachImpl(const cpp::OpDesc& opdesc,
// return true;
// return true;
// }
// }
// bool FusionElementwiseActivationGradExplicitOp::InferShape() const {
// bool FusionElementwiseActivationGradExplicitOp::InferShape
Impl
() const {
// param_.X_grad->Resize(param_.Out_grad->dims());
// param_.X_grad->Resize(param_.Out_grad->dims());
// param_.Y_grad->Resize(param_.Y->dims());
// param_.Y_grad->Resize(param_.Y->dims());
// return true;
// return true;
...
...
lite/operators/fusion_elementwise_activation_ops.h
浏览文件 @
c754a38f
...
@@ -29,7 +29,7 @@ class FusionElementwiseActivationOp : public OpLite {
...
@@ -29,7 +29,7 @@ class FusionElementwiseActivationOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
@@ -51,7 +51,7 @@ class FusionElementwiseActivationOp : public OpLite {
...
@@ -51,7 +51,7 @@ class FusionElementwiseActivationOp : public OpLite {
// bool CheckShape() const override;
// bool CheckShape() const override;
// bool InferShape() const override;
// bool InferShape
Impl
() const override;
// bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
// bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override;
...
...
lite/operators/gather_op.cc
浏览文件 @
c754a38f
...
@@ -26,7 +26,7 @@ bool GatherOp::CheckShape() const {
...
@@ -26,7 +26,7 @@ bool GatherOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
GatherOp
::
InferShape
()
const
{
bool
GatherOp
::
InferShape
Impl
()
const
{
auto
index_dims
=
param_
.
Index
->
dims
();
auto
index_dims
=
param_
.
Index
->
dims
();
CHECK
(
index_dims
.
size
()
==
1
||
CHECK
(
index_dims
.
size
()
==
1
||
(
index_dims
.
size
()
==
2
&&
index_dims
[
1
]
==
1
))
(
index_dims
.
size
()
==
2
&&
index_dims
[
1
]
==
1
))
...
...
lite/operators/gather_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class GatherOp : public OpLite {
...
@@ -30,7 +30,7 @@ class GatherOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/generate_proposals_op.cc
浏览文件 @
c754a38f
...
@@ -43,7 +43,7 @@ bool GenerateProposalsOpLite::CheckShape() const {
...
@@ -43,7 +43,7 @@ bool GenerateProposalsOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
GenerateProposalsOpLite
::
InferShape
()
const
{
bool
GenerateProposalsOpLite
::
InferShape
Impl
()
const
{
param_
.
RpnRois
->
Resize
(
std
::
vector
<
int64_t
>
({
-
1
,
4
}));
param_
.
RpnRois
->
Resize
(
std
::
vector
<
int64_t
>
({
-
1
,
4
}));
param_
.
RpnRoiProbs
->
Resize
(
std
::
vector
<
int64_t
>
({
-
1
,
1
}));
param_
.
RpnRoiProbs
->
Resize
(
std
::
vector
<
int64_t
>
({
-
1
,
1
}));
return
true
;
return
true
;
...
...
lite/operators/generate_proposals_op.h
浏览文件 @
c754a38f
...
@@ -32,7 +32,7 @@ class GenerateProposalsOpLite : public OpLite {
...
@@ -32,7 +32,7 @@ class GenerateProposalsOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/grid_sampler_op.cc
浏览文件 @
c754a38f
...
@@ -42,7 +42,7 @@ bool GridSamplerOp::CheckShape() const {
...
@@ -42,7 +42,7 @@ bool GridSamplerOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
GridSamplerOp
::
InferShape
()
const
{
bool
GridSamplerOp
::
InferShape
Impl
()
const
{
auto
x_dims
=
param_
.
x
->
dims
();
auto
x_dims
=
param_
.
x
->
dims
();
param_
.
out
->
Resize
(
x_dims
);
param_
.
out
->
Resize
(
x_dims
);
return
true
;
return
true
;
...
...
lite/operators/grid_sampler_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class GridSamplerOp : public OpLite {
...
@@ -31,7 +31,7 @@ class GridSamplerOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/gru_op.cc
浏览文件 @
c754a38f
...
@@ -51,7 +51,7 @@ bool GRUOpLite::CheckShape() const {
...
@@ -51,7 +51,7 @@ bool GRUOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
GRUOpLite
::
InferShape
()
const
{
bool
GRUOpLite
::
InferShape
Impl
()
const
{
const
auto
&
input_dims
=
param_
.
input
->
dims
();
const
auto
&
input_dims
=
param_
.
input
->
dims
();
const
auto
&
weight_dims
=
param_
.
weight
->
dims
();
const
auto
&
weight_dims
=
param_
.
weight
->
dims
();
int
frame_size
=
weight_dims
[
0
];
int
frame_size
=
weight_dims
[
0
];
...
...
lite/operators/gru_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class GRUOpLite : public OpLite {
...
@@ -30,7 +30,7 @@ class GRUOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/gru_unit_op.cc
浏览文件 @
c754a38f
...
@@ -51,7 +51,7 @@ bool GRUUnitOpLite::CheckShape() const {
...
@@ -51,7 +51,7 @@ bool GRUUnitOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
GRUUnitOpLite
::
InferShape
()
const
{
bool
GRUUnitOpLite
::
InferShape
Impl
()
const
{
auto
input_dims
=
param_
.
input
->
dims
();
auto
input_dims
=
param_
.
input
->
dims
();
auto
hidden_prev_dims
=
param_
.
hidden_prev
->
dims
();
auto
hidden_prev_dims
=
param_
.
hidden_prev
->
dims
();
auto
weight_dims
=
param_
.
weight
->
dims
();
auto
weight_dims
=
param_
.
weight
->
dims
();
...
...
lite/operators/gru_unit_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class GRUUnitOpLite : public OpLite {
...
@@ -30,7 +30,7 @@ class GRUUnitOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/im2sequence_op.cc
浏览文件 @
c754a38f
...
@@ -26,7 +26,7 @@ inline int Im2SeqOutputSize(
...
@@ -26,7 +26,7 @@ inline int Im2SeqOutputSize(
}
}
bool
Im2SequenceOp
::
CheckShape
()
const
{
return
true
;
}
bool
Im2SequenceOp
::
CheckShape
()
const
{
return
true
;
}
bool
Im2SequenceOp
::
InferShape
()
const
{
bool
Im2SequenceOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
Out
);
CHECK_OR_FALSE
(
param_
.
Out
);
// TODO(Superjomn) Enable data sharing.
// TODO(Superjomn) Enable data sharing.
auto
input_dims
=
param_
.
X
->
dims
();
auto
input_dims
=
param_
.
X
->
dims
();
...
...
lite/operators/im2sequence_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class Im2SequenceOp : public OpLite {
...
@@ -30,7 +30,7 @@ class Im2SequenceOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/increment_op.cc
浏览文件 @
c754a38f
...
@@ -25,7 +25,7 @@ bool IncrementOp::CheckShape() const {
...
@@ -25,7 +25,7 @@ bool IncrementOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
IncrementOp
::
InferShape
()
const
{
bool
IncrementOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
Out
);
CHECK_OR_FALSE
(
param_
.
Out
);
// TODO(Superjomn) Enable data sharing.
// TODO(Superjomn) Enable data sharing.
auto
out_dims
=
param_
.
X
->
dims
();
auto
out_dims
=
param_
.
X
->
dims
();
...
...
lite/operators/increment_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class IncrementOp : public OpLite {
...
@@ -30,7 +30,7 @@ class IncrementOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/instance_norm_op.cc
浏览文件 @
c754a38f
...
@@ -42,7 +42,7 @@ bool InstanceNormOp::CheckShape() const {
...
@@ -42,7 +42,7 @@ bool InstanceNormOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
InstanceNormOp
::
InferShape
()
const
{
bool
InstanceNormOp
::
InferShape
Impl
()
const
{
auto
x_dims
=
param_
.
x
->
dims
();
auto
x_dims
=
param_
.
x
->
dims
();
int64_t
batch_size
=
x_dims
[
0
];
int64_t
batch_size
=
x_dims
[
0
];
int64_t
channel_size
=
x_dims
[
1
];
int64_t
channel_size
=
x_dims
[
1
];
...
...
lite/operators/instance_norm_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class InstanceNormOp : public OpLite {
...
@@ -31,7 +31,7 @@ class InstanceNormOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/interpolate_op.cc
浏览文件 @
c754a38f
...
@@ -34,7 +34,7 @@ bool InterpolateOp::CheckShape() const {
...
@@ -34,7 +34,7 @@ bool InterpolateOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
InterpolateOp
::
InferShape
()
const
{
bool
InterpolateOp
::
InferShape
Impl
()
const
{
auto
X
=
param_
.
X
;
auto
X
=
param_
.
X
;
int
n
=
X
->
dims
()[
0
];
int
n
=
X
->
dims
()[
0
];
...
...
lite/operators/interpolate_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class InterpolateOp : public OpLite {
...
@@ -31,7 +31,7 @@ class InterpolateOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/io_copy_op.cc
浏览文件 @
c754a38f
...
@@ -24,7 +24,7 @@ bool IoCopyOp::CheckShape() const {
...
@@ -24,7 +24,7 @@ bool IoCopyOp::CheckShape() const {
CHECK_OR_FALSE
(
param_
.
y
);
CHECK_OR_FALSE
(
param_
.
y
);
return
true
;
return
true
;
}
}
bool
IoCopyOp
::
InferShape
()
const
{
bool
IoCopyOp
::
InferShape
Impl
()
const
{
param_
.
y
->
Resize
(
param_
.
x
->
dims
());
param_
.
y
->
Resize
(
param_
.
x
->
dims
());
return
true
;
return
true
;
}
}
...
...
lite/operators/io_copy_op.h
浏览文件 @
c754a38f
...
@@ -24,7 +24,7 @@ class IoCopyOp : public OpLite {
...
@@ -24,7 +24,7 @@ class IoCopyOp : public OpLite {
public:
public:
explicit
IoCopyOp
(
const
std
::
string
&
type
)
:
OpLite
(
type
)
{}
explicit
IoCopyOp
(
const
std
::
string
&
type
)
:
OpLite
(
type
)
{}
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
Run
()
override
;
bool
Run
()
override
;
std
::
string
DebugString
()
const
override
;
std
::
string
DebugString
()
const
override
;
...
...
lite/operators/is_empty_op.cc
浏览文件 @
c754a38f
...
@@ -21,7 +21,7 @@ namespace operators {
...
@@ -21,7 +21,7 @@ namespace operators {
bool
IsEmptyOp
::
CheckShape
()
const
{
return
true
;
}
bool
IsEmptyOp
::
CheckShape
()
const
{
return
true
;
}
bool
IsEmptyOp
::
InferShape
()
const
{
return
true
;
}
bool
IsEmptyOp
::
InferShape
Impl
()
const
{
return
true
;
}
bool
IsEmptyOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
IsEmptyOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
param_
.
X
=
param_
.
X
=
...
...
lite/operators/is_empty_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class IsEmptyOp : public OpLite {
...
@@ -30,7 +30,7 @@ class IsEmptyOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/layer_norm_op.cc
浏览文件 @
c754a38f
...
@@ -27,7 +27,7 @@ bool LayerNormOp::CheckShape() const {
...
@@ -27,7 +27,7 @@ bool LayerNormOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
LayerNormOp
::
InferShape
()
const
{
bool
LayerNormOp
::
InferShape
Impl
()
const
{
auto
out_dims
=
param_
.
X
->
dims
();
auto
out_dims
=
param_
.
X
->
dims
();
param_
.
Y
->
Resize
(
out_dims
);
param_
.
Y
->
Resize
(
out_dims
);
auto
inner_size
=
out_dims
.
Flatten2D
(
param_
.
begin_norm_axis
)[
0
];
auto
inner_size
=
out_dims
.
Flatten2D
(
param_
.
begin_norm_axis
)[
0
];
...
...
lite/operators/layer_norm_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class LayerNormOp : public OpLite {
...
@@ -30,7 +30,7 @@ class LayerNormOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/layout_op.cc
浏览文件 @
c754a38f
...
@@ -24,7 +24,7 @@ bool LayoutOp::CheckShape() const {
...
@@ -24,7 +24,7 @@ bool LayoutOp::CheckShape() const {
CHECK_OR_FALSE
(
param_
.
y
);
CHECK_OR_FALSE
(
param_
.
y
);
return
true
;
return
true
;
}
}
bool
LayoutOp
::
InferShape
()
const
{
bool
LayoutOp
::
InferShape
Impl
()
const
{
param_
.
y
->
Resize
(
param_
.
x
->
dims
());
param_
.
y
->
Resize
(
param_
.
x
->
dims
());
return
true
;
return
true
;
}
}
...
...
lite/operators/layout_op.h
浏览文件 @
c754a38f
...
@@ -24,7 +24,7 @@ class LayoutOp : public OpLite {
...
@@ -24,7 +24,7 @@ class LayoutOp : public OpLite {
public:
public:
explicit
LayoutOp
(
const
std
::
string
&
type
)
:
OpLite
(
type
)
{}
explicit
LayoutOp
(
const
std
::
string
&
type
)
:
OpLite
(
type
)
{}
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
Run
()
override
;
bool
Run
()
override
;
std
::
string
DebugString
()
const
override
;
std
::
string
DebugString
()
const
override
;
...
...
lite/operators/lod_reset_op.cc
浏览文件 @
c754a38f
...
@@ -25,7 +25,7 @@ bool LodResetOp::CheckShape() const {
...
@@ -25,7 +25,7 @@ bool LodResetOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
LodResetOp
::
InferShape
()
const
{
bool
LodResetOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
Out
);
CHECK_OR_FALSE
(
param_
.
Out
);
// TODO(Superjomn) Enable data sharing.
// TODO(Superjomn) Enable data sharing.
param_
.
Out
->
Resize
(
param_
.
X
->
dims
());
param_
.
Out
->
Resize
(
param_
.
X
->
dims
());
...
...
lite/operators/lod_reset_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class LodResetOp : public OpLite {
...
@@ -30,7 +30,7 @@ class LodResetOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/logical_op.cc
浏览文件 @
c754a38f
...
@@ -26,7 +26,7 @@ bool BinaryLogicalOp::CheckShape() const {
...
@@ -26,7 +26,7 @@ bool BinaryLogicalOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
BinaryLogicalOp
::
InferShape
()
const
{
bool
BinaryLogicalOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
Out
);
CHECK_OR_FALSE
(
param_
.
Out
);
// TODO(Superjomn) Enable data sharing.
// TODO(Superjomn) Enable data sharing.
auto
input_dims
=
param_
.
X
->
dims
();
auto
input_dims
=
param_
.
X
->
dims
();
...
@@ -53,7 +53,7 @@ bool UnaryLogicalOp::CheckShape() const {
...
@@ -53,7 +53,7 @@ bool UnaryLogicalOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
UnaryLogicalOp
::
InferShape
()
const
{
bool
UnaryLogicalOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
Out
);
CHECK_OR_FALSE
(
param_
.
Out
);
// TODO(Superjomn) Enable data sharing.
// TODO(Superjomn) Enable data sharing.
auto
input_dims
=
param_
.
X
->
dims
();
auto
input_dims
=
param_
.
X
->
dims
();
...
...
lite/operators/logical_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class BinaryLogicalOp : public OpLite {
...
@@ -30,7 +30,7 @@ class BinaryLogicalOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
@@ -49,7 +49,7 @@ class UnaryLogicalOp : public OpLite {
...
@@ -49,7 +49,7 @@ class UnaryLogicalOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/lookup_table_dequant_op.cc
浏览文件 @
c754a38f
...
@@ -36,7 +36,7 @@ bool LookupTableDequantOpLite::CheckShape() const {
...
@@ -36,7 +36,7 @@ bool LookupTableDequantOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
LookupTableDequantOpLite
::
InferShape
()
const
{
bool
LookupTableDequantOpLite
::
InferShape
Impl
()
const
{
const
auto
&
table_dims
=
param_
.
W
->
dims
();
const
auto
&
table_dims
=
param_
.
W
->
dims
();
const
auto
&
ids_dims
=
param_
.
Ids
->
dims
();
const
auto
&
ids_dims
=
param_
.
Ids
->
dims
();
...
...
lite/operators/lookup_table_dequant_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class LookupTableDequantOpLite : public OpLite {
...
@@ -31,7 +31,7 @@ class LookupTableDequantOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/lookup_table_op.cc
浏览文件 @
c754a38f
...
@@ -36,7 +36,7 @@ bool LookupTableOpLite::CheckShape() const {
...
@@ -36,7 +36,7 @@ bool LookupTableOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
LookupTableOpLite
::
InferShape
()
const
{
bool
LookupTableOpLite
::
InferShape
Impl
()
const
{
const
auto
&
table_dims
=
param_
.
W
->
dims
();
const
auto
&
table_dims
=
param_
.
W
->
dims
();
const
auto
&
ids_dims
=
param_
.
Ids
->
dims
();
const
auto
&
ids_dims
=
param_
.
Ids
->
dims
();
...
...
lite/operators/lookup_table_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class LookupTableOpLite : public OpLite {
...
@@ -30,7 +30,7 @@ class LookupTableOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/lookup_table_v2_op.cc
浏览文件 @
c754a38f
...
@@ -32,7 +32,7 @@ bool LookupTableV2OpLite::CheckShape() const {
...
@@ -32,7 +32,7 @@ bool LookupTableV2OpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
LookupTableV2OpLite
::
InferShape
()
const
{
bool
LookupTableV2OpLite
::
InferShape
Impl
()
const
{
auto
table_dims
=
param_
.
W
->
dims
();
auto
table_dims
=
param_
.
W
->
dims
();
auto
ids_dims
=
param_
.
Ids
->
dims
();
auto
ids_dims
=
param_
.
Ids
->
dims
();
...
...
lite/operators/lookup_table_v2_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class LookupTableV2OpLite : public OpLite {
...
@@ -30,7 +30,7 @@ class LookupTableV2OpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/lrn_op.cc
浏览文件 @
c754a38f
...
@@ -27,7 +27,7 @@ bool LrnOpLite::CheckShape() const {
...
@@ -27,7 +27,7 @@ bool LrnOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
LrnOpLite
::
InferShape
()
const
{
bool
LrnOpLite
::
InferShape
Impl
()
const
{
param_
.
Out
->
Resize
(
param_
.
X
->
dims
());
param_
.
Out
->
Resize
(
param_
.
X
->
dims
());
return
true
;
return
true
;
}
}
...
...
lite/operators/lrn_op.h
浏览文件 @
c754a38f
...
@@ -28,7 +28,7 @@ class LrnOpLite : public OpLite {
...
@@ -28,7 +28,7 @@ class LrnOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/lstm_op.cc
浏览文件 @
c754a38f
...
@@ -26,7 +26,7 @@ bool LstmOp::CheckShape() const {
...
@@ -26,7 +26,7 @@ bool LstmOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
LstmOp
::
InferShape
()
const
{
bool
LstmOp
::
InferShape
Impl
()
const
{
auto
in_dims
=
param_
.
Input
->
dims
();
auto
in_dims
=
param_
.
Input
->
dims
();
if
(
param_
.
H0
)
{
if
(
param_
.
H0
)
{
CHECK
(
param_
.
C0
)
<<
"lstm must has H0 and C0 in the same time"
;
CHECK
(
param_
.
C0
)
<<
"lstm must has H0 and C0 in the same time"
;
...
...
lite/operators/lstm_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class LstmOp : public OpLite {
...
@@ -30,7 +30,7 @@ class LstmOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/match_matrix_tensor_op.cc
浏览文件 @
c754a38f
...
@@ -42,7 +42,7 @@ bool MatchMatrixTensorOpLite::CheckShape() const {
...
@@ -42,7 +42,7 @@ bool MatchMatrixTensorOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
MatchMatrixTensorOpLite
::
InferShape
()
const
{
bool
MatchMatrixTensorOpLite
::
InferShape
Impl
()
const
{
const
Tensor
*
x
=
param_
.
x
;
const
Tensor
*
x
=
param_
.
x
;
const
Tensor
*
y
=
param_
.
y
;
const
Tensor
*
y
=
param_
.
y
;
DDim
x_dims
=
param_
.
x
->
dims
();
DDim
x_dims
=
param_
.
x
->
dims
();
...
...
lite/operators/match_matrix_tensor_op.h
浏览文件 @
c754a38f
...
@@ -32,7 +32,7 @@ class MatchMatrixTensorOpLite : public OpLite {
...
@@ -32,7 +32,7 @@ class MatchMatrixTensorOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/matmul_op.cc
浏览文件 @
c754a38f
...
@@ -27,7 +27,7 @@ bool MatMulOpLite::CheckShape() const {
...
@@ -27,7 +27,7 @@ bool MatMulOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
MatMulOpLite
::
InferShape
()
const
{
bool
MatMulOpLite
::
InferShape
Impl
()
const
{
const
auto
x_dims
=
param_
.
X
->
dims
();
const
auto
x_dims
=
param_
.
X
->
dims
();
const
auto
y_dims
=
param_
.
Y
->
dims
();
const
auto
y_dims
=
param_
.
Y
->
dims
();
bool
x_transpose
=
param_
.
transpose_X
;
bool
x_transpose
=
param_
.
transpose_X
;
...
...
lite/operators/matmul_op.h
浏览文件 @
c754a38f
...
@@ -33,7 +33,7 @@ class MatMulOpLite : public OpLite {
...
@@ -33,7 +33,7 @@ class MatMulOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/mean_grad_op.cc
浏览文件 @
c754a38f
...
@@ -28,7 +28,7 @@ bool MeanGradOp::CheckShape() const {
...
@@ -28,7 +28,7 @@ bool MeanGradOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
MeanGradOp
::
InferShape
()
const
{
bool
MeanGradOp
::
InferShape
Impl
()
const
{
param_
.
X_grad
->
Resize
(
param_
.
X
->
dims
());
param_
.
X_grad
->
Resize
(
param_
.
X
->
dims
());
return
true
;
return
true
;
}
}
...
...
lite/operators/mean_grad_op.h
浏览文件 @
c754a38f
...
@@ -27,7 +27,7 @@ class MeanGradOp : public OpLite {
...
@@ -27,7 +27,7 @@ class MeanGradOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/mean_op.cc
浏览文件 @
c754a38f
...
@@ -27,7 +27,7 @@ bool MeanOp::CheckShape() const {
...
@@ -27,7 +27,7 @@ bool MeanOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
MeanOp
::
InferShape
()
const
{
bool
MeanOp
::
InferShape
Impl
()
const
{
param_
.
Out
->
Resize
(
std
::
vector
<
int64_t
>
{
1
});
param_
.
Out
->
Resize
(
std
::
vector
<
int64_t
>
{
1
});
return
true
;
return
true
;
}
}
...
...
lite/operators/mean_op.h
浏览文件 @
c754a38f
...
@@ -27,7 +27,7 @@ class MeanOp : public OpLite {
...
@@ -27,7 +27,7 @@ class MeanOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/merge_lod_tensor_op.cc
浏览文件 @
c754a38f
...
@@ -34,7 +34,7 @@ bool MergeLodTensorOpLite::CheckShape() const {
...
@@ -34,7 +34,7 @@ bool MergeLodTensorOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
MergeLodTensorOpLite
::
InferShape
()
const
{
bool
MergeLodTensorOpLite
::
InferShape
Impl
()
const
{
auto
dims
=
param_
.
in_true
->
dims
();
auto
dims
=
param_
.
in_true
->
dims
();
param_
.
out
->
Resize
(
dims
);
param_
.
out
->
Resize
(
dims
);
return
true
;
return
true
;
...
...
lite/operators/merge_lod_tensor_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class MergeLodTensorOpLite : public OpLite {
...
@@ -31,7 +31,7 @@ class MergeLodTensorOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/mul_grad_op.cc
浏览文件 @
c754a38f
...
@@ -46,7 +46,7 @@ bool MulGradOpLite::CheckShape() const {
...
@@ -46,7 +46,7 @@ bool MulGradOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
MulGradOpLite
::
InferShape
()
const
{
bool
MulGradOpLite
::
InferShape
Impl
()
const
{
const
auto
x_dims
=
param_
.
x
->
dims
();
const
auto
x_dims
=
param_
.
x
->
dims
();
const
auto
y_dims
=
param_
.
y
->
dims
();
const
auto
y_dims
=
param_
.
y
->
dims
();
if
(
param_
.
x_grad
)
{
if
(
param_
.
x_grad
)
{
...
...
lite/operators/mul_grad_op.h
浏览文件 @
c754a38f
...
@@ -33,7 +33,7 @@ class MulGradOpLite : public OpLite {
...
@@ -33,7 +33,7 @@ class MulGradOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/mul_op.cc
浏览文件 @
c754a38f
...
@@ -35,7 +35,7 @@ bool MulOpLite::CheckShape() const {
...
@@ -35,7 +35,7 @@ bool MulOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
MulOpLite
::
InferShape
()
const
{
bool
MulOpLite
::
InferShape
Impl
()
const
{
const
auto
x_dims
=
param_
.
x
->
dims
();
const
auto
x_dims
=
param_
.
x
->
dims
();
const
auto
y_dims
=
param_
.
y
->
dims
();
const
auto
y_dims
=
param_
.
y
->
dims
();
...
...
lite/operators/mul_op.h
浏览文件 @
c754a38f
...
@@ -33,7 +33,7 @@ class MulOpLite : public OpLite {
...
@@ -33,7 +33,7 @@ class MulOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
// TODO(Superjomn) replace framework::OpDesc with a lite one.
// TODO(Superjomn) replace framework::OpDesc with a lite one.
...
...
lite/operators/multiclass_nms_op.cc
浏览文件 @
c754a38f
...
@@ -41,7 +41,7 @@ bool MulticlassNmsOpLite::CheckShape() const {
...
@@ -41,7 +41,7 @@ bool MulticlassNmsOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
MulticlassNmsOpLite
::
InferShape
()
const
{
bool
MulticlassNmsOpLite
::
InferShape
Impl
()
const
{
auto
box_dims
=
param_
.
bboxes
->
dims
();
auto
box_dims
=
param_
.
bboxes
->
dims
();
auto
score_dims
=
param_
.
scores
->
dims
();
auto
score_dims
=
param_
.
scores
->
dims
();
auto
score_size
=
score_dims
.
size
();
auto
score_size
=
score_dims
.
size
();
...
...
lite/operators/multiclass_nms_op.h
浏览文件 @
c754a38f
...
@@ -29,7 +29,7 @@ class MulticlassNmsOpLite : public OpLite {
...
@@ -29,7 +29,7 @@ class MulticlassNmsOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/negative_op.cc
浏览文件 @
c754a38f
...
@@ -26,7 +26,7 @@ bool NegativeOpLite::CheckShape() const {
...
@@ -26,7 +26,7 @@ bool NegativeOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
NegativeOpLite
::
InferShape
()
const
{
bool
NegativeOpLite
::
InferShape
Impl
()
const
{
lite
::
DDim
input_dims
;
lite
::
DDim
input_dims
;
input_dims
=
param_
.
X
->
dims
();
input_dims
=
param_
.
X
->
dims
();
param_
.
Out
->
Resize
(
lite
::
DDim
(
input_dims
));
param_
.
Out
->
Resize
(
lite
::
DDim
(
input_dims
));
...
...
lite/operators/negative_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class NegativeOpLite : public OpLite {
...
@@ -30,7 +30,7 @@ class NegativeOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/norm_op.cc
浏览文件 @
c754a38f
...
@@ -25,7 +25,7 @@ bool NormOp::CheckShape() const {
...
@@ -25,7 +25,7 @@ bool NormOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
NormOp
::
InferShape
()
const
{
bool
NormOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
Out
);
CHECK_OR_FALSE
(
param_
.
Out
);
// TODO(Superjomn) Enable data sharing.
// TODO(Superjomn) Enable data sharing.
auto
out_dims
=
param_
.
X
->
dims
();
auto
out_dims
=
param_
.
X
->
dims
();
...
...
lite/operators/norm_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class NormOp : public OpLite {
...
@@ -30,7 +30,7 @@ class NormOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/op_params.h
浏览文件 @
c754a38f
...
@@ -24,6 +24,7 @@
...
@@ -24,6 +24,7 @@
#include "lite/model_parser/cpp/block_desc.h"
#include "lite/model_parser/cpp/block_desc.h"
#include "lite/model_parser/desc_apis.h"
#include "lite/model_parser/desc_apis.h"
#include "lite/utils/all.h"
#include "lite/utils/all.h"
#include "lite/utils/variant.h"
/*
/*
* This file contains all the argument parameter data structure for operators.
* This file contains all the argument parameter data structure for operators.
*/
*/
...
@@ -32,6 +33,16 @@ namespace paddle {
...
@@ -32,6 +33,16 @@ namespace paddle {
namespace
lite
{
namespace
lite
{
namespace
operators
{
namespace
operators
{
struct
ParamBase
{
public:
const
std
::
vector
<
Tensor
*>*
input_tensor_ptrs
()
const
{
return
nullptr
;
}
std
::
vector
<
Tensor
*>*
output_tensor_ptrs
()
{
return
nullptr
;
}
protected:
std
::
shared_ptr
<
std
::
vector
<
const
Tensor
*>>
input_tensor_ptrs_cache_
{
nullptr
};
std
::
shared_ptr
<
std
::
vector
<
Tensor
*>>
output_tensor_ptrs_cache_
{
nullptr
};
};
using
param_t
=
Any
;
using
param_t
=
Any
;
#define WITH_INT8_CONFIG \
#define WITH_INT8_CONFIG \
bool enable_int8{false}; \
bool enable_int8{false}; \
...
@@ -41,38 +52,38 @@ using param_t = Any;
...
@@ -41,38 +52,38 @@ using param_t = Any;
int bit_length{8};
int bit_length{8};
/// ----------------------- Functional operators ------------------------------
/// ----------------------- Functional operators ------------------------------
struct
FeedParam
{
struct
FeedParam
:
ParamBase
{
std
::
vector
<
lite
::
Tensor
>*
feed_list
{};
std
::
vector
<
lite
::
Tensor
>*
feed_list
{};
lite
::
Tensor
*
out
{};
lite
::
Tensor
*
out
{};
int
col
;
int
col
;
};
};
struct
FetchParam
{
struct
FetchParam
:
ParamBase
{
const
lite
::
Tensor
*
input
{};
const
lite
::
Tensor
*
input
{};
std
::
vector
<
lite
::
Tensor
>*
fetch_list
{};
std
::
vector
<
lite
::
Tensor
>*
fetch_list
{};
int
col
;
int
col
;
};
};
// Helper op for lite framework
// Helper op for lite framework
struct
IoCopyParam
{
struct
IoCopyParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
y
{};
lite
::
Tensor
*
y
{};
int
process_type
{
0
};
int
process_type
{
0
};
};
};
struct
LayoutParam
{
struct
LayoutParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
y
{};
lite
::
Tensor
*
y
{};
int
process_type
{
0
};
int
process_type
{
0
};
};
};
struct
CalibParam
{
struct
CalibParam
:
ParamBase
{
const
lite
::
Tensor
*
input
{};
const
lite
::
Tensor
*
input
{};
lite
::
Tensor
*
output
{};
lite
::
Tensor
*
output
{};
float
scale
;
float
scale
;
};
};
struct
SubgraphParam
{
struct
SubgraphParam
:
ParamBase
{
std
::
vector
<
std
::
string
>
input_names
{};
std
::
vector
<
std
::
string
>
input_names
{};
std
::
vector
<
std
::
string
>
output_names
{};
std
::
vector
<
std
::
string
>
output_names
{};
std
::
vector
<
std
::
string
>
input_data_names
{};
std
::
vector
<
std
::
string
>
input_data_names
{};
...
@@ -84,7 +95,7 @@ struct SubgraphParam {
...
@@ -84,7 +95,7 @@ struct SubgraphParam {
/// -------------------------- NN operators ------------------------------------
/// -------------------------- NN operators ------------------------------------
struct
FcParam
{
struct
FcParam
:
ParamBase
{
lite
::
Tensor
*
input
{
nullptr
};
lite
::
Tensor
*
input
{
nullptr
};
lite
::
Tensor
*
w
{
nullptr
};
lite
::
Tensor
*
w
{
nullptr
};
lite
::
Tensor
*
bias
{
nullptr
};
lite
::
Tensor
*
bias
{
nullptr
};
...
@@ -95,9 +106,24 @@ struct FcParam {
...
@@ -95,9 +106,24 @@ struct FcParam {
bool
padding_weights
{
false
};
bool
padding_weights
{
false
};
// for int8
// for int8
WITH_INT8_CONFIG
WITH_INT8_CONFIG
};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
struct
SearchSeqFcParam
{
const
std
::
vector
<
const
Tensor
*>*
input_tensor_ptrs
()
{
if
(
UNLIKELY
(
input_tensor_ptrs_cache_
))
{
input_tensor_ptrs_cache_
.
reset
(
new
std
::
vector
<
const
Tensor
*>
({
input
}));
}
return
input_tensor_ptrs_cache_
.
get
();
}
// get a vector of output tensors
const
std
::
vector
<
Tensor
*>*
output_tensor_ptrs
()
{
if
(
UNLIKELY
(
output_tensor_ptrs_cache_
))
{
output_tensor_ptrs_cache_
.
reset
(
new
std
::
vector
<
lite
::
Tensor
*>
({
output
}));
}
return
output_tensor_ptrs_cache_
.
get
();
}
};
struct
SearchSeqFcParam
:
ParamBase
{
lite
::
Tensor
*
x
{
nullptr
};
lite
::
Tensor
*
x
{
nullptr
};
lite
::
Tensor
*
w
{
nullptr
};
lite
::
Tensor
*
w
{
nullptr
};
lite
::
Tensor
*
b
{
nullptr
};
lite
::
Tensor
*
b
{
nullptr
};
...
@@ -106,7 +132,7 @@ struct SearchSeqFcParam {
...
@@ -106,7 +132,7 @@ struct SearchSeqFcParam {
};
};
// For Interpolate Op
// For Interpolate Op
struct
InterpolateParam
{
struct
InterpolateParam
:
ParamBase
{
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
OutSize
{};
lite
::
Tensor
*
OutSize
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
...
@@ -123,7 +149,7 @@ struct InterpolateParam {
...
@@ -123,7 +149,7 @@ struct InterpolateParam {
};
};
// For Mul Op
// For Mul Op
struct
MulParam
{
struct
MulParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
y
{};
const
lite
::
Tensor
*
y
{};
lite
::
Tensor
*
output
{};
lite
::
Tensor
*
output
{};
...
@@ -134,7 +160,7 @@ struct MulParam {
...
@@ -134,7 +160,7 @@ struct MulParam {
WITH_INT8_CONFIG
WITH_INT8_CONFIG
};
};
struct
MulGradParam
{
struct
MulGradParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
y
{};
const
lite
::
Tensor
*
y
{};
const
lite
::
Tensor
*
output_grad
{};
const
lite
::
Tensor
*
output_grad
{};
...
@@ -146,7 +172,7 @@ struct MulGradParam {
...
@@ -146,7 +172,7 @@ struct MulGradParam {
};
};
// For ReduceMean Op
// For ReduceMean Op
struct
ReduceMeanParam
{
struct
ReduceMeanParam
:
ParamBase
{
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
...
@@ -155,7 +181,7 @@ struct ReduceMeanParam {
...
@@ -155,7 +181,7 @@ struct ReduceMeanParam {
};
};
// For Stack Op
// For Stack Op
struct
StackParam
{
struct
StackParam
:
ParamBase
{
std
::
vector
<
lite
::
Tensor
*>
X
;
std
::
vector
<
lite
::
Tensor
*>
X
;
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
...
@@ -163,7 +189,7 @@ struct StackParam {
...
@@ -163,7 +189,7 @@ struct StackParam {
};
};
// For Power Op
// For Power Op
struct
PowerParam
{
struct
PowerParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
...
@@ -172,7 +198,7 @@ struct PowerParam {
...
@@ -172,7 +198,7 @@ struct PowerParam {
float
power
{};
float
power
{};
};
};
struct
ShuffleChannelParam
{
struct
ShuffleChannelParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
...
@@ -180,7 +206,7 @@ struct ShuffleChannelParam {
...
@@ -180,7 +206,7 @@ struct ShuffleChannelParam {
};
};
// For Yolobox
// For Yolobox
struct
YoloBoxParam
{
struct
YoloBoxParam
:
ParamBase
{
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
ImgSize
{};
lite
::
Tensor
*
ImgSize
{};
lite
::
Tensor
*
Boxes
{};
lite
::
Tensor
*
Boxes
{};
...
@@ -193,7 +219,7 @@ struct YoloBoxParam {
...
@@ -193,7 +219,7 @@ struct YoloBoxParam {
};
};
// For Scale Op
// For Scale Op
struct
ScaleParam
{
struct
ScaleParam
:
ParamBase
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
output
{};
lite
::
Tensor
*
output
{};
...
@@ -203,14 +229,29 @@ struct ScaleParam {
...
@@ -203,14 +229,29 @@ struct ScaleParam {
};
};
// For Softmax op
// For Softmax op
struct
SoftmaxParam
{
struct
SoftmaxParam
:
ParamBase
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
output
{};
lite
::
Tensor
*
output
{};
int
axis
{
-
1
};
int
axis
{
-
1
};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const
std
::
vector
<
const
Tensor
*>*
input_tensor_ptrs
()
{
if
(
UNLIKELY
(
input_tensor_ptrs_cache_
))
{
input_tensor_ptrs_cache_
.
reset
(
new
std
::
vector
<
const
Tensor
*>
({
x
}));
}
return
input_tensor_ptrs_cache_
.
get
();
}
// get a vector of output tensors
const
std
::
vector
<
Tensor
*>*
output_tensor_ptrs
()
{
if
(
UNLIKELY
(
output_tensor_ptrs_cache_
))
{
output_tensor_ptrs_cache_
.
reset
(
new
std
::
vector
<
lite
::
Tensor
*>
({
output
}));
}
return
output_tensor_ptrs_cache_
.
get
();
}
};
};
// For Reshape and Reshape2 Op
// For Reshape and Reshape2 Op
struct
ReshapeParam
{
struct
ReshapeParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
x
{};
std
::
vector
<
const
lite
::
Tensor
*>
shape_tensor_vct
{};
std
::
vector
<
const
lite
::
Tensor
*>
shape_tensor_vct
{};
const
lite
::
Tensor
*
shape_tensor
{};
const
lite
::
Tensor
*
shape_tensor
{};
...
@@ -222,7 +263,7 @@ struct ReshapeParam {
...
@@ -222,7 +263,7 @@ struct ReshapeParam {
};
};
// For Concat op
// For Concat op
struct
ConcatParam
{
struct
ConcatParam
:
ParamBase
{
std
::
vector
<
lite
::
Tensor
*>
x
{};
std
::
vector
<
lite
::
Tensor
*>
x
{};
lite
::
Tensor
*
output
{};
lite
::
Tensor
*
output
{};
int
axis
{
0
};
int
axis
{
0
};
...
@@ -230,7 +271,7 @@ struct ConcatParam {
...
@@ -230,7 +271,7 @@ struct ConcatParam {
};
};
/// ----------------------- activation operators ----------------------
/// ----------------------- activation operators ----------------------
struct
ActivationParam
{
struct
ActivationParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
float
Leaky_relu_alpha
{
0
};
// leaky_relu param
float
Leaky_relu_alpha
{
0
};
// leaky_relu param
float
Relu_clipped_coef
{
6
};
// relu_clipped param
float
Relu_clipped_coef
{
6
};
// relu_clipped param
...
@@ -245,7 +286,7 @@ struct ActivationParam {
...
@@ -245,7 +286,7 @@ struct ActivationParam {
lite_api
::
ActivationType
active_type
;
lite_api
::
ActivationType
active_type
;
};
};
struct
ActivationGradParam
{
struct
ActivationGradParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Out
{};
const
lite
::
Tensor
*
Out
{};
// for backward
// for backward
...
@@ -254,7 +295,7 @@ struct ActivationGradParam {
...
@@ -254,7 +295,7 @@ struct ActivationGradParam {
};
};
// For Convolution op
// For Convolution op
struct
ConvParam
{
struct
ConvParam
:
ParamBase
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
filter
{};
lite
::
Tensor
*
filter
{};
lite
::
Tensor
*
bias
{
nullptr
};
lite
::
Tensor
*
bias
{
nullptr
};
...
@@ -294,10 +335,26 @@ struct ConvParam {
...
@@ -294,10 +335,26 @@ struct ConvParam {
std
::
vector
<
int
>
output_size
;
std
::
vector
<
int
>
output_size
;
// for int8
// for int8
WITH_INT8_CONFIG
WITH_INT8_CONFIG
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
const
std
::
vector
<
const
Tensor
*>*
input_tensor_ptrs
()
{
if
(
UNLIKELY
(
input_tensor_ptrs_cache_
))
{
input_tensor_ptrs_cache_
.
reset
(
new
std
::
vector
<
const
Tensor
*>
({
x
}));
}
return
input_tensor_ptrs_cache_
.
get
();
}
// get a vector of output tensors
const
std
::
vector
<
Tensor
*>*
output_tensor_ptrs
()
{
if
(
UNLIKELY
(
output_tensor_ptrs_cache_
))
{
output_tensor_ptrs_cache_
.
reset
(
new
std
::
vector
<
lite
::
Tensor
*>
({
output
}));
}
return
output_tensor_ptrs_cache_
.
get
();
}
};
};
// For BatchNorm op
// For BatchNorm op
struct
BatchNormParam
{
struct
BatchNormParam
:
ParamBase
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
bias
{};
lite
::
Tensor
*
bias
{};
lite
::
Tensor
*
scale
{};
lite
::
Tensor
*
scale
{};
...
@@ -316,7 +373,7 @@ struct BatchNormParam {
...
@@ -316,7 +373,7 @@ struct BatchNormParam {
};
};
// For Pooling op
// For Pooling op
struct
PoolParam
{
struct
PoolParam
:
ParamBase
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
output
{};
lite
::
Tensor
*
output
{};
std
::
string
pooling_type
{
""
};
std
::
string
pooling_type
{
""
};
...
@@ -340,7 +397,7 @@ struct PoolParam {
...
@@ -340,7 +397,7 @@ struct PoolParam {
};
};
// For Dropout op
// For Dropout op
struct
DropoutParam
{
struct
DropoutParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
output
{};
lite
::
Tensor
*
output
{};
lite
::
Tensor
*
mask
{};
lite
::
Tensor
*
mask
{};
...
@@ -352,7 +409,7 @@ struct DropoutParam {
...
@@ -352,7 +409,7 @@ struct DropoutParam {
};
};
// For Split op
// For Split op
struct
SplitParam
{
struct
SplitParam
:
ParamBase
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
x
{};
std
::
vector
<
lite
::
Tensor
*>
output
{};
std
::
vector
<
lite
::
Tensor
*>
output
{};
lite
::
Tensor
*
axis_tensor
;
lite
::
Tensor
*
axis_tensor
;
...
@@ -364,7 +421,7 @@ struct SplitParam {
...
@@ -364,7 +421,7 @@ struct SplitParam {
};
};
// For Transpose op
// For Transpose op
struct
TransposeParam
{
struct
TransposeParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
output
{};
lite
::
Tensor
*
output
{};
lite
::
Tensor
*
xshape
{};
lite
::
Tensor
*
xshape
{};
...
@@ -375,7 +432,7 @@ struct TransposeParam {
...
@@ -375,7 +432,7 @@ struct TransposeParam {
};
};
/// ----------------------- element wise operators ----------------------
/// ----------------------- element wise operators ----------------------
struct
ElementwiseParam
{
struct
ElementwiseParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Y
{};
const
lite
::
Tensor
*
Y
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
...
@@ -384,9 +441,24 @@ struct ElementwiseParam {
...
@@ -384,9 +441,24 @@ struct ElementwiseParam {
WITH_INT8_CONFIG
WITH_INT8_CONFIG
float
x_input_scale
{
1.0
};
float
x_input_scale
{
1.0
};
float
y_input_scale
{
1.0
};
float
y_input_scale
{
1.0
};
};
///////////////////////////////////////////////////////////////////////////////////
// get a vector of input tensors
struct
ElementwiseGradParam
{
const
std
::
vector
<
const
Tensor
*>*
input_tensor_ptrs
()
{
if
(
UNLIKELY
(
input_tensor_ptrs_cache_
))
{
input_tensor_ptrs_cache_
.
reset
(
new
std
::
vector
<
const
Tensor
*>
({
X
,
Y
}));
}
return
input_tensor_ptrs_cache_
.
get
();
}
// get a vector of output tensors
const
std
::
vector
<
Tensor
*>*
output_tensor_ptrs
()
{
if
(
UNLIKELY
(
output_tensor_ptrs_cache_
))
{
output_tensor_ptrs_cache_
.
reset
(
new
std
::
vector
<
lite
::
Tensor
*>
({
Out
}));
}
return
output_tensor_ptrs_cache_
.
get
();
}
};
struct
ElementwiseGradParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Y
{};
const
lite
::
Tensor
*
Y
{};
const
lite
::
Tensor
*
OutGrad
{};
const
lite
::
Tensor
*
OutGrad
{};
...
@@ -404,12 +476,12 @@ struct FusionElementwiseActivationGradParam : public ElementwiseGradParam {
...
@@ -404,12 +476,12 @@ struct FusionElementwiseActivationGradParam : public ElementwiseGradParam {
};
};
/// ----------------------- mean operators ----------------------
/// ----------------------- mean operators ----------------------
struct
MeanParam
{
struct
MeanParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
};
};
struct
MeanGradParam
{
struct
MeanGradParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Out_grad
{};
const
lite
::
Tensor
*
Out_grad
{};
// for backward
// for backward
...
@@ -417,7 +489,7 @@ struct MeanGradParam {
...
@@ -417,7 +489,7 @@ struct MeanGradParam {
};
};
/// ----------------------- fill_constant operators ----------------------
/// ----------------------- fill_constant operators ----------------------
struct
FillConstantParam
{
struct
FillConstantParam
:
ParamBase
{
int
dtype
{
static_cast
<
int
>
(
VarDescAPI
::
VarDataType
::
FP32
)};
int
dtype
{
static_cast
<
int
>
(
VarDescAPI
::
VarDataType
::
FP32
)};
std
::
vector
<
int64_t
>
shape
{};
std
::
vector
<
int64_t
>
shape
{};
lite
::
Tensor
*
shape_tensor
{
nullptr
};
lite
::
Tensor
*
shape_tensor
{
nullptr
};
...
@@ -429,7 +501,7 @@ struct FillConstantParam {
...
@@ -429,7 +501,7 @@ struct FillConstantParam {
lite
::
Tensor
*
out
{};
lite
::
Tensor
*
out
{};
};
};
struct
FillConstantBatchSizeLikeParam
{
struct
FillConstantBatchSizeLikeParam
:
ParamBase
{
const
lite
::
Tensor
*
input
{
nullptr
};
const
lite
::
Tensor
*
input
{
nullptr
};
lite
::
Tensor
*
out
{
nullptr
};
lite
::
Tensor
*
out
{
nullptr
};
...
@@ -443,7 +515,7 @@ struct FillConstantBatchSizeLikeParam {
...
@@ -443,7 +515,7 @@ struct FillConstantBatchSizeLikeParam {
};
};
//
//
struct
FakeQuantizeMovingAvgMaxAbsParam
{
struct
FakeQuantizeMovingAvgMaxAbsParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
in_scale
{};
const
lite
::
Tensor
*
in_scale
{};
const
lite
::
Tensor
*
in_accum
{};
const
lite
::
Tensor
*
in_accum
{};
...
@@ -457,14 +529,14 @@ struct FakeQuantizeMovingAvgMaxAbsParam {
...
@@ -457,14 +529,14 @@ struct FakeQuantizeMovingAvgMaxAbsParam {
float
moving_rate
{
0.9
};
float
moving_rate
{
0.9
};
};
};
struct
FakeDequantizeMaxAbsParam
{
struct
FakeDequantizeMaxAbsParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
in_scale
{};
const
lite
::
Tensor
*
in_scale
{};
lite
::
Tensor
*
out
{};
lite
::
Tensor
*
out
{};
float
max_range
;
float
max_range
;
};
};
struct
FakeChannelWiseDequantizeMaxAbsParam
{
struct
FakeChannelWiseDequantizeMaxAbsParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
x
{};
std
::
vector
<
const
lite
::
Tensor
*>
scale_tensors
{};
std
::
vector
<
const
lite
::
Tensor
*>
scale_tensors
{};
lite
::
Tensor
*
out
{};
lite
::
Tensor
*
out
{};
...
@@ -472,7 +544,7 @@ struct FakeChannelWiseDequantizeMaxAbsParam {
...
@@ -472,7 +544,7 @@ struct FakeChannelWiseDequantizeMaxAbsParam {
};
};
/// ----------------------- sgd operators ----------------------
/// ----------------------- sgd operators ----------------------
struct
SGDParam
{
struct
SGDParam
:
ParamBase
{
int
dtype
{
static_cast
<
int
>
(
VarDescAPI
::
VarDataType
::
FP32
)};
int
dtype
{
static_cast
<
int
>
(
VarDescAPI
::
VarDataType
::
FP32
)};
const
lite
::
Tensor
*
Param
{};
const
lite
::
Tensor
*
Param
{};
...
@@ -482,7 +554,7 @@ struct SGDParam {
...
@@ -482,7 +554,7 @@ struct SGDParam {
};
};
/// ----------------------- uniform_random operators ----------------------
/// ----------------------- uniform_random operators ----------------------
struct
UniformRandomParam
{
struct
UniformRandomParam
:
ParamBase
{
std
::
vector
<
int64_t
>
shape
{};
std
::
vector
<
int64_t
>
shape
{};
float
min
{
-
1.0
f
};
float
min
{
-
1.0
f
};
float
max
{
1.0
f
};
float
max
{
1.0
f
};
...
@@ -491,12 +563,12 @@ struct UniformRandomParam {
...
@@ -491,12 +563,12 @@ struct UniformRandomParam {
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
};
};
/// ----------------------- negative operators --------------
/// ----------------------- negative operators --------------
struct
NegativeParam
{
struct
NegativeParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
};
};
/// ----------------------- pad2d operators ----------------------
/// ----------------------- pad2d operators ----------------------
struct
Pad2dParam
{
struct
Pad2dParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
std
::
vector
<
int
>
paddings
{
0
,
0
,
0
,
0
};
std
::
vector
<
int
>
paddings
{
0
,
0
,
0
,
0
};
...
@@ -506,7 +578,7 @@ struct Pad2dParam {
...
@@ -506,7 +578,7 @@ struct Pad2dParam {
};
};
/// ----------------------- Crop operators ----------------------
/// ----------------------- Crop operators ----------------------
struct
CropParam
{
struct
CropParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
std
::
vector
<
int
>
offsets
;
std
::
vector
<
int
>
offsets
;
...
@@ -514,21 +586,21 @@ struct CropParam {
...
@@ -514,21 +586,21 @@ struct CropParam {
};
};
///----------------------- argmax operators ----------------------
///----------------------- argmax operators ----------------------
struct
ArgmaxParam
{
struct
ArgmaxParam
:
ParamBase
{
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
int
Axis
{
0
};
int
Axis
{
0
};
};
};
///----------------------- axpy operators ----------------------
///----------------------- axpy operators ----------------------
struct
AxpyParam
{
struct
AxpyParam
:
ParamBase
{
lite
::
Tensor
*
Scale
{};
lite
::
Tensor
*
Scale
{};
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Bias
{};
lite
::
Tensor
*
Bias
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
};
};
/// ----------------------- GRU unit operators ----------------------f
/// ----------------------- GRU unit operators ----------------------f
struct
GRUUnitParam
{
struct
GRUUnitParam
:
ParamBase
{
enum
ActType
{
identity
,
sigmoid
,
tanh
,
relu
};
enum
ActType
{
identity
,
sigmoid
,
tanh
,
relu
};
const
lite
::
Tensor
*
input
{
nullptr
};
const
lite
::
Tensor
*
input
{
nullptr
};
const
lite
::
Tensor
*
hidden_prev
{
nullptr
};
const
lite
::
Tensor
*
hidden_prev
{
nullptr
};
...
@@ -544,7 +616,7 @@ struct GRUUnitParam {
...
@@ -544,7 +616,7 @@ struct GRUUnitParam {
};
};
/// ------------------------------ lrn operators ------------------------------
/// ------------------------------ lrn operators ------------------------------
struct
LrnParam
{
struct
LrnParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
int
n
{
5
};
int
n
{
5
};
...
@@ -555,7 +627,7 @@ struct LrnParam {
...
@@ -555,7 +627,7 @@ struct LrnParam {
};
};
/// ----------------------- decode_bboxes operators ----------------------
/// ----------------------- decode_bboxes operators ----------------------
struct
DecodeBboxesParam
{
struct
DecodeBboxesParam
:
ParamBase
{
const
lite
::
Tensor
*
loc_data
{};
const
lite
::
Tensor
*
loc_data
{};
const
lite
::
Tensor
*
prior_data
{};
const
lite
::
Tensor
*
prior_data
{};
lite
::
Tensor
*
bbox_data
{};
lite
::
Tensor
*
bbox_data
{};
...
@@ -571,7 +643,7 @@ struct DecodeBboxesParam {
...
@@ -571,7 +643,7 @@ struct DecodeBboxesParam {
};
};
/// ----------------------- box_coder operators ----------------------
/// ----------------------- box_coder operators ----------------------
struct
BoxCoderParam
{
struct
BoxCoderParam
:
ParamBase
{
const
lite
::
Tensor
*
prior_box
{};
const
lite
::
Tensor
*
prior_box
{};
const
lite
::
Tensor
*
prior_box_var
{};
const
lite
::
Tensor
*
prior_box_var
{};
const
lite
::
Tensor
*
target_box
{};
const
lite
::
Tensor
*
target_box
{};
...
@@ -584,7 +656,7 @@ struct BoxCoderParam {
...
@@ -584,7 +656,7 @@ struct BoxCoderParam {
};
};
/// ----------------------- multiclass_nms operators ----------------------
/// ----------------------- multiclass_nms operators ----------------------
struct
MulticlassNmsParam
{
struct
MulticlassNmsParam
:
ParamBase
{
const
lite
::
Tensor
*
bboxes
{};
const
lite
::
Tensor
*
bboxes
{};
const
lite
::
Tensor
*
scores
{};
const
lite
::
Tensor
*
scores
{};
lite
::
Tensor
*
out
{};
lite
::
Tensor
*
out
{};
...
@@ -599,7 +671,7 @@ struct MulticlassNmsParam {
...
@@ -599,7 +671,7 @@ struct MulticlassNmsParam {
};
};
/// ----------------------- priorbox operators ----------------------
/// ----------------------- priorbox operators ----------------------
struct
PriorBoxParam
{
struct
PriorBoxParam
:
ParamBase
{
lite
::
Tensor
*
input
{};
lite
::
Tensor
*
input
{};
lite
::
Tensor
*
image
{};
lite
::
Tensor
*
image
{};
lite
::
Tensor
*
boxes
{};
lite
::
Tensor
*
boxes
{};
...
@@ -628,7 +700,7 @@ struct DensityPriorBoxParam : public PriorBoxParam {
...
@@ -628,7 +700,7 @@ struct DensityPriorBoxParam : public PriorBoxParam {
std
::
vector
<
int
>
density_sizes
;
std
::
vector
<
int
>
density_sizes
;
};
};
/// ----------------------- GRU operators ----------------------f
/// ----------------------- GRU operators ----------------------f
struct
GRUParam
{
struct
GRUParam
:
ParamBase
{
const
lite
::
Tensor
*
input
{
nullptr
};
const
lite
::
Tensor
*
input
{
nullptr
};
const
lite
::
Tensor
*
h0
{
nullptr
};
const
lite
::
Tensor
*
h0
{
nullptr
};
const
lite
::
Tensor
*
weight
{
nullptr
};
const
lite
::
Tensor
*
weight
{
nullptr
};
...
@@ -645,7 +717,7 @@ struct GRUParam {
...
@@ -645,7 +717,7 @@ struct GRUParam {
};
};
/// ----------------------- BeamSearchDecode operators ----------------------f
/// ----------------------- BeamSearchDecode operators ----------------------f
struct
BeamSearchDecodeParam
{
struct
BeamSearchDecodeParam
:
ParamBase
{
std
::
vector
<
lite
::
Tensor
>*
ids
{
nullptr
};
std
::
vector
<
lite
::
Tensor
>*
ids
{
nullptr
};
std
::
vector
<
lite
::
Tensor
>*
scores
{
nullptr
};
std
::
vector
<
lite
::
Tensor
>*
scores
{
nullptr
};
lite
::
Tensor
*
sentence_ids
{
nullptr
};
lite
::
Tensor
*
sentence_ids
{
nullptr
};
...
@@ -655,21 +727,21 @@ struct BeamSearchDecodeParam {
...
@@ -655,21 +727,21 @@ struct BeamSearchDecodeParam {
};
};
/// ----------------------- LookupTable operators ----------------------f
/// ----------------------- LookupTable operators ----------------------f
struct
LookupTableParam
{
struct
LookupTableParam
:
ParamBase
{
const
lite
::
Tensor
*
W
{
nullptr
};
const
lite
::
Tensor
*
W
{
nullptr
};
const
lite
::
Tensor
*
Ids
{
nullptr
};
const
lite
::
Tensor
*
Ids
{
nullptr
};
lite
::
Tensor
*
Out
{
nullptr
};
lite
::
Tensor
*
Out
{
nullptr
};
int64_t
padding_idx
{
-
1
};
int64_t
padding_idx
{
-
1
};
};
};
struct
LookupTableDequantParam
{
struct
LookupTableDequantParam
:
ParamBase
{
lite
::
Tensor
*
W
{
nullptr
};
lite
::
Tensor
*
W
{
nullptr
};
lite
::
Tensor
*
Ids
{
nullptr
};
lite
::
Tensor
*
Ids
{
nullptr
};
lite
::
Tensor
*
Out
{
nullptr
};
lite
::
Tensor
*
Out
{
nullptr
};
int64_t
padding_idx
{
-
1
};
int64_t
padding_idx
{
-
1
};
};
};
struct
Im2SequenceParam
{
struct
Im2SequenceParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Y
{};
const
lite
::
Tensor
*
Y
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
...
@@ -679,19 +751,19 @@ struct Im2SequenceParam {
...
@@ -679,19 +751,19 @@ struct Im2SequenceParam {
std
::
vector
<
int
>
out_strides
{
1
,
1
};
std
::
vector
<
int
>
out_strides
{
1
,
1
};
};
};
struct
SequenceSoftmaxParam
{
struct
SequenceSoftmaxParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
};
};
struct
NormParam
{
struct
NormParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Norm
{};
lite
::
Tensor
*
Norm
{};
int
axis
{
1
};
int
axis
{
1
};
float
epsilon
{
1e-10
};
float
epsilon
{
1e-10
};
};
};
struct
LayerNormParam
{
struct
LayerNormParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Scale
{};
const
lite
::
Tensor
*
Scale
{};
const
lite
::
Tensor
*
Bias
{};
const
lite
::
Tensor
*
Bias
{};
...
@@ -702,13 +774,13 @@ struct LayerNormParam {
...
@@ -702,13 +774,13 @@ struct LayerNormParam {
float
epsilon
{
1e-5
};
float
epsilon
{
1e-5
};
};
};
struct
LogicalParam
{
struct
LogicalParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Y
{};
const
lite
::
Tensor
*
Y
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
};
};
struct
CompareParam
{
struct
CompareParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Y
{};
const
lite
::
Tensor
*
Y
{};
bool
force_cpu
{
0
};
bool
force_cpu
{
0
};
...
@@ -716,7 +788,7 @@ struct CompareParam {
...
@@ -716,7 +788,7 @@ struct CompareParam {
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
};
};
struct
WhileParam
{
struct
WhileParam
:
ParamBase
{
Scope
*
scope
{};
Scope
*
scope
{};
Tensor
*
cond
{};
Tensor
*
cond
{};
cpp
::
BlockDesc
*
sub_block
{};
cpp
::
BlockDesc
*
sub_block
{};
...
@@ -724,32 +796,32 @@ struct WhileParam {
...
@@ -724,32 +796,32 @@ struct WhileParam {
std
::
vector
<
Tensor
*>
outs
{};
std
::
vector
<
Tensor
*>
outs
{};
};
};
struct
TopkParam
{
struct
TopkParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Indices
{};
lite
::
Tensor
*
Indices
{};
int
K
{
1
};
int
K
{
1
};
};
};
struct
IncrementParam
{
struct
IncrementParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
float
step
{
1
};
float
step
{
1
};
};
};
struct
WriteToArrayParam
{
struct
WriteToArrayParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{
nullptr
};
const
lite
::
Tensor
*
X
{
nullptr
};
const
lite
::
Tensor
*
I
{
nullptr
};
const
lite
::
Tensor
*
I
{
nullptr
};
std
::
vector
<
lite
::
Tensor
>*
Out
{
nullptr
};
std
::
vector
<
lite
::
Tensor
>*
Out
{
nullptr
};
};
};
struct
ReadFromArrayParam
{
struct
ReadFromArrayParam
:
ParamBase
{
const
std
::
vector
<
lite
::
Tensor
>*
X
{
nullptr
};
const
std
::
vector
<
lite
::
Tensor
>*
X
{
nullptr
};
const
lite
::
Tensor
*
I
{
nullptr
};
const
lite
::
Tensor
*
I
{
nullptr
};
lite
::
Tensor
*
Out
{
nullptr
};
lite
::
Tensor
*
Out
{
nullptr
};
};
};
struct
BeamSearchParam
{
struct
BeamSearchParam
:
ParamBase
{
const
lite
::
Tensor
*
pre_ids
{};
const
lite
::
Tensor
*
pre_ids
{};
const
lite
::
Tensor
*
pre_scores
{};
const
lite
::
Tensor
*
pre_scores
{};
const
lite
::
Tensor
*
ids
{};
const
lite
::
Tensor
*
ids
{};
...
@@ -763,7 +835,7 @@ struct BeamSearchParam {
...
@@ -763,7 +835,7 @@ struct BeamSearchParam {
bool
is_accumulated
;
bool
is_accumulated
;
};
};
struct
SequencePoolParam
{
struct
SequencePoolParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
std
::
string
pool_type
{
"AVERAGE"
};
std
::
string
pool_type
{
"AVERAGE"
};
...
@@ -773,7 +845,7 @@ struct SequencePoolParam {
...
@@ -773,7 +845,7 @@ struct SequencePoolParam {
#endif
#endif
};
};
struct
SequenceConvParam
{
struct
SequenceConvParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Filter
{};
const
lite
::
Tensor
*
Filter
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
...
@@ -782,13 +854,13 @@ struct SequenceConvParam {
...
@@ -782,13 +854,13 @@ struct SequenceConvParam {
int
contextLength
;
int
contextLength
;
};
};
struct
SequencePoolConcatParam
{
struct
SequencePoolConcatParam
:
ParamBase
{
std
::
vector
<
lite
::
Tensor
*>
X
{};
std
::
vector
<
lite
::
Tensor
*>
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
std
::
vector
<
std
::
string
>
pool_type
{};
std
::
vector
<
std
::
string
>
pool_type
{};
};
};
struct
SearchGroupPaddingParam
{
struct
SearchGroupPaddingParam
:
ParamBase
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
out_emb_padding
{};
lite
::
Tensor
*
out_emb_padding
{};
lite
::
Tensor
*
out_new
{};
lite
::
Tensor
*
out_new
{};
...
@@ -796,36 +868,36 @@ struct SearchGroupPaddingParam {
...
@@ -796,36 +868,36 @@ struct SearchGroupPaddingParam {
int
pad_id
;
int
pad_id
;
};
};
struct
SequenceReshapeParam
{
struct
SequenceReshapeParam
:
ParamBase
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
output
{};
lite
::
Tensor
*
output
{};
int
new_dim
;
int
new_dim
;
};
};
struct
SequenceExpandParam
{
struct
SequenceExpandParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Y
{};
const
lite
::
Tensor
*
Y
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
int
ref_level
{
-
1
};
int
ref_level
{
-
1
};
};
};
struct
SequenceExpandAsParam
{
struct
SequenceExpandAsParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{
nullptr
};
const
lite
::
Tensor
*
x
{
nullptr
};
const
lite
::
Tensor
*
y
{
nullptr
};
const
lite
::
Tensor
*
y
{
nullptr
};
lite
::
Tensor
*
out
{
nullptr
};
lite
::
Tensor
*
out
{
nullptr
};
};
};
struct
SequenceReverseParam
{
struct
SequenceReverseParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
};
};
struct
SequenceConcatParam
{
struct
SequenceConcatParam
:
ParamBase
{
std
::
vector
<
lite
::
Tensor
*>
X
{};
std
::
vector
<
lite
::
Tensor
*>
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
};
};
struct
AttentionPaddingMaskParam
{
struct
AttentionPaddingMaskParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Y
{};
const
lite
::
Tensor
*
Y
{};
int
pad_id
;
int
pad_id
;
...
@@ -834,21 +906,21 @@ struct AttentionPaddingMaskParam {
...
@@ -834,21 +906,21 @@ struct AttentionPaddingMaskParam {
lite
::
Tensor
*
pad_begin
{};
lite
::
Tensor
*
pad_begin
{};
};
};
struct
SequenceArithmeticParam
{
struct
SequenceArithmeticParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Y
{};
const
lite
::
Tensor
*
Y
{};
int
op_type
{
1
};
int
op_type
{
1
};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
};
};
struct
ReduceMaxParam
{
struct
ReduceMaxParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
std
::
vector
<
int
>
dim
{};
std
::
vector
<
int
>
dim
{};
bool
keep_dim
{
false
};
bool
keep_dim
{
false
};
};
};
struct
LodResetParam
{
struct
LodResetParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Y
{};
const
lite
::
Tensor
*
Y
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
...
@@ -856,12 +928,12 @@ struct LodResetParam {
...
@@ -856,12 +928,12 @@ struct LodResetParam {
bool
append
;
bool
append
;
};
};
struct
IsEmptyParam
{
struct
IsEmptyParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
};
};
struct
ReduceParam
{
struct
ReduceParam
:
ParamBase
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
output
{};
lite
::
Tensor
*
output
{};
std
::
vector
<
int
>
dim
{
0
};
std
::
vector
<
int
>
dim
{
0
};
...
@@ -869,7 +941,7 @@ struct ReduceParam {
...
@@ -869,7 +941,7 @@ struct ReduceParam {
bool
reduce_all
{
false
};
bool
reduce_all
{
false
};
};
};
struct
VarConv2DParam
{
struct
VarConv2DParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
ROW
{};
const
lite
::
Tensor
*
ROW
{};
const
lite
::
Tensor
*
COLUMN
{};
const
lite
::
Tensor
*
COLUMN
{};
...
@@ -888,19 +960,19 @@ struct VarConv2DParam {
...
@@ -888,19 +960,19 @@ struct VarConv2DParam {
};
};
/// ----------------------- shape operators ----------------------
/// ----------------------- shape operators ----------------------
struct
ShapeParam
{
struct
ShapeParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
};
};
struct
CastParam
{
struct
CastParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
int
out_dtype
{
2
};
int
out_dtype
{
2
};
int
in_dtype
{
2
};
int
in_dtype
{
2
};
};
};
struct
SliceParam
{
struct
SliceParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
std
::
vector
<
int
>
axes
{};
std
::
vector
<
int
>
axes
{};
...
@@ -914,7 +986,7 @@ struct SliceParam {
...
@@ -914,7 +986,7 @@ struct SliceParam {
lite
::
Tensor
*
EndsTensor
{
nullptr
};
lite
::
Tensor
*
EndsTensor
{
nullptr
};
};
};
struct
AffineChannelParam
{
struct
AffineChannelParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
// X is 4D tensor
const
lite
::
Tensor
*
X
{};
// X is 4D tensor
const
lite
::
Tensor
*
Scale
{};
const
lite
::
Tensor
*
Scale
{};
const
lite
::
Tensor
*
Bias
{};
const
lite
::
Tensor
*
Bias
{};
...
@@ -922,7 +994,7 @@ struct AffineChannelParam {
...
@@ -922,7 +994,7 @@ struct AffineChannelParam {
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
};
};
struct
AnchorGeneratorParam
{
struct
AnchorGeneratorParam
:
ParamBase
{
const
lite
::
Tensor
*
Input
{};
const
lite
::
Tensor
*
Input
{};
std
::
vector
<
float
>
anchor_sizes
{};
std
::
vector
<
float
>
anchor_sizes
{};
std
::
vector
<
float
>
aspect_ratios
{};
std
::
vector
<
float
>
aspect_ratios
{};
...
@@ -934,7 +1006,7 @@ struct AnchorGeneratorParam {
...
@@ -934,7 +1006,7 @@ struct AnchorGeneratorParam {
lite
::
Tensor
*
Variances
{};
lite
::
Tensor
*
Variances
{};
};
};
struct
GenerateProposalsParam
{
struct
GenerateProposalsParam
:
ParamBase
{
// inputs
// inputs
const
lite
::
Tensor
*
Scores
{};
const
lite
::
Tensor
*
Scores
{};
const
lite
::
Tensor
*
BboxDeltas
{};
const
lite
::
Tensor
*
BboxDeltas
{};
...
@@ -954,14 +1026,14 @@ struct GenerateProposalsParam {
...
@@ -954,14 +1026,14 @@ struct GenerateProposalsParam {
lite
::
Tensor
*
RpnRoiProbs
{};
lite
::
Tensor
*
RpnRoiProbs
{};
};
};
/// ----------------------- squeeze operators ----------------------
/// ----------------------- squeeze operators ----------------------
struct
SqueezeParam
{
struct
SqueezeParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
XShape
{};
lite
::
Tensor
*
XShape
{};
std
::
vector
<
int
>
axes
{};
std
::
vector
<
int
>
axes
{};
};
};
struct
UnsqueezeParam
{
struct
UnsqueezeParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
XShape
{};
lite
::
Tensor
*
XShape
{};
...
@@ -971,14 +1043,14 @@ struct UnsqueezeParam {
...
@@ -971,14 +1043,14 @@ struct UnsqueezeParam {
};
};
/// ----------------------- expand operators ----------------------
/// ----------------------- expand operators ----------------------
struct
ExpandParam
{
struct
ExpandParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
std
::
vector
<
int
>
expand_times
{};
std
::
vector
<
int
>
expand_times
{};
};
};
/// ----------------------- matmul operators ----------------------
/// ----------------------- matmul operators ----------------------
struct
MatMulParam
{
struct
MatMulParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Y
{};
const
lite
::
Tensor
*
Y
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
...
@@ -987,20 +1059,20 @@ struct MatMulParam {
...
@@ -987,20 +1059,20 @@ struct MatMulParam {
float
alpha
{
1.0
f
};
float
alpha
{
1.0
f
};
};
};
struct
GatherParam
{
struct
GatherParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
Index
{};
const
lite
::
Tensor
*
Index
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
};
};
/// ----------------------- assign operators -----------------------
/// ----------------------- assign operators -----------------------
struct
AssignParam
{
struct
AssignParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
};
};
/// ----------------------- roi_align operators -----------------------
/// ----------------------- roi_align operators -----------------------
struct
RoiAlignParam
{
struct
RoiAlignParam
:
ParamBase
{
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
X
{};
lite
::
Tensor
*
ROIs
{};
lite
::
Tensor
*
ROIs
{};
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
...
@@ -1011,13 +1083,13 @@ struct RoiAlignParam {
...
@@ -1011,13 +1083,13 @@ struct RoiAlignParam {
};
};
/// ----------------------- box_clip operators -----------------------
/// ----------------------- box_clip operators -----------------------
struct
BoxClipParam
{
struct
BoxClipParam
:
ParamBase
{
const
lite
::
Tensor
*
Input
{};
const
lite
::
Tensor
*
Input
{};
const
lite
::
Tensor
*
ImInfo
{};
const
lite
::
Tensor
*
ImInfo
{};
lite
::
Tensor
*
Output
{};
lite
::
Tensor
*
Output
{};
};
};
struct
RangeParam
{
struct
RangeParam
:
ParamBase
{
const
lite
::
Tensor
*
Start
;
const
lite
::
Tensor
*
Start
;
const
lite
::
Tensor
*
End
;
const
lite
::
Tensor
*
End
;
const
lite
::
Tensor
*
Step
;
const
lite
::
Tensor
*
Step
;
...
@@ -1025,7 +1097,7 @@ struct RangeParam {
...
@@ -1025,7 +1097,7 @@ struct RangeParam {
};
};
/// ----------------------- assign_value operators -----------------------
/// ----------------------- assign_value operators -----------------------
struct
AssignValueParam
{
struct
AssignValueParam
:
ParamBase
{
std
::
vector
<
int
>
shape
{};
std
::
vector
<
int
>
shape
{};
int
dtype
{};
int
dtype
{};
std
::
vector
<
float
>
fp32_values
{};
std
::
vector
<
float
>
fp32_values
{};
...
@@ -1034,7 +1106,7 @@ struct AssignValueParam {
...
@@ -1034,7 +1106,7 @@ struct AssignValueParam {
};
};
/// --------------- sequence_topk_avg_pooling operators ------------------
/// --------------- sequence_topk_avg_pooling operators ------------------
struct
SequenceTopkAvgPoolingParam
{
struct
SequenceTopkAvgPoolingParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
ROW
{};
const
lite
::
Tensor
*
ROW
{};
const
lite
::
Tensor
*
COLUMN
{};
const
lite
::
Tensor
*
COLUMN
{};
...
@@ -1045,7 +1117,7 @@ struct SequenceTopkAvgPoolingParam {
...
@@ -1045,7 +1117,7 @@ struct SequenceTopkAvgPoolingParam {
};
};
/// --------------- search_fc operators ------------------
/// --------------- search_fc operators ------------------
struct
SearchFcParam
{
struct
SearchFcParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
W
{};
const
lite
::
Tensor
*
W
{};
const
lite
::
Tensor
*
b
{};
const
lite
::
Tensor
*
b
{};
...
@@ -1053,7 +1125,7 @@ struct SearchFcParam {
...
@@ -1053,7 +1125,7 @@ struct SearchFcParam {
int
out_size
{};
int
out_size
{};
};
};
/// --------------------- match_matrix_tensor operators --------------------
/// --------------------- match_matrix_tensor operators --------------------
struct
MatchMatrixTensorParam
{
struct
MatchMatrixTensorParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
y
{};
const
lite
::
Tensor
*
y
{};
const
lite
::
Tensor
*
w
{};
const
lite
::
Tensor
*
w
{};
...
@@ -1064,14 +1136,14 @@ struct MatchMatrixTensorParam {
...
@@ -1064,14 +1136,14 @@ struct MatchMatrixTensorParam {
};
};
/// --------------------- search_seq_depadding operators --------------------
/// --------------------- search_seq_depadding operators --------------------
struct
SearchSeqDepaddingParam
{
struct
SearchSeqDepaddingParam
:
ParamBase
{
const
lite
::
Tensor
*
pad
{};
const
lite
::
Tensor
*
pad
{};
const
lite
::
Tensor
*
src
{};
const
lite
::
Tensor
*
src
{};
lite
::
Tensor
*
out
{};
lite
::
Tensor
*
out
{};
};
};
/// --------------------- search_grnn operators --------------------
/// --------------------- search_grnn operators --------------------
struct
SearchGrnnParam
{
struct
SearchGrnnParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
wi
{};
const
lite
::
Tensor
*
wi
{};
const
lite
::
Tensor
*
wh
{};
const
lite
::
Tensor
*
wh
{};
...
@@ -1084,7 +1156,7 @@ struct SearchGrnnParam {
...
@@ -1084,7 +1156,7 @@ struct SearchGrnnParam {
lite
::
Tensor
*
layout_input
{};
lite
::
Tensor
*
layout_input
{};
};
};
struct
SplitLodTensorParam
{
struct
SplitLodTensorParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
mask
{};
const
lite
::
Tensor
*
mask
{};
lite
::
Tensor
*
out_true
{};
lite
::
Tensor
*
out_true
{};
...
@@ -1092,7 +1164,7 @@ struct SplitLodTensorParam {
...
@@ -1092,7 +1164,7 @@ struct SplitLodTensorParam {
int
level
{};
int
level
{};
};
};
struct
MergeLodTensorParam
{
struct
MergeLodTensorParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
x
{};
const
lite
::
Tensor
*
mask
{};
const
lite
::
Tensor
*
mask
{};
const
lite
::
Tensor
*
in_true
{};
const
lite
::
Tensor
*
in_true
{};
...
@@ -1101,7 +1173,7 @@ struct MergeLodTensorParam {
...
@@ -1101,7 +1173,7 @@ struct MergeLodTensorParam {
int
level
{};
int
level
{};
};
};
struct
ConditionalBlockParam
{
struct
ConditionalBlockParam
:
ParamBase
{
const
lite
::
Tensor
*
cond
{};
const
lite
::
Tensor
*
cond
{};
std
::
vector
<
lite
::
Tensor
*>
x
{};
std
::
vector
<
lite
::
Tensor
*>
x
{};
std
::
vector
<
lite
::
Tensor
*>
outs
{};
std
::
vector
<
lite
::
Tensor
*>
outs
{};
...
@@ -1110,14 +1182,14 @@ struct ConditionalBlockParam {
...
@@ -1110,14 +1182,14 @@ struct ConditionalBlockParam {
bool
is_scalar_condition
{};
bool
is_scalar_condition
{};
};
};
struct
CollectFpnProposalsParam
{
struct
CollectFpnProposalsParam
:
ParamBase
{
std
::
vector
<
lite
::
Tensor
*>
multi_level_rois
{};
std
::
vector
<
lite
::
Tensor
*>
multi_level_rois
{};
std
::
vector
<
lite
::
Tensor
*>
multi_level_scores
{};
std
::
vector
<
lite
::
Tensor
*>
multi_level_scores
{};
lite
::
Tensor
*
fpn_rois
{};
lite
::
Tensor
*
fpn_rois
{};
int
post_nms_topN
{};
int
post_nms_topN
{};
};
};
struct
DistributeFpnProposalsParam
{
struct
DistributeFpnProposalsParam
:
ParamBase
{
const
lite
::
Tensor
*
fpn_rois
{};
const
lite
::
Tensor
*
fpn_rois
{};
std
::
vector
<
lite
::
Tensor
*>
multi_fpn_rois
{};
std
::
vector
<
lite
::
Tensor
*>
multi_fpn_rois
{};
lite
::
Tensor
*
restore_index
{};
lite
::
Tensor
*
restore_index
{};
...
@@ -1128,7 +1200,7 @@ struct DistributeFpnProposalsParam {
...
@@ -1128,7 +1200,7 @@ struct DistributeFpnProposalsParam {
};
};
/// --------------------- instance_norm operators --------------------
/// --------------------- instance_norm operators --------------------
struct
InstanceNormParam
{
struct
InstanceNormParam
:
ParamBase
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
out
{};
lite
::
Tensor
*
out
{};
lite
::
Tensor
*
bias
{};
lite
::
Tensor
*
bias
{};
...
@@ -1138,12 +1210,12 @@ struct InstanceNormParam {
...
@@ -1138,12 +1210,12 @@ struct InstanceNormParam {
float
epsilon
;
float
epsilon
;
};
};
/// --------------------- grid sampler operators --------------------
/// --------------------- grid sampler operators --------------------
struct
GridSamplerParam
{
struct
GridSamplerParam
:
ParamBase
{
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
x
{};
lite
::
Tensor
*
out
{};
lite
::
Tensor
*
out
{};
lite
::
Tensor
*
grid
{};
lite
::
Tensor
*
grid
{};
};
};
struct
LstmParam
{
struct
LstmParam
:
ParamBase
{
lite
::
Tensor
*
Input
{};
lite
::
Tensor
*
Input
{};
lite
::
Tensor
*
Weight
{};
lite
::
Tensor
*
Weight
{};
lite
::
Tensor
*
Bias
{};
lite
::
Tensor
*
Bias
{};
...
@@ -1160,7 +1232,7 @@ struct LstmParam {
...
@@ -1160,7 +1232,7 @@ struct LstmParam {
std
::
string
candidate_activation
;
std
::
string
candidate_activation
;
};
};
struct
CrfDecodingParam
{
struct
CrfDecodingParam
:
ParamBase
{
lite
::
Tensor
*
emission
{};
lite
::
Tensor
*
emission
{};
lite
::
Tensor
*
transition
{};
lite
::
Tensor
*
transition
{};
lite
::
Tensor
*
label
{};
lite
::
Tensor
*
label
{};
...
...
lite/operators/pad2d_op.cc
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ bool Pad2dOpLite::CheckShape() const {
...
@@ -30,7 +30,7 @@ bool Pad2dOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
Pad2dOpLite
::
InferShape
()
const
{
bool
Pad2dOpLite
::
InferShape
Impl
()
const
{
// nchw
// nchw
auto
x_dims
=
param_
.
X
->
dims
();
auto
x_dims
=
param_
.
X
->
dims
();
int
out_h
=
x_dims
[
2
]
+
param_
.
paddings
[
0
]
+
param_
.
paddings
[
1
];
int
out_h
=
x_dims
[
2
]
+
param_
.
paddings
[
0
]
+
param_
.
paddings
[
1
];
...
...
lite/operators/pad2d_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class Pad2dOpLite : public OpLite {
...
@@ -30,7 +30,7 @@ class Pad2dOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/pool_op.cc
浏览文件 @
c754a38f
...
@@ -60,7 +60,7 @@ int PoolOutputSize(int input_size,
...
@@ -60,7 +60,7 @@ int PoolOutputSize(int input_size,
return
output_size
;
return
output_size
;
}
}
bool
PoolOpLite
::
InferShape
()
const
{
bool
PoolOpLite
::
InferShape
Impl
()
const
{
const
auto
x_dims
=
param_
.
x
->
dims
();
const
auto
x_dims
=
param_
.
x
->
dims
();
std
::
vector
<
int
>&
ksize
=
param_
.
ksize
;
std
::
vector
<
int
>&
ksize
=
param_
.
ksize
;
// dynamic update 4-pad
// dynamic update 4-pad
...
...
lite/operators/pool_op.h
浏览文件 @
c754a38f
...
@@ -37,7 +37,7 @@ class PoolOpLite : public OpLite {
...
@@ -37,7 +37,7 @@ class PoolOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
// TODO(Superjomn) replace framework::OpDesc with a lite one.
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
{
...
...
lite/operators/power_op.cc
浏览文件 @
c754a38f
...
@@ -27,7 +27,7 @@ bool PowerOp::CheckShape() const {
...
@@ -27,7 +27,7 @@ bool PowerOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
PowerOp
::
InferShape
()
const
{
bool
PowerOp
::
InferShape
Impl
()
const
{
param_
.
Out
->
Resize
(
param_
.
X
->
dims
());
param_
.
Out
->
Resize
(
param_
.
X
->
dims
());
return
true
;
return
true
;
}
}
...
...
lite/operators/power_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class PowerOp : public OpLite {
...
@@ -31,7 +31,7 @@ class PowerOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/prior_box_op.cc
浏览文件 @
c754a38f
...
@@ -27,7 +27,7 @@ bool PriorBoxOpLite::CheckShape() const {
...
@@ -27,7 +27,7 @@ bool PriorBoxOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
PriorBoxOpLite
::
InferShape
()
const
{
return
true
;
}
bool
PriorBoxOpLite
::
InferShape
Impl
()
const
{
return
true
;
}
bool
PriorBoxOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
PriorBoxOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
auto
input
=
opdesc
.
Input
(
"Input"
).
front
();
auto
input
=
opdesc
.
Input
(
"Input"
).
front
();
...
...
lite/operators/prior_box_op.h
浏览文件 @
c754a38f
...
@@ -29,7 +29,7 @@ class PriorBoxOpLite : public OpLite {
...
@@ -29,7 +29,7 @@ class PriorBoxOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/range_op.cc
浏览文件 @
c754a38f
...
@@ -41,7 +41,7 @@ void GetSize(T start, T end, T step, int64_t* size) {
...
@@ -41,7 +41,7 @@ void GetSize(T start, T end, T step, int64_t* size) {
:
std
::
ceil
(
std
::
abs
((
end
-
start
)
/
step
));
:
std
::
ceil
(
std
::
abs
((
end
-
start
)
/
step
));
}
}
bool
RangeOpLite
::
InferShape
()
const
{
bool
RangeOpLite
::
InferShape
Impl
()
const
{
int
start
=
param_
.
Start
->
data
<
float
>
()[
0
];
int
start
=
param_
.
Start
->
data
<
float
>
()[
0
];
int
end
=
param_
.
End
->
data
<
float
>
()[
0
];
int
end
=
param_
.
End
->
data
<
float
>
()[
0
];
int
step
=
param_
.
Step
->
data
<
float
>
()[
0
];
int
step
=
param_
.
Step
->
data
<
float
>
()[
0
];
...
...
lite/operators/range_op.h
浏览文件 @
c754a38f
...
@@ -29,7 +29,7 @@ class RangeOpLite : public OpLite {
...
@@ -29,7 +29,7 @@ class RangeOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/read_from_array_op.cc
浏览文件 @
c754a38f
...
@@ -26,7 +26,7 @@ bool ReadFromArrayOp::CheckShape() const {
...
@@ -26,7 +26,7 @@ bool ReadFromArrayOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
ReadFromArrayOp
::
InferShape
()
const
{
bool
ReadFromArrayOp
::
InferShape
Impl
()
const
{
int
id
=
param_
.
I
->
data
<
int64_t
>
()[
0
];
int
id
=
param_
.
I
->
data
<
int64_t
>
()[
0
];
auto
out_dims
=
(
*
param_
.
X
)[
id
].
dims
();
auto
out_dims
=
(
*
param_
.
X
)[
id
].
dims
();
param_
.
Out
->
Resize
(
out_dims
);
param_
.
Out
->
Resize
(
out_dims
);
...
...
lite/operators/read_from_array_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class ReadFromArrayOp : public OpLite {
...
@@ -30,7 +30,7 @@ class ReadFromArrayOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/reduce_max_op.cc
浏览文件 @
c754a38f
...
@@ -39,7 +39,7 @@ bool ReduceMaxOp::CheckShape() const {
...
@@ -39,7 +39,7 @@ bool ReduceMaxOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
ReduceMaxOp
::
InferShape
()
const
{
bool
ReduceMaxOp
::
InferShape
Impl
()
const
{
auto
dims
=
param_
.
dim
;
auto
dims
=
param_
.
dim
;
auto
x_dims
=
param_
.
X
->
dims
();
auto
x_dims
=
param_
.
X
->
dims
();
bool
reduce_all
=
false
;
bool
reduce_all
=
false
;
...
...
lite/operators/reduce_max_op.h
浏览文件 @
c754a38f
...
@@ -28,7 +28,7 @@ class ReduceMaxOp : public OpLite {
...
@@ -28,7 +28,7 @@ class ReduceMaxOp : public OpLite {
ReduceMaxOp
()
{}
ReduceMaxOp
()
{}
explicit
ReduceMaxOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
explicit
ReduceMaxOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/reduce_mean_op.cc
浏览文件 @
c754a38f
...
@@ -39,7 +39,7 @@ bool ReduceMeanOp::CheckShape() const {
...
@@ -39,7 +39,7 @@ bool ReduceMeanOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
ReduceMeanOp
::
InferShape
()
const
{
bool
ReduceMeanOp
::
InferShape
Impl
()
const
{
auto
dims
=
param_
.
dim
;
auto
dims
=
param_
.
dim
;
auto
x_dims
=
param_
.
X
->
dims
();
auto
x_dims
=
param_
.
X
->
dims
();
bool
reduce_all
=
false
;
bool
reduce_all
=
false
;
...
...
lite/operators/reduce_mean_op.h
浏览文件 @
c754a38f
...
@@ -28,7 +28,7 @@ class ReduceMeanOp : public OpLite {
...
@@ -28,7 +28,7 @@ class ReduceMeanOp : public OpLite {
ReduceMeanOp
()
{}
ReduceMeanOp
()
{}
explicit
ReduceMeanOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
explicit
ReduceMeanOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/reduce_ops.cc
浏览文件 @
c754a38f
...
@@ -28,7 +28,7 @@ bool ReduceOp::CheckShape() const {
...
@@ -28,7 +28,7 @@ bool ReduceOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
ReduceOp
::
InferShape
()
const
{
bool
ReduceOp
::
InferShape
Impl
()
const
{
const
auto
&
x_dims
=
param_
.
x
->
dims
();
const
auto
&
x_dims
=
param_
.
x
->
dims
();
auto
x_rank
=
x_dims
.
size
();
auto
x_rank
=
x_dims
.
size
();
auto
dims
=
param_
.
dim
;
auto
dims
=
param_
.
dim
;
...
...
lite/operators/reduce_ops.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class ReduceOp : public OpLite {
...
@@ -30,7 +30,7 @@ class ReduceOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/reduce_prod_op.cc
浏览文件 @
c754a38f
...
@@ -28,7 +28,7 @@ bool ReduceProdOpLite::CheckShape() const {
...
@@ -28,7 +28,7 @@ bool ReduceProdOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
ReduceProdOpLite
::
InferShape
()
const
{
bool
ReduceProdOpLite
::
InferShape
Impl
()
const
{
auto
x
=
param_
.
x
;
auto
x
=
param_
.
x
;
auto
out
=
param_
.
output
;
auto
out
=
param_
.
output
;
std
::
vector
<
int
>
dim
=
param_
.
dim
;
std
::
vector
<
int
>
dim
=
param_
.
dim
;
...
...
lite/operators/reduce_prod_op.h
浏览文件 @
c754a38f
...
@@ -29,7 +29,7 @@ class ReduceProdOpLite : public OpLite {
...
@@ -29,7 +29,7 @@ class ReduceProdOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/relu_op.cc
浏览文件 @
c754a38f
...
@@ -20,7 +20,7 @@ namespace lite {
...
@@ -20,7 +20,7 @@ namespace lite {
namespace
operators
{
namespace
operators
{
bool
ReluOp
::
CheckShape
()
const
{
return
true
;
}
bool
ReluOp
::
CheckShape
()
const
{
return
true
;
}
bool
ReluOp
::
InferShape
()
const
{
bool
ReluOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
X
);
CHECK_OR_FALSE
(
param_
.
X
);
CHECK_OR_FALSE
(
param_
.
Out
);
CHECK_OR_FALSE
(
param_
.
Out
);
// TODO(Superjomn) Enable data sharing.
// TODO(Superjomn) Enable data sharing.
...
...
lite/operators/relu_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class ReluOp : public OpLite {
...
@@ -30,7 +30,7 @@ class ReluOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/reshape_op.cc
浏览文件 @
c754a38f
...
@@ -26,7 +26,7 @@ bool ReshapeOp::CheckShape() const {
...
@@ -26,7 +26,7 @@ bool ReshapeOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
ReshapeOp
::
InferShape
()
const
{
bool
ReshapeOp
::
InferShape
Impl
()
const
{
const
auto
&
shape_tensor_vct
=
param_
.
shape_tensor_vct
;
const
auto
&
shape_tensor_vct
=
param_
.
shape_tensor_vct
;
auto
*
shape_tensor
=
param_
.
shape_tensor
;
auto
*
shape_tensor
=
param_
.
shape_tensor
;
const
auto
&
shape_vct
=
param_
.
shape_vct
;
const
auto
&
shape_vct
=
param_
.
shape_vct
;
...
@@ -97,8 +97,8 @@ bool Reshape2Op::CheckShape() const {
...
@@ -97,8 +97,8 @@ bool Reshape2Op::CheckShape() const {
return
true
;
return
true
;
}
}
bool
Reshape2Op
::
InferShape
()
const
{
bool
Reshape2Op
::
InferShape
Impl
()
const
{
ReshapeOp
::
InferShape
();
ReshapeOp
::
InferShape
Impl
();
const
auto
&
x_dims
=
param_
.
x
->
dims
();
const
auto
&
x_dims
=
param_
.
x
->
dims
();
std
::
vector
<
DDim
::
value_type
>
xshape_dims
(
x_dims
.
size
()
+
1
);
std
::
vector
<
DDim
::
value_type
>
xshape_dims
(
x_dims
.
size
()
+
1
);
xshape_dims
[
0
]
=
0
;
xshape_dims
[
0
]
=
0
;
...
...
lite/operators/reshape_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class ReshapeOp : public OpLite {
...
@@ -30,7 +30,7 @@ class ReshapeOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
@@ -48,7 +48,7 @@ class Reshape2Op : public ReshapeOp {
...
@@ -48,7 +48,7 @@ class Reshape2Op : public ReshapeOp {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/roi_align_op.cc
浏览文件 @
c754a38f
...
@@ -38,7 +38,7 @@ bool RoiAlignOpLite::CheckShape() const {
...
@@ -38,7 +38,7 @@ bool RoiAlignOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
RoiAlignOpLite
::
InferShape
()
const
{
bool
RoiAlignOpLite
::
InferShape
Impl
()
const
{
auto
x_dims
=
param_
.
X
->
dims
();
auto
x_dims
=
param_
.
X
->
dims
();
auto
rois_dims
=
param_
.
ROIs
->
dims
();
auto
rois_dims
=
param_
.
ROIs
->
dims
();
...
...
lite/operators/roi_align_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class RoiAlignOpLite : public OpLite {
...
@@ -31,7 +31,7 @@ class RoiAlignOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/scale_op.cc
浏览文件 @
c754a38f
...
@@ -24,7 +24,7 @@ bool ScaleOp::CheckShape() const {
...
@@ -24,7 +24,7 @@ bool ScaleOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
ScaleOp
::
InferShape
()
const
{
bool
ScaleOp
::
InferShape
Impl
()
const
{
param_
.
output
->
Resize
(
param_
.
x
->
dims
());
param_
.
output
->
Resize
(
param_
.
x
->
dims
());
return
true
;
return
true
;
}
}
...
...
lite/operators/scale_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class ScaleOp : public OpLite {
...
@@ -30,7 +30,7 @@ class ScaleOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/search_aligned_mat_mul_op.cc
浏览文件 @
c754a38f
...
@@ -27,7 +27,7 @@ bool SearchAlignedMatMulOpLite::CheckShape() const {
...
@@ -27,7 +27,7 @@ bool SearchAlignedMatMulOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
SearchAlignedMatMulOpLite
::
InferShape
()
const
{
bool
SearchAlignedMatMulOpLite
::
InferShape
Impl
()
const
{
const
auto
x_dims
=
param_
.
X
->
dims
();
const
auto
x_dims
=
param_
.
X
->
dims
();
const
auto
y_dims
=
param_
.
Y
->
dims
();
const
auto
y_dims
=
param_
.
Y
->
dims
();
const
auto
&
x_lod
=
param_
.
X
->
lod
();
const
auto
&
x_lod
=
param_
.
X
->
lod
();
...
...
lite/operators/search_aligned_mat_mul_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class SearchAlignedMatMulOpLite : public OpLite {
...
@@ -31,7 +31,7 @@ class SearchAlignedMatMulOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/search_fc_op.cc
浏览文件 @
c754a38f
...
@@ -50,7 +50,7 @@ bool SearchFcOpLite::CheckShape() const {
...
@@ -50,7 +50,7 @@ bool SearchFcOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
SearchFcOpLite
::
InferShape
()
const
{
bool
SearchFcOpLite
::
InferShape
Impl
()
const
{
auto
out_size
=
param_
.
out_size
;
auto
out_size
=
param_
.
out_size
;
lite
::
DDim
dims
(
std
::
vector
<
int64_t
>
({
-
1
,
out_size
}));
lite
::
DDim
dims
(
std
::
vector
<
int64_t
>
({
-
1
,
out_size
}));
param_
.
Out
->
Resize
(
dims
);
param_
.
Out
->
Resize
(
dims
);
...
...
lite/operators/search_fc_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class SearchFcOpLite : public OpLite {
...
@@ -30,7 +30,7 @@ class SearchFcOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/search_grnn_op.cc
浏览文件 @
c754a38f
...
@@ -51,7 +51,7 @@ bool SearchGrnnOpLite::CheckShape() const {
...
@@ -51,7 +51,7 @@ bool SearchGrnnOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
SearchGrnnOpLite
::
InferShape
()
const
{
bool
SearchGrnnOpLite
::
InferShape
Impl
()
const
{
const
auto
&
x_dims
=
param_
.
x
->
dims
();
const
auto
&
x_dims
=
param_
.
x
->
dims
();
const
auto
&
x_lod
=
param_
.
x
->
lod
();
const
auto
&
x_lod
=
param_
.
x
->
lod
();
CHECK_OR_FALSE
(
!
x_lod
.
empty
());
CHECK_OR_FALSE
(
!
x_lod
.
empty
());
...
...
lite/operators/search_grnn_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class SearchGrnnOpLite : public OpLite {
...
@@ -31,7 +31,7 @@ class SearchGrnnOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/search_group_padding_op.cc
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ bool SearchGroupPaddingOp::CheckShape() const {
...
@@ -31,7 +31,7 @@ bool SearchGroupPaddingOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
SearchGroupPaddingOp
::
InferShape
()
const
{
bool
SearchGroupPaddingOp
::
InferShape
Impl
()
const
{
std
::
vector
<
int64_t
>
x_dims
=
param_
.
x
->
dims
().
Vectorize
();
std
::
vector
<
int64_t
>
x_dims
=
param_
.
x
->
dims
().
Vectorize
();
param_
.
out_emb_padding
->
Resize
({
-
1
,
x_dims
[
1
]});
param_
.
out_emb_padding
->
Resize
({
-
1
,
x_dims
[
1
]});
...
...
lite/operators/search_group_padding_op.h
浏览文件 @
c754a38f
...
@@ -27,7 +27,7 @@ class SearchGroupPaddingOp : public OpLite {
...
@@ -27,7 +27,7 @@ class SearchGroupPaddingOp : public OpLite {
SearchGroupPaddingOp
()
{}
SearchGroupPaddingOp
()
{}
explicit
SearchGroupPaddingOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
explicit
SearchGroupPaddingOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"search_group_padding"
;
}
std
::
string
DebugString
()
const
override
{
return
"search_group_padding"
;
}
...
...
lite/operators/search_seq_depadding_op.cc
浏览文件 @
c754a38f
...
@@ -44,7 +44,7 @@ bool SearchSeqDepaddingOpLite::CheckShape() const {
...
@@ -44,7 +44,7 @@ bool SearchSeqDepaddingOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
SearchSeqDepaddingOpLite
::
InferShape
()
const
{
bool
SearchSeqDepaddingOpLite
::
InferShape
Impl
()
const
{
DDim
pad_dims
=
param_
.
pad
->
dims
();
DDim
pad_dims
=
param_
.
pad
->
dims
();
param_
.
out
->
Resize
({
-
1
,
pad_dims
[
1
]});
param_
.
out
->
Resize
({
-
1
,
pad_dims
[
1
]});
return
true
;
return
true
;
...
...
lite/operators/search_seq_depadding_op.h
浏览文件 @
c754a38f
...
@@ -32,7 +32,7 @@ class SearchSeqDepaddingOpLite : public OpLite {
...
@@ -32,7 +32,7 @@ class SearchSeqDepaddingOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/search_seq_fc_op.cc
浏览文件 @
c754a38f
...
@@ -26,7 +26,7 @@ bool SearchSeqFcOpLite::CheckShape() const {
...
@@ -26,7 +26,7 @@ bool SearchSeqFcOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
SearchSeqFcOpLite
::
InferShape
()
const
{
bool
SearchSeqFcOpLite
::
InferShape
Impl
()
const
{
const
auto
x_dims
=
param_
.
x
->
dims
();
const
auto
x_dims
=
param_
.
x
->
dims
();
const
auto
w_dims
=
param_
.
w
->
dims
();
const
auto
w_dims
=
param_
.
w
->
dims
();
const
auto
&
x_lod
=
param_
.
x
->
lod
();
const
auto
&
x_lod
=
param_
.
x
->
lod
();
...
...
lite/operators/search_seq_fc_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class SearchSeqFcOpLite : public OpLite {
...
@@ -31,7 +31,7 @@ class SearchSeqFcOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/search_seq_softmax_op.cc
浏览文件 @
c754a38f
...
@@ -25,7 +25,7 @@ bool SearchSeqSoftmaxOp::CheckShape() const {
...
@@ -25,7 +25,7 @@ bool SearchSeqSoftmaxOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
SearchSeqSoftmaxOp
::
InferShape
()
const
{
bool
SearchSeqSoftmaxOp
::
InferShape
Impl
()
const
{
param_
.
output
->
Resize
(
param_
.
x
->
dims
());
param_
.
output
->
Resize
(
param_
.
x
->
dims
());
param_
.
output
->
set_lod
(
param_
.
x
->
lod
());
param_
.
output
->
set_lod
(
param_
.
x
->
lod
());
return
true
;
return
true
;
...
...
lite/operators/search_seq_softmax_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class SearchSeqSoftmaxOp : public OpLite {
...
@@ -31,7 +31,7 @@ class SearchSeqSoftmaxOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/sequence_arithmetic_op.cc
浏览文件 @
c754a38f
...
@@ -28,7 +28,7 @@ bool SequenceArithmeticOp::CheckShape() const {
...
@@ -28,7 +28,7 @@ bool SequenceArithmeticOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
SequenceArithmeticOp
::
InferShape
()
const
{
bool
SequenceArithmeticOp
::
InferShape
Impl
()
const
{
param_
.
Out
->
Resize
(
param_
.
X
->
dims
());
param_
.
Out
->
Resize
(
param_
.
X
->
dims
());
param_
.
Out
->
set_lod
(
param_
.
X
->
lod
());
param_
.
Out
->
set_lod
(
param_
.
X
->
lod
());
return
true
;
return
true
;
...
...
lite/operators/sequence_arithmetic_op.h
浏览文件 @
c754a38f
...
@@ -29,7 +29,7 @@ class SequenceArithmeticOp : public OpLite {
...
@@ -29,7 +29,7 @@ class SequenceArithmeticOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/sequence_concat_op.cc
浏览文件 @
c754a38f
...
@@ -26,7 +26,7 @@ bool SequenceConcatOp::CheckShape() const {
...
@@ -26,7 +26,7 @@ bool SequenceConcatOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
SequenceConcatOp
::
InferShape
()
const
{
return
true
;
}
bool
SequenceConcatOp
::
InferShape
Impl
()
const
{
return
true
;
}
bool
SequenceConcatOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
bool
SequenceConcatOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
lite
::
Scope
*
scope
)
{
...
...
lite/operators/sequence_concat_op.h
浏览文件 @
c754a38f
...
@@ -27,7 +27,7 @@ class SequenceConcatOp : public OpLite {
...
@@ -27,7 +27,7 @@ class SequenceConcatOp : public OpLite {
SequenceConcatOp
()
{}
SequenceConcatOp
()
{}
explicit
SequenceConcatOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
explicit
SequenceConcatOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"sequence_concat"
;
}
std
::
string
DebugString
()
const
override
{
return
"sequence_concat"
;
}
...
...
lite/operators/sequence_conv_op.cc
浏览文件 @
c754a38f
...
@@ -44,7 +44,7 @@ bool SequenceConvOp::CheckShape() const {
...
@@ -44,7 +44,7 @@ bool SequenceConvOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
SequenceConvOp
::
InferShape
()
const
{
bool
SequenceConvOp
::
InferShape
Impl
()
const
{
const
auto
*
input
=
param_
.
X
;
const
auto
*
input
=
param_
.
X
;
const
auto
*
filter
=
param_
.
Filter
;
const
auto
*
filter
=
param_
.
Filter
;
auto
in_dims
=
input
->
dims
();
auto
in_dims
=
input
->
dims
();
...
...
lite/operators/sequence_conv_op.h
浏览文件 @
c754a38f
...
@@ -28,7 +28,7 @@ class SequenceConvOp : public OpLite {
...
@@ -28,7 +28,7 @@ class SequenceConvOp : public OpLite {
SequenceConvOp
()
{}
SequenceConvOp
()
{}
explicit
SequenceConvOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
explicit
SequenceConvOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/sequence_expand_as_op.cc
浏览文件 @
c754a38f
...
@@ -34,7 +34,7 @@ bool SequenceExpandAsOpLite::CheckShape() const {
...
@@ -34,7 +34,7 @@ bool SequenceExpandAsOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
SequenceExpandAsOpLite
::
InferShape
()
const
{
bool
SequenceExpandAsOpLite
::
InferShape
Impl
()
const
{
auto
x_dims
=
param_
.
x
->
dims
();
auto
x_dims
=
param_
.
x
->
dims
();
auto
y_lod
=
param_
.
y
->
lod
();
auto
y_lod
=
param_
.
y
->
lod
();
auto
out_dims
=
x_dims
;
auto
out_dims
=
x_dims
;
...
...
lite/operators/sequence_expand_as_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class SequenceExpandAsOpLite : public OpLite {
...
@@ -31,7 +31,7 @@ class SequenceExpandAsOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/sequence_expand_op.cc
浏览文件 @
c754a38f
...
@@ -40,7 +40,7 @@ bool SequenceExpandOp::CheckShape() const {
...
@@ -40,7 +40,7 @@ bool SequenceExpandOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
SequenceExpandOp
::
InferShape
()
const
{
bool
SequenceExpandOp
::
InferShape
Impl
()
const
{
const
auto
x_lod
=
param_
.
X
->
lod
();
const
auto
x_lod
=
param_
.
X
->
lod
();
auto
x_dims
=
param_
.
X
->
dims
();
auto
x_dims
=
param_
.
X
->
dims
();
int
ref_level
=
param_
.
ref_level
;
int
ref_level
=
param_
.
ref_level
;
...
...
lite/operators/sequence_expand_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class SequenceExpandOp : public OpLite {
...
@@ -30,7 +30,7 @@ class SequenceExpandOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/sequence_pool_concat_op.cc
浏览文件 @
c754a38f
...
@@ -26,7 +26,7 @@ bool SequencePoolConcatOp::CheckShape() const {
...
@@ -26,7 +26,7 @@ bool SequencePoolConcatOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
SequencePoolConcatOp
::
InferShape
()
const
{
bool
SequencePoolConcatOp
::
InferShape
Impl
()
const
{
int
out_dim
=
0
;
int
out_dim
=
0
;
for
(
int
i
=
0
;
i
<
param_
.
X
.
size
();
++
i
)
{
for
(
int
i
=
0
;
i
<
param_
.
X
.
size
();
++
i
)
{
out_dim
+=
param_
.
X
[
i
]
->
dims
().
count
(
1
,
param_
.
X
[
i
]
->
dims
().
size
());
out_dim
+=
param_
.
X
[
i
]
->
dims
().
count
(
1
,
param_
.
X
[
i
]
->
dims
().
size
());
...
...
lite/operators/sequence_pool_concat_op.h
浏览文件 @
c754a38f
...
@@ -28,7 +28,7 @@ class SequencePoolConcatOp : public OpLite {
...
@@ -28,7 +28,7 @@ class SequencePoolConcatOp : public OpLite {
SequencePoolConcatOp
()
{}
SequencePoolConcatOp
()
{}
explicit
SequencePoolConcatOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
explicit
SequencePoolConcatOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/sequence_pool_op.cc
浏览文件 @
c754a38f
...
@@ -29,7 +29,7 @@ bool SequencePoolOp::CheckShape() const {
...
@@ -29,7 +29,7 @@ bool SequencePoolOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
SequencePoolOp
::
InferShape
()
const
{
bool
SequencePoolOp
::
InferShape
Impl
()
const
{
const
auto
*
input
=
param_
.
X
;
const
auto
*
input
=
param_
.
X
;
auto
out_dims
=
input
->
dims
();
auto
out_dims
=
input
->
dims
();
out_dims
[
0
]
=
input
->
lod
()[
0
].
size
()
-
1
;
out_dims
[
0
]
=
input
->
lod
()[
0
].
size
()
-
1
;
...
...
lite/operators/sequence_pool_op.h
浏览文件 @
c754a38f
...
@@ -28,7 +28,7 @@ class SequencePoolOp : public OpLite {
...
@@ -28,7 +28,7 @@ class SequencePoolOp : public OpLite {
SequencePoolOp
()
{}
SequencePoolOp
()
{}
explicit
SequencePoolOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
explicit
SequencePoolOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/sequence_reshape_op.cc
浏览文件 @
c754a38f
...
@@ -27,7 +27,7 @@ bool SequenceReshapeOp::CheckShape() const {
...
@@ -27,7 +27,7 @@ bool SequenceReshapeOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
SequenceReshapeOp
::
InferShape
()
const
{
bool
SequenceReshapeOp
::
InferShape
Impl
()
const
{
int
new_dim
=
param_
.
new_dim
;
int
new_dim
=
param_
.
new_dim
;
auto
x_numel
=
param_
.
x
->
dims
().
production
();
auto
x_numel
=
param_
.
x
->
dims
().
production
();
std
::
vector
<
int64_t
>
out_shape
{
x_numel
/
new_dim
,
std
::
vector
<
int64_t
>
out_shape
{
x_numel
/
new_dim
,
...
...
lite/operators/sequence_reshape_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class SequenceReshapeOp : public OpLite {
...
@@ -31,7 +31,7 @@ class SequenceReshapeOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/sequence_reverse_op.cc
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ bool SequenceReverseOp::CheckShape() const {
...
@@ -30,7 +30,7 @@ bool SequenceReverseOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
SequenceReverseOp
::
InferShape
()
const
{
bool
SequenceReverseOp
::
InferShape
Impl
()
const
{
const
auto
*
input
=
param_
.
X
;
const
auto
*
input
=
param_
.
X
;
auto
out_dims
=
input
->
dims
();
auto
out_dims
=
input
->
dims
();
param_
.
Out
->
Resize
(
out_dims
);
param_
.
Out
->
Resize
(
out_dims
);
...
...
lite/operators/sequence_reverse_op.h
浏览文件 @
c754a38f
...
@@ -27,7 +27,7 @@ class SequenceReverseOp : public OpLite {
...
@@ -27,7 +27,7 @@ class SequenceReverseOp : public OpLite {
SequenceReverseOp
()
{}
SequenceReverseOp
()
{}
explicit
SequenceReverseOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
explicit
SequenceReverseOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"sequence_reverse"
;
}
std
::
string
DebugString
()
const
override
{
return
"sequence_reverse"
;
}
...
...
lite/operators/sequence_softmax_op.cc
浏览文件 @
c754a38f
...
@@ -24,7 +24,7 @@ bool SequenceSoftmaxOp::CheckShape() const {
...
@@ -24,7 +24,7 @@ bool SequenceSoftmaxOp::CheckShape() const {
CHECK_OR_FALSE
(
param_
.
Out
);
CHECK_OR_FALSE
(
param_
.
Out
);
return
true
;
return
true
;
}
}
bool
SequenceSoftmaxOp
::
InferShape
()
const
{
bool
SequenceSoftmaxOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
Out
);
CHECK_OR_FALSE
(
param_
.
Out
);
// TODO(Superjomn) Enable data sharing.
// TODO(Superjomn) Enable data sharing.
auto
input_dims
=
param_
.
X
->
dims
();
auto
input_dims
=
param_
.
X
->
dims
();
...
...
lite/operators/sequence_softmax_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class SequenceSoftmaxOp : public OpLite {
...
@@ -30,7 +30,7 @@ class SequenceSoftmaxOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/sequence_topk_avg_pooling_op.cc
浏览文件 @
c754a38f
...
@@ -43,7 +43,7 @@ bool SequenceTopkAvgPoolingOpLite::CheckShape() const {
...
@@ -43,7 +43,7 @@ bool SequenceTopkAvgPoolingOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
SequenceTopkAvgPoolingOpLite
::
InferShape
()
const
{
bool
SequenceTopkAvgPoolingOpLite
::
InferShape
Impl
()
const
{
int
channel_num
=
param_
.
channel_num
;
int
channel_num
=
param_
.
channel_num
;
std
::
vector
<
int
>
topks
=
param_
.
topks
;
std
::
vector
<
int
>
topks
=
param_
.
topks
;
auto
row_dim
=
param_
.
ROW
->
dims
();
auto
row_dim
=
param_
.
ROW
->
dims
();
...
...
lite/operators/sequence_topk_avg_pooling_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class SequenceTopkAvgPoolingOpLite : public OpLite {
...
@@ -31,7 +31,7 @@ class SequenceTopkAvgPoolingOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/sgd_op.cc
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ bool SGDOpLite::CheckShape() const {
...
@@ -30,7 +30,7 @@ bool SGDOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
SGDOpLite
::
InferShape
()
const
{
bool
SGDOpLite
::
InferShape
Impl
()
const
{
param_
.
ParamOut
->
Resize
(
param_
.
Param
->
dims
());
param_
.
ParamOut
->
Resize
(
param_
.
Param
->
dims
());
return
true
;
return
true
;
}
}
...
...
lite/operators/sgd_op.h
浏览文件 @
c754a38f
...
@@ -33,7 +33,7 @@ class SGDOpLite : public OpLite {
...
@@ -33,7 +33,7 @@ class SGDOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/shape_op.cc
浏览文件 @
c754a38f
...
@@ -25,7 +25,7 @@ bool ShapeOpLite::CheckShape() const {
...
@@ -25,7 +25,7 @@ bool ShapeOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
ShapeOpLite
::
InferShape
()
const
{
bool
ShapeOpLite
::
InferShape
Impl
()
const
{
std
::
vector
<
int64_t
>
shape_vec
;
std
::
vector
<
int64_t
>
shape_vec
;
shape_vec
.
push_back
(
static_cast
<
int64_t
>
(
param_
.
X
->
dims
().
size
()));
shape_vec
.
push_back
(
static_cast
<
int64_t
>
(
param_
.
X
->
dims
().
size
()));
param_
.
Out
->
Resize
(
shape_vec
);
param_
.
Out
->
Resize
(
shape_vec
);
...
...
lite/operators/shape_op.h
浏览文件 @
c754a38f
...
@@ -28,7 +28,7 @@ class ShapeOpLite : public OpLite {
...
@@ -28,7 +28,7 @@ class ShapeOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/shuffle_channel_op.cc
浏览文件 @
c754a38f
...
@@ -27,7 +27,7 @@ bool ShuffleChannelOpLite::CheckShape() const {
...
@@ -27,7 +27,7 @@ bool ShuffleChannelOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
ShuffleChannelOpLite
::
InferShape
()
const
{
bool
ShuffleChannelOpLite
::
InferShape
Impl
()
const
{
param_
.
Out
->
Resize
(
param_
.
X
->
dims
());
param_
.
Out
->
Resize
(
param_
.
X
->
dims
());
return
true
;
return
true
;
}
}
...
...
lite/operators/shuffle_channel_op.h
浏览文件 @
c754a38f
...
@@ -33,7 +33,7 @@ class ShuffleChannelOpLite : public OpLite {
...
@@ -33,7 +33,7 @@ class ShuffleChannelOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/slice_op.cc
浏览文件 @
c754a38f
...
@@ -27,7 +27,7 @@ bool SliceOp::CheckShape() const {
...
@@ -27,7 +27,7 @@ bool SliceOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
SliceOp
::
InferShape
()
const
{
bool
SliceOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
Out
);
CHECK_OR_FALSE
(
param_
.
Out
);
// TODO(Superjomn) Enable data sharing.
// TODO(Superjomn) Enable data sharing.
auto
in_dims
=
param_
.
X
->
dims
();
auto
in_dims
=
param_
.
X
->
dims
();
...
...
lite/operators/slice_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class SliceOp : public OpLite {
...
@@ -30,7 +30,7 @@ class SliceOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/softmax_op.cc
浏览文件 @
c754a38f
...
@@ -29,35 +29,7 @@ bool SoftmaxOp::CheckShape() const {
...
@@ -29,35 +29,7 @@ bool SoftmaxOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
SoftmaxOp
::
SmartInferShape
()
{
bool
SoftmaxOp
::
InferShapeImpl
()
const
{
if
(
!
last_input_shapes
.
empty
()
&&
!
last_output_shapes
.
empty
())
{
if
(
param_
.
x
->
dims
()
==
last_input_shapes
[
0
]
&&
param_
.
x
->
lod
()
==
last_input_lods
[
0
])
{
param_
.
output
->
Resize
(
last_output_shapes
[
0
]);
param_
.
output
->
set_lod
(
last_output_lods
[
0
]);
return
true
;
}
}
this
->
InferShape
();
if
(
!
last_input_shapes
.
empty
())
{
last_input_shapes
.
clear
();
last_input_lods
.
clear
();
}
last_input_shapes
.
push_back
(
param_
.
x
->
dims
());
last_input_lods
.
push_back
(
param_
.
x
->
lod
());
if
(
!
last_output_shapes
.
empty
())
{
last_output_shapes
.
clear
();
last_output_lods
.
clear
();
}
last_output_shapes
.
push_back
(
param_
.
output
->
dims
());
last_output_lods
.
push_back
(
param_
.
output
->
lod
());
return
true
;
}
bool
SoftmaxOp
::
InferShape
()
const
{
param_
.
output
->
Resize
(
param_
.
x
->
dims
());
param_
.
output
->
Resize
(
param_
.
x
->
dims
());
auto
out_lod
=
param_
.
output
->
mutable_lod
();
auto
out_lod
=
param_
.
output
->
mutable_lod
();
*
out_lod
=
param_
.
x
->
lod
();
*
out_lod
=
param_
.
x
->
lod
();
...
...
lite/operators/softmax_op.h
浏览文件 @
c754a38f
...
@@ -30,8 +30,7 @@ class SoftmaxOp : public OpLite {
...
@@ -30,8 +30,7 @@ class SoftmaxOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShapeImpl
()
const
override
;
bool
SmartInferShape
()
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/split_lod_tensor_op.cc
浏览文件 @
c754a38f
...
@@ -33,7 +33,7 @@ bool SplitLodTensorOpLite::CheckShape() const {
...
@@ -33,7 +33,7 @@ bool SplitLodTensorOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
SplitLodTensorOpLite
::
InferShape
()
const
{
bool
SplitLodTensorOpLite
::
InferShape
Impl
()
const
{
auto
x_dims
=
param_
.
x
->
dims
();
auto
x_dims
=
param_
.
x
->
dims
();
param_
.
out_true
->
Resize
(
x_dims
);
param_
.
out_true
->
Resize
(
x_dims
);
param_
.
out_false
->
Resize
(
x_dims
);
param_
.
out_false
->
Resize
(
x_dims
);
...
...
lite/operators/split_lod_tensor_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class SplitLodTensorOpLite : public OpLite {
...
@@ -31,7 +31,7 @@ class SplitLodTensorOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/split_op.cc
浏览文件 @
c754a38f
...
@@ -29,7 +29,7 @@ bool SplitOp::CheckShape() const {
...
@@ -29,7 +29,7 @@ bool SplitOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
SplitOp
::
InferShape
()
const
{
bool
SplitOp
::
InferShape
Impl
()
const
{
const
auto
&
outs
=
param_
.
output
;
const
auto
&
outs
=
param_
.
output
;
auto
in_dims
=
param_
.
x
->
dims
();
auto
in_dims
=
param_
.
x
->
dims
();
int
axis
=
param_
.
axis
;
int
axis
=
param_
.
axis
;
...
...
lite/operators/split_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class SplitOp : public OpLite {
...
@@ -30,7 +30,7 @@ class SplitOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/squeeze_op.cc
浏览文件 @
c754a38f
...
@@ -75,7 +75,7 @@ bool SqueezeOp::CheckShape() const {
...
@@ -75,7 +75,7 @@ bool SqueezeOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
SqueezeOp
::
InferShape
()
const
{
bool
SqueezeOp
::
InferShape
Impl
()
const
{
std
::
vector
<
int
>
squeeze_dims
=
param_
.
axes
;
std
::
vector
<
int
>
squeeze_dims
=
param_
.
axes
;
DDim
in_dims
=
param_
.
X
->
dims
();
DDim
in_dims
=
param_
.
X
->
dims
();
DDim
out_dim
=
GetOutputShape
(
squeeze_dims
,
in_dims
,
true
);
DDim
out_dim
=
GetOutputShape
(
squeeze_dims
,
in_dims
,
true
);
...
@@ -105,8 +105,8 @@ bool Squeeze2Op::CheckShape() const {
...
@@ -105,8 +105,8 @@ bool Squeeze2Op::CheckShape() const {
return
true
;
return
true
;
}
}
bool
Squeeze2Op
::
InferShape
()
const
{
bool
Squeeze2Op
::
InferShape
Impl
()
const
{
SqueezeOp
::
InferShape
();
SqueezeOp
::
InferShape
Impl
();
auto
x_dims
=
param_
.
X
->
dims
();
auto
x_dims
=
param_
.
X
->
dims
();
std
::
vector
<
DDim
::
value_type
>
xshape_dims
(
x_dims
.
size
()
+
1
,
1
);
std
::
vector
<
DDim
::
value_type
>
xshape_dims
(
x_dims
.
size
()
+
1
,
1
);
for
(
size_t
i
=
0
;
i
<
x_dims
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
x_dims
.
size
();
i
++
)
{
...
...
lite/operators/squeeze_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class SqueezeOp : public OpLite {
...
@@ -30,7 +30,7 @@ class SqueezeOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
@@ -48,7 +48,7 @@ class Squeeze2Op : public SqueezeOp {
...
@@ -48,7 +48,7 @@ class Squeeze2Op : public SqueezeOp {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/stack_op.cc
浏览文件 @
c754a38f
...
@@ -32,7 +32,7 @@ bool StackOp::CheckShape() const {
...
@@ -32,7 +32,7 @@ bool StackOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
StackOp
::
InferShape
()
const
{
bool
StackOp
::
InferShape
Impl
()
const
{
auto
input
=
param_
.
X
;
auto
input
=
param_
.
X
;
auto
input_dims
=
input
[
0
]
->
dims
();
auto
input_dims
=
input
[
0
]
->
dims
();
int
axis
=
param_
.
axis
;
int
axis
=
param_
.
axis
;
...
...
lite/operators/stack_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class StackOp : public OpLite {
...
@@ -31,7 +31,7 @@ class StackOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/subgraph_op.cc
浏览文件 @
c754a38f
...
@@ -22,7 +22,7 @@ namespace operators {
...
@@ -22,7 +22,7 @@ namespace operators {
bool
SubgraphOp
::
CheckShape
()
const
{
return
true
;
}
bool
SubgraphOp
::
CheckShape
()
const
{
return
true
;
}
bool
SubgraphOp
::
InferShape
()
const
{
return
CheckShape
();
/* enrich me */
}
bool
SubgraphOp
::
InferShape
Impl
()
const
{
return
CheckShape
();
/* enrich me */
}
bool
SubgraphOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
bool
SubgraphOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
param_
.
input_names
=
op_desc
.
Input
(
"Inputs"
);
param_
.
input_names
=
op_desc
.
Input
(
"Inputs"
);
...
...
lite/operators/subgraph_op.h
浏览文件 @
c754a38f
...
@@ -35,7 +35,7 @@ class SubgraphOp : public OpLite {
...
@@ -35,7 +35,7 @@ class SubgraphOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/topk_op.cc
浏览文件 @
c754a38f
...
@@ -25,7 +25,7 @@ bool TopkOp::CheckShape() const {
...
@@ -25,7 +25,7 @@ bool TopkOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
TopkOp
::
InferShape
()
const
{
bool
TopkOp
::
InferShape
Impl
()
const
{
auto
out_dims
=
param_
.
X
->
dims
();
auto
out_dims
=
param_
.
X
->
dims
();
out_dims
[
out_dims
.
size
()
-
1
]
=
param_
.
K
;
out_dims
[
out_dims
.
size
()
-
1
]
=
param_
.
K
;
auto
out
=
param_
.
Out
;
auto
out
=
param_
.
Out
;
...
...
lite/operators/topk_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class TopkOp : public OpLite {
...
@@ -30,7 +30,7 @@ class TopkOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/transpose_op.cc
浏览文件 @
c754a38f
...
@@ -42,7 +42,7 @@ bool TransposeOp::CheckShape() const {
...
@@ -42,7 +42,7 @@ bool TransposeOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
TransposeOp
::
InferShape
()
const
{
bool
TransposeOp
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
x
);
CHECK_OR_FALSE
(
param_
.
x
);
CHECK_OR_FALSE
(
param_
.
output
);
CHECK_OR_FALSE
(
param_
.
output
);
auto
x_dims
=
param_
.
x
->
dims
();
auto
x_dims
=
param_
.
x
->
dims
();
...
@@ -111,7 +111,7 @@ bool Transpose2Op::CheckShape() const {
...
@@ -111,7 +111,7 @@ bool Transpose2Op::CheckShape() const {
return
true
;
return
true
;
}
}
bool
Transpose2Op
::
InferShape
()
const
{
bool
Transpose2Op
::
InferShape
Impl
()
const
{
CHECK_OR_FALSE
(
param_
.
x
);
CHECK_OR_FALSE
(
param_
.
x
);
CHECK_OR_FALSE
(
param_
.
output
);
CHECK_OR_FALSE
(
param_
.
output
);
auto
x_dims
=
param_
.
x
->
dims
();
auto
x_dims
=
param_
.
x
->
dims
();
...
...
lite/operators/transpose_op.h
浏览文件 @
c754a38f
...
@@ -31,7 +31,7 @@ class TransposeOp : public OpLite {
...
@@ -31,7 +31,7 @@ class TransposeOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
@@ -50,7 +50,7 @@ class Transpose2Op : public OpLite {
...
@@ -50,7 +50,7 @@ class Transpose2Op : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/uniform_random_op.cc
浏览文件 @
c754a38f
...
@@ -22,7 +22,7 @@ namespace operators {
...
@@ -22,7 +22,7 @@ namespace operators {
bool
UniformRandomOpLite
::
CheckShape
()
const
{
return
true
;
}
bool
UniformRandomOpLite
::
CheckShape
()
const
{
return
true
;
}
bool
UniformRandomOpLite
::
InferShape
()
const
{
bool
UniformRandomOpLite
::
InferShape
Impl
()
const
{
param_
.
Out
->
Resize
(
param_
.
shape
);
param_
.
Out
->
Resize
(
param_
.
shape
);
return
true
;
return
true
;
}
}
...
...
lite/operators/uniform_random_op.h
浏览文件 @
c754a38f
...
@@ -33,7 +33,7 @@ class UniformRandomOpLite : public OpLite {
...
@@ -33,7 +33,7 @@ class UniformRandomOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
...
...
lite/operators/unsqueeze_op.cc
浏览文件 @
c754a38f
...
@@ -62,7 +62,7 @@ bool UnsqueezeOp::CheckShape() const {
...
@@ -62,7 +62,7 @@ bool UnsqueezeOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
UnsqueezeOp
::
InferShape
()
const
{
bool
UnsqueezeOp
::
InferShape
Impl
()
const
{
std
::
vector
<
int
>
final_axes
;
std
::
vector
<
int
>
final_axes
;
auto
axes
=
param_
.
axes
;
auto
axes
=
param_
.
axes
;
auto
*
axes_tensor
=
param_
.
axes_tensor
;
auto
*
axes_tensor
=
param_
.
axes_tensor
;
...
@@ -129,8 +129,8 @@ bool Unsqueeze2Op::CheckShape() const {
...
@@ -129,8 +129,8 @@ bool Unsqueeze2Op::CheckShape() const {
return
true
;
return
true
;
}
}
bool
Unsqueeze2Op
::
InferShape
()
const
{
bool
Unsqueeze2Op
::
InferShape
Impl
()
const
{
UnsqueezeOp
::
InferShape
();
UnsqueezeOp
::
InferShape
Impl
();
auto
x_dims
=
param_
.
X
->
dims
();
auto
x_dims
=
param_
.
X
->
dims
();
std
::
vector
<
DDim
::
value_type
>
xshape_dims
(
x_dims
.
size
()
+
1
,
1
);
std
::
vector
<
DDim
::
value_type
>
xshape_dims
(
x_dims
.
size
()
+
1
,
1
);
for
(
size_t
i
=
0
;
i
<
x_dims
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
x_dims
.
size
();
i
++
)
{
...
...
lite/operators/unsqueeze_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class UnsqueezeOp : public OpLite {
...
@@ -30,7 +30,7 @@ class UnsqueezeOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
@@ -48,7 +48,7 @@ class Unsqueeze2Op : public UnsqueezeOp {
...
@@ -48,7 +48,7 @@ class Unsqueeze2Op : public UnsqueezeOp {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/var_conv_2d_op.cc
浏览文件 @
c754a38f
...
@@ -21,7 +21,7 @@ namespace operators {
...
@@ -21,7 +21,7 @@ namespace operators {
bool
VarConv2dOp
::
CheckShape
()
const
{
return
true
;
}
bool
VarConv2dOp
::
CheckShape
()
const
{
return
true
;
}
bool
VarConv2dOp
::
InferShape
()
const
{
return
true
;
}
bool
VarConv2dOp
::
InferShape
Impl
()
const
{
return
true
;
}
bool
VarConv2dOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
bool
VarConv2dOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
param_
.
X
=
const_cast
<
lite
::
Tensor
*>
(
param_
.
X
=
const_cast
<
lite
::
Tensor
*>
(
...
...
lite/operators/var_conv_2d_op.h
浏览文件 @
c754a38f
...
@@ -27,7 +27,7 @@ class VarConv2dOp : public OpLite {
...
@@ -27,7 +27,7 @@ class VarConv2dOp : public OpLite {
VarConv2dOp
()
{}
VarConv2dOp
()
{}
explicit
VarConv2dOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
explicit
VarConv2dOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"var_conv_2d"
;
}
std
::
string
DebugString
()
const
override
{
return
"var_conv_2d"
;
}
...
...
lite/operators/while_op.cc
浏览文件 @
c754a38f
...
@@ -27,7 +27,7 @@ bool WhileOpLite::CheckShape() const {
...
@@ -27,7 +27,7 @@ bool WhileOpLite::CheckShape() const {
return
true
;
return
true
;
}
}
bool
WhileOpLite
::
InferShape
()
const
{
return
true
;
}
bool
WhileOpLite
::
InferShape
Impl
()
const
{
return
true
;
}
bool
WhileOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
bool
WhileOpLite
::
AttachImpl
(
const
cpp
::
OpDesc
&
op_desc
,
lite
::
Scope
*
scope
)
{
auto
inputs
=
op_desc
.
Input
(
"X"
);
auto
inputs
=
op_desc
.
Input
(
"X"
);
...
...
lite/operators/while_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class WhileOpLite : public OpLite {
...
@@ -30,7 +30,7 @@ class WhileOpLite : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/write_to_array_op.cc
浏览文件 @
c754a38f
...
@@ -26,7 +26,7 @@ bool WriteToArrayOp::CheckShape() const {
...
@@ -26,7 +26,7 @@ bool WriteToArrayOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
WriteToArrayOp
::
InferShape
()
const
{
bool
WriteToArrayOp
::
InferShape
Impl
()
const
{
int
id
=
param_
.
I
->
data
<
int64_t
>
()[
0
];
int
id
=
param_
.
I
->
data
<
int64_t
>
()[
0
];
if
(
param_
.
Out
->
size
()
<
id
+
1
)
{
if
(
param_
.
Out
->
size
()
<
id
+
1
)
{
param_
.
Out
->
resize
(
id
+
1
);
param_
.
Out
->
resize
(
id
+
1
);
...
...
lite/operators/write_to_array_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class WriteToArrayOp : public OpLite {
...
@@ -30,7 +30,7 @@ class WriteToArrayOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
lite/operators/yolo_box_op.cc
浏览文件 @
c754a38f
...
@@ -46,7 +46,7 @@ bool YoloBoxOp::CheckShape() const {
...
@@ -46,7 +46,7 @@ bool YoloBoxOp::CheckShape() const {
return
true
;
return
true
;
}
}
bool
YoloBoxOp
::
InferShape
()
const
{
bool
YoloBoxOp
::
InferShape
Impl
()
const
{
auto
*
X
=
param_
.
X
;
auto
*
X
=
param_
.
X
;
auto
anchors
=
param_
.
anchors
;
auto
anchors
=
param_
.
anchors
;
int
anchor_num
=
anchors
.
size
()
/
2
;
int
anchor_num
=
anchors
.
size
()
/
2
;
...
...
lite/operators/yolo_box_op.h
浏览文件 @
c754a38f
...
@@ -30,7 +30,7 @@ class YoloBoxOp : public OpLite {
...
@@ -30,7 +30,7 @@ class YoloBoxOp : public OpLite {
bool
CheckShape
()
const
override
;
bool
CheckShape
()
const
override
;
bool
InferShape
()
const
override
;
bool
InferShape
Impl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录