diff --git a/src/framework/operator.cpp b/src/framework/operator.cpp index 857e8ea1c7795c72b95b843e4a5e6351b8c9a5a9..dfdf0af79ac98d0bb79c7da3fdcc872341417b87 100644 --- a/src/framework/operator.cpp +++ b/src/framework/operator.cpp @@ -13,11 +13,33 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "framework/operator.h" -#include "framework/op_info.h" +#include "operators/op_param.h" namespace paddle_mobile { namespace framework { +template +vector OperatorBase::GetOutKeys() const { + auto it = op_input_output_key.find(type_); + if (it == op_input_output_key.end()) { + DLOG << type_ << " has no outputs"; + return {}; + } + return it->second.second; +} + +template +static T *GetVarValue(const string &key, const VariableNameMap &var_map, + const Scope &scope) { + auto var_vec = var_map.at(key); + if (!var_vec.empty()) { + auto var = scope.FindVar(var_vec[0]); + return var->GetMutable(); + } else { + return nullptr; + } +} + template OperatorBase::OperatorBase(const std::string &type, const VariableNameMap &inputs, @@ -31,9 +53,24 @@ OperatorBase::OperatorBase(const std::string &type, scope_(scope) { CheckAllInputOutputSet(); } + template void OperatorBase::CheckAllInputOutputSet() const {} +template +void OperatorBase::Run() const { + RunImpl(); +#ifdef PADDLE_MOBILE_DEBUG + vector output_keys = GetOutKeys(); + for (const auto &key : output_keys) { + Tensor *out_ = GetVarValue(key, outputs_, *scope_); + if (out_) { + DLOG << type_ << " output- " << key << "=" << *out_; + } + } +#endif +} + template class OperatorBase; template class OperatorWithKernel; diff --git a/src/framework/operator.h b/src/framework/operator.h index 5a40a9266309845e2ad1950b227fe44863d8053e..549916b9a38f88c3fed9f8bf0d874d6ad2bf11c9 100644 --- a/src/framework/operator.h +++ b/src/framework/operator.h @@ -36,6 +36,8 @@ limitations under the License. 
*/ namespace paddle_mobile { namespace framework { +using std::string; +using std::vector; static std::unordered_map< std::string, std::pair, std::vector>> op_input_output_key = {{"conv2d", {{"Input"}, {"Output"}}}, @@ -57,7 +59,9 @@ class OperatorBase : PaddleMobileObject { const VariableNameMap &outputs, const AttributeMap &attrs, std::shared_ptr scope); virtual ~OperatorBase() {} - virtual void Run() const = 0; + void Run() const; + vector GetOutKeys() const; + virtual void RunImpl() const = 0; virtual void InferShape() const = 0; const VariableNameMap &Inputs() const { return inputs_; } @@ -88,7 +92,8 @@ class OperatorWithKernel : public OperatorBase { const VariableNameMap &outputs, const AttributeMap &attrs, std::shared_ptr scope) : OperatorBase(type, inputs, outputs, attrs, scope) {} - virtual void Run() const = 0; + + virtual void RunImpl() const = 0; virtual void InferShape() const = 0; }; @@ -113,7 +118,7 @@ class FusionOpMatcher : PaddleMobileObject { virtual std::string Type() = 0; - virtual void FolderNodes(Node &node) { + virtual void FolderNodes(const Node &node) { node.Folder(node_.Depth(), Type(), {}); } diff --git a/src/framework/tensor.h b/src/framework/tensor.h index b6a7c724ad13d3757101ebaf48d089cc4e7f957e..674edd67733ef8d0520d28f5c131e9da6746ad17 100644 --- a/src/framework/tensor.h +++ b/src/framework/tensor.h @@ -18,11 +18,12 @@ limitations under the License. 
*/ #include #include #include +#include #include #include -#include "data_layout.h" -#include "ddim.h" +#include "framework/data_layout.h" +#include "framework/ddim.h" #include "memory/t_malloc.h" namespace paddle_mobile { @@ -62,8 +63,8 @@ struct SizeOfTypeFunctor { static inline size_t SizeOfType(std::type_index type) { SizeOfTypeFunctor functor; size_t size = functor(type); - // PADDLE_ENFORCE(size != 0UL, "Cannot get size of type %s", - // type.name()); + + PADDLE_MOBILE_ENFORCE(size != 0UL, "Cannot get size of type %s", type.name()); return size; } @@ -72,16 +73,27 @@ class LoDTensor; class Tensor { public: Tensor() : offset_(0) {} + template + Tensor(std::vector input, DDim ddim) : offset_(0) { + PADDLE_MOBILE_ENFORCE( + input.size() == framework::product(ddim), + "input vector'length should be equal to tensor's length"); + auto input_ptr = mutable_data(ddim); + for (int i = 0; i < input.size(); ++i) { + input_ptr[i] = input[i]; + } + } /*! Return a pointer to mutable memory block. 
*/ template inline T *data() { check_memory_size(); - // PADDLE_ENFORCE(std::is_same::value || - // holder_->type().hash_code() == - // typeid(T).hash_code(), - // "Tensor holds the wrong type, it holds %s", - // this->holder_->type().name()); + PADDLE_MOBILE_ENFORCE( + (std::is_same::value || + holder_->type().hash_code() == typeid(T).hash_code()), + "Tensor holds the wrong type, it holds %s", + this->holder_->type().name()); + return reinterpret_cast(reinterpret_cast(holder_->ptr()) + offset_); } @@ -90,11 +102,11 @@ class Tensor { template inline const T *data() const { check_memory_size(); - // PADDLE_ENFORCE(std::is_same::value || - // holder_->type().hash_code() == - // typeid(T).hash_code(), - // "Tensor holds the wrong type, it holds %s", - // this->holder_->type().name()); + PADDLE_MOBILE_ENFORCE( + (std::is_same::value || + holder_->type().hash_code() == typeid(T).hash_code()), + "Tensor holds the wrong type, it holds %s", + this->holder_->type().name()); return reinterpret_cast( reinterpret_cast(holder_->ptr()) + offset_); @@ -116,17 +128,11 @@ class Tensor { if (holder_ != nullptr) { holder_->set_type(type); } - // PADDLE_ENFORCE_GE(numel(), 0, - // "When calling this method, the Tensor's - // numel must be - // " "equal or larger than zero. 
" "Please - // check - // Tensor::Resize has been called first."); + PADDLE_MOBILE_ENFORCE(numel() >= 0, "the Tensor'snumel must >=0.") int64_t size = numel() * SizeOfType(type); /* some versions of boost::variant don't have operator!= */ if (holder_ == nullptr || holder_->size() < size + offset_) { holder_.reset(new PlaceholderImpl(size, type)); - offset_ = 0; } return reinterpret_cast( @@ -179,16 +185,13 @@ class Tensor { */ inline Tensor Slice(int begin_idx, int end_idx) const { check_memory_size(); - // PADDLE_ENFORCE_GE(begin_idx, 0, - // "The start row index must be greater than - // 0."); - // PADDLE_ENFORCE_LE(end_idx, dims_[0], "The end row index is - // out of - // bound."); PADDLE_ENFORCE_LT( - // begin_idx, end_idx, - // "The start row index must be lesser than the end row - // index."); - + PADDLE_MOBILE_ENFORCE(begin_idx >= 0, + "The start row index must be greater than 0.") + PADDLE_MOBILE_ENFORCE(end_idx <= dims_[0], + "The end row index is out of bound.") + PADDLE_MOBILE_ENFORCE( + begin_idx < end_idx, + "The start row index must be lesser than the end row index") if (dims_[0] == 1) { return *this; } else { @@ -205,10 +208,9 @@ class Tensor { } std::type_index type() const { - // PADDLE_ENFORCE_NOT_NULL( - // holder_, "Tensor not initialized yet - // when - // Tensor::type() is called."); + PADDLE_MOBILE_ENFORCE( + holder_ != nullptr, + "Tensor not initialized yet when Tensor::type() is called.") return holder_->type(); } @@ -221,12 +223,8 @@ class Tensor { PADDLE_MOBILE_ENFORCE( holder_ != nullptr, "Tensor holds no memory. Call Tensor::mutable_data first."); - PADDLE_MOBILE_ENFORCE( - numel() * SizeOfType(type()) <= memory_size(), - "Tensor's dims_ is out of bound. CallTensor::mutable_data " - "first to re-allocate memory.\n" - "or maybe the required data-type mismatches the data\ - already stored."); + PADDLE_MOBILE_ENFORCE(numel() * SizeOfType(type()) <= memory_size(), + "Tensor's dims_ is out of bound. 
"); } inline DataLayout layout() const { return layout_; } @@ -257,13 +255,8 @@ class Tensor { memory::PODDeleter()), size_(size), type_(type) { - // PADDLE_ENFORCE_NOT_NULL(ptr_, - // "Insufficient %s - // memory to allocation.", - // (is_cpu_place(place_) - // ? - // "CPU" : - // "GPU")); + PADDLE_MOBILE_ENFORCE(ptr_ != nullptr, + "Insufficient memory to allocation"); } virtual size_t size() const { return size_; } @@ -321,6 +314,19 @@ class Tensor { size_t offset_; }; +#ifdef PADDLE_MOBILE_DEBUG +inline Print &operator<<(Print &printer, const Tensor &tensor) { + printer << " dims: " << tensor.dims() << "\n"; + int stride = tensor.numel() / 20; + stride = stride > 0 ? stride : 1; + for (int i = 0; i < tensor.numel(); i += stride) { + printer << tensor.data()[i] << " "; + } + return printer; +} + +#endif + inline Tensor ReshapeToMatrix(const Tensor &src, int num_col_dims) { Tensor res; res.ShareDataWith(src); diff --git a/src/io.cpp b/src/io.cpp index 1c5e97bbb7eaa0257bb2f81ef131b8c6bc48547f..bfb3c5a7e2b9c91016afbe288972de9aca9d470f 100644 --- a/src/io.cpp +++ b/src/io.cpp @@ -12,10 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include "io.h" +#include "io/io.h" #include #include - #include "common/enforce.h" #include "common/log.h" #include "framework/framework.pb-c.h" @@ -53,7 +52,7 @@ static size_t ReadBuffer(const char *file_name, uint8_t **out) { DLOG << "model size: " << size; - *out = (uint8_t *)malloc(size); + *out = reinterpret_cast<uint8_t *>(malloc(size)); size_t cur_len = 0; size_t nread; @@ -364,7 +363,7 @@ void Executor::LoadMemory(const framework::VarDesc var_desc, is.read(static_cast(memory), memory_size * type_size); is.close(); -}; +} template void Executor::InitMemory() { @@ -381,6 +380,7 @@ void Executor::InitMemory() { } else { if (var_desc->Type() == framework::VARTYPE_TYPE_LOD_TENSOR) { auto tensor = var->template GetMutable(); + tensor->template mutable_data(); } } @@ -406,15 +406,7 @@ void Executor::predict(const framework::Tensor &t, int block_id) { template std::vector::Ptype> Executor::predict( const std::vector &input, const std::vector &dims) { - DLOG << "start predict: "; - - framework::LoDTensor tensor; - auto ddim = framework::make_ddim(dims); - - auto input_ptr = tensor.mutable_data(ddim); - for (int i = 0; i < input.size(); ++i) { - input_ptr[i] = input[i]; - } + framework::Tensor tensor(input, framework::make_ddim(dims)); predict(tensor, 0); diff --git a/src/operators/batchnorm_op.h b/src/operators/batchnorm_op.h index 072fbd5f42445167d2ab10d86c39dd161dd78a92..760466eeddcb472ed2a47625b786a021ce7c1ef5 100644 --- a/src/operators/batchnorm_op.h +++ b/src/operators/batchnorm_op.h @@ -12,19 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#pragma once + +#include #include "framework/operator.h" #include "operators/kernel/batchnorm_kernel.h" #include "operators/op_param.h" namespace paddle_mobile { namespace operators { - -using namespace framework; - +using std::string; template class BatchNormOp : public framework::OperatorWithKernel { public: - BatchNormOp(const std::string &type, const VariableNameMap &inputs, + BatchNormOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap attrs, std::shared_ptr scope) @@ -32,7 +33,7 @@ class BatchNormOp : public framework::OperatorWithKernel { scope), param_(inputs, outputs, attrs, *scope) {} - void Run() const { + void RunImpl() const { operators::BatchNormKernel kernel; kernel.Compute(param_); } diff --git a/src/operators/box_coder_op.h b/src/operators/box_coder_op.h index 76f4b15174267948823a5e11889ec1365ebc4034..a2203e1d89f8b5b6270c1576711a4c008d927e34 100644 --- a/src/operators/box_coder_op.h +++ b/src/operators/box_coder_op.h @@ -36,7 +36,7 @@ class BoxCoderOp : public framework::OperatorWithKernel { scope), param_(inputs, outputs, attrs, *scope) {} - void Run() const { + void RunImpl() const { operators::BoxCoderKernel kernel; kernel.Compute(param_); } diff --git a/src/operators/concat_op.h b/src/operators/concat_op.h index 611e46af6a64864611b57b28e50954656188214c..15160e20a403d73bb11e982f5a527454f26b5dd6 100644 --- a/src/operators/concat_op.h +++ b/src/operators/concat_op.h @@ -13,25 +13,25 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once + +#include #include "framework/operator.h" #include "operators/kernel/concat_kernel.h" #include "operators/op_param.h" namespace paddle_mobile { namespace operators { - -using namespace framework; - +using std::string; template class ConcatOp : public framework::OperatorWithKernel { public: - ConcatOp(const std::string &type, const VariableNameMap &inputs, + ConcatOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap attrs, std::shared_ptr scope) : framework::OperatorWithKernel(type, inputs, outputs, attrs, scope), param_(inputs, outputs, attrs, *scope) {} - void Run() const { + void RunImpl() const { operators::ConcatKernel kernel; kernel.Compute(param_); } diff --git a/src/operators/conv_op.h b/src/operators/conv_op.h index 047fa1a8e6c6fb4955361de35327cb64adc156f6..1557f2f06eed8237f7b7e9ff44adc233129a49a3 100644 --- a/src/operators/conv_op.h +++ b/src/operators/conv_op.h @@ -14,14 +14,13 @@ limitations under the License. */ #pragma once +#include #include "framework/operator.h" #include "operators/kernel/conv_kernel.h" namespace paddle_mobile { namespace operators { - -using namespace framework; - +using std::string; template class ConvOp : public framework::OperatorWithKernel { public: @@ -35,7 +34,7 @@ class ConvOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape() const override; - void Run() const { + void RunImpl() const { operators::ConvKernel kernel; kernel.Compute(param_); this->ClearVariables({"Filter", "Input"}); diff --git a/src/operators/elementwise_add_op.h b/src/operators/elementwise_add_op.h index 47fa52c46960b5aa928fd6cbc2e29757598dfc8b..7dd7e147a0630450c3ad9f830d661b2b92a5f995 100644 --- a/src/operators/elementwise_add_op.h +++ b/src/operators/elementwise_add_op.h @@ -12,19 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. */ +#pragma once + +#include #include "framework/operator.h" #include "kernel/elementwise_add_kernel.h" -#include "op_param.h" +#include "operators/op_param.h" namespace paddle_mobile { namespace operators { - -using namespace framework; - +using std::string; template class ElementwiseAddOp : public framework::OperatorWithKernel { public: - ElementwiseAddOp(const std::string &type, const VariableNameMap &inputs, + ElementwiseAddOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap attrs, std::shared_ptr scope) @@ -32,7 +33,7 @@ class ElementwiseAddOp : public framework::OperatorWithKernel { scope), param_(inputs, outputs, attrs, *scope) {} - void Run() const { + void RunImpl() const { operators::ElementwiseAddKernel kernel; kernel.Compute(param_); } diff --git a/src/operators/feed_op.h b/src/operators/feed_op.h index 426d5f6220db4e5ee05465288ab3b2f76735da52..25a82894ea96420e94d9d2e4d70809930a954642 100644 --- a/src/operators/feed_op.h +++ b/src/operators/feed_op.h @@ -14,22 +14,23 @@ limitations under the License. 
*/ #pragma once +#include #include "framework/operator.h" #include "operators/op_param.h" namespace paddle_mobile { namespace operators { - +using std::string; template class FeedOp : public framework::OperatorBase { public: - FeedOp(const std::string &type, const VariableNameMap &inputs, + FeedOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap attrs, std::shared_ptr scope) : framework::OperatorBase(type, inputs, outputs, attrs, scope), param_(inputs, outputs, attrs, *scope) {} - void Run() const { param_.Out()->ShareDataWith(*param_.InputX()); } + void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); } void InferShape() const { auto out_dims = param_.Out()->dims(); diff --git a/src/operators/fetch_op.h b/src/operators/fetch_op.h index 7dddd67992990c67ebb73172ac516cb0c0d525d3..31e17f2b562567de1b4194098995f6ee4cd3caa3 100644 --- a/src/operators/fetch_op.h +++ b/src/operators/fetch_op.h @@ -14,27 +14,24 @@ limitations under the License. 
*/ #pragma once +#include #include "framework/operator.h" #include "operators/op_param.h" namespace paddle_mobile { namespace operators { +using std::string; template class FetchOp : public framework::OperatorBase { public: - FetchOp(const std::string &type, const VariableNameMap &inputs, + FetchOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap attrs, std::shared_ptr scope) : framework::OperatorBase(type, inputs, outputs, attrs, scope), param_(inputs, outputs, attrs, *scope) {} - void Run() const { - param_.Out()->ShareDataWith(*param_.InputX()); - for (int i = 0; i < param_.Out()->numel(); ++i) { - DLOG << param_.Out()->template data()[i]; - } - } + void RunImpl() const { param_.Out()->ShareDataWith(*param_.InputX()); } void InferShape() const { auto x_dims = param_.InputX()->dims(); diff --git a/src/operators/fusion_fc_op.h b/src/operators/fusion_fc_op.h index 1dd5d2bf535520c46ee838d1cf2945d988557a4c..6e0c50170a170ab469f206d17a9d7f9787b7bdbb 100644 --- a/src/operators/fusion_fc_op.h +++ b/src/operators/fusion_fc_op.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once #include +#include #include "framework/operator.h" #include "framework/program/program-optimize/fusion_op_register.h" @@ -22,7 +23,8 @@ limitations under the License. 
*/ namespace paddle_mobile { namespace operators { - +using std::string; +using std::vector; class FusionFcMatcher : public framework::FusionOpMatcher { public: FusionFcMatcher() { @@ -30,8 +32,8 @@ class FusionFcMatcher : public framework::FusionOpMatcher { node_ > std::make_shared("elementwise_add"); } - void FolderNodes(framework::Node &node) { - std::vector> origin_descs = + void FolderNodes(const framework::Node &node) { + vector> origin_descs = node.OpDescs(node_.Depth()); node.Folder(node_.Depth(), Type(), {{"elementwise_add", {"Y", "Z"}}}); } @@ -42,7 +44,7 @@ class FusionFcMatcher : public framework::FusionOpMatcher { template class FushionFcOp : public framework::OperatorWithKernel { public: - FushionFcOp(const std::string &type, const VariableNameMap &inputs, + FushionFcOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap attrs, std::shared_ptr scope) @@ -50,7 +52,7 @@ class FushionFcOp : public framework::OperatorWithKernel { scope), param_(inputs, outputs, attrs, *scope) {} - void Run() const { + void RunImpl() const { operators::FushionFcKernel kernel; kernel.Compute(param_); } diff --git a/src/operators/lrn_op.h b/src/operators/lrn_op.h index 112053b97f9479baf5db6faec2c3dad09e7d0ce3..e5d98e1bb103307e1fae9c2460be19fe9d0f01a0 100644 --- a/src/operators/lrn_op.h +++ b/src/operators/lrn_op.h @@ -11,27 +11,27 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ +#pragma once +#include #include "framework/operator.h" #include "operators/kernel/lrn_kernel.h" #include "operators/op_param.h" namespace paddle_mobile { namespace operators { - -using namespace framework; - +using std::string; template class LrnOp : public framework::OperatorWithKernel { public: - LrnOp(const std::string &type, const VariableNameMap &inputs, + LrnOp(const string &type, const VariableNameMap &inputs, const VariableNameMap &outputs, const framework::AttributeMap attrs, std::shared_ptr scope) : framework::OperatorWithKernel(type, inputs, outputs, attrs, scope), param_(inputs, outputs, attrs, *scope) {} - void Run() const { + void RunImpl() const { operators::LrnKernel kernel; kernel.Compute(param_); } diff --git a/src/operators/mul_op.h b/src/operators/mul_op.h index 8685651ea687f36e0f6f0dff1e883f19341b1ab4..ded618551fca682daea0bacc3635776eeb81301c 100644 --- a/src/operators/mul_op.h +++ b/src/operators/mul_op.h @@ -11,7 +11,9 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#pragma once +#include #include "framework/operator.h" #include "operators/kernel/mul_kernel.h" #include "operators/op_param.h" @@ -19,8 +21,6 @@ limitations under the License. 
*/ namespace paddle_mobile { namespace operators { -using namespace framework; - template class MulOp : public framework::OperatorWithKernel { public: @@ -31,7 +31,7 @@ class MulOp : public framework::OperatorWithKernel { scope), param_(inputs, outputs, attrs, *scope) {} - void Run() const { + void RunImpl() const { operators::MulKernel kernel; kernel.Compute(param_); } diff --git a/src/operators/multiclass_nms_op.h b/src/operators/multiclass_nms_op.h index 40466af60743f2e9bb28ab92b33d0498fcd97c34..c424856b8cdc09b365a7ece28df39a911b6d3af8 100644 --- a/src/operators/multiclass_nms_op.h +++ b/src/operators/multiclass_nms_op.h @@ -36,7 +36,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel { scope), param_(inputs, outputs, attrs, *scope) {} - void Run() const { + void RunImpl() const { operators::MultiClassNMSKernel kernel; kernel.Compute(param_); } diff --git a/src/operators/pool_op.h b/src/operators/pool_op.h index 3cc1facbef4338ad1112efcf24e95f51e01e0002..7195c3b4e17b1a9ee9cdc2e81fe7fb9a634546c3 100644 --- a/src/operators/pool_op.h +++ b/src/operators/pool_op.h @@ -17,25 +17,26 @@ limitations under the License. 
*/ #include #include #include +#include namespace paddle_mobile { namespace operators { -using namespace framework; - +using framework::AttributeMap; +using framework::OperatorWithKernel; +using framework::Scope; +using std::string; template -class PoolOp : public framework::OperatorWithKernel { +class PoolOp : public OperatorWithKernel { public: - PoolOp(const std::string &type, const VariableNameMap &inputs, - const VariableNameMap &outputs, const framework::AttributeMap &attrs, - std::shared_ptr scope) - : framework::OperatorWithKernel(type, inputs, outputs, attrs, - scope), + PoolOp(const string &type, const VariableNameMap &inputs, + const VariableNameMap &outputs, const AttributeMap &attrs, + std::shared_ptr scope) + : OperatorWithKernel(type, inputs, outputs, attrs, scope), param_(inputs, outputs, attrs, *scope) {} - using framework::OperatorWithKernel::OperatorWithKernel; + using OperatorWithKernel::OperatorWithKernel; void InferShape() const override; - void Run() const { - // InferShape(); + void RunImpl() const { operators::PoolKernel kernel; kernel.Compute(param_); this->ClearVariables({"X"}); diff --git a/src/operators/prior_box_op.h b/src/operators/prior_box_op.h index 17a583cac9699462a1b2f2f90bc1be52ccbcd9eb..84481e602a6cb4143a50760e66b0d430b8a1c719 100644 --- a/src/operators/prior_box_op.h +++ b/src/operators/prior_box_op.h @@ -36,7 +36,7 @@ class PriorBoxOp : public framework::OperatorWithKernel { scope), param_(inputs, outputs, attrs, *scope) {} - void Run() const { + void RunImpl() const { operators::PriorBoxKernel kernel; kernel.Compute(param_); } diff --git a/src/operators/relu_op.h b/src/operators/relu_op.h index 26bee848c1b44c53b73fbe58e6cbe45e95d91a1e..6c3a614a1a0316e6b487532739f01bf7027557bc 100644 --- a/src/operators/relu_op.h +++ b/src/operators/relu_op.h @@ -35,7 +35,7 @@ class ReluOp : public framework::OperatorWithKernel { scope), param_(inputs, outputs, attrs, *scope) {} - void Run() const { + void RunImpl() const { operators::ReluKernel kernel; kernel.Compute(param_); } 
diff --git a/src/operators/reshape_op.h b/src/operators/reshape_op.h index 62bcb3a67980b75f46487aba4dbf5c89d2b65c7d..b244e62a930a0e6a98d56fe06a4e4a7e37f7d5e1 100644 --- a/src/operators/reshape_op.h +++ b/src/operators/reshape_op.h @@ -35,7 +35,7 @@ class ReshapeOp : public framework::OperatorWithKernel { scope), param_(inputs, outputs, attrs, *scope) {} - void Run() const { + void RunImpl() const { operators::ReshapeKernel kernel; kernel.Compute(param_); } diff --git a/src/operators/sigmoid_op.h b/src/operators/sigmoid_op.h index ba5d3d0299fe5de3f94284546b9fc7d81ca6d524..f631ba51759ea31f91ddcdf7c90a0dc874e86b20 100644 --- a/src/operators/sigmoid_op.h +++ b/src/operators/sigmoid_op.h @@ -36,7 +36,7 @@ class SigmoidOp : public framework::OperatorWithKernel { void InferShape() const override; - void Run() const { + void RunImpl() const { operators::SigmoidKernel kernel; kernel.Compute(param_); this->ClearVariables({"X"}); diff --git a/src/operators/softmax_op.h b/src/operators/softmax_op.h index 550a7698f969ee05f64f4d7c55fe092d4dfe4812..07fd9b945cb29cecd6f4d629b6be58035f971ce4 100644 --- a/src/operators/softmax_op.h +++ b/src/operators/softmax_op.h @@ -36,7 +36,7 @@ class SoftmaxOp : public framework::OperatorWithKernel { void InferShape() const override; - void Run() const { + void RunImpl() const { operators::SoftmaxKernel kernel; kernel.Compute(param_); this->ClearVariables({"X"}); diff --git a/src/operators/transpose_op.h b/src/operators/transpose_op.h index a56771b4c643a8140e17afa588cbba200f6de032..0f67339533261f98374c6257494278306f3a7208 100644 --- a/src/operators/transpose_op.h +++ b/src/operators/transpose_op.h @@ -36,7 +36,7 @@ class TransposeOp : public framework::OperatorWithKernel { scope), param_(inputs, outputs, attrs, *scope) {} - void Run() const { + void RunImpl() const { operators::TransposeKernel kernel; kernel.Compute(param_); } diff --git a/test/executor_for_test.h b/test/executor_for_test.h index 
89b546178261a95862122654a0530f59e8ec32cc..1eac65302098524b6353a2e34bc703054198c0d7 100644 --- a/test/executor_for_test.h +++ b/test/executor_for_test.h @@ -17,9 +17,9 @@ limitations under the License. */ #include #include -#include "./io.h" #include "common/log.h" #include "framework/op_registry.h" +#include "io/io.h" #include "operators/conv_op.h" #include "operators/elementwise_add_op.h" #include "operators/pool_op.h" diff --git a/test/framework/test_load.cpp b/test/framework/test_load.cpp index 0370e6d946fc0ce5d7ce0317b2c820ffd38b0faf..cae699b792fa26294f0a56e7723908f4b8c3d54a 100644 --- a/test/framework/test_load.cpp +++ b/test/framework/test_load.cpp @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "io.h" +#include "io/io.h" int main() { paddle_mobile::Loader loader; diff --git a/test/framework/test_optimize.cpp b/test/framework/test_optimize.cpp index c721c453739296685aa0075ca13db41a9072353e..6681ce83bb5ff5072c927577e118ca0b9fe3bc7b 100644 --- a/test/framework/test_optimize.cpp +++ b/test/framework/test_optimize.cpp @@ -14,7 +14,7 @@ limitations under the License. */ #include "framework/program/program-optimize/node.h" #include "framework/program/program-optimize/program_optimize.h" -#include "io.h" +#include "io/io.h" int main() { paddle_mobile::Loader loader; diff --git a/test/net/test_googlenet.cpp b/test/net/test_googlenet.cpp index ee03ed0b146da6312337d6961803534358a6fe2e..363825fe726359f365bf9d3b4feafc1202236fb6 100644 --- a/test/net/test_googlenet.cpp +++ b/test/net/test_googlenet.cpp @@ -16,7 +16,7 @@ limitations under the License. 
*/ #include "../test_helper.h" #include "../test_include.h" -#include "io.h" +#include "io/io.h" int main() { paddle_mobile::Loader loader; diff --git a/test/operators/test_pool_op.cpp b/test/operators/test_pool_op.cpp index 85c4f6106dae0e951a3a026f66a9f2cd41de94a0..d92fb66efd7e575487a22fa6f5f98dff6173d2d9 100644 --- a/test/operators/test_pool_op.cpp +++ b/test/operators/test_pool_op.cpp @@ -14,7 +14,7 @@ limitations under the License. */ #include "../executor_for_test.h" #include "../test_helper.h" -#include "io.h" +#include "io/io.h" int main() { paddle_mobile::Loader loader; diff --git a/test/operators/test_reshape_op.cpp b/test/operators/test_reshape_op.cpp index 7ba2faa47dfad443df3ff59e8db4a66f8d1d8bcf..d0cb9ac2df01df4066022ac23a912c0a55a04a3e 100644 --- a/test/operators/test_reshape_op.cpp +++ b/test/operators/test_reshape_op.cpp @@ -14,7 +14,7 @@ limitations under the License. */ #include "../executor_for_test.h" #include "../test_helper.h" -#include "./io.h" +#include "io/io.h" int main() { paddle_mobile::Loader loader; diff --git a/test/operators/test_sigmoid_op.cpp b/test/operators/test_sigmoid_op.cpp index adf0376132773989c7ee728f4a38c561760254b5..4ed3efaf28aa986f0b679729c46cb386150583e3 100644 --- a/test/operators/test_sigmoid_op.cpp +++ b/test/operators/test_sigmoid_op.cpp @@ -14,7 +14,7 @@ limitations under the License. */ #include "../../src/operators/kernel/sigmoid_kernel.h" #include "../test_helper.h" -#include "./io.h" +#include "io/io.h" int main() { paddle_mobile::framework::Tensor input; diff --git a/test/operators/test_softmax_op.cpp b/test/operators/test_softmax_op.cpp index ed5a1a49f5583e7fe8108675accdc2fc71a6a086..e0a616c9a4600427e34497d9e25dd0f2d1e8589f 100644 --- a/test/operators/test_softmax_op.cpp +++ b/test/operators/test_softmax_op.cpp @@ -14,7 +14,7 @@ limitations under the License. 
*/ #include "../executor_for_test.h" #include "../test_helper.h" -#include "./io.h" +#include "io/io.h" int main() { paddle_mobile::Loader loader; diff --git a/test/operators/test_transpose_op.cpp b/test/operators/test_transpose_op.cpp index ffdb34f2f500d6907e47fda66b29ad98d071ea38..4ca05d612b7ed4280733c1565c017c5d9b964a15 100644 --- a/test/operators/test_transpose_op.cpp +++ b/test/operators/test_transpose_op.cpp @@ -14,7 +14,7 @@ limitations under the License. */ #include "../executor_for_test.h" #include "../test_helper.h" -#include "./io.h" +#include "io/io.h" int main() { paddle_mobile::Loader loader; diff --git a/test/test_include.h b/test/test_include.h index 25efbb9f4c00921495a5ab054acdde329c4ef58a..dd4bf5d127d2c180e415edd91e1abbf379dac5c1 100644 --- a/test/test_include.h +++ b/test/test_include.h @@ -29,4 +29,4 @@ limitations under the License. */ #include "framework/scope.h" #include "framework/tensor.h" #include "framework/variable.h" -#include "io.h" +#include "io/io.h"