diff --git a/src/framework/attribute.h b/src/framework/attribute.h
index 2a18bfc8a76fc2b988885468de88e12df1fae20e..e00cee09b36a4372c938f356900faab88e610010 100644
--- a/src/framework/attribute.h
+++ b/src/framework/attribute.h
@@ -93,6 +93,14 @@ class Attribute {
       case PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__BLOCK: {
        break;
      }
+      case PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__LONGS: {
+        vector<int64_t> val(attr_desc->n_longs);
+        for (int i = 0; i < attr_desc->n_longs; ++i) {
+          val[i] = attr_desc->longs[i];
+        }
+        attr.Set<vector<int64_t>>(val);
+        break;
+      }
       default:
         PADDLE_MOBILE_THROW_EXCEPTION("attr type not support");
     }
@@ -144,6 +152,8 @@ class Attribute {
     } else if (attr.variant_.TypeId() ==
                typeid(vector<BlockDesc *>).hash_code()) {
       return vistor(attr.variant_.Get<vector<BlockDesc *>>());
+    } else if (attr.variant_.TypeId() == typeid(vector<int64_t>).hash_code()) {
+      return vistor(attr.variant_.Get<vector<int64_t>>());
     } else {
       PADDLE_MOBILE_THROW_EXCEPTION("type not support");
     }
@@ -151,7 +161,8 @@ class Attribute {
 
  private:
   Variant<int, float, string, vector<int>, vector<float>, vector<string>, bool,
-          vector<bool>, BlockDesc *, int64_t>
+          vector<bool>, BlockDesc *, vector<BlockDesc *>, int64_t,
+          vector<int64_t>>
       variant_;
 };
 
diff --git a/src/framework/executor.cpp b/src/framework/executor.cpp
index e9438dcc06a86aaa4176d73072d5d69e5c14cf47..5d95df063b50d86165dc73d5da31dd17827c09d7 100644
--- a/src/framework/executor.cpp
+++ b/src/framework/executor.cpp
@@ -74,7 +74,6 @@ Executor<Device, T>::Executor(const Program<Device> &program,
     }
     ops_of_block0_.push_back(op_handler);
   }
-
   if (program_.combined) {
     InitCombineMemory();
   } else {
@@ -423,15 +422,23 @@ void Executor<Device, T>::SetInput(const LoDTensor &input,
 template <typename Device, typename T>
 std::shared_ptr<LoDTensor> Executor<Device, T>::GetOutput(
     const std::string &var_name) {
-  int index = 0;
-  if (fetch_indices_.find(var_name) != fetch_indices_.end()) {
-    index = fetch_indices_.find(var_name)->second;
-  }
-  auto *fetch_var = program_.scope->Var("fetch");
-  framework::LoDTensor &target =
-      fetch_var->template GetMutable<framework::LoDTensorArray>()->at(index);
+  const auto &iter = fetch_indices_.find(var_name);
+  if (var_name == "fetch" || iter != fetch_indices_.end()) {
+    int index = 0;
+    if (iter != fetch_indices_.end()) {
+      index = iter->second;
+    }
+    auto *fetch_var = program_.scope->Var("fetch");
+    framework::LoDTensor &target =
+        fetch_var->template GetMutable<framework::LoDTensorArray>()->at(index);
 
-  return std::make_shared<LoDTensor>(target);
+    return std::make_shared<LoDTensor>(target);
+  } else {
+    auto *fetch_var = program_.scope->Var(var_name);
+    framework::LoDTensor *target =
+        fetch_var->template GetMutable<framework::LoDTensor>();
+    return std::make_shared<LoDTensor>(*target);
+  }
 }
 
 template <typename Device, typename T>
diff --git a/src/framework/framework.pb-c.c b/src/framework/framework.pb-c.c
index bbccc76a22f5efbe69b58e6a546d063923077af6..394c17f09724a9db2bfc62603a3ffa46cf032899 100644
--- a/src/framework/framework.pb-c.c
+++ b/src/framework/framework.pb-c.c
@@ -13,13 +13,6 @@ void paddle_mobile__framework__proto__version__init(
       PADDLE_MOBILE__FRAMEWORK__PROTO__VERSION__INIT;
   *message = init_value;
 }
-size_t paddle_mobile__framework__proto__version__get_packed_size(
-    const PaddleMobile__Framework__Proto__Version *message) {
-  assert(message->base.descriptor ==
-         &paddle_mobile__framework__proto__version__descriptor);
-  return protobuf_c_message_get_packed_size(
-      (const ProtobufCMessage *)(message));
-}
 PaddleMobile__Framework__Proto__Version *
 paddle_mobile__framework__proto__version__unpack(ProtobufCAllocator *allocator,
                                                  size_t len,
@@ -54,13 +47,6 @@ void paddle_mobile__framework__proto__op_desc__init(
       PADDLE_MOBILE__FRAMEWORK__PROTO__OP_DESC__INIT;
   *message = init_value;
 }
-size_t paddle_mobile__framework__proto__op_desc__get_packed_size(
-    const PaddleMobile__Framework__Proto__OpDesc *message) {
-  assert(message->base.descriptor ==
-         &paddle_mobile__framework__proto__op_desc__descriptor);
-  return protobuf_c_message_get_packed_size(
-      (const ProtobufCMessage *)(message));
-}
 PaddleMobile__Framework__Proto__OpDesc *
 paddle_mobile__framework__proto__op_desc__unpack(ProtobufCAllocator *allocator,
                                                  size_t len,
@@ -95,13 +81,6 @@ void paddle_mobile__framework__proto__op_proto__init(
       PADDLE_MOBILE__FRAMEWORK__PROTO__OP_PROTO__INIT;
   *message = init_value;
 }
-size_t paddle_mobile__framework__proto__op_proto__get_packed_size(
-    const PaddleMobile__Framework__Proto__OpProto *message) {
-  assert(message->base.descriptor ==
-         &paddle_mobile__framework__proto__op_proto__descriptor);
-  return protobuf_c_message_get_packed_size(
-      (const ProtobufCMessage *)(message));
-}
 PaddleMobile__Framework__Proto__OpProto *
 paddle_mobile__framework__proto__op_proto__unpack(ProtobufCAllocator *allocator,
                                                   size_t len,
@@ -162,13 +141,6 @@ void paddle_mobile__framework__proto__var_type__init(
       PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_TYPE__INIT;
   *message = init_value;
 }
-size_t paddle_mobile__framework__proto__var_type__get_packed_size(
-    const PaddleMobile__Framework__Proto__VarType *message) {
-  assert(message->base.descriptor ==
-         &paddle_mobile__framework__proto__var_type__descriptor);
-  return protobuf_c_message_get_packed_size(
-      (const ProtobufCMessage *)(message));
-}
 PaddleMobile__Framework__Proto__VarType *
 paddle_mobile__framework__proto__var_type__unpack(ProtobufCAllocator *allocator,
                                                   size_t len,
@@ -191,13 +163,6 @@ void paddle_mobile__framework__proto__var_desc__init(
       PADDLE_MOBILE__FRAMEWORK__PROTO__VAR_DESC__INIT;
   *message = init_value;
 }
-size_t paddle_mobile__framework__proto__var_desc__get_packed_size(
-    const PaddleMobile__Framework__Proto__VarDesc *message) {
-  assert(message->base.descriptor ==
-         &paddle_mobile__framework__proto__var_desc__descriptor);
-  return protobuf_c_message_get_packed_size(
-      (const ProtobufCMessage *)(message));
-}
 PaddleMobile__Framework__Proto__VarDesc *
 paddle_mobile__framework__proto__var_desc__unpack(ProtobufCAllocator *allocator,
                                                   size_t len,
@@ -220,13 +185,6 @@ void paddle_mobile__framework__proto__block_desc__init(
       PADDLE_MOBILE__FRAMEWORK__PROTO__BLOCK_DESC__INIT;
   *message = init_value;
 }
-size_t paddle_mobile__framework__proto__block_desc__get_packed_size(
-    const PaddleMobile__Framework__Proto__BlockDesc *message) {
-  assert(message->base.descriptor ==
-         &paddle_mobile__framework__proto__block_desc__descriptor);
-  return protobuf_c_message_get_packed_size(
-      (const ProtobufCMessage *)(message));
-}
 PaddleMobile__Framework__Proto__BlockDesc *
 paddle_mobile__framework__proto__block_desc__unpack(
     ProtobufCAllocator *allocator, size_t len, const uint8_t *data) {
@@ -248,13 +206,6 @@ void paddle_mobile__framework__proto__program_desc__init(
       PADDLE_MOBILE__FRAMEWORK__PROTO__PROGRAM_DESC__INIT;
   *message = init_value;
 }
-size_t paddle_mobile__framework__proto__program_desc__get_packed_size(
-    const PaddleMobile__Framework__Proto__ProgramDesc *message) {
-  assert(message->base.descriptor ==
-         &paddle_mobile__framework__proto__program_desc__descriptor);
-  return protobuf_c_message_get_packed_size(
-      (const ProtobufCMessage *)(message));
-}
 PaddleMobile__Framework__Proto__ProgramDesc *
 paddle_mobile__framework__proto__program_desc__unpack(
     ProtobufCAllocator *allocator, size_t len, const uint8_t *data) {
@@ -310,7 +261,7 @@ const ProtobufCMessageDescriptor
     NULL /* reserved[123] */
 };
 static const ProtobufCFieldDescriptor
-    paddle_mobile__framework__proto__op_desc__attr__field_descriptors[13] = {
+    paddle_mobile__framework__proto__op_desc__attr__field_descriptors[14] = {
         {
             "name", 1, PROTOBUF_C_LABEL_REQUIRED, PROTOBUF_C_TYPE_STRING,
             0, /* quantifier_offset */
@@ -405,6 +356,13 @@ static const ProtobufCFieldDescriptor
             NULL, NULL, 0, /* flags */
             0, NULL, NULL /* reserved1,reserved2, etc */
         },
+        {
+            "longs", 15, PROTOBUF_C_LABEL_REPEATED, PROTOBUF_C_TYPE_INT64,
+            offsetof(PaddleMobile__Framework__Proto__OpDesc__Attr, n_longs),
+            offsetof(PaddleMobile__Framework__Proto__OpDesc__Attr, longs), NULL,
+            NULL, 0, /* flags */
+            0, NULL, NULL /* reserved1,reserved2, etc */
+        },
 };
 static const unsigned
     paddle_mobile__framework__proto__op_desc__attr__field_indices_by_name[] = {
@@ -417,6 +375,7 @@ static const unsigned
         2,  /* field[2] = i */
         5,  /* field[5] = ints */
         11, /* field[11] = l */
+        13, /* field[13] = longs */
         0,  /* field[0] = name */
         4,  /* field[4] = s */
         7,  /* field[7] = strings */
@@ -424,7 +383,7 @@ static const unsigned
 };
 static const ProtobufCIntRange
     paddle_mobile__framework__proto__op_desc__attr__number_ranges[2 + 1] = {
-        {1, 0}, {10, 8}, {0, 13}};
+        {1, 0}, {10, 8}, {0, 14}};
 const ProtobufCMessageDescriptor
     paddle_mobile__framework__proto__op_desc__attr__descriptor = {
         PROTOBUF_C__MESSAGE_DESCRIPTOR_MAGIC,
@@ -433,7 +392,7 @@ const ProtobufCMessageDescriptor
         "PaddleMobile__Framework__Proto__OpDesc__Attr",
         "paddle_mobile.framework.proto",
         sizeof(PaddleMobile__Framework__Proto__OpDesc__Attr),
-        13,
+        14,
         paddle_mobile__framework__proto__op_desc__attr__field_descriptors,
         paddle_mobile__framework__proto__op_desc__attr__field_indices_by_name,
         2,
@@ -1448,7 +1407,7 @@ const ProtobufCMessageDescriptor
     NULL /* reserved[123] */
 };
 static const ProtobufCEnumValue
-    paddle_mobile__framework__proto__attr_type__enum_values_by_number[11] = {
+    paddle_mobile__framework__proto__attr_type__enum_values_by_number[12] = {
        {"INT", "PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__INT", 0},
        {"FLOAT", "PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__FLOAT", 1},
        {"STRING", "PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__STRING", 2},
@@ -1460,15 +1419,16 @@ static const ProtobufCEnumValue
        {"BLOCK", "PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__BLOCK", 8},
        {"LONG", "PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__LONG", 9},
        {"BLOCKS", "PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__BLOCKS", 10},
+       {"LONGS", "PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__LONGS", 11},
 };
 static const ProtobufCIntRange
     paddle_mobile__framework__proto__attr_type__value_ranges[] = {{0, 0},
-                                                                  {0, 11}};
+                                                                  {0, 12}};
 static const ProtobufCEnumValueIndex
-    paddle_mobile__framework__proto__attr_type__enum_values_by_name[11] = {
+    paddle_mobile__framework__proto__attr_type__enum_values_by_name[12] = {
        {"BLOCK", 8}, {"BLOCKS", 10}, {"BOOLEAN", 6}, {"BOOLEANS", 7},
        {"FLOAT", 1}, {"FLOATS", 4},  {"INT", 0},     {"INTS", 3},
-       {"LONG", 9},  {"STRING", 2},  {"STRINGS", 5},
+       {"LONG", 9},  {"LONGS", 11},  {"STRING", 2},  {"STRINGS", 5},
 };
 const ProtobufCEnumDescriptor
     paddle_mobile__framework__proto__attr_type__descriptor = {
@@ -1477,9 +1437,9 @@ const ProtobufCEnumDescriptor
         "AttrType",
         "PaddleMobile__Framework__Proto__AttrType",
         "paddle_mobile.framework.proto",
-        11,
+        12,
         paddle_mobile__framework__proto__attr_type__enum_values_by_number,
-        11,
+        12,
         paddle_mobile__framework__proto__attr_type__enum_values_by_name,
         1,
         paddle_mobile__framework__proto__attr_type__value_ranges,
diff --git a/src/framework/framework.pb-c.h b/src/framework/framework.pb-c.h
index b7bac7ef9c99f62489bcd74936b3c0b55374abfb..a0f2eaee12acded26ce210c1016aeba0c4eba4ed 100644
--- a/src/framework/framework.pb-c.h
+++ b/src/framework/framework.pb-c.h
@@ -102,8 +102,9 @@ typedef enum _PaddleMobile__Framework__Proto__AttrType {
   PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__BOOLEANS = 7,
   PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__BLOCK = 8,
   PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__LONG = 9,
-  PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__BLOCKS =
-      10 PROTOBUF_C__FORCE_ENUM_TO_BE_INT_SIZE(
+  PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__BLOCKS = 10,
+  PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__LONGS =
+      11 PROTOBUF_C__FORCE_ENUM_TO_BE_INT_SIZE(
           PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE)
 } PaddleMobile__Framework__Proto__AttrType;
 
@@ -152,13 +153,15 @@ struct _PaddleMobile__Framework__Proto__OpDesc__Attr {
   int64_t l;
   size_t n_blocks_idx;
   int32_t *blocks_idx;
+  size_t n_longs;
+  int64_t *longs;
 };
 #define PADDLE_MOBILE__FRAMEWORK__PROTO__OP_DESC__ATTR__INIT \
   { \
     PROTOBUF_C_MESSAGE_INIT( \
         &paddle_mobile__framework__proto__op_desc__attr__descriptor) \
     , NULL, PADDLE_MOBILE__FRAMEWORK__PROTO__ATTR_TYPE__INT, 0, 0, 0, 0, NULL, \
-        0, NULL, 0, NULL, 0, NULL, 0, 0, 0, NULL, 0, 0, 0, 0, 0, NULL \
+        0, NULL, 0, NULL, 0, NULL, 0, 0, 0, NULL, 0, 0, 0, 0, 0, NULL, 0, NULL \
   }
 
 struct _PaddleMobile__Framework__Proto__OpDesc__Var {
@@ -417,8 +420,6 @@ struct _PaddleMobile__Framework__Proto__ProgramDesc {
 /* PaddleMobile__Framework__Proto__Version methods */
 void paddle_mobile__framework__proto__version__init(
     PaddleMobile__Framework__Proto__Version *message);
-size_t paddle_mobile__framework__proto__version__get_packed_size(
-    const PaddleMobile__Framework__Proto__Version *message);
 PaddleMobile__Framework__Proto__Version *
 paddle_mobile__framework__proto__version__unpack(ProtobufCAllocator *allocator,
                                                  size_t len,
@@ -435,8 +436,6 @@ void paddle_mobile__framework__proto__op_desc__var__init(
 /* PaddleMobile__Framework__Proto__OpDesc methods */
 void paddle_mobile__framework__proto__op_desc__init(
     PaddleMobile__Framework__Proto__OpDesc *message);
-size_t paddle_mobile__framework__proto__op_desc__get_packed_size(
-    const PaddleMobile__Framework__Proto__OpDesc *message);
 PaddleMobile__Framework__Proto__OpDesc *
 paddle_mobile__framework__proto__op_desc__unpack(ProtobufCAllocator *allocator,
                                                  size_t len,
@@ -453,8 +452,6 @@ void paddle_mobile__framework__proto__op_proto__attr__init(
 /* PaddleMobile__Framework__Proto__OpProto methods */
 void paddle_mobile__framework__proto__op_proto__init(
     PaddleMobile__Framework__Proto__OpProto *message);
-size_t paddle_mobile__framework__proto__op_proto__get_packed_size(
-    const PaddleMobile__Framework__Proto__OpProto *message);
 PaddleMobile__Framework__Proto__OpProto *
 paddle_mobile__framework__proto__op_proto__unpack(ProtobufCAllocator *allocator,
                                                   size_t len,
@@ -483,8 +480,6 @@ void paddle_mobile__framework__proto__var_type__tuple__init(
 /* PaddleMobile__Framework__Proto__VarType methods */
 void paddle_mobile__framework__proto__var_type__init(
     PaddleMobile__Framework__Proto__VarType *message);
-size_t paddle_mobile__framework__proto__var_type__get_packed_size(
-    const PaddleMobile__Framework__Proto__VarType *message);
 PaddleMobile__Framework__Proto__VarType *
 paddle_mobile__framework__proto__var_type__unpack(ProtobufCAllocator *allocator,
                                                   size_t len,
@@ -495,8 +490,6 @@ void paddle_mobile__framework__proto__var_type__free_unpacked(
 /* PaddleMobile__Framework__Proto__VarDesc methods */
 void paddle_mobile__framework__proto__var_desc__init(
     PaddleMobile__Framework__Proto__VarDesc *message);
-size_t paddle_mobile__framework__proto__var_desc__get_packed_size(
-    const PaddleMobile__Framework__Proto__VarDesc *message);
 PaddleMobile__Framework__Proto__VarDesc *
 paddle_mobile__framework__proto__var_desc__unpack(ProtobufCAllocator *allocator,
                                                   size_t len,
@@ -507,8 +500,6 @@ void paddle_mobile__framework__proto__var_desc__free_unpacked(
 /* PaddleMobile__Framework__Proto__BlockDesc methods */
 void paddle_mobile__framework__proto__block_desc__init(
     PaddleMobile__Framework__Proto__BlockDesc *message);
-size_t paddle_mobile__framework__proto__block_desc__get_packed_size(
-    const PaddleMobile__Framework__Proto__BlockDesc *message);
 PaddleMobile__Framework__Proto__BlockDesc *
 paddle_mobile__framework__proto__block_desc__unpack(
     ProtobufCAllocator *allocator, size_t len, const uint8_t *data);
@@ -518,8 +509,6 @@ void paddle_mobile__framework__proto__block_desc__free_unpacked(
 /* PaddleMobile__Framework__Proto__ProgramDesc methods */
 void paddle_mobile__framework__proto__program_desc__init(
     PaddleMobile__Framework__Proto__ProgramDesc *message);
-size_t paddle_mobile__framework__proto__program_desc__get_packed_size(
-    const PaddleMobile__Framework__Proto__ProgramDesc *message);
 PaddleMobile__Framework__Proto__ProgramDesc *
 paddle_mobile__framework__proto__program_desc__unpack(
     ProtobufCAllocator *allocator, size_t len, const uint8_t *data);
diff --git a/src/framework/framework.proto b/src/framework/framework.proto
index 4f41e26dc2df8550a6ce318d6e39ef4f3e875e73..27a98e0d6178b0fb20dcf640635413691efb7f10 100644
--- a/src/framework/framework.proto
+++ b/src/framework/framework.proto
@@ -35,6 +35,7 @@ enum AttrType {
   BLOCK = 8;
   LONG = 9;
   BLOCKS = 10;
+  LONGS = 11;
 }
 
 // OpDesc describes an instance of a C++ framework::OperatorBase
@@ -55,6 +56,7 @@ message OpDesc {
     optional int32 block_idx = 12;
     optional int64 l = 13;
     repeated int32 blocks_idx = 14;
+    repeated int64 longs = 15;
   };
 
   message Var {
diff --git a/src/framework/load_ops.h b/src/framework/load_ops.h
index 7c14b81f6845e4236c3d6fcf590c9d59c5735f1e..98af2ca6053fe544b49df4510b74ad0ac505b009 100644
--- a/src/framework/load_ops.h
+++ b/src/framework/load_ops.h
@@ -324,9 +324,15 @@ LOAD_OP1(psroi_pool, CPU);
 #ifdef ROI_PERSPECTIVE_OP
 LOAD_OP1(roi_perspective_transform, CPU);
 #endif
+#ifdef BEAM_SEARCH_OP
+LOAD_OP1(beam_search, CPU);
+#endif
 #ifdef BEAM_SEARCH_DECODE_OP
 LOAD_OP1(beam_search_decode, CPU);
 #endif
 #ifdef PAD2D_OP
 LOAD_OP1(pad2d, CPU);
 #endif
+#ifdef ONE_HOT_OP
+LOAD_OP1(one_hot, CPU);
+#endif
diff --git a/src/operators/beam_search_decode_op.cpp b/src/operators/beam_search_decode_op.cpp
index 410446944ce3bc7f0968ba84ea3445bf709605d6..9b01d2e17f363d3b729102a9747f6dc6682ea8aa 100644
--- a/src/operators/beam_search_decode_op.cpp
+++ b/src/operators/beam_search_decode_op.cpp
@@ -33,5 +33,4 @@ namespace ops = paddle_mobile::operators;
 REGISTER_OPERATOR_CPU(beam_search_decode, ops::BeamSearchDecodeOp);
 #endif
 
-namespace ops = paddle_mobile::operators;
 #endif  // BEAM_SEARCH_DECODE_OP
diff --git a/src/operators/beam_search_op.cpp b/src/operators/beam_search_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..502510ebeefd29336531fac24d279e009f6b8d6d
--- /dev/null
+++ b/src/operators/beam_search_op.cpp
@@ -0,0 +1,36 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef BEAM_SEARCH_OP
+
+#pragma once
+
+#include "operators/beam_search_op.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+template <typename Dtype, typename T>
+void BeamSearchOp<Dtype, T>::InferShape() const {}
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+namespace ops = paddle_mobile::operators;
+
+#ifdef PADDLE_MOBILE_CPU
+REGISTER_OPERATOR_CPU(beam_search, ops::BeamSearchOp);
+#endif
+
+#endif  // BEAM_SEARCH_OP
diff --git a/src/operators/beam_search_op.h b/src/operators/beam_search_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..985552d9f6efde5a474ca57672b8500bfc558e32
--- /dev/null
+++ b/src/operators/beam_search_op.h
@@ -0,0 +1,31 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef BEAM_SEARCH_OP
+
+#pragma once
+
+#include <string>
+#include "framework/operator.h"
+#include "operators/kernel/beam_search_kernel.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+DECLARE_OPERATOR(BeamSearch, BeamSearchParam, BeamSearchKernel);
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif  // BEAM_SEARCH_OP
diff --git a/src/operators/kernel/arm/beam_search_kernel.cpp b/src/operators/kernel/arm/beam_search_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..5e88e2f18eed1d9aefbeb954a02245ff6daae036
--- /dev/null
+++ b/src/operators/kernel/arm/beam_search_kernel.cpp
@@ -0,0 +1,261 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef BEAM_SEARCH_OP
+
+#include "operators/kernel/beam_search_kernel.h"
+#include <cmath>
+
+namespace paddle_mobile {
+namespace operators {
+
+template <typename Dtype, typename T>
+class BeamSearchFunctor {
+ public:
+  void operator()(const framework::LoDTensor *pre_ids,
+                  const framework::LoDTensor *pre_scores,
+                  const framework::LoDTensor *ids,
+                  const framework::LoDTensor *scores,
+                  framework::LoDTensor *selected_ids,
+                  framework::LoDTensor *selected_scores,
+                  framework::Tensor *parent_idx, size_t level, size_t beam_size,
+                  int end_id, bool is_accumulated) {
+    auto abs_lod = framework::ToAbsOffset(scores->lod());
+    auto &high_level = abs_lod[level];
+
+    auto items = SelectTopBeamSizeItems(pre_ids, pre_scores, ids, scores, level,
+                                        beam_size, end_id, is_accumulated);
+    auto selected_items = ToMap(items, high_level.back());
+
+    PruneEndBeams(pre_ids, abs_lod, &selected_items, level, end_id);
+    // calculate the output tensor's height
+    size_t num_instances = std::accumulate(
+        std::begin(selected_items), std::end(selected_items), 0,
+        [](size_t a, std::vector<Item> &b) { return a + b.size(); });
+    // the output tensor shape should be [num_instances, 1]
+    auto dims = framework::make_ddim(
+        std::vector<int64_t>({static_cast<int64_t>(num_instances), 1}));
+    selected_ids->Resize(dims);
+    selected_scores->Resize(dims);
+    parent_idx->Resize({static_cast<int64_t>(num_instances)});
+
+    auto *selected_ids_data = selected_ids->mutable_data<int64_t>();
+    auto *selected_scores_data = selected_scores->mutable_data<float>();
+    auto *parent_idx_data = parent_idx->mutable_data<int>();
+
+    // fill in data
+    std::vector<size_t> low_level;
+    size_t low_offset = 0;
+    for (auto &items : selected_items) {
+      low_level.push_back(low_offset);
+      for (auto &item : items) {
+        parent_idx_data[low_offset] = static_cast<int>(low_level.size() - 1);
+        selected_ids_data[low_offset] = item.id;
+        selected_scores_data[low_offset] = item.score;
+        low_offset++;
+      }
+    }
+    low_level.push_back(low_offset);
+
+    // fill lod
+    framework::LoD lod(2);
+    lod[0].assign(high_level.begin(), high_level.end());
+    lod[1].assign(low_level.begin(), low_level.end());
+    selected_ids->set_lod(lod);
+    selected_scores->set_lod(lod);
+  }
+
+  /*
+   * The basic items help to sort.
+   */
+  struct Item {
+    Item() {}
+    Item(size_t offset, size_t id, float score)
+        : offset(offset), id(id), score(score) {}
+    // offset in the higher lod level.
+    size_t offset;
+    // prefix id in the lower lod level.
+    // size_t prefix;
+    // the candidate id
+    size_t id;
+    // the corresponding score
+    float score;
+
+    inline bool operator<(const Item &in) const {
+      return (score < in.score) ||
+             ((score == in.score) && (offset < in.offset));
+    }
+
+    inline void operator=(const Item &in) {
+      offset = in.offset;
+      id = in.id;
+      score = in.score;
+    }
+  };
+
+ protected:
+  /*
+   * Prune the source sentences all branchs finished, and it is optional.
+   * Pruning must one step later than finishing (thus pre_ids is needed here),
+   * since the end tokens must be writed out.
+   */
+  void PruneEndBeams(const framework::LoDTensor *pre_ids,
+                     const framework::LoD &abs_lod,
+                     std::vector<std::vector<Item>> *items, size_t lod_level,
+                     int end_id) {
+    auto *pre_ids_data = pre_ids->data<int64_t>();
+    auto &high_level = abs_lod[lod_level];
+    for (size_t src_idx = 0; src_idx < high_level.size() - 1; ++src_idx) {
+      size_t src_prefix_start = high_level[src_idx];
+      size_t src_prefix_end = high_level[src_idx + 1];
+      bool finish_flag = true;
+      for (size_t offset = src_prefix_start; offset < src_prefix_end;
+           offset++) {
+        for (auto &item : items->at(offset)) {
+          if (item.id != static_cast<size_t>(end_id) ||
+              pre_ids_data[offset] != end_id) {
+            finish_flag = false;
+            break;
+          }
+        }
+        if (!finish_flag) break;
+      }
+      if (finish_flag) {  // all branchs of the beam (source sentence) end and
+                          // prune this beam
+        for (size_t offset = src_prefix_start; offset < src_prefix_end;
+             offset++)
+          items->at(offset).clear();
+      }
+    }
+  }
+
+  /*
+   * Transform the items into a map whose key is offset, value is the items.
+   * NOTE low performance.
+   */
+  std::vector<std::vector<Item>> ToMap(
+      const std::vector<std::vector<Item>> &items, size_t element_num) {
+    std::vector<std::vector<Item>> result;
+    result.resize(element_num);
+    for (auto &entries : items) {
+      for (const auto &item : entries) {
+        result[item.offset].push_back(item);
+      }
+    }
+    return result;
+  }
+
+  void Insert(std::vector<Item> *top_beam_ptr, const Item &item,
+              size_t beam_size) {
+    std::vector<Item> &top_beam = *top_beam_ptr;
+
+    size_t num_beams = top_beam.size();
+    if (num_beams < beam_size) {
+      top_beam.resize(num_beams + 1);
+      num_beams++;
+    } else {
+      if (item < top_beam[beam_size - 1]) {
+        return;
+      }
+    }
+
+    for (int k = static_cast<int>(num_beams) - 2; k >= 0; --k) {
+      if (top_beam[k] < item) {
+        top_beam[k + 1] = top_beam[k];
+      } else {
+        top_beam[k + 1] = item;
+        return;
+      }
+    }
+    top_beam[0] = item;
+  }
+
+  /*
+   * For each source, select top beam_size records.
+   */
+  std::vector<std::vector<Item>> SelectTopBeamSizeItems(
+      const framework::LoDTensor *pre_ids,
+      const framework::LoDTensor *pre_scores, const framework::LoDTensor *ids,
+      const framework::LoDTensor *scores, size_t lod_level, size_t beam_size,
+      int end_id, bool is_accumulated) {
+    std::vector<std::vector<Item>> result;
+
+    // find the current candidates
+    auto abs_lod = framework::ToAbsOffset(scores->lod());
+
+    auto *pre_ids_data = pre_ids->data<int64_t>();
+    auto *pre_scores_data = pre_scores->data<float>();
+
+    auto *ids_data = ids ? ids->data<int64_t>() : nullptr;
+    auto *scores_data = scores->data<float>();
+
+    size_t num_seqs = scores->NumElements(lod_level);
+    size_t seq_width = 1;
+    for (int i = 1; i < scores->dims().size(); i++) {
+      seq_width *= scores->dims()[i];
+    }
+
+    for (size_t seq_id = 0; seq_id < num_seqs; ++seq_id) {
+      size_t seq_offset_start = abs_lod[lod_level][seq_id];
+      size_t seq_offset_end = abs_lod[lod_level][seq_id + 1];
+
+      std::vector<Item> top_beam;
+      top_beam.reserve(beam_size);
+
+      for (size_t offset = seq_offset_start; offset < seq_offset_end;
+           ++offset) {
+        auto pre_id = pre_ids_data[offset];
+        auto pre_score = pre_scores_data[offset];
+        if (pre_id == end_id) {
+          // Allocate all probability mass to end_id for finished branchs and
+          // the other candidate ids can be ignored.
+          Item item(offset, end_id, pre_score);
+          Insert(&top_beam, item, beam_size);
+        } else {
+          size_t index = offset * seq_width;
+          for (size_t d = 0; d < seq_width; d++, index++) {
+            int64_t id = ids_data ? ids_data[index] : static_cast<int64_t>(d);
+            float score = is_accumulated
+                              ? scores_data[index]
+                              : pre_score + std::log(scores_data[index]);
+            Item item(offset, id, score);
+            Insert(&top_beam, item, beam_size);
+          }
+        }
+      }
+
+      result.emplace_back(top_beam);
+    }
+
+    return result;
+  }
+};
+
+template <>
+bool BeamSearchKernel<CPU, float>::Init(BeamSearchParam<CPU> *param) {
+  return true;
+}
+
+template <>
+void BeamSearchKernel<CPU, float>::Compute(const BeamSearchParam<CPU> &param) {
+  BeamSearchFunctor<CPU, float> alg;
+  alg(param.pre_ids_, param.pre_scores_, param.ids_, param.scores_,
+      param.selected_ids_, param.selected_scores_, param.parent_idx_,
+      param.level_, param.beam_size_, param.end_id_, param.is_accumulated_);
+}
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif
diff --git a/src/operators/kernel/arm/one_hot_kernel.cpp b/src/operators/kernel/arm/one_hot_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..208b34ea2cd7a2357870c08d27fdcfd164380d0c
--- /dev/null
+++ b/src/operators/kernel/arm/one_hot_kernel.cpp
@@ -0,0 +1,61 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef ONE_HOT_OP
+
+#include "operators/kernel/one_hot_kernel.h"
+#include "framework/data_type.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+template <typename InT>
+struct OnehotOpFunctor {
+  const framework::LoDTensor* in_;
+  framework::LoDTensor* out_;
+  int depth_;
+
+  OnehotOpFunctor(const framework::LoDTensor* in, framework::LoDTensor* out,
+                  int depth)
+      : in_(in), out_(out), depth_(depth) {}
+
+  template <typename OutT>
+  void apply() const {
+    auto* p_in_data = in_->data<InT>();
+    auto numel = in_->numel();
+    auto* p_out_data = out_->mutable_data<OutT>();
+    memset(p_out_data, 0, out_->numel() * sizeof(OutT));
+
+    for (int i = 0; i < numel; ++i) {
+      *(p_out_data + i * depth_ + p_in_data[i]) = 1.0;
+    }
+  }
+};
+
+template <>
+bool OnehotKernel<CPU, float>::Init(OnehotParam<CPU>* param) {
+  return true;
+}
+
+template <>
+void OnehotKernel<CPU, float>::Compute(const OnehotParam<CPU>& param) {
+  framework::VisitDataType(
+      framework::ToDataType(param.dtype_),
+      OnehotOpFunctor<int64_t>(param.input_, param.output_, param.depth_));
+}
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif  // ONE_HOT_OP
diff --git a/src/operators/kernel/beam_search_kernel.h b/src/operators/kernel/beam_search_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..4e3640d905f08e997ae797b6e91e861fdb57728c
--- /dev/null
+++ b/src/operators/kernel/beam_search_kernel.h
@@ -0,0 +1,75 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef BEAM_SEARCH_OP
+
+#pragma once
+
+#include "framework/operator.h"
+#include "operators/op_param.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+#define GET_VAR_AS_LOD_TENSOR(name, name_dict, scope) \
+  OpParam::GetVarValue<framework::LoDTensor>(name, name_dict, scope)
+
+template <typename Dtype>
+class BeamSearchParam : public OpParam {
+ public:
+  BeamSearchParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
+                  const AttributeMap &attrs, const Scope &scope) {
+    pre_ids_ = GET_VAR_AS_LOD_TENSOR("pre_ids", inputs, scope);
+    pre_scores_ = GET_VAR_AS_LOD_TENSOR("pre_scores", inputs, scope);
+    ids_ = GET_VAR_AS_LOD_TENSOR("ids", inputs, scope);
+    scores_ = GET_VAR_AS_LOD_TENSOR("scores", inputs, scope);
+
+    selected_ids_ = GET_VAR_AS_LOD_TENSOR("selected_ids", outputs, scope);
+    selected_scores_ = GET_VAR_AS_LOD_TENSOR("selected_scores", outputs, scope);
+    if (outputs.count("parent_idx")) {
+      parent_idx_ = GET_VAR_AS_LOD_TENSOR("parent_idx", outputs, scope);
+    } else {
+      parent_idx_ = new framework::Tensor();
+    }
+
+    level_ = OpParam::GetAttr<int>("level", attrs);
+    beam_size_ = OpParam::GetAttr<int>("beam_size", attrs);
+    end_id_ = OpParam::GetAttr<int>("end_id", attrs);
+    if (OpParam::HasAttr("is_accumulated", attrs)) {
+      is_accumulated_ = OpParam::GetAttr<bool>("is_accumulated", attrs);
+    }
+  }
+
+ public:
+  framework::LoDTensor *pre_ids_;
+  framework::LoDTensor *pre_scores_;
+  framework::LoDTensor *ids_;
+  framework::LoDTensor *scores_;
+
+  framework::LoDTensor *selected_ids_;
+  framework::LoDTensor *selected_scores_;
+  framework::Tensor *parent_idx_;
+
+  int level_;
+  int beam_size_;
+  int end_id_;
+  bool is_accumulated_ = true;
+};
+
+DECLARE_KERNEL(BeamSearch, BeamSearchParam);
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif  // BEAM_SEARCH_OP
diff --git a/src/operators/kernel/central-arm-func/elementwise_sub_arm_func.h b/src/operators/kernel/central-arm-func/elementwise_sub_arm_func.h
index 663c65c83a0f5b76e292925ea8cb0994b0f99ad1..cb5bbc91c3b2cede812d28c77e669ddbe46078bf 100644
--- a/src/operators/kernel/central-arm-func/elementwise_sub_arm_func.h
+++ b/src/operators/kernel/central-arm-func/elementwise_sub_arm_func.h
@@ -15,6 +15,8 @@ limitations under the License. */
 #ifdef ELEMENTWISESUB_OP
 
 #pragma once
+
+#include "framework/data_type.h"
 #include "operators/math/elementwise_op_function.h"
 #include "operators/op_param.h"
 
@@ -26,15 +28,33 @@ struct SubFunctor {
   inline T operator()(T a, T b) const { return a - b; }
 };
 
+struct SubOpFunctor {
+  const framework::Tensor* x_;
+  const framework::Tensor* y_;
+  const int axis_;
+  framework::Tensor* out_;
+
+  SubOpFunctor(const framework::Tensor* x, const framework::Tensor* y,
+               framework::Tensor* out, const int axis)
+      : x_(x), y_(y), out_(out), axis_(axis) {}
+
+  template <typename T>
+  void apply() const {
+    out_->mutable_data<T>();
+    ElementwiseComputeEx<SubFunctor<T>, T>(x_, y_, axis_, SubFunctor<T>(),
+                                           out_);
+  }
+};
+
 template <typename P>
-void ElementwiseSubCompute(const ElementwiseSubParam<CPU> &param) {
-  const Tensor *input_x = param.InputX();
-  const Tensor *input_y = param.InputY();
-  Tensor *Out = param.Out();
-  Out->mutable_data<float>();
+void ElementwiseSubCompute(const ElementwiseSubParam<CPU>& param) {
+  const Tensor* input_x = param.InputX();
+  const Tensor* input_y = param.InputY();
+  Tensor* out = param.Out();
+
   int axis = param.Axis();
-  ElementwiseComputeEx<SubFunctor<float>, float>(input_x, input_y, axis,
-                                                 SubFunctor<float>(), Out);
+  framework::VisitDataType(framework::ToDataType(input_x->type()),
+                           SubOpFunctor(input_x, input_y, out, axis));
 }
 
 template class ElementwiseSubKernel<CPU, float>;
diff --git a/src/operators/kernel/one_hot_kernel.h b/src/operators/kernel/one_hot_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..0a57158823e7ba6024668af5d04093d0426ebaac
--- /dev/null
+++ b/src/operators/kernel/one_hot_kernel.h
@@ -0,0 +1,53 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef ONE_HOT_OP
+
+#pragma once
+
+#include "framework/operator.h"
+#include "operators/op_param.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+#define GET_VAR_AS_LOD_TENSOR(name, name_dict, scope) \
+  OpParam::GetVarValue<framework::LoDTensor>(name, name_dict, scope)
+
+template <typename Dtype>
+class OnehotParam : public OpParam {
+ public:
+  OnehotParam(const VariableNameMap &inputs, const VariableNameMap &outputs,
+              const AttributeMap &attrs, const Scope &scope) {
+    input_ = GET_VAR_AS_LOD_TENSOR("X", inputs, scope);
+    output_ = GET_VAR_AS_LOD_TENSOR("Out", outputs, scope);
+
+    depth_ = OpParam::GetAttr<int>("depth", attrs);
+    dtype_ = OpParam::GetAttr<int>("dtype", attrs);
+  }
+
+ public:
+  framework::LoDTensor *input_;
+  framework::LoDTensor *output_;
+
+  int depth_;
+  int dtype_;
+};
+
+DECLARE_KERNEL(Onehot, OnehotParam);
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif  // ONE_HOT_OP
diff --git a/src/operators/one_hot_op.cpp b/src/operators/one_hot_op.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..396f55318a80dc3177a0ae5f4b151eaec7806a6d
--- /dev/null
+++ b/src/operators/one_hot_op.cpp
@@ -0,0 +1,43 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef ONE_HOT_OP
+
+#pragma once
+
+#include "operators/one_hot_op.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+template <typename Dtype, typename T>
+void OnehotOp<Dtype, T>::InferShape() const {
+  const auto &x_dims = this->param_.input_->dims();
+  int depth = this->param_.depth_;
+  framework::DDim out_dims(x_dims);
+  out_dims[out_dims.size() - 1] = depth;
+  this->param_.output_->Resize(out_dims);
+  this->param_.output_->set_lod(this->param_.input_->lod());
+}
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+namespace ops = paddle_mobile::operators;
+
+#ifdef PADDLE_MOBILE_CPU
+REGISTER_OPERATOR_CPU(one_hot, ops::OnehotOp);
+#endif
+
+#endif  // ONE_HOT_OP
diff --git a/src/operators/one_hot_op.h b/src/operators/one_hot_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..4b7e83bf996873de087844887839055031d97f66
--- /dev/null
+++ b/src/operators/one_hot_op.h
@@ -0,0 +1,31 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef ONE_HOT_OP
+
+#pragma once
+
+#include <string>
+#include "framework/operator.h"
+#include "operators/kernel/one_hot_kernel.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+DECLARE_OPERATOR(Onehot, OnehotParam, OnehotKernel);
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif  // ONE_HOT_OP
diff --git a/tools/op.cmake b/tools/op.cmake
index a362f9685899e2cfca184e5f08387e96908235d0..3bdedc15d8e228d5ce69356de8388a0e28cf4a6a 100644
--- a/tools/op.cmake
+++ b/tools/op.cmake
@@ -299,8 +299,10 @@ if(NOT FOUND_MATCH)
     set(PROPOSAL_OP ON)
     set(PSROI_POOL_OP ON)
     set(ROI_PERSPECTIVE_OP ON)
+    set(BEAM_SEARCH_OP ON)
     set(BEAM_SEARCH_DECODE_OP ON)
     set(PAD2D_OP ON)
+    set(ONE_HOT_OP ON)
 endif()
 
 # option(BATCHNORM_OP "" ON)
@@ -604,9 +606,15 @@ endif()
 if (ROI_PERSPECTIVE_OP)
     add_definitions(-DROI_PERSPECTIVE_OP)
 endif()
+if (BEAM_SEARCH_OP)
+    add_definitions(-DBEAM_SEARCH_OP)
+endif()
 if (BEAM_SEARCH_DECODE_OP)
     add_definitions(-DBEAM_SEARCH_DECODE_OP)
 endif()
 if (PAD2D_OP)
     add_definitions(-DPAD2D_OP)
 endif()
+if (ONE_HOT_OP)
+    add_definitions(-DONE_HOT_OP)
+endif()