diff --git a/src/framework/executor.cpp b/src/framework/executor.cpp index 1cbae18e7ee32223f815c813975a4f5c29c48749..8c3fa0364c8310795a10bb36b6d2eb9f35df5a53 100644 --- a/src/framework/executor.cpp +++ b/src/framework/executor.cpp @@ -170,20 +170,19 @@ void Executor::InitMemory() { for (const auto &block : program_desc_->Blocks()) { for (const auto &var_desc : block->Vars()) { auto var = program_.scope->Var(var_desc->Name()); - auto tensor = var->template GetMutable(); if (var_desc->Persistable()) { if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { + var->template GetMutable(); continue; } char *origin_data = ReadFileToBuff(program_.model_path + "/" + var_desc->Name()); char *data = origin_data; + auto tensor = var->template GetMutable(); LoadMemory(reinterpret_cast(&data), var_desc, tensor); delete[] origin_data; } else { - if (var_desc->Type() == VARTYPE_TYPE_LOD_TENSOR) { - varInputMemory(var_desc, var, tensor); - } + varInputMemory(var_desc, var); } } } @@ -205,23 +204,18 @@ void Executor::InitCombineMemory() { for (const auto &block : program_desc_->Blocks()) { for (const auto &var_desc : block->Vars()) { auto var = program_.scope->Var(var_desc->Name()); - auto tensor = var->template GetMutable(); if (var_desc->Persistable()) { if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { + var->template GetMutable(); continue; } DLOG << " init combine memory persistable: " << var_desc->Name(); - + auto tensor = var->template GetMutable(); LoadMemory(reinterpret_cast(&data), var_desc, tensor); } else { - if (var_desc->Type() == VARTYPE_TYPE_LOD_TENSOR) { - DLOG << " init combine memory no persistable in lod: " - << var_desc->Name(); - varInputMemory(var_desc, var, tensor); - } else { - DLOG << " init combine memory no persistable: " << var_desc->Name(); - } + DLOG << " init combine memory no persistable: " << var_desc->Name(); + varInputMemory(var_desc, var); } } } @@ -239,6 +233,7 @@ void Executor::InitNoPersistableMemory(const Tensor &input_tensor) { auto tensor = var->template GetMutable(); if (var_desc->Persistable()) { if (var_desc->Name() == "feed" || var_desc->Name() == "fetch") { + var->template GetMutable(); continue; } } else { @@ -249,6 +244,9 @@ void Executor::InitNoPersistableMemory(const Tensor &input_tensor) { input_tensor.dims()[3]}); tensor->Resize(new_dim); tensor->template mutable_data(); + } else { + PADDLE_MOBILE_THROW_EXCEPTION("Unsupported var type `%d`", + var_desc->Type()); } } } @@ -261,34 +259,43 @@ void Executor::InitNoPersistableMemory(const Tensor &input_tensor) { template bool Executor::varInputMemory( - const std::shared_ptr &var_desc, Variable *var, - LoDTensor *tensor) const { + const std::shared_ptr &var_desc, Variable *var) const { #ifdef PADDLE_MOBILE_FPGA tensor->init(typeid(float)); return true; #endif - auto type = var_desc->Tensor_desc().DataType(); - switch (type) { - case VARTYPE_TYPE_FP32: - tensor->mutable_data(); - break; - case VARTYPE_TYPE_INT8: - tensor->mutable_data(); - break; - case VARTYPE_TYPE_INT32: - tensor->mutable_data(); - break; - case VARTYPE_TYPE_INT64: - tensor->mutable_data(); - break; - default: - break; + auto TypeId = [](const VarType_Type &type) -> std::type_index { + switch (type) { + case VARTYPE_TYPE_BOOL: + return typeid(bool); + case VARTYPE_TYPE_FP32: + return typeid(float); + case VARTYPE_TYPE_INT8: + return typeid(int8_t); + case VARTYPE_TYPE_INT32: + return typeid(int); + case VARTYPE_TYPE_INT64: + return typeid(int64_t); + default: + PADDLE_MOBILE_THROW_EXCEPTION("got unhandled var type `%d`", type); + } + }; + + auto type = var_desc->Type(); + if (type == VARTYPE_TYPE_LOD_TENSOR) { + auto data_type = var_desc->Tensor_desc().DataType(); + framework::LoDTensor *tensor = var->template GetMutable(); + tensor->mutable_data(TypeId(data_type)); + } else if (type == VARTYPE_TYPE_STEP_SCOPES) { + std::vector *step_scopes = + var->template GetMutable>(); + } else if (type == VARTYPE_TYPE_STEP_LOD_TENSOR_ARRAY) { + framework::LoDTensorArray *tensor_array = + var->template GetMutable(); + } else { + PADDLE_MOBILE_THROW_EXCEPTION("got unhandled var type `%d`", type); } - bool is_mute_match = - (type == VARTYPE_TYPE_FP32) || (type == VARTYPE_TYPE_INT8) || - (type == VARTYPE_TYPE_INT32) || (type == VARTYPE_TYPE_INT64); - PADDLE_MOBILE_ENFORCE(is_mute_match, "got unhandled data type : %d", type); - return is_mute_match; + return true; } template diff --git a/src/framework/executor.h b/src/framework/executor.h index 2c2af1f9814b48b9264b13f01868c1e998c73701..b9692db388b384b2336b224a8cf19854df61041d 100644 --- a/src/framework/executor.h +++ b/src/framework/executor.h @@ -61,8 +61,8 @@ class Executor { protected: Executor() = default; - bool varInputMemory(const std::shared_ptr &var_desc, Variable *var, - LoDTensor *tensor) const; + bool varInputMemory(const std::shared_ptr &var_desc, + Variable *var) const; void InitMemory(); void InitCombineMemory(); void InitNoPersistableMemory(const Tensor &input_tensor); diff --git a/src/framework/load_ops.h b/src/framework/load_ops.h index 3b214a52f1314202d183b871784bebae0b6ec795..95524a76760359a5f543220bf4365f2778d32090 100644 --- a/src/framework/load_ops.h +++ b/src/framework/load_ops.h @@ -324,3 +324,6 @@ LOAD_OP1(psroi_pool, CPU); #ifdef ROI_PERSPECTIVE_OP LOAD_OP1(roi_perspective_transform, CPU); #endif +#ifdef BEAM_SEARCH_DECODE_OP +LOAD_OP1(beam_search_decode, CPU); +#endif diff --git a/src/framework/operator.cpp b/src/framework/operator.cpp index 611b134eaab7376dfba9ef3294f8fa14bf13e0da..4d68b8477febd967800bca9eb0334a85843b40cf 100644 --- a/src/framework/operator.cpp +++ b/src/framework/operator.cpp @@ -64,9 +64,10 @@ void OperatorBase::Run() { for (const auto key : input_keys) { auto var_vec_in = inputs_.at(key); for (int i = 0; i < var_vec_in.size(); ++i) { - auto vari = this->scope_->FindVar(var_vec_in[i]); - if (vari->IsInitialized()) { - const Tensor *tensor = vari->template Get(); + auto var = this->scope_->FindVar(var_vec_in[i]); + if (var->IsInitialized() && + var->template IsType()) { + const Tensor *tensor = var->template Get(); if (tensor) DLOG << type_ << " input- " << key << "=" << *tensor; } } @@ -74,9 +75,10 @@ void OperatorBase::Run() { for (const auto key : GetOutKeys()) { auto var_vec_out = outputs_.at(key); for (int i = 0; i < var_vec_out.size(); ++i) { - auto vari = scope_->FindVar(var_vec_out[i]); - if (vari->IsInitialized()) { - const Tensor *tensor = vari->template Get(); + auto var = scope_->FindVar(var_vec_out[i]); + if (var->IsInitialized() && + var->template IsType()) { + const Tensor *tensor = var->template Get(); if (tensor) DLOG << type_ << " output- " << key << "=" << *tensor; } } diff --git a/src/framework/program/program_desc.cpp b/src/framework/program/program_desc.cpp index aba0c829e56712282f83dabee2adf98002288432..6866ab9c75cb06ad1af86ab99a32d59dfa7b45f5 100644 --- a/src/framework/program/program_desc.cpp +++ b/src/framework/program/program_desc.cpp @@ -78,9 +78,8 @@ void ProgramDesc::Description(std::string header) { } for (const auto &var_desc : block->Vars()) { + LOG(kLOG_DEBUG1) << "var name: " << var_desc->Name(); if (var_desc->Type() == VARTYPE_TYPE_LOD_TENSOR) { - LOG(kLOG_DEBUG1) << "var name: " << var_desc->Name(); - const TensorDesc &tensor_desc = var_desc->Tensor_desc(); LOG(kLOG_DEBUG2) << "in var tensor desc dims size: " diff --git a/src/io/paddle_mobile.cpp b/src/io/paddle_mobile.cpp index 7e8fcb82887dca02a84677aebe2b8a4547338bc9..7ea501fc7582e28180aa464edb950d56e250a741 100644 --- a/src/io/paddle_mobile.cpp +++ b/src/io/paddle_mobile.cpp @@ -152,14 +152,14 @@ PMStatus PaddleMobile::Predict() { } template -void PaddleMobile::Feed(const framework::Tensor &input, - const std::string &var_name) { +void PaddleMobile::Feed(const std::string &var_name, + const framework::Tensor &input) { executor_->SetInput(input, var_name); } template -void PaddleMobile::Feed(const framework::LoDTensor &input, - const std::string &var_name) { +void PaddleMobile::Feed(const std::string &var_name, + const framework::LoDTensor &input) { executor_->SetInput(input, var_name); } diff --git a/src/io/paddle_mobile.h b/src/io/paddle_mobile.h index 55fcaf3598603bc9adcf9410d529c68063442dc3..b651028f29fa10111ccef334ddf41b9fbec46c1e 100644 --- a/src/io/paddle_mobile.h +++ b/src/io/paddle_mobile.h @@ -33,7 +33,7 @@ namespace paddle_mobile { template class PaddleMobile { public: - PaddleMobile(PaddleMobileConfigInternal config) : config_(config) { + explicit PaddleMobile(PaddleMobileConfigInternal config) : config_(config) { #ifndef PADDLE_MOBILE_CL bool is_gpu = std::is_same, Device>::value; PADDLE_MOBILE_ENFORCE(!is_gpu, "Please recompile with GPU_CL is on"); @@ -69,8 +69,8 @@ class PaddleMobile { const std::vector &dims); PMStatus Predict(); - void Feed(const framework::LoDTensor &input, const std::string &var_name); - void Feed(const framework::Tensor &input, const std::string &var_name); + void Feed(const std::string &var_name, const framework::LoDTensor &input); + void Feed(const std::string &var_name, const framework::Tensor &input); typedef std::shared_ptr LoDTensorPtr; LoDTensorPtr Fetch(const std::string &var_name); diff --git a/src/operators/beam_search_decode_op.cpp b/src/operators/beam_search_decode_op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..410446944ce3bc7f0968ba84ea3445bf709605d6 --- /dev/null +++ b/src/operators/beam_search_decode_op.cpp @@ -0,0 +1,37 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef BEAM_SEARCH_DECODE_OP + +#pragma once + +#include "operators/beam_search_decode_op.h" + +namespace paddle_mobile { +namespace operators { + +template +void BeamSearchDecodeOp::InferShape() const {} + +} // namespace operators +} // namespace paddle_mobile + +namespace ops = paddle_mobile::operators; + +#ifdef PADDLE_MOBILE_CPU +REGISTER_OPERATOR_CPU(beam_search_decode, ops::BeamSearchDecodeOp); +#endif + +namespace ops = paddle_mobile::operators; +#endif // BEAM_SEARCH_DECODE_OP diff --git a/src/operators/beam_search_decode_op.h b/src/operators/beam_search_decode_op.h new file mode 100644 index 0000000000000000000000000000000000000000..f212959474eade3da0f026bcdb1e3d15ddd30c6d --- /dev/null +++ b/src/operators/beam_search_decode_op.h @@ -0,0 +1,32 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef BEAM_SEARCH_DECODE_OP + +#pragma once + +#include +#include "framework/operator.h" +#include "operators/kernel/beam_search_decode_kernel.h" + +namespace paddle_mobile { +namespace operators { + +DECLARE_OPERATOR(BeamSearchDecode, BeamSearchDecodeParam, + BeamSearchDecodeKernel); + +} // namespace operators +} // namespace paddle_mobile + +#endif // BEAM_SEARCH_DECODE_OP diff --git a/src/operators/feed_op.cpp b/src/operators/feed_op.cpp index 4e496fb51d16c47d801eabada7c36dbdefdd2140..9e0b037c8dff4e4ea27d6f2f3155d06c9ed4821f 100644 --- a/src/operators/feed_op.cpp +++ b/src/operators/feed_op.cpp @@ -21,7 +21,8 @@ template void FeedOp::InferShape() const { auto out_dims = this->param_.Out()->dims(); out_dims[0] = this->param_.BatchSize(); - auto input_dims = this->param_.InputX()->dims(); + int col = this->param_.Col(); + auto input_dims = this->param_.InputX()->at(col).dims(); if (input_dims.size() == 4) { this->param_.Out()->Resize(input_dims); } else { diff --git a/src/operators/kernel/arm/beam_search_decode_kernel.cpp b/src/operators/kernel/arm/beam_search_decode_kernel.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e23c67804e3f9178f1410da8cf895227c192bc73 --- /dev/null +++ b/src/operators/kernel/arm/beam_search_decode_kernel.cpp @@ -0,0 +1,37 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef BEAM_SEARCH_DECODE_OP + +#include "operators/kernel/beam_search_decode_kernel.h" + +namespace paddle_mobile { +namespace operators { + +template <> +bool BeamSearchDecodeKernel::Init( + BeamSearchDecodeParam *param) { + return true; +} + +template <> +void BeamSearchDecodeKernel::Compute( + const BeamSearchDecodeParam ¶m) { + // TODO(hjchen2) +} + +} // namespace operators +} // namespace paddle_mobile + +#endif diff --git a/src/operators/kernel/arm/feed_kernel.cpp b/src/operators/kernel/arm/feed_kernel.cpp index 598f6df01b16683f4d6e06f6418a2930a7ec8736..26ea2ac5f7d806aa6e69dfe9697ed84b61347c0e 100644 --- a/src/operators/kernel/arm/feed_kernel.cpp +++ b/src/operators/kernel/arm/feed_kernel.cpp @@ -24,8 +24,9 @@ bool FeedKernel::Init(FeedParam *param) { template <> void FeedKernel::Compute(const FeedParam ¶m) { - param.Out()->ShareDataWith(*(param.InputX())); - param.Out()->set_lod(param.InputX()->lod()); + int col = param.Col(); + param.Out()->ShareDataWith(param.InputX()->at(col)); + param.Out()->set_lod(param.InputX()->at(col).lod()); } template class FeedKernel; diff --git a/src/operators/kernel/arm/tensor_array_read_write_kernel.cpp b/src/operators/kernel/arm/tensor_array_read_write_kernel.cpp index 5d6b86c34c963808e72a05db08940db97b0212b9..5381fe26b9d6fb80887bcb3d5371b526d11530e4 100644 --- a/src/operators/kernel/arm/tensor_array_read_write_kernel.cpp +++ b/src/operators/kernel/arm/tensor_array_read_write_kernel.cpp @@ -28,8 +28,9 @@ void WriteToArrayKernel::Compute( const WriteToArrayParam ¶m) { int64_t offset = param.index_->data()[0]; if (offset >= param.output_->size()) { - param.output_->resize(offset); + param.output_->resize(offset + 1); } + framework::LoDTensor *out_tensor = &(param.output_->at(offset)); out_tensor->set_lod(param.input_->lod()); if (param.input_->memory_size() > 0) { diff --git a/src/operators/kernel/arm/while_kernel.cpp b/src/operators/kernel/arm/while_kernel.cpp index f27a897ffcadbb1ade759f324766ccbfb8dd49d5..ecdfd29d7cc34d0fb39e32dcc079ab62a6784dfd 100644 --- a/src/operators/kernel/arm/while_kernel.cpp +++ b/src/operators/kernel/arm/while_kernel.cpp @@ -12,12 +12,46 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#ifdef WHILE_OP + #include "operators/kernel/while_kernel.h" +#include "framework/op_registry.h" +#include "framework/operator.h" namespace paddle_mobile { namespace operators { -#ifdef WHILE_OP +class StepExecutor { + typedef std::shared_ptr> OperatorPtr; + + public: + StepExecutor(const framework::BlockDesc *block, framework::Scope *scope) + : scope_(std::shared_ptr(scope)) { + std::vector> ops = block->Ops(); + ops_of_block_.resize(ops.size()); + for (int i = 0; i < ops.size(); ++i) { + std::shared_ptr op_desc = ops[i]; + auto op_handler = framework::OpRegistry::CreateOp( + op_desc->Type(), op_desc->GetInputs(), op_desc->GetOutputs(), + op_desc->GetAttrMap(), scope_); + ops_of_block_[i] = op_handler; + } + } + + void Run() { + for (auto &op_handler : ops_of_block_) { + DLOG << "run op: " << op_handler->Type(); + op_handler->InferShape(); + op_handler->Run(); + DLOG << "run op finish"; + } + } + + private: + std::shared_ptr scope_; + std::vector ops_of_block_; +}; + template <> bool WhileKernel::Init(WhileParam *param) { return true; @@ -26,8 +60,15 @@ bool WhileKernel::Init(WhileParam *param) { template <> void WhileKernel::Compute(const WhileParam ¶m) { // TODO(hjchen2) + auto ¤t_scope = param.scope_->NewScope(); + StepExecutor executor(param.sub_block_, ¤t_scope); + while (param.cond_->data()[0]) { + executor.Run(); + } + param.scope_->DeleteScope(¤t_scope); } -#endif // WHILE_OP } // namespace operators } // namespace paddle_mobile + +#endif // WHILE_OP diff --git a/src/operators/kernel/beam_search_decode_kernel.h b/src/operators/kernel/beam_search_decode_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..fa40436115c5e03bfbfb6d2e9915ec6649d6f4b1 --- /dev/null +++ b/src/operators/kernel/beam_search_decode_kernel.h @@ -0,0 +1,57 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef BEAM_SEARCH_DECODE_OP + +#pragma once + +#include "framework/operator.h" +#include "operators/op_param.h" + +namespace paddle_mobile { +namespace operators { + +template +class BeamSearchDecodeParam : public OpParam { + public: + BeamSearchDecodeParam(const VariableNameMap &inputs, + const VariableNameMap &outputs, + const AttributeMap &attrs, const Scope &scope) { + ids_ = + OpParam::GetVarValue("Ids", inputs, scope); + scores_ = OpParam::GetVarValue("Scores", inputs, + scope); + sentence_ids_ = OpParam::GetVarValue("SentenceIds", + outputs, scope); + sentence_scores_ = OpParam::GetVarValue( + "SentenceScores", outputs, scope); + beam_size_ = OpParam::GetAttr("beam_size", attrs); + end_id_ = OpParam::GetAttr("end_id", attrs); + } + + public: + framework::LoDTensorArray *ids_; + framework::LoDTensorArray *scores_; + framework::LoDTensor *sentence_ids_; + framework::LoDTensor *sentence_scores_; + int beam_size_; + int end_id_; +}; + +DECLARE_KERNEL(BeamSearchDecode, BeamSearchDecodeParam); + +} // namespace operators +} // namespace paddle_mobile + +#endif // BEAM_SEARCH_DECODE_OP diff --git a/src/operators/kernel/while_kernel.h b/src/operators/kernel/while_kernel.h index 4b9d0ffe58f4c552b4d05f494432296a6fc87fd4..64fb7a607e7f9d8fdbd2c6d1091b9da7133831be 100644 --- a/src/operators/kernel/while_kernel.h +++ b/src/operators/kernel/while_kernel.h @@ -26,18 +26,16 @@ class WhileParam : public OpParam { public: WhileParam(const VariableNameMap &inputs, const VariableNameMap &outputs, const AttributeMap &attrs, const Scope &scope) - : inputs_(inputs), outputs_(outputs), scope_(scope) { + : scope_(&scope) { cond_ = OpParam::GetVarValue("Condition", inputs, scope); sub_block_ = OpParam::GetAttr("sub_block", attrs); } public: + const Scope *scope_; framework::LoDTensor *cond_; - const framework::BlockDesc *sub_block_; - const VariableNameMap inputs_; - const VariableNameMap outputs_; - const Scope scope_; + framework::BlockDesc *sub_block_; }; DECLARE_KERNEL(While, WhileParam); diff --git a/src/operators/op_param.h b/src/operators/op_param.h index 00fbfbc771cfe9329b8ba76f120a5bc304dc80fc..267ed2effba14116a2d43914f33f9c921f1a456b 100644 --- a/src/operators/op_param.h +++ b/src/operators/op_param.h @@ -1172,18 +1172,21 @@ class FeedParam : public OpParam { public: FeedParam(const VariableNameMap &inputs, const VariableNameMap &outputs, const AttributeMap &attrs, const Scope &scope) { - input_x_ = InputXFrom(inputs, scope); + input_x_ = InputXFrom(inputs, scope); out_ = OutFrom(outputs, scope); + col_ = GetAttr("col", attrs); auto var = scope.FindVar("batch_size"); batch_size = var->GetValue(); } - const LoDTensor *InputX() const { return input_x_; } + const framework::LoDTensorArray *InputX() const { return input_x_; } GType *Out() const { return out_; } + const int Col() const { return col_; } const int BatchSize() const { return batch_size; } private: - LoDTensor *input_x_; + framework::LoDTensorArray *input_x_; GType *out_; + int col_; int batch_size; }; @@ -3008,24 +3011,6 @@ class LogicalUnaryParam : public OpParam { }; #endif // LOGICAL_NOT_OP -// #ifdef WHILE_OP -// template -// class WhileParam : public OpParam { -// public: -// WhileParam(const VariableNameMap &inputs, -// const VariableNameMap &outputs, const AttributeMap &attrs, -// const Scope &scope) { -// cond_ = OpParam::GetVarValue("Condition", inputs, -// scope); block_desc_ = OpParam::GetAttr("sub_block", attrs); -// } -// -// public: -// framework::LoDTensor *cond_; -// const framework::BlockDesc *block_desc_; -// }; -// #endif // WHILE_OP - #ifdef WRITE_TO_ARRAY_OP template class WriteToArrayParam : public OpParam { @@ -3099,17 +3084,17 @@ class IncrementParam : public OpParam { const AttributeMap &attrs, const Scope &scope) { input_x_ = InputXFrom(inputs, scope); output_ = OutFrom(outputs, scope); - step_ = OpParam::GetAttr("step", attrs); + step_ = OpParam::GetAttr("step", attrs); } const GType *InputX() const { return input_x_; } GType *Out() const { return output_; } - int Step() const { return step_; } + float Step() const { return step_; } public: GType *input_x_; GType *output_; - int step_; + float step_; }; #endif // INCREMENT_OP diff --git a/tools/op.cmake b/tools/op.cmake index 89c23e47ca4ad98cd6b10d48e489f41c8319fb60..d25fce7cff14effbc1264dc46cba6364cee486bf 100644 --- a/tools/op.cmake +++ b/tools/op.cmake @@ -294,6 +294,7 @@ if(NOT FOUND_MATCH) set(PROPOSAL_OP ON) set(PSROI_POOL_OP ON) set(ROI_PERSPECTIVE_OP ON) + set(BEAM_SEARCH_DECODE_OP ON) endif() # option(BATCHNORM_OP "" ON) @@ -597,3 +598,6 @@ endif() if (ROI_PERSPECTIVE_OP) add_definitions(-DROI_PERSPECTIVE_OP) endif() +if (BEAM_SEARCH_DECODE_OP) + add_definitions(-DBEAM_SEARCH_DECODE_OP) +endif()